Merge: Update by command from UI

Close #428

* commit '70e329956776cc381fdb28805375d5b2f0e22dbf':
  * openapi: update
  * client: add link to the update error
  * client: add update timeout
  * client: add error message if update failed
  + client: handle update
  * go linter
  * control: /version.json: use new JSON format
  + set config.runningAsService
  * app: --help: more pretty help info
  + app: add --check-config command-line argument
  * app: optimize config file reading
  + /control/update handler
  * control: don't use custom resolver for tests
  + doc: Update algorithm
  - control: fix race in /control/version.json handler
This commit is contained in:
Simon Zolin 2019-05-20 13:38:23 +03:00
commit aa2d942783
20 changed files with 814 additions and 139 deletions

View File

@ -9,6 +9,9 @@ Contents:
* "Check configuration" command
* Disable DNSStubListener
* "Apply configuration" command
* Updating
* Get version command
* Update command
* Enable DHCP server
* "Check DHCP" command
* "Enable DHCP" command
@ -187,6 +190,92 @@ On error, server responds with code 400 or 500. In this case UI should show err
ERROR MESSAGE
## Updating
Algorithm of an update by command:
* UI requests the latest version information from Server
* Server requests information from Internet; stores the data in cache for several hours; sends data to UI
* If UI sees that a new version is available, it shows notification message and "Update Now" button
* When user clicks on "Update Now" button, UI sends Update command to Server
* UI shows "Please wait, AGH is being updated..." message
* Server performs an update:
* Use working directory from `--work-dir` if necessary
* Download new package for the current OS and CPU
* Unpack the package to a temporary directory `update-vXXX`
* Copy the current configuration file to the directory we unpacked new AGH to
* Check configuration compatibility by executing `./AGH --check-config`. If this command fails, we won't be able to update.
* Create `backup-vXXX` directory and copy the current configuration file there
* Stop all tasks, including DNS server, DHCP server, HTTP server
* Move the current binary file to backup directory
* Note: if power fails here, AGH won't be able to start at system boot. Administrator has to fix it manually
* Move new binary file to the current directory
* If AGH is running as a service, use service control functionality to restart
* If AGH is not running as a service, use the current process arguments to start a new process
* Exit process
* UI resends Get Status command until Server responds to it with the new version. This means that Server is successfully restarted after update.
* UI reloads itself
### Get version command
On receiving this request server downloads version.json data from github and stores it in cache for several hours.
Example of version.json data:
{
"version": "v0.95-hotfix",
"announcement": "AdGuard Home v0.95-hotfix is now available!",
"announcement_url": "",
"download_windows_amd64": "",
"download_windows_386": "",
"download_darwin_amd64": "",
"download_linux_amd64": "",
"download_linux_386": "",
"download_linux_arm": "",
"download_linux_arm64": "",
"download_linux_mips": "",
"download_linux_mipsle": "",
"selfupdate_min_version": "v0.0"
}
Request:
GET /control/version.json
Response:
200 OK
{
"new_version": "v0.95",
"announcement": "AdGuard Home v0.95 is now available!",
"announcement_url": "http://...",
"can_autoupdate": true
}
If `can_autoupdate` is true, then the server can automatically upgrade to a new version.
### Update command
Perform an update procedure to the latest available version
Request:
POST /control/update
Response:
200 OK
Error response:
500
UI shows error message "Auto-update has failed"
## Enable DHCP server
Algorithm:

83
app.go
View File

