*(global): refactoring - moved runtime properties to Context

This commit is contained in:
Andrey Meshkov 2020-02-12 15:53:36 +03:00
parent d8d48c5386
commit ae2990582d
16 changed files with 141 additions and 131 deletions

View File

@ -307,7 +307,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
return return
} }
cookie := config.auth.httpCookie(req) cookie := Context.auth.httpCookie(req)
if len(cookie) == 0 { if len(cookie) == 0 {
log.Info("Auth: invalid user name or password: name='%s'", req.Name) log.Info("Auth: invalid user name or password: name='%s'", req.Name)
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
@ -328,7 +328,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
cookie := r.Header.Get("Cookie") cookie := r.Header.Get("Cookie")
sess := parseCookie(cookie) sess := parseCookie(cookie)
config.auth.RemoveSession(sess) Context.auth.RemoveSession(sess)
w.Header().Set("Location", "/login.html") w.Header().Set("Location", "/login.html")
@ -365,10 +365,10 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re
if r.URL.Path == "/login.html" { if r.URL.Path == "/login.html" {
// redirect to dashboard if already authenticated // redirect to dashboard if already authenticated
authRequired := config.auth != nil && config.auth.AuthRequired() authRequired := Context.auth != nil && Context.auth.AuthRequired()
cookie, err := r.Cookie(sessionCookieName) cookie, err := r.Cookie(sessionCookieName)
if authRequired && err == nil { if authRequired && err == nil {
r := config.auth.CheckSession(cookie.Value) r := Context.auth.CheckSession(cookie.Value)
if r == 0 { if r == 0 {
w.Header().Set("Location", "/") w.Header().Set("Location", "/")
w.WriteHeader(http.StatusFound) w.WriteHeader(http.StatusFound)
@ -383,12 +383,12 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re
strings.HasPrefix(r.URL.Path, "/__locales/") { strings.HasPrefix(r.URL.Path, "/__locales/") {
// process as usual // process as usual
} else if config.auth != nil && config.auth.AuthRequired() { } else if Context.auth != nil && Context.auth.AuthRequired() {
// redirect to login page if not authenticated // redirect to login page if not authenticated
ok := false ok := false
cookie, err := r.Cookie(sessionCookieName) cookie, err := r.Cookie(sessionCookieName)
if err == nil { if err == nil {
r := config.auth.CheckSession(cookie.Value) r := Context.auth.CheckSession(cookie.Value)
if r == 0 { if r == 0 {
ok = true ok = true
} else if r < 0 { } else if r < 0 {
@ -398,7 +398,7 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re
// there's no Cookie, check Basic authentication // there's no Cookie, check Basic authentication
user, pass, ok2 := r.BasicAuth() user, pass, ok2 := r.BasicAuth()
if ok2 { if ok2 {
u := config.auth.UserFind(user, pass) u := Context.auth.UserFind(user, pass)
if len(u.Name) != 0 { if len(u.Name) != 0 {
ok = true ok = true
} else { } else {
@ -474,7 +474,7 @@ func (a *Auth) GetCurrentUser(r *http.Request) User {
// there's no Cookie, check Basic authentication // there's no Cookie, check Basic authentication
user, pass, ok := r.BasicAuth() user, pass, ok := r.BasicAuth()
if ok { if ok {
u := config.auth.UserFind(user, pass) u := Context.auth.UserFind(user, pass)
return u return u
} }
return User{} return User{}

View File

@ -100,7 +100,7 @@ func TestAuthHTTP(t *testing.T) {
users := []User{ users := []User{
User{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"}, User{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
} }
config.auth = InitAuth(fn, users, 60) Context.auth = InitAuth(fn, users, 60)
handlerCalled := false handlerCalled := false
handler := func(w http.ResponseWriter, r *http.Request) { handler := func(w http.ResponseWriter, r *http.Request) {
@ -129,7 +129,7 @@ func TestAuthHTTP(t *testing.T) {
assert.True(t, handlerCalled) assert.True(t, handlerCalled)
// perform login // perform login
cookie := config.auth.httpCookie(loginJSON{Name: "name", Password: "password"}) cookie := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"})
assert.True(t, cookie != "") assert.True(t, cookie != "")
// get / // get /
@ -173,5 +173,5 @@ func TestAuthHTTP(t *testing.T) {
assert.True(t, handlerCalled) assert.True(t, handlerCalled)
r.Header.Del("Cookie") r.Header.Del("Cookie")
config.auth.Close() Context.auth.Close()
} }

View File

@ -44,19 +44,6 @@ type configuration struct {
// It's reset after config is parsed // It's reset after config is parsed
fileData []byte fileData []byte
ourConfigFilename string // Config filename (can be overridden via the command line arguments)
ourWorkingDir string // Location of our directory, used to protect against CWD being somewhere else
firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html
pidFileName string // PID file name. Empty if no PID file was created.
// runningAsService flag is set to true when options are passed from the service runner
runningAsService bool
disableUpdate bool // If set, don't check for updates
appSignalChannel chan os.Signal
controlLock sync.Mutex
transport *http.Transport
client *http.Client
auth *Auth // HTTP authentication module
// cached version.json to avoid hammering github.io for each page reload // cached version.json to avoid hammering github.io for each page reload
versionCheckJSON []byte versionCheckJSON []byte
versionCheckLastTime time.Time versionCheckLastTime time.Time
@ -152,7 +139,6 @@ type tlsConfig struct {
// initialize to default values, will be changed later when reading config or parsing command line // initialize to default values, will be changed later when reading config or parsing command line
var config = configuration{ var config = configuration{
ourConfigFilename: "AdGuardHome.yaml",
BindPort: 3000, BindPort: 3000,
BindHost: "0.0.0.0", BindHost: "0.0.0.0",
DNS: dnsConfig{ DNS: dnsConfig{
@ -185,14 +171,6 @@ var config = configuration{
// initConfig initializes default configuration for the current OS&ARCH // initConfig initializes default configuration for the current OS&ARCH
func initConfig() { func initConfig() {
config.transport = &http.Transport{
DialContext: customDialContext,
}
config.client = &http.Client{
Timeout: time.Minute * 5,
Transport: config.transport,
}
config.WebSessionTTLHours = 30 * 24 config.WebSessionTTLHours = 30 * 24
config.DNS.QueryLogEnabled = true config.DNS.QueryLogEnabled = true
@ -209,24 +187,19 @@ func initConfig() {
// getConfigFilename returns path to the current config file // getConfigFilename returns path to the current config file
func (c *configuration) getConfigFilename() string { func (c *configuration) getConfigFilename() string {
configFile, err := filepath.EvalSymlinks(config.ourConfigFilename) configFile, err := filepath.EvalSymlinks(Context.configFilename)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
log.Error("unexpected error while config file path evaluation: %s", err) log.Error("unexpected error while config file path evaluation: %s", err)
} }
configFile = config.ourConfigFilename configFile = Context.configFilename
} }
if !filepath.IsAbs(configFile) { if !filepath.IsAbs(configFile) {
configFile = filepath.Join(config.ourWorkingDir, configFile) configFile = filepath.Join(Context.workDir, configFile)
} }
return configFile return configFile
} }
// getDataDir returns path to the directory where we store databases and filters
func (c *configuration) getDataDir() string {
return filepath.Join(c.ourWorkingDir, dataDir)
}
// getLogSettings reads logging settings from the config file. // getLogSettings reads logging settings from the config file.
// we do it in a separate method in order to configure logger before the actual configuration is parsed and applied. // we do it in a separate method in order to configure logger before the actual configuration is parsed and applied.
func getLogSettings() logSettings { func getLogSettings() logSettings {
@ -292,8 +265,8 @@ func (c *configuration) write() error {
Context.clients.WriteDiskConfig(&config.Clients) Context.clients.WriteDiskConfig(&config.Clients)
if config.auth != nil { if Context.auth != nil {
config.Users = config.auth.GetUsers() config.Users = Context.auth.GetUsers()
} }
if Context.stats != nil { if Context.stats != nil {

View File

@ -129,7 +129,7 @@ type profileJSON struct {
func handleGetProfile(w http.ResponseWriter, r *http.Request) { func handleGetProfile(w http.ResponseWriter, r *http.Request) {
pj := profileJSON{} pj := profileJSON{}
u := config.auth.GetCurrentUser(r) u := Context.auth.GetCurrentUser(r)
pj.Name = u.Name pj.Name = u.Name
data, err := json.Marshal(pj) data, err := json.Marshal(pj)

View File

@ -210,9 +210,9 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
} }
func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
config.controlLock.Unlock() Context.controlLock.Unlock()
nUpdated, err := refreshFilters() nUpdated, err := refreshFilters()
config.controlLock.Lock() Context.controlLock.Lock()
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err) httpError(w, http.StatusInternalServerError, "%s", err)
return return

View File

@ -302,7 +302,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
var curConfig configuration var curConfig configuration
copyInstallSettings(&curConfig, &config) copyInstallSettings(&curConfig, &config)
config.firstRun = false Context.firstRun = false
config.BindHost = newSettings.Web.IP config.BindHost = newSettings.Web.IP
config.BindPort = newSettings.Web.Port config.BindPort = newSettings.Web.Port
config.DNS.BindHost = newSettings.DNS.IP config.DNS.BindHost = newSettings.DNS.IP
@ -317,7 +317,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
} }
} }
if err != nil || err2 != nil { if err != nil || err2 != nil {
config.firstRun = true Context.firstRun = true
copyInstallSettings(&config, &curConfig) copyInstallSettings(&config, &curConfig)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't initialize DNS server: %s", err) httpError(w, http.StatusInternalServerError, "Couldn't initialize DNS server: %s", err)
@ -329,11 +329,11 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
u := User{} u := User{}
u.Name = newSettings.Username u.Name = newSettings.Username
config.auth.UserAdd(&u, newSettings.Password) Context.auth.UserAdd(&u, newSettings.Password)
err = config.write() err = config.write()
if err != nil { if err != nil {
config.firstRun = true Context.firstRun = true
copyInstallSettings(&config, &curConfig) copyInstallSettings(&config, &curConfig)
httpError(w, http.StatusInternalServerError, "Couldn't write config: %s", err) httpError(w, http.StatusInternalServerError, "Couldn't write config: %s", err)
return return

View File

@ -64,7 +64,7 @@ type getVersionJSONRequest struct {
// Get the latest available version from the Internet // Get the latest available version from the Internet
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
if config.disableUpdate { if Context.disableUpdate {
return return
} }
@ -77,10 +77,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
now := time.Now() now := time.Now()
if !req.RecheckNow { if !req.RecheckNow {
config.controlLock.Lock() Context.controlLock.Lock()
cached := now.Sub(config.versionCheckLastTime) <= versionCheckPeriod && len(config.versionCheckJSON) != 0 cached := now.Sub(config.versionCheckLastTime) <= versionCheckPeriod && len(config.versionCheckJSON) != 0
data := config.versionCheckJSON data := config.versionCheckJSON
config.controlLock.Unlock() Context.controlLock.Unlock()
if cached { if cached {
log.Tracef("Returning cached data") log.Tracef("Returning cached data")
@ -93,7 +93,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
var resp *http.Response var resp *http.Response
for i := 0; i != 3; i++ { for i := 0; i != 3; i++ {
log.Tracef("Downloading data from %s", versionCheckURL) log.Tracef("Downloading data from %s", versionCheckURL)
resp, err = config.client.Get(versionCheckURL) resp, err = Context.client.Get(versionCheckURL)
if err != nil && strings.HasSuffix(err.Error(), "i/o timeout") { if err != nil && strings.HasSuffix(err.Error(), "i/o timeout") {
// This case may happen while we're restarting DNS server // This case may happen while we're restarting DNS server
// https://github.com/AdguardTeam/AdGuardHome/issues/934 // https://github.com/AdguardTeam/AdGuardHome/issues/934
@ -116,10 +116,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
return return
} }
config.controlLock.Lock() Context.controlLock.Lock()
config.versionCheckLastTime = now config.versionCheckLastTime = now
config.versionCheckJSON = body config.versionCheckJSON = body
config.controlLock.Unlock() Context.controlLock.Unlock()
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(getVersionResp(body)) _, err = w.Write(getVersionResp(body))
@ -158,7 +158,7 @@ type updateInfo struct {
func getUpdateInfo(jsonData []byte) (*updateInfo, error) { func getUpdateInfo(jsonData []byte) (*updateInfo, error) {
var u updateInfo var u updateInfo
workDir := config.ourWorkingDir workDir := Context.workDir
versionJSON := make(map[string]interface{}) versionJSON := make(map[string]interface{})
err := json.Unmarshal(jsonData, &versionJSON) err := json.Unmarshal(jsonData, &versionJSON)
@ -365,7 +365,7 @@ func copySupportingFiles(files []string, srcdir, dstdir string, useSrcNameOnly,
// Download package file and save it to disk // Download package file and save it to disk
func getPackageFile(u *updateInfo) error { func getPackageFile(u *updateInfo) error {
resp, err := config.client.Get(u.pkgURL) resp, err := Context.client.Get(u.pkgURL)
if err != nil { if err != nil {
return fmt.Errorf("HTTP request failed: %s", err) return fmt.Errorf("HTTP request failed: %s", err)
} }
@ -436,17 +436,17 @@ func doUpdate(u *updateInfo) error {
} }
// ./README.md -> backup/README.md // ./README.md -> backup/README.md
err = copySupportingFiles(files, config.ourWorkingDir, u.backupDir, true, true) err = copySupportingFiles(files, Context.workDir, u.backupDir, true, true)
if err != nil { if err != nil {
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s",
config.ourWorkingDir, u.backupDir, err) Context.workDir, u.backupDir, err)
} }
// update/[AdGuardHome/]README.md -> ./README.md // update/[AdGuardHome/]README.md -> ./README.md
err = copySupportingFiles(files, u.updateDir, config.ourWorkingDir, false, true) err = copySupportingFiles(files, u.updateDir, Context.workDir, false, true)
if err != nil { if err != nil {
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s",
u.updateDir, config.ourWorkingDir, err) u.updateDir, Context.workDir, err)
} }
log.Tracef("Renaming: %s -> %s", u.curBinName, u.bkpBinName) log.Tracef("Renaming: %s -> %s", u.curBinName, u.bkpBinName)
@ -478,8 +478,7 @@ func finishUpdate(u *updateInfo) {
cleanupAlways() cleanupAlways()
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
if Context.runningAsService {
if config.runningAsService {
// Note: // Note:
// we can't restart the service via "kardianos/service" package - it kills the process first // we can't restart the service via "kardianos/service" package - it kills the process first
// we can't start a new instance - Windows doesn't allow it // we can't start a new instance - Windows doesn't allow it

View File

@ -10,7 +10,7 @@ import (
func TestDoUpdate(t *testing.T) { func TestDoUpdate(t *testing.T) {
config.DNS.Port = 0 config.DNS.Port = 0
config.ourWorkingDir = "..." // set absolute path Context.workDir = "..." // set absolute path
newver := "v0.96" newver := "v0.96"
data := `{ data := `{
@ -35,15 +35,15 @@ func TestDoUpdate(t *testing.T) {
u := updateInfo{ u := updateInfo{
pkgURL: "https://github.com/AdguardTeam/AdGuardHome/releases/download/" + newver + "/AdGuardHome_linux_amd64.tar.gz", pkgURL: "https://github.com/AdguardTeam/AdGuardHome/releases/download/" + newver + "/AdGuardHome_linux_amd64.tar.gz",
pkgName: config.ourWorkingDir + "/agh-update-" + newver + "/AdGuardHome_linux_amd64.tar.gz", pkgName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome_linux_amd64.tar.gz",
newVer: newver, newVer: newver,
updateDir: config.ourWorkingDir + "/agh-update-" + newver, updateDir: Context.workDir + "/agh-update-" + newver,
backupDir: config.ourWorkingDir + "/agh-backup", backupDir: Context.workDir + "/agh-backup",
configName: config.ourWorkingDir + "/AdGuardHome.yaml", configName: Context.workDir + "/AdGuardHome.yaml",
updateConfigName: config.ourWorkingDir + "/agh-update-" + newver + "/AdGuardHome/AdGuardHome.yaml", updateConfigName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/AdGuardHome.yaml",
curBinName: config.ourWorkingDir + "/AdGuardHome", curBinName: Context.workDir + "/AdGuardHome",
bkpBinName: config.ourWorkingDir + "/agh-backup/AdGuardHome", bkpBinName: Context.workDir + "/agh-backup/AdGuardHome",
newBinName: config.ourWorkingDir + "/agh-update-" + newver + "/AdGuardHome/AdGuardHome", newBinName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/AdGuardHome",
} }
if uu.pkgURL != u.pkgURL || if uu.pkgURL != u.pkgURL ||

View File

@ -25,7 +25,7 @@ func onConfigModified() {
// Please note that we must do it even if we don't start it // Please note that we must do it even if we don't start it
// so that we had access to the query log and the stats // so that we had access to the query log and the stats
func initDNSServer() error { func initDNSServer() error {
baseDir := config.getDataDir() baseDir := Context.getDataDir()
err := os.MkdirAll(baseDir, 0755) err := os.MkdirAll(baseDir, 0755)
if err != nil { if err != nil {
@ -71,8 +71,8 @@ func initDNSServer() error {
} }
sessFilename := filepath.Join(baseDir, "sessions.db") sessFilename := filepath.Join(baseDir, "sessions.db")
config.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60) Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
if config.auth == nil { if Context.auth == nil {
closeDNSServer() closeDNSServer()
return fmt.Errorf("Couldn't initialize Auth module") return fmt.Errorf("Couldn't initialize Auth module")
} }
@ -294,9 +294,9 @@ func closeDNSServer() {
Context.queryLog = nil Context.queryLog = nil
} }
if config.auth != nil { if Context.auth != nil {
config.auth.Close() Context.auth.Close()
config.auth = nil Context.auth = nil
} }
log.Debug("Closed all DNS modules") log.Debug("Closed all DNS modules")

View File

@ -424,7 +424,7 @@ func parseFilterContents(contents []byte) (int, string) {
func (filter *filter) update() (bool, error) { func (filter *filter) update() (bool, error) {
log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL) log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL)
resp, err := config.client.Get(filter.URL) resp, err := Context.client.Get(filter.URL)
if resp != nil && resp.Body != nil { if resp != nil && resp.Body != nil {
defer resp.Body.Close() defer resp.Body.Close()
} }
@ -538,7 +538,7 @@ func (filter *filter) unload() {
// Path to the filter contents // Path to the filter contents
func (filter *filter) Path() string { func (filter *filter) Path() string {
return filepath.Join(config.getDataDir(), filterDir, strconv.FormatInt(filter.ID, 10)+".txt") return filepath.Join(Context.getDataDir(), filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
} }
// LastTimeUpdated returns the time when the filter was last time updated // LastTimeUpdated returns the time when the filter was last time updated

View File

@ -10,7 +10,7 @@ import (
) )
func TestFilters(t *testing.T) { func TestFilters(t *testing.T) {
config.client = &http.Client{ Context.client = &http.Client{
Timeout: time.Minute * 5, Timeout: time.Minute * 5,
} }

View File

@ -32,8 +32,8 @@ func ensure(method string, handler func(http.ResponseWriter, *http.Request)) fun
} }
if method == "POST" || method == "PUT" || method == "DELETE" { if method == "POST" || method == "PUT" || method == "DELETE" {
config.controlLock.Lock() Context.controlLock.Lock()
defer config.controlLock.Unlock() defer Context.controlLock.Unlock()
} }
handler(w, r) handler(w, r)
@ -67,9 +67,9 @@ func ensureHandler(method string, handler func(http.ResponseWriter, *http.Reques
// first run / install // first run / install
// ------------------- // -------------------
func detectFirstRun() bool { func detectFirstRun() bool {
configfile := config.ourConfigFilename configfile := Context.configFilename
if !filepath.IsAbs(configfile) { if !filepath.IsAbs(configfile) {
configfile = filepath.Join(config.ourWorkingDir, config.ourConfigFilename) configfile = filepath.Join(Context.workDir, Context.configFilename)
} }
_, err := os.Stat(configfile) _, err := os.Stat(configfile)
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
@ -82,7 +82,7 @@ func detectFirstRun() bool {
// preInstall lets the handler run only if firstRun is true, no redirects // preInstall lets the handler run only if firstRun is true, no redirects
func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if !config.firstRun { if !Context.firstRun {
// if it's not first run, don't let users access it (for example /install.html when configuration is done) // if it's not first run, don't let users access it (for example /install.html when configuration is done)
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return return
@ -109,7 +109,7 @@ func preInstallHandler(handler http.Handler) http.Handler {
// it also enforces HTTPS if it is enabled and configured // it also enforces HTTPS if it is enabled and configured
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if config.firstRun && if Context.firstRun &&
!strings.HasPrefix(r.URL.Path, "/install.") && !strings.HasPrefix(r.URL.Path, "/install.") &&
r.URL.Path != "/favicon.png" { r.URL.Path != "/favicon.png" {
http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable

View File

@ -49,6 +49,9 @@ const versionCheckPeriod = time.Hour * 8
// Global context // Global context
type homeContext struct { type homeContext struct {
// Modules
// --
clients clientsContainer // per-client-settings module clients clientsContainer // per-client-settings module
stats stats.Stats // statistics module stats stats.Stats // statistics module
queryLog querylog.QueryLog // query log module queryLog querylog.QueryLog // query log module
@ -57,8 +60,29 @@ type homeContext struct {
whois *Whois // WHOIS module whois *Whois // WHOIS module
dnsFilter *dnsfilter.Dnsfilter // DNS filtering module dnsFilter *dnsfilter.Dnsfilter // DNS filtering module
dhcpServer *dhcpd.Server // DHCP module dhcpServer *dhcpd.Server // DHCP module
auth *Auth // HTTP authentication module
httpServer *http.Server // HTTP module httpServer *http.Server // HTTP module
httpsServer HTTPSServer // HTTPS module httpsServer HTTPSServer // HTTPS module
// Runtime properties
// --
configFilename string // Config filename (can be overridden via the command line arguments)
workDir string // Location of our directory, used to protect against CWD being somewhere else
firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html
pidFileName string // PID file name. Empty if no PID file was created.
disableUpdate bool // If set, don't check for updates
controlLock sync.Mutex
transport *http.Transport
client *http.Client
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
// runningAsService flag is set to true when options are passed from the service runner
runningAsService bool
}
// getDataDir returns path to the directory where we store databases and filters
func (c *homeContext) getDataDir() string {
return filepath.Join(c.workDir, dataDir)
} }
// Context - a global context object // Context - a global context object
@ -81,17 +105,38 @@ func Main(version string, channel string, armVer string) {
return return
} }
Context.appSignalChannel = make(chan os.Signal)
signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
<-Context.appSignalChannel
cleanup()
cleanupAlways()
os.Exit(0)
}()
// run the protection // run the protection
run(args) run(args)
} }
// run initializes configuration and runs the AdGuard Home // run initializes configuration and runs the AdGuard Home
// run is a blocking method and it won't exit until the service is stopped! // run is a blocking method!
// nolint // nolint
func run(args options) { func run(args options) {
// config file path can be overridden by command-line arguments: // config file path can be overridden by command-line arguments:
if args.configFilename != "" { if args.configFilename != "" {
config.ourConfigFilename = args.configFilename Context.configFilename = args.configFilename
} else {
// Default config file name
Context.configFilename = "AdGuardHome.yaml"
}
// Init some of the Context fields right away
Context.transport = &http.Transport{
DialContext: customDialContext,
}
Context.client = &http.Client{
Timeout: time.Minute * 5,
Transport: Context.transport,
} }
// configure working dir and config path // configure working dir and config path
@ -106,31 +151,22 @@ func run(args options) {
msg = msg + " v" + ARMVersion msg = msg + " v" + ARMVersion
} }
log.Printf(msg, versionString, updateChannel, runtime.GOOS, runtime.GOARCH, ARMVersion) log.Printf(msg, versionString, updateChannel, runtime.GOOS, runtime.GOARCH, ARMVersion)
log.Debug("Current working directory is %s", config.ourWorkingDir) log.Debug("Current working directory is %s", Context.workDir)
if args.runningAsService { if args.runningAsService {
log.Info("AdGuard Home is running as a service") log.Info("AdGuard Home is running as a service")
} }
config.runningAsService = args.runningAsService Context.runningAsService = args.runningAsService
config.disableUpdate = args.disableUpdate Context.disableUpdate = args.disableUpdate
config.firstRun = detectFirstRun() Context.firstRun = detectFirstRun()
if config.firstRun { if Context.firstRun {
requireAdminRights() requireAdminRights()
} }
config.appSignalChannel = make(chan os.Signal)
signal.Notify(config.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
<-config.appSignalChannel
cleanup()
cleanupAlways()
os.Exit(0)
}()
initConfig() initConfig()
initServices() initServices()
if !config.firstRun { if !Context.firstRun {
// Do the upgrade if necessary // Do the upgrade if necessary
err := upgradeConfig() err := upgradeConfig()
if err != nil { if err != nil {
@ -148,7 +184,7 @@ func run(args options) {
} }
} }
config.DHCP.WorkDir = config.ourWorkingDir config.DHCP.WorkDir = Context.workDir
config.DHCP.HTTPRegister = httpRegister config.DHCP.HTTPRegister = httpRegister
config.DHCP.ConfigModified = onConfigModified config.DHCP.ConfigModified = onConfigModified
Context.dhcpServer = dhcpd.Create(config.DHCP) Context.dhcpServer = dhcpd.Create(config.DHCP)
@ -168,7 +204,7 @@ func run(args options) {
config.BindPort = args.bindPort config.BindPort = args.bindPort
} }
if !config.firstRun { if !Context.firstRun {
// Save the updated config // Save the updated config
err := config.write() err := config.write()
if err != nil { if err != nil {
@ -193,7 +229,7 @@ func run(args options) {
} }
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) { if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
config.pidFileName = args.pidFile Context.pidFileName = args.pidFile
} }
// Initialize and run the admin Web interface // Initialize and run the admin Web interface
@ -204,7 +240,7 @@ func run(args options) {
registerControlHandlers() registerControlHandlers()
// add handlers for /install paths, we only need them when we're not configured yet // add handlers for /install paths, we only need them when we're not configured yet
if config.firstRun { if Context.firstRun {
log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ") log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
http.Handle("/install.html", preInstallHandler(http.FileServer(box))) http.Handle("/install.html", preInstallHandler(http.FileServer(box)))
registerInstallHandlers() registerInstallHandlers()
@ -331,19 +367,19 @@ func writePIDFile(fn string) bool {
return true return true
} }
// initWorkingDir initializes the ourWorkingDir // initWorkingDir initializes the workDir
// if no command-line arguments specified, we use the directory where our binary file is located // if no command-line arguments specified, we use the directory where our binary file is located
func initWorkingDir(args options) { func initWorkingDir(args options) {
exec, err := os.Executable() execPath, err := os.Executable()
if err != nil { if err != nil {
panic(err) panic(err)
} }
if args.workDir != "" { if args.workDir != "" {
// If there is a custom config file, use it's directory as our working dir // If there is a custom config file, use it's directory as our working dir
config.ourWorkingDir = args.workDir Context.workDir = args.workDir
} else { } else {
config.ourWorkingDir = filepath.Dir(exec) Context.workDir = filepath.Dir(execPath)
} }
} }
@ -381,7 +417,7 @@ func configureLogger(args options) {
log.Fatalf("cannot initialize syslog: %s", err) log.Fatalf("cannot initialize syslog: %s", err)
} }
} else { } else {
logFilePath := filepath.Join(config.ourWorkingDir, ls.LogFile) logFilePath := filepath.Join(Context.workDir, ls.LogFile)
if filepath.IsAbs(ls.LogFile) { if filepath.IsAbs(ls.LogFile) {
logFilePath = ls.LogFile logFilePath = ls.LogFile
} }
@ -420,8 +456,8 @@ func stopHTTPServer() {
// This function is called before application exits // This function is called before application exits
func cleanupAlways() { func cleanupAlways() {
if len(config.pidFileName) != 0 { if len(Context.pidFileName) != 0 {
os.Remove(config.pidFileName) _ = os.Remove(Context.pidFileName)
} }
log.Info("Stopped") log.Info("Stopped")
} }

View File

@ -120,16 +120,18 @@ func TestHome(t *testing.T) {
args.workDir = dir args.workDir = dir
go run(args) go run(args)
for i := 0; i < 5; i++ {
// Waiting until the DNS server is up and running
if !isRunning() {
time.Sleep(1 * time.Second)
}
}
var err error var err error
var resp *http.Response var resp *http.Response
h := http.Client{} h := http.Client{}
for i := 0; i != 5; i++ {
resp, err = h.Get("http://127.0.0.1:3000/") resp, err = h.Get("http://127.0.0.1:3000/")
if err == nil && resp.StatusCode != 404 {
break
}
time.Sleep(1 * time.Second)
}
assert.Truef(t, err == nil, "%s", err) assert.Truef(t, err == nil, "%s", err)
assert.Equal(t, 200, resp.StatusCode) assert.Equal(t, 200, resp.StatusCode)
@ -140,7 +142,7 @@ func TestHome(t *testing.T) {
// test DNS over UDP // test DNS over UDP
r := upstream.NewResolver("127.0.0.1:5354", 3*time.Second) r := upstream.NewResolver("127.0.0.1:5354", 3*time.Second)
addrs, err := r.LookupIPAddr(context.TODO(), "static.adguard.com") addrs, err := r.LookupIPAddr(context.TODO(), "static.adguard.com")
assert.Truef(t, err == nil, "%s", err) assert.Nil(t, err)
haveIP := len(addrs) != 0 haveIP := len(addrs) != 0
assert.True(t, haveIP) assert.True(t, haveIP)
@ -159,7 +161,7 @@ func TestHome(t *testing.T) {
assert.True(t, resp.StatusCode == http.StatusOK) assert.True(t, resp.StatusCode == http.StatusOK)
response := dns.Msg{} response := dns.Msg{}
err = response.Unpack(body) err = response.Unpack(body)
assert.True(t, err == nil, "%s", err) assert.Nil(t, err)
addrs = nil addrs = nil
proxyutil.AppendIPAddrs(&addrs, response.Answer) proxyutil.AppendIPAddrs(&addrs, response.Answer)
haveIP = len(addrs) != 0 haveIP = len(addrs) != 0

View File

@ -34,10 +34,10 @@ func (p *program) Start(s service.Service) error {
// Stop stops the program // Stop stops the program
func (p *program) Stop(s service.Service) error { func (p *program) Stop(s service.Service) error {
// Stop should not block. Return with a few seconds. // Stop should not block. Return with a few seconds.
if config.appSignalChannel == nil { if Context.appSignalChannel == nil {
os.Exit(0) os.Exit(0)
} }
config.appSignalChannel <- syscall.SIGINT Context.appSignalChannel <- syscall.SIGINT
return nil return nil
} }

View File

@ -116,7 +116,7 @@ func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) err
func upgradeSchema0to1(diskConfig *map[string]interface{}) error { func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", _Func()) log.Printf("%s(): called", _Func())
dnsFilterPath := filepath.Join(config.ourWorkingDir, "dnsfilter.txt") dnsFilterPath := filepath.Join(Context.workDir, "dnsfilter.txt")
if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) { if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) {
log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath) log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath)
err = os.Remove(dnsFilterPath) err = os.Remove(dnsFilterPath)
@ -137,7 +137,7 @@ func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
func upgradeSchema1to2(diskConfig *map[string]interface{}) error { func upgradeSchema1to2(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", _Func()) log.Printf("%s(): called", _Func())
coreFilePath := filepath.Join(config.ourWorkingDir, "Corefile") coreFilePath := filepath.Join(Context.workDir, "Corefile")
if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) { if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) {
log.Printf("Deleting %s as we don't need it anymore", coreFilePath) log.Printf("Deleting %s as we don't need it anymore", coreFilePath)
err = os.Remove(coreFilePath) err = os.Remove(coreFilePath)