diff --git a/home/auth.go b/home/auth.go index 9afe2c87..3e0d4300 100644 --- a/home/auth.go +++ b/home/auth.go @@ -307,7 +307,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) { return } - cookie := config.auth.httpCookie(req) + cookie := Context.auth.httpCookie(req) if len(cookie) == 0 { log.Info("Auth: invalid user name or password: name='%s'", req.Name) time.Sleep(1 * time.Second) @@ -328,7 +328,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) { cookie := r.Header.Get("Cookie") sess := parseCookie(cookie) - config.auth.RemoveSession(sess) + Context.auth.RemoveSession(sess) w.Header().Set("Location", "/login.html") @@ -365,10 +365,10 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re if r.URL.Path == "/login.html" { // redirect to dashboard if already authenticated - authRequired := config.auth != nil && config.auth.AuthRequired() + authRequired := Context.auth != nil && Context.auth.AuthRequired() cookie, err := r.Cookie(sessionCookieName) if authRequired && err == nil { - r := config.auth.CheckSession(cookie.Value) + r := Context.auth.CheckSession(cookie.Value) if r == 0 { w.Header().Set("Location", "/") w.WriteHeader(http.StatusFound) @@ -383,12 +383,12 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re strings.HasPrefix(r.URL.Path, "/__locales/") { // process as usual - } else if config.auth != nil && config.auth.AuthRequired() { + } else if Context.auth != nil && Context.auth.AuthRequired() { // redirect to login page if not authenticated ok := false cookie, err := r.Cookie(sessionCookieName) if err == nil { - r := config.auth.CheckSession(cookie.Value) + r := Context.auth.CheckSession(cookie.Value) if r == 0 { ok = true } else if r < 0 { @@ -398,7 +398,7 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re // there's no Cookie, check Basic authentication user, pass, ok2 := r.BasicAuth() if ok2 { - u := config.auth.UserFind(user, pass) + u := Context.auth.UserFind(user, pass) if len(u.Name) != 0 { ok = true } else { @@ -474,7 +474,7 @@ func (a *Auth) GetCurrentUser(r *http.Request) User { // there's no Cookie, check Basic authentication user, pass, ok := r.BasicAuth() if ok { - u := config.auth.UserFind(user, pass) + u := Context.auth.UserFind(user, pass) return u } return User{} diff --git a/home/auth_test.go b/home/auth_test.go index 19cd5001..38f826ec 100644 --- a/home/auth_test.go +++ b/home/auth_test.go @@ -100,7 +100,7 @@ func TestAuthHTTP(t *testing.T) { users := []User{ User{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"}, } - config.auth = InitAuth(fn, users, 60) + Context.auth = InitAuth(fn, users, 60) handlerCalled := false handler := func(w http.ResponseWriter, r *http.Request) { @@ -129,7 +129,7 @@ func TestAuthHTTP(t *testing.T) { assert.True(t, handlerCalled) // perform login - cookie := config.auth.httpCookie(loginJSON{Name: "name", Password: "password"}) + cookie := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"}) assert.True(t, cookie != "") // get / @@ -173,5 +173,5 @@ func TestAuthHTTP(t *testing.T) { assert.True(t, handlerCalled) r.Header.Del("Cookie") - config.auth.Close() + Context.auth.Close() } diff --git a/home/config.go b/home/config.go index 693c4a40..ab875fdc 100644 --- a/home/config.go +++ b/home/config.go @@ -44,19 +44,6 @@ type configuration struct { // It's reset after config is parsed fileData []byte - ourConfigFilename string // Config filename (can be overridden via the command line arguments) - ourWorkingDir string // Location of our directory, used to protect against CWD being somewhere else - firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html - pidFileName string // PID file name. Empty if no PID file was created. - // runningAsService flag is set to true when options are passed from the service runner - runningAsService bool - disableUpdate bool // If set, don't check for updates - appSignalChannel chan os.Signal - controlLock sync.Mutex - transport *http.Transport - client *http.Client - auth *Auth // HTTP authentication module - // cached version.json to avoid hammering github.io for each page reload versionCheckJSON []byte versionCheckLastTime time.Time @@ -152,9 +139,8 @@ type tlsConfig struct { // initialize to default values, will be changed later when reading config or parsing command line var config = configuration{ - ourConfigFilename: "AdGuardHome.yaml", - BindPort: 3000, - BindHost: "0.0.0.0", + BindPort: 3000, + BindHost: "0.0.0.0", DNS: dnsConfig{ BindHost: "0.0.0.0", Port: 53, @@ -185,14 +171,6 @@ var config = configuration{ // initConfig initializes default configuration for the current OS&ARCH func initConfig() { - config.transport = &http.Transport{ - DialContext: customDialContext, - } - config.client = &http.Client{ - Timeout: time.Minute * 5, - Transport: config.transport, - } - config.WebSessionTTLHours = 30 * 24 config.DNS.QueryLogEnabled = true @@ -209,24 +187,19 @@ func initConfig() { // getConfigFilename returns path to the current config file func (c *configuration) getConfigFilename() string { - configFile, err := filepath.EvalSymlinks(config.ourConfigFilename) + configFile, err := filepath.EvalSymlinks(Context.configFilename) if err != nil { if !os.IsNotExist(err) { log.Error("unexpected error while config file path evaluation: %s", err) } - configFile = config.ourConfigFilename + configFile = Context.configFilename } if !filepath.IsAbs(configFile) { - configFile = filepath.Join(config.ourWorkingDir, configFile) + configFile = filepath.Join(Context.workDir, configFile) } return configFile } -// getDataDir returns path to the directory where we store databases and filters -func (c *configuration) getDataDir() string { - return filepath.Join(c.ourWorkingDir, dataDir) -} - // getLogSettings reads logging settings from the config file. // we do it in a separate method in order to configure logger before the actual configuration is parsed and applied. func getLogSettings() logSettings { @@ -292,8 +265,8 @@ func (c *configuration) write() error { Context.clients.WriteDiskConfig(&config.Clients) - if config.auth != nil { - config.Users = config.auth.GetUsers() + if Context.auth != nil { + config.Users = Context.auth.GetUsers() } if Context.stats != nil { diff --git a/home/control.go b/home/control.go index 87247190..d852ef66 100644 --- a/home/control.go +++ b/home/control.go @@ -129,7 +129,7 @@ type profileJSON struct { func handleGetProfile(w http.ResponseWriter, r *http.Request) { pj := profileJSON{} - u := config.auth.GetCurrentUser(r) + u := Context.auth.GetCurrentUser(r) pj.Name = u.Name data, err := json.Marshal(pj) diff --git a/home/control_filtering.go b/home/control_filtering.go index 77c6cafa..8846b980 100644 --- a/home/control_filtering.go +++ b/home/control_filtering.go @@ -210,9 +210,9 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { } func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { - config.controlLock.Unlock() + Context.controlLock.Unlock() nUpdated, err := refreshFilters() - config.controlLock.Lock() + Context.controlLock.Lock() if err != nil { httpError(w, http.StatusInternalServerError, "%s", err) return diff --git a/home/control_install.go b/home/control_install.go index 196b792b..8a595bb8 100644 --- a/home/control_install.go +++ b/home/control_install.go @@ -302,7 +302,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { var curConfig configuration copyInstallSettings(&curConfig, &config) - config.firstRun = false + Context.firstRun = false config.BindHost = newSettings.Web.IP config.BindPort = newSettings.Web.Port config.DNS.BindHost = newSettings.DNS.IP @@ -317,7 +317,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { } } if err != nil || err2 != nil { - config.firstRun = true + Context.firstRun = true copyInstallSettings(&config, &curConfig) if err != nil { httpError(w, http.StatusInternalServerError, "Couldn't initialize DNS server: %s", err) @@ -329,11 +329,11 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { u := User{} u.Name = newSettings.Username - config.auth.UserAdd(&u, newSettings.Password) + Context.auth.UserAdd(&u, newSettings.Password) err = config.write() if err != nil { - config.firstRun = true + Context.firstRun = true copyInstallSettings(&config, &curConfig) httpError(w, http.StatusInternalServerError, "Couldn't write config: %s", err) return diff --git a/home/control_update.go b/home/control_update.go index 7864cfbb..10e94e50 100644 --- a/home/control_update.go +++ b/home/control_update.go @@ -64,7 +64,7 @@ type getVersionJSONRequest struct { // Get the latest available version from the Internet func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { - if config.disableUpdate { + if Context.disableUpdate { return } @@ -77,10 +77,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { now := time.Now() if !req.RecheckNow { - config.controlLock.Lock() + Context.controlLock.Lock() cached := now.Sub(config.versionCheckLastTime) <= versionCheckPeriod && len(config.versionCheckJSON) != 0 data := config.versionCheckJSON - config.controlLock.Unlock() + Context.controlLock.Unlock() if cached { log.Tracef("Returning cached data") @@ -93,7 +93,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { var resp *http.Response for i := 0; i != 3; i++ { log.Tracef("Downloading data from %s", versionCheckURL) - resp, err = config.client.Get(versionCheckURL) + resp, err = Context.client.Get(versionCheckURL) if err != nil && strings.HasSuffix(err.Error(), "i/o timeout") { // This case may happen while we're restarting DNS server // https://github.com/AdguardTeam/AdGuardHome/issues/934 @@ -116,10 +116,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { return } - config.controlLock.Lock() + Context.controlLock.Lock() config.versionCheckLastTime = now config.versionCheckJSON = body - config.controlLock.Unlock() + Context.controlLock.Unlock() w.Header().Set("Content-Type", "application/json") _, err = w.Write(getVersionResp(body)) @@ -158,7 +158,7 @@ type updateInfo struct { func getUpdateInfo(jsonData []byte) (*updateInfo, error) { var u updateInfo - workDir := config.ourWorkingDir + workDir := Context.workDir versionJSON := make(map[string]interface{}) err := json.Unmarshal(jsonData, &versionJSON) @@ -365,7 +365,7 @@ func copySupportingFiles(files []string, srcdir, dstdir string, useSrcNameOnly, // Download package file and save it to disk func getPackageFile(u *updateInfo) error { - resp, err := config.client.Get(u.pkgURL) + resp, err := Context.client.Get(u.pkgURL) if err != nil { return fmt.Errorf("HTTP request failed: %s", err) } @@ -436,17 +436,17 @@ func doUpdate(u *updateInfo) error { } // ./README.md -> backup/README.md - err = copySupportingFiles(files, config.ourWorkingDir, u.backupDir, true, true) + err = copySupportingFiles(files, Context.workDir, u.backupDir, true, true) if err != nil { return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", - config.ourWorkingDir, u.backupDir, err) + Context.workDir, u.backupDir, err) } // update/[AdGuardHome/]README.md -> ./README.md - err = copySupportingFiles(files, u.updateDir, config.ourWorkingDir, false, true) + err = copySupportingFiles(files, u.updateDir, Context.workDir, false, true) if err != nil { return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", - u.updateDir, config.ourWorkingDir, err) + u.updateDir, Context.workDir, err) } log.Tracef("Renaming: %s -> %s", u.curBinName, u.bkpBinName) @@ -478,8 +478,7 @@ func finishUpdate(u *updateInfo) { cleanupAlways() if runtime.GOOS == "windows" { - - if config.runningAsService { + if Context.runningAsService { // Note: // we can't restart the service via "kardianos/service" package - it kills the process first // we can't start a new instance - Windows doesn't allow it diff --git a/home/control_update_test.go b/home/control_update_test.go index c30a72e4..cfc13823 100644 --- a/home/control_update_test.go +++ b/home/control_update_test.go @@ -10,7 +10,7 @@ import ( func TestDoUpdate(t *testing.T) { config.DNS.Port = 0 - config.ourWorkingDir = "..." // set absolute path + Context.workDir = "..." // set absolute path newver := "v0.96" data := `{ @@ -35,15 +35,15 @@ func TestDoUpdate(t *testing.T) { u := updateInfo{ pkgURL: "https://github.com/AdguardTeam/AdGuardHome/releases/download/" + newver + "/AdGuardHome_linux_amd64.tar.gz", - pkgName: config.ourWorkingDir + "/agh-update-" + newver + "/AdGuardHome_linux_amd64.tar.gz", + pkgName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome_linux_amd64.tar.gz", newVer: newver, - updateDir: config.ourWorkingDir + "/agh-update-" + newver, - backupDir: config.ourWorkingDir + "/agh-backup", - configName: config.ourWorkingDir + "/AdGuardHome.yaml", - updateConfigName: config.ourWorkingDir + "/agh-update-" + newver + "/AdGuardHome/AdGuardHome.yaml", - curBinName: config.ourWorkingDir + "/AdGuardHome", - bkpBinName: config.ourWorkingDir + "/agh-backup/AdGuardHome", - newBinName: config.ourWorkingDir + "/agh-update-" + newver + "/AdGuardHome/AdGuardHome", + updateDir: Context.workDir + "/agh-update-" + newver, + backupDir: Context.workDir + "/agh-backup", + configName: Context.workDir + "/AdGuardHome.yaml", + updateConfigName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/AdGuardHome.yaml", + curBinName: Context.workDir + "/AdGuardHome", + bkpBinName: Context.workDir + "/agh-backup/AdGuardHome", + newBinName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/AdGuardHome", } if uu.pkgURL != u.pkgURL || diff --git a/home/dns.go b/home/dns.go index 4d5dceb3..167662b1 100644 --- a/home/dns.go +++ b/home/dns.go @@ -25,7 +25,7 @@ func onConfigModified() { // Please note that we must do it even if we don't start it // so that we had access to the query log and the stats func initDNSServer() error { - baseDir := config.getDataDir() + baseDir := Context.getDataDir() err := os.MkdirAll(baseDir, 0755) if err != nil { @@ -71,8 +71,8 @@ func initDNSServer() error { } sessFilename := filepath.Join(baseDir, "sessions.db") - config.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60) - if config.auth == nil { + Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60) + if Context.auth == nil { closeDNSServer() return fmt.Errorf("Couldn't initialize Auth module") } @@ -294,9 +294,9 @@ func closeDNSServer() { Context.queryLog = nil } - if config.auth != nil { - config.auth.Close() - config.auth = nil + if Context.auth != nil { + Context.auth.Close() + Context.auth = nil } log.Debug("Closed all DNS modules") diff --git a/home/filter.go b/home/filter.go index 6b0a16ef..9e0a5abf 100644 --- a/home/filter.go +++ b/home/filter.go @@ -424,7 +424,7 @@ func parseFilterContents(contents []byte) (int, string) { func (filter *filter) update() (bool, error) { log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL) - resp, err := config.client.Get(filter.URL) + resp, err := Context.client.Get(filter.URL) if resp != nil && resp.Body != nil { defer resp.Body.Close() } @@ -538,7 +538,7 @@ func (filter *filter) unload() { // Path to the filter contents func (filter *filter) Path() string { - return filepath.Join(config.getDataDir(), filterDir, strconv.FormatInt(filter.ID, 10)+".txt") + return filepath.Join(Context.getDataDir(), filterDir, strconv.FormatInt(filter.ID, 10)+".txt") } // LastTimeUpdated returns the time when the filter was last time updated diff --git a/home/filter_test.go b/home/filter_test.go index 63736c38..fe1826d2 100644 --- a/home/filter_test.go +++ b/home/filter_test.go @@ -10,7 +10,7 @@ import ( ) func TestFilters(t *testing.T) { - config.client = &http.Client{ + Context.client = &http.Client{ Timeout: time.Minute * 5, } diff --git a/home/helpers.go b/home/helpers.go index c5540f14..5a87aec7 100644 --- a/home/helpers.go +++ b/home/helpers.go @@ -32,8 +32,8 @@ func ensure(method string, handler func(http.ResponseWriter, *http.Request)) fun } if method == "POST" || method == "PUT" || method == "DELETE" { - config.controlLock.Lock() - defer config.controlLock.Unlock() + Context.controlLock.Lock() + defer Context.controlLock.Unlock() } handler(w, r) @@ -67,9 +67,9 @@ func ensureHandler(method string, handler func(http.ResponseWriter, *http.Reques // first run / install // ------------------- func detectFirstRun() bool { - configfile := config.ourConfigFilename + configfile := Context.configFilename if !filepath.IsAbs(configfile) { - configfile = filepath.Join(config.ourWorkingDir, config.ourConfigFilename) + configfile = filepath.Join(Context.workDir, Context.configFilename) } _, err := os.Stat(configfile) if !os.IsNotExist(err) { @@ -82,7 +82,7 @@ func detectFirstRun() bool { // preInstall lets the handler run only if firstRun is true, no redirects func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - if !config.firstRun { + if !Context.firstRun { // if it's not first run, don't let users access it (for example /install.html when configuration is done) http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) return @@ -109,7 +109,7 @@ func preInstallHandler(handler http.Handler) http.Handler { // it also enforces HTTPS if it is enabled and configured func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - if config.firstRun && + if Context.firstRun && !strings.HasPrefix(r.URL.Path, "/install.") && r.URL.Path != "/favicon.png" { http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable diff --git a/home/home.go b/home/home.go index 5b34847b..b31f1044 100644 --- a/home/home.go +++ b/home/home.go @@ -49,6 +49,9 @@ const versionCheckPeriod = time.Hour * 8 // Global context type homeContext struct { + // Modules + // -- + clients clientsContainer // per-client-settings module stats stats.Stats // statistics module queryLog querylog.QueryLog // query log module @@ -57,8 +60,29 @@ type homeContext struct { whois *Whois // WHOIS module dnsFilter *dnsfilter.Dnsfilter // DNS filtering module dhcpServer *dhcpd.Server // DHCP module + auth *Auth // HTTP authentication module httpServer *http.Server // HTTP module httpsServer HTTPSServer // HTTPS module + + // Runtime properties + // -- + + configFilename string // Config filename (can be overridden via the command line arguments) + workDir string // Location of our directory, used to protect against CWD being somewhere else + firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html + pidFileName string // PID file name. Empty if no PID file was created. + disableUpdate bool // If set, don't check for updates + controlLock sync.Mutex + transport *http.Transport + client *http.Client + appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app + // runningAsService flag is set to true when options are passed from the service runner + runningAsService bool +} + +// getDataDir returns path to the directory where we store databases and filters +func (c *homeContext) getDataDir() string { + return filepath.Join(c.workDir, dataDir) } // Context - a global context object @@ -81,17 +105,38 @@ func Main(version string, channel string, armVer string) { return } + Context.appSignalChannel = make(chan os.Signal) + signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) + go func() { + <-Context.appSignalChannel + cleanup() + cleanupAlways() + os.Exit(0) + }() + // run the protection run(args) } // run initializes configuration and runs the AdGuard Home -// run is a blocking method and it won't exit until the service is stopped! +// run is a blocking method! // nolint func run(args options) { // config file path can be overridden by command-line arguments: if args.configFilename != "" { - config.ourConfigFilename = args.configFilename + Context.configFilename = args.configFilename + } else { + // Default config file name + Context.configFilename = "AdGuardHome.yaml" + } + + // Init some of the Context fields right away + Context.transport = &http.Transport{ + DialContext: customDialContext, + } + Context.client = &http.Client{ + Timeout: time.Minute * 5, + Transport: Context.transport, } // configure working dir and config path @@ -106,31 +151,22 @@ func run(args options) { msg = msg + " v" + ARMVersion } log.Printf(msg, versionString, updateChannel, runtime.GOOS, runtime.GOARCH, ARMVersion) - log.Debug("Current working directory is %s", config.ourWorkingDir) + log.Debug("Current working directory is %s", Context.workDir) if args.runningAsService { log.Info("AdGuard Home is running as a service") } - config.runningAsService = args.runningAsService - config.disableUpdate = args.disableUpdate + Context.runningAsService = args.runningAsService + Context.disableUpdate = args.disableUpdate - config.firstRun = detectFirstRun() - if config.firstRun { + Context.firstRun = detectFirstRun() + if Context.firstRun { requireAdminRights() } - config.appSignalChannel = make(chan os.Signal) - signal.Notify(config.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) - go func() { - <-config.appSignalChannel - cleanup() - cleanupAlways() - os.Exit(0) - }() - initConfig() initServices() - if !config.firstRun { + if !Context.firstRun { // Do the upgrade if necessary err := upgradeConfig() if err != nil { @@ -148,7 +184,7 @@ func run(args options) { } } - config.DHCP.WorkDir = config.ourWorkingDir + config.DHCP.WorkDir = Context.workDir config.DHCP.HTTPRegister = httpRegister config.DHCP.ConfigModified = onConfigModified Context.dhcpServer = dhcpd.Create(config.DHCP) @@ -168,7 +204,7 @@ func run(args options) { config.BindPort = args.bindPort } - if !config.firstRun { + if !Context.firstRun { // Save the updated config err := config.write() if err != nil { @@ -193,7 +229,7 @@ func run(args options) { } if len(args.pidFile) != 0 && writePIDFile(args.pidFile) { - config.pidFileName = args.pidFile + Context.pidFileName = args.pidFile } // Initialize and run the admin Web interface @@ -204,7 +240,7 @@ func run(args options) { registerControlHandlers() // add handlers for /install paths, we only need them when we're not configured yet - if config.firstRun { + if Context.firstRun { log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ") http.Handle("/install.html", preInstallHandler(http.FileServer(box))) registerInstallHandlers() @@ -331,19 +367,19 @@ func writePIDFile(fn string) bool { return true } -// initWorkingDir initializes the ourWorkingDir +// initWorkingDir initializes the workDir // if no command-line arguments specified, we use the directory where our binary file is located func initWorkingDir(args options) { - exec, err := os.Executable() + execPath, err := os.Executable() if err != nil { panic(err) } if args.workDir != "" { // If there is a custom config file, use it's directory as our working dir - config.ourWorkingDir = args.workDir + Context.workDir = args.workDir } else { - config.ourWorkingDir = filepath.Dir(exec) + Context.workDir = filepath.Dir(execPath) } } @@ -381,7 +417,7 @@ func configureLogger(args options) { log.Fatalf("cannot initialize syslog: %s", err) } } else { - logFilePath := filepath.Join(config.ourWorkingDir, ls.LogFile) + logFilePath := filepath.Join(Context.workDir, ls.LogFile) if filepath.IsAbs(ls.LogFile) { logFilePath = ls.LogFile } @@ -420,8 +456,8 @@ func stopHTTPServer() { // This function is called before application exits func cleanupAlways() { - if len(config.pidFileName) != 0 { - os.Remove(config.pidFileName) + if len(Context.pidFileName) != 0 { + _ = os.Remove(Context.pidFileName) } log.Info("Stopped") } diff --git a/home/home_test.go b/home/home_test.go index 771c74d0..2c40e97f 100644 --- a/home/home_test.go +++ b/home/home_test.go @@ -120,16 +120,18 @@ func TestHome(t *testing.T) { args.workDir = dir go run(args) + for i := 0; i < 5; i++ { + // Waiting until the DNS server is up and running + if !isRunning() { + time.Sleep(1 * time.Second) + } + } + var err error var resp *http.Response h := http.Client{} - for i := 0; i != 5; i++ { - resp, err = h.Get("http://127.0.0.1:3000/") - if err == nil && resp.StatusCode != 404 { - break - } - time.Sleep(1 * time.Second) - } + + resp, err = h.Get("http://127.0.0.1:3000/") assert.Truef(t, err == nil, "%s", err) assert.Equal(t, 200, resp.StatusCode) @@ -140,7 +142,7 @@ func TestHome(t *testing.T) { // test DNS over UDP r := upstream.NewResolver("127.0.0.1:5354", 3*time.Second) addrs, err := r.LookupIPAddr(context.TODO(), "static.adguard.com") - assert.Truef(t, err == nil, "%s", err) + assert.Nil(t, err) haveIP := len(addrs) != 0 assert.True(t, haveIP) @@ -159,7 +161,7 @@ func TestHome(t *testing.T) { assert.True(t, resp.StatusCode == http.StatusOK) response := dns.Msg{} err = response.Unpack(body) - assert.True(t, err == nil, "%s", err) + assert.Nil(t, err) addrs = nil proxyutil.AppendIPAddrs(&addrs, response.Answer) haveIP = len(addrs) != 0 diff --git a/home/service.go b/home/service.go index edca9244..1951ed28 100644 --- a/home/service.go +++ b/home/service.go @@ -34,10 +34,10 @@ func (p *program) Start(s service.Service) error { // Stop stops the program func (p *program) Stop(s service.Service) error { // Stop should not block. Return with a few seconds. - if config.appSignalChannel == nil { + if Context.appSignalChannel == nil { os.Exit(0) } - config.appSignalChannel <- syscall.SIGINT + Context.appSignalChannel <- syscall.SIGINT return nil } diff --git a/home/upgrade.go b/home/upgrade.go index 3a703ebc..7336f846 100644 --- a/home/upgrade.go +++ b/home/upgrade.go @@ -116,7 +116,7 @@ func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) err func upgradeSchema0to1(diskConfig *map[string]interface{}) error { log.Printf("%s(): called", _Func()) - dnsFilterPath := filepath.Join(config.ourWorkingDir, "dnsfilter.txt") + dnsFilterPath := filepath.Join(Context.workDir, "dnsfilter.txt") if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) { log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath) err = os.Remove(dnsFilterPath) @@ -137,7 +137,7 @@ func upgradeSchema0to1(diskConfig *map[string]interface{}) error { func upgradeSchema1to2(diskConfig *map[string]interface{}) error { log.Printf("%s(): called", _Func()) - coreFilePath := filepath.Join(config.ourWorkingDir, "Corefile") + coreFilePath := filepath.Join(Context.workDir, "Corefile") if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) { log.Printf("Deleting %s as we don't need it anymore", coreFilePath) err = os.Remove(coreFilePath)