@ -2,6 +2,7 @@ package main
import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
@ -30,6 +31,7 @@ var httpsServer struct {
server *http.Server
cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
sync.Mutex // protects config.TLS
shutdown bool // if TRUE, don't restart the server
}
var pidFileName string // PID file name. Empty if no PID file was created.
@ -76,6 +78,7 @@ func run(args options) {
if args.runningAsService {
log.Info("AdGuard Home is running as a service")
}
config.runningAsService = args.runningAsService
config.firstRun = detectFirstRun()
if config.firstRun {
@ -91,16 +94,22 @@ func run(args options) {
os.Exit(0)
}()
// Do the upgrade if necessary
err := upgradeConfig()
if err != nil {
log.Fatal(err)
}
if !config.firstRun {
// Do the upgrade if necessary
err := upgradeConfig()
if err != nil {
log.Fatal(err)
}
// parse from config file
err = parseConfig()
if err != nil {
log.Fatal(err)
err = parseConfig()
if err != nil {
os.Exit(1)
}
if args.checkConfig {
log.Info("Configuration file is OK")
os.Exit(0)
}
}
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
@ -118,10 +127,12 @@ func run(args options) {
loadFilters()
// Save the updated config
err = config.write()
if err != nil {
log.Fatal(err)
if !config.firstRun {
// Save the updated config
err := config.write()
if err != nil {
log.Fatal(err)
}
}
// Init the DNS server instance before registering HTTP handlers
@ -129,7 +140,7 @@ func run(args options) {
initDNSServer(dnsBaseDir)
if !config.firstRun {
err = startDNSServer()
err := startDNSServer()
if err != nil {
log.Fatal(err)
}
@ -171,7 +182,7 @@ func run(args options) {
go httpServerLoop()
// this loop is used as an ability to change listening host and/or port
for {
for !httpsServer.shutdown {
printHTTPAddresses("http")
// we need to have new instance, because after Shutdown() the Server is not usable
@ -186,10 +197,13 @@ func run(args options) {
}
// We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop
}
// wait indefinitely for other go-routines to complete their job
select {}
}
func httpServerLoop() {
for {
for !httpsServer.shutdown {
httpsServer.cond.L.Lock()
// this mechanism doesn't let us through until all conditions are met
for config.TLS.Enabled == false ||
@ -367,6 +381,15 @@ func cleanup() {
}
}
// Stop HTTP server, possibly waiting for all active connections to be closed
func stopHTTPServer() {
httpsServer.shutdown = true
if httpsServer.server != nil {
httpsServer.server.Shutdown(context.TODO())
}
httpServer.Shutdown(context.TODO())
}
// This function is called before application exits
func cleanupAlways() {
if len(pidFileName) != 0 {
@ -384,6 +407,7 @@ type options struct {
bindPort int // port to serve HTTP pages on
logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
pidFile string // File name to save PID to
checkConfig bool // Check configuration and exit
// service control action (see service.ControlAction array + "status" command)
serviceControlAction string
@ -404,25 +428,26 @@ func loadOptions() options {
callbackWithValue func(value string)
callbackNoValue func()
}{
{"config", "c", "path to the config file", func(value string) { o.configFilename = value }, nil},
{"work-dir", "w", "path to the working directory", func(value string) { o.workDir = value }, nil},
{"host", "h", "host address to bind HTTP server on", func(value string) { o.bindHost = value }, nil},
{"port", "p", "port to serve HTTP pages on", func(value string) {
{"config", "c", "Path to the config file", func(value string) { o.configFilename = value }, nil},
{"work-dir", "w", "Path to the working directory", func(value string) { o.workDir = value }, nil},
{"host", "h", "Host address to bind HTTP server on", func(value string) { o.bindHost = value }, nil},
{"port", "p", "Port to serve HTTP pages on", func(value string) {
v, err := strconv.Atoi(value)
if err != nil {
panic("Got port that is not a number")
}
o.bindPort = v
}, nil},
{"service", "s", "service control action: status, install, uninstall, start, stop, restart", func(value string) {
{"service", "s", "Service control action: status, install, uninstall, start, stop, restart", func(value string) {
o.serviceControlAction = value
}, nil},
{"logfile", "l", "path to the log file. If empty, writes to stdout, if 'syslog' -- system log", func(value string) {
{"logfile", "l", "Path to log file. If empty: write to stdout; if 'syslog': write to system log", func(value string) {
o.logFile = value
}, nil},
{"pidfile", "", "File name to save PID to", func(value string) { o.pidFile = value }, nil},
{"verbose", "v", "enable verbose output", nil, func() { o.verbose = true }},
{"help", "", "print this help", nil, func() {
{"pidfile", "", "Path to a file where PID is stored", func(value string) { o.pidFile = value }, nil},
{"check-config", "", "Check configuration and exit", nil, func() { o.checkConfig = true }},
{"verbose", "v", "Enable verbose output", nil, func() { o.verbose = true }},
{"help", "", "Print this help", nil, func() {
printHelp()
os.Exit(64)
}},
@ -432,10 +457,14 @@ func loadOptions() options {
fmt.Printf("%s [options]\n\n", os.Args[0])
fmt.Printf("Options:\n")
for _, opt := range opts {
val := ""
if opt.callbackWithValue != nil {
val = " VALUE"
}
if opt.shortName != "" {
fmt.Printf(" -%s, %-30s %s\n", opt.shortName, "--"+opt.longName, opt.description)
fmt.Printf(" -%s, %-30s %s\n", opt.shortName, "--"+opt.longName+val, opt.description)
} else {
fmt.Printf(" %-34s %s\n", "--"+opt.longName, opt.description)
fmt.Printf(" %-34s %s\n", "--"+opt.longName+val, opt.description)
}
}
}

View File

@ -260,5 +260,8 @@
"dns_addresses": "DNS addresses",
"down": "Down",
"fix": "Fix",
"dns_providers": "Here is a <0>list of known DNS providers</0> to choose from."
"dns_providers": "Here is a <0>list of known DNS providers</0> to choose from.",
"update_now": "Update now",
"update_failed": "Auto-update failed. Please <a href='https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#update'>follow the steps<\/a> to update manually.",
"processing_update": "Please wait, AdGuard Home is being updated"
}

View File

@ -2,15 +2,17 @@ import { createAction } from 'redux-actions';
import round from 'lodash/round';
import { t } from 'i18next';
import { showLoading, hideLoading } from 'react-redux-loading-bar';
import axios from 'axios';
import { normalizeHistory, normalizeFilteringStatus, normalizeLogs, normalizeTextarea } from '../helpers/helpers';
import { SETTINGS_NAMES } from '../helpers/constants';
import { SETTINGS_NAMES, CHECK_TIMEOUT } from '../helpers/constants';
import Api from '../api/Api';
const apiClient = new Api();
export const addErrorToast = createAction('ADD_ERROR_TOAST');
export const addSuccessToast = createAction('ADD_SUCCESS_TOAST');
export const addNoticeToast = createAction('ADD_NOTICE_TOAST');
export const removeToast = createAction('REMOVE_TOAST');
export const toggleSettingStatus = createAction('SETTING_STATUS_TOGGLE');
@ -154,6 +156,56 @@ export const getVersion = () => async (dispatch) => {
}
};
export const getUpdateRequest = createAction('GET_UPDATE_REQUEST');
export const getUpdateFailure = createAction('GET_UPDATE_FAILURE');
export const getUpdateSuccess = createAction('GET_UPDATE_SUCCESS');
export const getUpdate = () => async (dispatch) => {
dispatch(getUpdateRequest());
try {
await apiClient.getUpdate();
const checkUpdate = async (attempts) => {
let count = attempts || 1;
let timeout;
if (count > 60) {
dispatch(addNoticeToast({ error: 'update_failed' }));
dispatch(getUpdateFailure());
return false;
}
const rmTimeout = t => t && clearTimeout(t);
const setRecursiveTimeout = (time, ...args) => setTimeout(
checkUpdate,
time,
...args,
);
axios.get('control/status')
.then((response) => {
rmTimeout(timeout);
if (response) {
dispatch(getUpdateSuccess());
window.location.reload(true);
}
timeout = setRecursiveTimeout(CHECK_TIMEOUT, count += 1);
})
.catch(() => {
rmTimeout(timeout);
timeout = setRecursiveTimeout(CHECK_TIMEOUT, count += 1);
});
return false;
};
checkUpdate();
} catch (error) {
dispatch(addNoticeToast({ error: 'update_failed' }));
dispatch(getUpdateFailure());
}
};
export const getClientsRequest = createAction('GET_CLIENTS_REQUEST');
export const getClientsFailure = createAction('GET_CLIENTS_FAILURE');
export const getClientsSuccess = createAction('GET_CLIENTS_SUCCESS');

View File

@ -40,6 +40,8 @@ export default class Api {
GLOBAL_ENABLE_PROTECTION = { path: 'enable_protection', method: 'POST' };
GLOBAL_DISABLE_PROTECTION = { path: 'disable_protection', method: 'POST' };
GLOBAL_CLIENTS = { path: 'clients', method: 'GET' }
GLOBAL_CLIENTS = { path: 'clients', method: 'GET' };
GLOBAL_UPDATE = { path: 'update', method: 'POST' };
restartGlobalFiltering() {
const { path, method } = this.GLOBAL_RESTART;
@ -145,6 +147,11 @@ export default class Api {
return this.makeRequest(path, method);
}
getUpdate() {
const { path, method } = this.GLOBAL_UPDATE;
return this.makeRequest(path, method);
}
// Filtering
FILTERING_STATUS = { path: 'filtering/status', method: 'GET' };
FILTERING_ENABLE = { path: 'filtering/enable', method: 'POST' };

View File

@ -19,6 +19,7 @@ import Toasts from '../Toasts';
import Footer from '../ui/Footer';
import Status from '../ui/Status';
import UpdateTopline from '../ui/UpdateTopline';
import UpdateOverlay from '../ui/UpdateOverlay';
import EncryptionTopline from '../ui/EncryptionTopline';
import i18n from '../../i18n';
@ -37,6 +38,10 @@ class App extends Component {
this.props.enableDns();
};
handleUpdate = () => {
this.props.getUpdate();
}
setLanguage = () => {
const { processing, language } = this.props.dashboard;
@ -62,10 +67,16 @@ class App extends Component {
<HashRouter hashType='noslash'>
<Fragment>
{updateAvailable &&
<UpdateTopline
url={dashboard.announcementUrl}
version={dashboard.version}
/>
<Fragment>
<UpdateTopline
url={dashboard.announcementUrl}
version={dashboard.newVersion}
canAutoUpdate={dashboard.canAutoUpdate}
getUpdate={this.handleUpdate}
processingUpdate={dashboard.processingUpdate}
/>
<UpdateOverlay processingUpdate={dashboard.processingUpdate} />
</Fragment>
}
{!encryption.processing &&
<EncryptionTopline notAfter={encryption.not_after} />
@ -100,6 +111,7 @@ class App extends Component {
App.propTypes = {
getDnsStatus: PropTypes.func,
getUpdate: PropTypes.func,
enableDns: PropTypes.func,
dashboard: PropTypes.object,
isCoreRunning: PropTypes.bool,

View File

@ -32,6 +32,12 @@
overflow: hidden;
}
.toast__content a {
font-weight: 600;
color: #fff;
text-decoration: underline;
}
.toast__dismiss {
display: block;
flex: 0 0 auto;

View File

@ -4,7 +4,7 @@ import { Trans, withNamespaces } from 'react-i18next';
class Toast extends Component {
componentDidMount() {
const timeout = this.props.type === 'error' ? 30000 : 5000;
const timeout = this.props.type === 'success' ? 5000 : 30000;
setTimeout(() => {
this.props.removeToast(this.props.id);
@ -15,13 +15,25 @@ class Toast extends Component {
return false;
}
showMessage(t, type, message) {
if (type === 'notice') {
return <span dangerouslySetInnerHTML={{ __html: t(message) }} />;
}
return <Trans>{message}</Trans>;
}
render() {
const {
type, id, t, message,
} = this.props;
return (
<div className={`toast toast--${this.props.type}`}>
<div className={`toast toast--${type}`}>
<p className="toast__content">
<Trans>{this.props.message}</Trans>
{this.showMessage(t, type, message)}
</p>
<button className="toast__dismiss" onClick={() => this.props.removeToast(this.props.id)}>
<button className="toast__dismiss" onClick={() => this.props.removeToast(id)}>
<svg stroke="#fff" fill="none" width="20" height="20" strokeWidth="2" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="m18 6-12 12"/><path d="m6 6 12 12"/></svg>
</button>
</div>
@ -30,6 +42,7 @@ class Toast extends Component {
}
Toast.propTypes = {
t: PropTypes.func.isRequired,
id: PropTypes.string.isRequired,
message: PropTypes.string.isRequired,
type: PropTypes.string.isRequired,

View File

@ -0,0 +1,40 @@
.overlay {
display: none;
position: fixed;
top: 0;
left: 0;
z-index: 110;
width: 100%;
height: 100%;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 20px;
font-size: 28px;
font-weight: 600;
text-align: center;
background-color: rgba(255, 255, 255, 0.8);
}
.overlay--visible {
display: flex;
}
.overlay__loading {
width: 40px;
height: 40px;
margin-bottom: 20px;
background-image: url("data:image/svg+xml;charset=utf-8,%3Csvg%20xmlns%3D%22http%3A%2F%2Fwww.w3.org%2F2000%2Fsvg%22%20viewBox%3D%220%200%2047.6%2047.6%22%20height%3D%22100%25%22%20width%3D%22100%25%22%3E%3Cpath%20opacity%3D%22.235%22%20fill%3D%22%23979797%22%20d%3D%22M44.4%2011.9l-5.2%203c1.5%202.6%202.4%205.6%202.4%208.9%200%209.8-8%2017.8-17.8%2017.8-6.6%200-12.3-3.6-15.4-8.9l-5.2%203C7.3%2042.8%2015%2047.6%2023.8%2047.6c13.1%200%2023.8-10.7%2023.8-23.8%200-4.3-1.2-8.4-3.2-11.9z%22%2F%3E%3Cpath%20fill%3D%22%2366b574%22%20d%3D%22M3.2%2035.7C0%2030.2-.8%2023.8.8%2017.6%202.5%2011.5%206.4%206.4%2011.9%203.2%2017.4%200%2023.8-.8%2030%20.8c6.1%201.6%2011.3%205.6%2014.4%2011.1l-5.2%203c-2.4-4.1-6.2-7.1-10.8-8.3C23.8%205.4%2019%206%2014.9%208.4s-7.1%206.2-8.3%2010.8c-1.2%204.6-.6%209.4%201.8%2013.5l-5.2%203z%22%2F%3E%3C%2Fsvg%3E");
will-change: transform;
animation: clockwise 2s linear infinite;
}
@keyframes clockwise {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}

View File

@ -0,0 +1,26 @@
import React from 'react';
import PropTypes from 'prop-types';
import { Trans, withNamespaces } from 'react-i18next';
import classnames from 'classnames';
import './Overlay.css';
const UpdateOverlay = (props) => {
const overlayClass = classnames({
overlay: true,
'overlay--visible': props.processingUpdate,
});
return (
<div className={overlayClass}>
<div className="overlay__loading"></div>
<Trans>processing_update</Trans>
</div>
);
};
UpdateOverlay.propTypes = {
processingUpdate: PropTypes.bool,
};
export default withNamespaces()(UpdateOverlay);

View File

@ -1,4 +1,4 @@
import React from 'react';
import React, { Fragment } from 'react';
import PropTypes from 'prop-types';
import { Trans, withNamespaces } from 'react-i18next';
@ -6,22 +6,37 @@ import Topline from './Topline';
const UpdateTopline = props => (
<Topline type="info">
<Trans
values={{ version: props.version }}
components={[
<a href={props.url} target="_blank" rel="noopener noreferrer" key="0">
Click here
</a>,
]}
>
update_announcement
</Trans>
<Fragment>
<Trans
values={{ version: props.version }}
components={[
<a href={props.url} target="_blank" rel="noopener noreferrer" key="0">
Click here
</a>,
]}
>
update_announcement
</Trans>
{props.canAutoUpdate &&
<button
type="button"
className="btn btn-sm btn-primary ml-3"
onClick={props.getUpdate}
disabled={props.processingUpdate}
>
<Trans>update_now</Trans>
</button>
}
</Fragment>
</Topline>
);
UpdateTopline.propTypes = {
version: PropTypes.string.isRequired,
version: PropTypes.string,
url: PropTypes.string.isRequired,
canAutoUpdate: PropTypes.bool,
getUpdate: PropTypes.func,
processingUpdate: PropTypes.bool,
};
export default withNamespaces()(UpdateTopline);

View File

@ -126,12 +126,16 @@ const dashboard = handleActions({
const {
version,
announcement_url: announcementUrl,
new_version: newVersion,
can_autoupdate: canAutoUpdate,
} = payload;
const newState = {
...state,
version,
announcementUrl,
newVersion,
canAutoUpdate,
isUpdateAvailable: true,
};
return newState;
@ -140,6 +144,13 @@ const dashboard = handleActions({
return state;
},
[actions.getUpdateRequest]: state => ({ ...state, processingUpdate: true }),
[actions.getUpdateFailure]: state => ({ ...state, processingUpdate: false }),
[actions.getUpdateSuccess]: (state) => {
const newState = { ...state, processingUpdate: false };
return newState;
},
[actions.getFilteringRequest]: state => ({ ...state, processingFiltering: true }),
[actions.getFilteringFailure]: state => ({ ...state, processingFiltering: false }),
[actions.getFilteringSuccess]: (state, { payload }) => {
@ -187,6 +198,7 @@ const dashboard = handleActions({
processingVersion: true,
processingFiltering: true,
processingClients: true,
processingUpdate: false,
upstreamDns: '',
bootstrapDns: '',
allServers: false,

View File

@ -1,7 +1,7 @@
import { handleActions } from 'redux-actions';
import nanoid from 'nanoid';
import { addErrorToast, addSuccessToast, removeToast } from '../actions';
import { addErrorToast, addSuccessToast, addNoticeToast, removeToast } from '../actions';
const toasts = handleActions({
[addErrorToast]: (state, { payload }) => {
@ -24,6 +24,16 @@ const toasts = handleActions({
const newState = { ...state, notices: [...state.notices, successToast] };
return newState;
},
[addNoticeToast]: (state, { payload }) => {
const noticeToast = {
id: nanoid(),
message: payload.error.toString(),
type: 'notice',
};
const newState = { ...state, notices: [...state.notices, noticeToast] };
return newState;
},
[removeToast]: (state, { payload }) => {
const filtered = state.notices.filter(notice => notice.id !== payload);
const newState = { ...state, notices: filtered };

View File

@ -30,9 +30,15 @@ type logSettings struct {
// configuration is loaded from YAML
// field ordering is important -- yaml fields will mirror ordering from here
type configuration struct {
// Raw file data to avoid re-reading of configuration file
// It's reset after config is parsed
fileData []byte
ourConfigFilename string // Config filename (can be overridden via the command line arguments)
ourWorkingDir string // Location of our directory, used to protect against CWD being somewhere else
firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html
// runningAsService flag is set to true when options are passed from the service runner
runningAsService bool
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
@ -113,10 +119,10 @@ var config = configuration{
BindHost: "0.0.0.0",
Port: 53,
FilteringConfig: dnsforward.FilteringConfig{
ProtectionEnabled: true, // whether or not use any of dnsfilter features
FilteringEnabled: true, // whether or not use filter lists
ProtectionEnabled: true, // whether or not use any of dnsfilter features
FilteringEnabled: true, // whether or not use filter lists
BlockingMode: "nxdomain", // mode how to answer filtered requests
BlockedResponseTTL: 10, // in seconds
BlockedResponseTTL: 10, // in seconds
QueryLogEnabled: true,
Ratelimit: 20,
RefuseAny: true,
@ -174,7 +180,7 @@ func (c *configuration) getConfigFilename() string {
func getLogSettings() logSettings {
l := logSettings{}
yamlFile, err := readConfigFile()
if err != nil || yamlFile == nil {
if err != nil {
return l
}
err = yaml.Unmarshal(yamlFile, &l)
@ -190,13 +196,9 @@ func parseConfig() error {
log.Debug("Reading config file: %s", configFile)
yamlFile, err := readConfigFile()
if err != nil {
log.Error("Couldn't read config file: %s", err)
return err
}
if yamlFile == nil {
log.Error("YAML file doesn't exist, skipping it")
return nil
}
config.fileData = nil
err = yaml.Unmarshal(yamlFile, &config)
if err != nil {
log.Error("Couldn't parse config file: %s", err)
@ -213,22 +215,23 @@ func parseConfig() error {
// readConfigFile reads config file contents if it exists
func readConfigFile() ([]byte, error) {
configFile := config.getConfigFilename()
if _, err := os.Stat(configFile); os.IsNotExist(err) {
// do nothing, file doesn't exist
return nil, nil
if len(config.fileData) != 0 {
return config.fileData, nil
}
return ioutil.ReadFile(configFile)
configFile := config.getConfigFilename()
d, err := ioutil.ReadFile(configFile)
if err != nil {
log.Error("Couldn't read config file %s: %s", configFile, err)
return nil, err
}
return d, nil
}
// Saves configuration to the YAML file and also saves the user filter contents to a file
func (c *configuration) write() error {
c.Lock()
defer c.Unlock()
if config.firstRun {
log.Debug("Silently refusing to write config because first run and not configured yet")
return nil
}
configFile := config.getConfigFilename()
log.Debug("Writing YAML file: %s", configFile)
yamlText, err := yaml.Marshal(&config)

View File

@ -557,42 +557,6 @@ func checkDNS(input string, bootstrap []string) error {
return nil
}
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
now := time.Now()
if now.Sub(versionCheckLastTime) <= versionCheckPeriod && len(versionCheckJSON) != 0 {
// return cached copy
w.Header().Set("Content-Type", "application/json")
w.Write(versionCheckJSON)
return
}
resp, err := client.Get(versionCheckURL)
if err != nil {
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
return
}
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
// read the body entirely
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
httpError(w, http.StatusBadGateway, "Couldn't read response body from %s: %s", versionCheckURL, err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(body)
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
versionCheckLastTime = now
versionCheckJSON = body
}
// ---------
// filtering
// ---------
@ -1006,6 +970,7 @@ func registerControlHandlers() {
http.HandleFunc("/control/stats_history", postInstall(optionalAuth(ensureGET(handleStatsHistory))))
http.HandleFunc("/control/stats_reset", postInstall(optionalAuth(ensurePOST(handleStatsReset))))
http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON)))
http.HandleFunc("/control/update", postInstall(optionalAuth(ensurePOST(handleUpdate))))
http.HandleFunc("/control/filtering/enable", postInstall(optionalAuth(ensurePOST(handleFilteringEnable))))
http.HandleFunc("/control/filtering/disable", postInstall(optionalAuth(ensurePOST(handleFilteringDisable))))
http.HandleFunc("/control/filtering/add_url", postInstall(optionalAuth(ensurePOST(handleFilteringAddURL))))

371
control_update.go Normal file
View File

@ -0,0 +1,371 @@
package main
import (
"archive/zip"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"syscall"
"time"
"github.com/AdguardTeam/golibs/log"
)
// Convert version.json data to our JSON response
func getVersionResp(data []byte) []byte {
versionJSON := make(map[string]interface{})
err := json.Unmarshal(data, &versionJSON)
if err != nil {
log.Error("version.json: %s", err)
return []byte{}
}
ret := make(map[string]interface{})
ret["can_autoupdate"] = false
var ok1, ok2, ok3 bool
ret["new_version"], ok1 = versionJSON["version"].(string)
ret["announcement"], ok2 = versionJSON["announcement"].(string)
ret["announcement_url"], ok3 = versionJSON["announcement_url"].(string)
if !ok1 || !ok2 || !ok3 {
log.Error("version.json: invalid data")
return []byte{}
}
_, ok := versionJSON[fmt.Sprintf("download_%s_%s", runtime.GOOS, runtime.GOARCH)]
if ok && ret["new_version"] != VersionString {
ret["can_autoupdate"] = true
}
d, _ := json.Marshal(ret)
return d
}
// Get the latest available version from the Internet
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
now := time.Now()
controlLock.Lock()
cached := now.Sub(versionCheckLastTime) <= versionCheckPeriod && len(versionCheckJSON) != 0
data := versionCheckJSON
controlLock.Unlock()
if cached {
// return cached copy
w.Header().Set("Content-Type", "application/json")
w.Write(getVersionResp(data))
return
}
resp, err := client.Get(versionCheckURL)
if err != nil {
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
return
}
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
// read the body entirely
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
httpError(w, http.StatusBadGateway, "Couldn't read response body from %s: %s", versionCheckURL, err)
return
}
controlLock.Lock()
versionCheckLastTime = now
versionCheckJSON = body
controlLock.Unlock()
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(getVersionResp(body))
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
// Copy file on disk
func copyFile(src, dst string) error {
d, e := ioutil.ReadFile(src)
if e != nil {
return e
}
e = ioutil.WriteFile(dst, d, 0644)
if e != nil {
return e
}
return nil
}
type updateInfo struct {
pkgURL string // URL for the new package
pkgName string // Full path to package file
newVer string // New version string
updateDir string // Full path to the directory containing unpacked files from the new package
backupDir string // Full path to backup directory
configName string // Full path to the current configuration file
updateConfigName string // Full path to the configuration file to check by the new binary
curBinName string // Full path to the current executable file
bkpBinName string // Full path to the current executable file in backup directory
newBinName string // Full path to the new executable file
}
// Fill in updateInfo object
func getUpdateInfo(jsonData []byte) (*updateInfo, error) {
var u updateInfo
workDir := config.ourWorkingDir
versionJSON := make(map[string]interface{})
err := json.Unmarshal(jsonData, &versionJSON)
if err != nil {
return nil, fmt.Errorf("JSON parse: %s", err)
}
u.pkgURL = versionJSON[fmt.Sprintf("download_%s_%s", runtime.GOOS, runtime.GOARCH)].(string)
u.newVer = versionJSON["version"].(string)
if len(u.pkgURL) == 0 || len(u.newVer) == 0 {
return nil, fmt.Errorf("Invalid JSON")
}
if u.newVer == VersionString {
return nil, fmt.Errorf("No need to update")
}
_, pkgFileName := filepath.Split(u.pkgURL)
if len(pkgFileName) == 0 {
return nil, fmt.Errorf("Invalid JSON")
}
u.pkgName = filepath.Join(workDir, pkgFileName)
u.updateDir = filepath.Join(workDir, fmt.Sprintf("update-%s", u.newVer))
u.backupDir = filepath.Join(workDir, fmt.Sprintf("backup-%s", VersionString))
u.configName = config.getConfigFilename()
u.updateConfigName = filepath.Join(u.updateDir, "AdGuardHome", "AdGuardHome.yaml")
if strings.HasSuffix(pkgFileName, ".zip") {
u.updateConfigName = filepath.Join(u.updateDir, "AdGuardHome.yaml")
}
binName := "AdGuardHome"
if runtime.GOOS == "windows" {
binName = "AdGuardHome.exe"
}
u.curBinName = filepath.Join(workDir, binName)
u.bkpBinName = filepath.Join(u.backupDir, binName)
u.newBinName = filepath.Join(u.updateDir, "AdGuardHome", binName)
if strings.HasSuffix(pkgFileName, ".zip") {
u.newBinName = filepath.Join(u.updateDir, binName)
}
return &u, nil
}
// Unpack all files from .zip file to the specified directory
func zipFileUnpack(zipfile, outdir string) error {
r, err := zip.OpenReader(zipfile)
if err != nil {
return fmt.Errorf("zip.OpenReader(): %s", err)
}
defer r.Close()
for _, zf := range r.File {
zr, err := zf.Open()
if err != nil {
return fmt.Errorf("zip file Open(): %s", err)
}
fi := zf.FileInfo()
fn := filepath.Join(outdir, fi.Name())
if fi.IsDir() {
err = os.Mkdir(fn, fi.Mode())
if err != nil {
return fmt.Errorf("zip file Read(): %s", err)
}
continue
}
f, err := os.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
if err != nil {
zr.Close()
return fmt.Errorf("os.OpenFile(): %s", err)
}
_, err = io.Copy(f, zr)
if err != nil {
zr.Close()
return fmt.Errorf("io.Copy(): %s", err)
}
zr.Close()
}
return nil
}
// Unpack all files from .tar.gz file to the specified directory
func targzFileUnpack(tarfile, outdir string) error {
cmd := exec.Command("tar", "zxf", tarfile, "-C", outdir)
log.Tracef("Unpacking: %v", cmd.Args)
_, err := cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
return fmt.Errorf("exec.Command() failed: %s", err)
}
return nil
}
// Perform an update procedure
func doUpdate(u *updateInfo) error {
log.Info("Updating from %s to %s. URL:%s Package:%s",
VersionString, u.newVer, u.pkgURL, u.pkgName)
resp, err := client.Get(u.pkgURL)
if err != nil {
return fmt.Errorf("HTTP request failed: %s", err)
}
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
log.Tracef("Reading HTTP body")
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("ioutil.ReadAll() failed: %s", err)
}
log.Tracef("Saving package to file")
err = ioutil.WriteFile(u.pkgName, body, 0644)
if err != nil {
return fmt.Errorf("ioutil.WriteFile() failed: %s", err)
}
log.Tracef("Unpacking the package")
_ = os.Mkdir(u.updateDir, 0755)
_, file := filepath.Split(u.pkgName)
if strings.HasSuffix(file, ".zip") {
err = zipFileUnpack(u.pkgName, u.updateDir)
if err != nil {
return fmt.Errorf("zipFileUnpack() failed: %s", err)
}
} else if strings.HasSuffix(file, ".tar.gz") {
err = targzFileUnpack(u.pkgName, u.updateDir)
if err != nil {
return fmt.Errorf("zipFileUnpack() failed: %s", err)
}
} else {
return fmt.Errorf("Unknown package extension")
}
log.Tracef("Checking configuration")
err = copyFile(u.configName, u.updateConfigName)
if err != nil {
return fmt.Errorf("copyFile() failed: %s", err)
}
cmd := exec.Command(u.newBinName, "--check-config")
err = cmd.Run()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
return fmt.Errorf("exec.Command(): %s %d", err, cmd.ProcessState.ExitCode())
}
log.Tracef("Backing up the current configuration")
_ = os.Mkdir(u.backupDir, 0755)
err = copyFile(u.configName, filepath.Join(u.backupDir, "AdGuardHome.yaml"))
if err != nil {
return fmt.Errorf("copyFile() failed: %s", err)
}
log.Tracef("Renaming: %s -> %s", u.curBinName, u.bkpBinName)
err = os.Rename(u.curBinName, u.bkpBinName)
if err != nil {
return err
}
if runtime.GOOS == "windows" {
// rename fails with "File in use" error
err = copyFile(u.newBinName, u.curBinName)
} else {
err = os.Rename(u.newBinName, u.curBinName)
}
if err != nil {
return err
}
log.Tracef("Renamed: %s -> %s", u.newBinName, u.curBinName)
_ = os.Remove(u.pkgName)
// _ = os.RemoveAll(u.updateDir)
return nil
}
// Complete an update procedure
func finishUpdate(u *updateInfo) {
log.Info("Stopping all tasks")
cleanup()
stopHTTPServer()
cleanupAlways()
if runtime.GOOS == "windows" {
if config.runningAsService {
// Note:
// we can't restart the service via "kardianos/service" package - it kills the process first
// we can't start a new instance - Windows doesn't allow it
cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome")
err := cmd.Start()
if err != nil {
log.Fatalf("exec.Command() failed: %s", err)
}
os.Exit(0)
}
cmd := exec.Command(u.curBinName, os.Args[1:]...)
log.Info("Restarting: %v", cmd.Args)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Start()
if err != nil {
log.Fatalf("exec.Command() failed: %s", err)
}
os.Exit(0)
} else {
log.Info("Restarting: %v", os.Args)
err := syscall.Exec(u.curBinName, os.Args, os.Environ())
if err != nil {
log.Fatalf("syscall.Exec() failed: %s", err)
}
// Unreachable code
}
}
// Perform an update procedure to the latest available version
func handleUpdate(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
if len(versionCheckJSON) == 0 {
httpError(w, http.StatusBadRequest, "/update request isn't allowed now")
return
}
u, err := getUpdateInfo(versionCheckJSON)
if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err)
return
}
err = doUpdate(u)
if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err)
return
}
returnOK(w)
time.Sleep(time.Second) // wait (hopefully) until response is sent (not sure whether it's really necessary)
go finishUpdate(u)
}

39
control_update_test.go Normal file
View File

@ -0,0 +1,39 @@
package main
import (
"os"
"testing"
)
func testDoUpdate(t *testing.T) {
config.DNS.Port = 0
u := updateInfo{
pkgURL: "https://github.com/AdguardTeam/AdGuardHome/releases/download/v0.95/AdGuardHome_v0.95_linux_amd64.tar.gz",
pkgName: "./AdGuardHome_v0.95_linux_amd64.tar.gz",
newVer: "v0.95",
updateDir: "./update-v0.95",
backupDir: "./backup-v0.94",
configName: "./AdGuardHome.yaml",
updateConfigName: "./update-v0.95/AdGuardHome/AdGuardHome.yaml",
curBinName: "./AdGuardHome",
bkpBinName: "./backup-v0.94/AdGuardHome",
newBinName: "./update-v0.95/AdGuardHome/AdGuardHome",
}
e := doUpdate(&u)
if e != nil {
t.Fatalf("FAILED: %s", e)
}
os.RemoveAll(u.backupDir)
os.RemoveAll(u.updateDir)
}
func testZipFileUnpack(t *testing.T) {
fn := "./dist/AdGuardHome_v0.95_Windows_amd64.zip"
outdir := "./test-unpack"
_ = os.Mkdir(outdir, 0755)
e := zipFileUnpack(fn, outdir)
if e != nil {
t.Fatalf("FAILED: %s", e)
}
os.RemoveAll(outdir)
}

View File

@ -318,7 +318,7 @@ func customDialContext(ctx context.Context, network, addr string) (net.Conn, err
Timeout: time.Minute * 5,
}
if net.ParseIP(host) != nil {
if net.ParseIP(host) != nil || config.DNS.Port == 0 {
con, err := dialer.DialContext(ctx, network, addr)
return con, err
}

View File

@ -151,6 +151,17 @@ paths:
description: 'Cannot write answer'
502:
description: 'Cannot retrieve the version.json file contents'
/update:
post:
tags:
- global
operationId: beginUpdate
summary: 'Begin auto-upgrade procedure'
responses:
200:
description: OK
500:
description: Failed
# --------------------------------------------------
# Query log methods
@ -906,17 +917,8 @@ definitions:
VersionInfo:
type: "object"
description: "Information about the latest available version of AdGuard Home"
required:
- "version"
- "announcement"
- "announcement_url"
- "download_darwin_amd64"
- "download_linux_amd64"
- "download_linux_386"
- "download_linux_arm"
- "selfupdate_min_version"
properties:
version:
new_version:
type: "string"
example: "v0.9"
announcement:
@ -925,21 +927,8 @@ definitions:
announcement_url:
type: "string"
example: "https://github.com/AdguardTeam/AdGuardHome/releases/tag/v0.9"
download_darwin_amd64:
type: "string"
example: "https://github.com/AdguardTeam/AdGuardHome/releases/download/v0.9/AdGuardHome_v0.9_MacOS.zip"
download_linux_amd64:
type: "string"
example: "https://github.com/AdguardTeam/AdGuardHome/releases/download/v0.9/AdGuardHome_v0.9_linux_amd64.tar.gz"
download_linux_386:
type: "string"
example: "https://github.com/AdguardTeam/AdGuardHome/releases/download/v0.9/AdGuardHome_v0.9_linux_386.tar.gz"
download_linux_arm:
type: "string"
example: "https://github.com/AdguardTeam/AdGuardHome/releases/download/v0.9/AdGuardHome_v0.9_linux_arm.tar.gz"
selfupdate_min_version:
type: "string"
example: "v0.0"
can_autoupdate:
type: "boolean"
Stats:
type: "object"
description: "General server stats for the last 24 hours"

View File

@ -2,7 +2,6 @@ package main
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
@ -16,21 +15,15 @@ const currentSchemaVersion = 3 // used for upgrading from old configs to new con
// Performs necessary upgrade operations if needed
func upgradeConfig() error {
// read a config file into an interface map, so we can manipulate values without losing any
configFile := config.getConfigFilename()
if _, err := os.Stat(configFile); os.IsNotExist(err) {
log.Printf("config file %s does not exist, nothing to upgrade", configFile)
return nil
}
diskConfig := map[string]interface{}{}
body, err := ioutil.ReadFile(configFile)
body, err := readConfigFile()
if err != nil {
log.Printf("Couldn't read config file '%s': %s", configFile, err)
return err
}
err = yaml.Unmarshal(body, &diskConfig)
if err != nil {
log.Printf("Couldn't parse config file '%s': %s", configFile, err)
log.Printf("Couldn't parse config file: %s", err)
return err
}
@ -87,6 +80,7 @@ func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) err
return err
}
config.fileData = body
err = file.SafeWrite(configFile, body)
if err != nil {
log.Printf("Couldn't save YAML config: %s", err)