From 313fd7107f3b305977bb132e14577eaff65d9777 Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Mon, 15 Mar 2021 14:19:04 +0300 Subject: [PATCH] Pull request: all: imp code, err handling Closes #2571. Squashed commit of the following: commit a5b50ee011a995f4ab3d93314acd6f0ca82d99cf Author: Ainar Garipov Date: Mon Mar 15 14:05:25 2021 +0300 all: imp code commit bc610f8f438549e8c6b04c8a213b5422dda2aff5 Author: Ainar Garipov Date: Fri Mar 12 20:00:14 2021 +0300 all: imp code, err handling --- internal/dhcpd/db.go | 6 +++-- internal/dhcpd/http.go | 4 +-- internal/home/control.go | 4 +-- internal/home/controlfiltering.go | 41 ++++++++++++++++++++----------- internal/home/home.go | 8 +++--- internal/home/mobileconfig.go | 2 +- internal/home/service.go | 2 +- internal/home/web.go | 10 ++++++-- internal/querylog/http.go | 15 ++++++++--- internal/querylog/qlogreader.go | 14 +++++++++-- internal/querylog/search.go | 25 ++++++------------- internal/updater/updater.go | 36 ++++++++++++++++----------- internal/util/autohosts.go | 31 ++++++++++++----------- internal/util/helpers.go | 17 ------------- internal/util/tls.go | 5 ++-- 15 files changed, 121 insertions(+), 99 deletions(-) diff --git a/internal/dhcpd/db.go b/internal/dhcpd/db.go index 2618f12e..7408c9e0 100644 --- a/internal/dhcpd/db.go +++ b/internal/dhcpd/db.go @@ -4,6 +4,7 @@ package dhcpd import ( "encoding/json" + "errors" "io/ioutil" "net" "os" @@ -39,9 +40,10 @@ func (s *Server) dbLoad() { data, err := ioutil.ReadFile(s.conf.DBFilePath) if err != nil { - if !os.IsNotExist(err) { - log.Error("DHCP: can't read file %s: %v", s.conf.DBFilePath, err) + if !errors.Is(err, os.ErrNotExist) { + log.Error("dhcp: can't read file %q: %v", s.conf.DBFilePath, err) } + return } diff --git a/internal/dhcpd/http.go b/internal/dhcpd/http.go index b0817ab6..9fea7bf6 100644 --- a/internal/dhcpd/http.go +++ b/internal/dhcpd/http.go @@ -508,8 +508,8 @@ func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) { s.Stop() err := os.Remove(s.conf.DBFilePath) - if err != nil && !os.IsNotExist(err) { - log.Error("DHCP: os.Remove: %s: %s", s.conf.DBFilePath, err) + if err != nil && !errors.Is(err, os.ErrNotExist) { + log.Error("dhcp: removing %q: %s", s.conf.DBFilePath, err) } oldconf := s.conf diff --git a/internal/home/control.go b/internal/home/control.go index 19876b12..e7724794 100644 --- a/internal/home/control.go +++ b/internal/home/control.go @@ -239,7 +239,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) { } httpsURL := &url.URL{ - Scheme: "https", + Scheme: schemeHTTPS, Host: hostPort, Path: r.URL.Path, RawQuery: r.URL.RawQuery, @@ -255,7 +255,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) { // // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin. originURL := &url.URL{ - Scheme: "http", + Scheme: schemeHTTP, Host: r.Host, } w.Header().Set("Access-Control-Allow-Origin", originURL.String()) diff --git a/internal/home/controlfiltering.go b/internal/home/controlfiltering.go index b64423d1..2de312e4 100644 --- a/internal/home/controlfiltering.go +++ b/internal/home/controlfiltering.go @@ -12,26 +12,31 @@ import ( "strings" "time" - "github.com/AdguardTeam/AdGuardHome/internal/util" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) -// isValidURL - return TRUE if URL or file path is valid -func isValidURL(rawurl string) bool { - if filepath.IsAbs(rawurl) { - // this is a file path - return util.FileExists(rawurl) +// validateFilterURL validates the filter list URL or file name. +func validateFilterURL(urlStr string) (err error) { + if filepath.IsAbs(urlStr) { + _, err = os.Stat(urlStr) + if err != nil { + return fmt.Errorf("checking filter file: %w", err) + } + + return nil } - url, err := url.ParseRequestURI(rawurl) + url, err := url.ParseRequestURI(urlStr) if err != nil { - return false // Couldn't even parse the rawurl + return fmt.Errorf("checking filter url: %w", err) } - if len(url.Scheme) == 0 { - return false // No Scheme found + + if s := url.Scheme; s != schemeHTTP && s != schemeHTTPS { + return fmt.Errorf("checking filter url: invalid scheme %q", s) } - return true + + return nil } type filterAddJSON struct { @@ -48,8 +53,11 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request return } - if !isValidURL(fj.URL) { - http.Error(w, "Invalid URL or file path", http.StatusBadRequest) + err = validateFilterURL(fj.URL) + if err != nil { + msg := fmt.Sprintf("invalid url: %s", err) + http.Error(w, msg, http.StatusBadRequest) + return } @@ -168,8 +176,11 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request return } - if !isValidURL(fj.Data.URL) { - http.Error(w, "invalid URL or file path", http.StatusBadRequest) + err = validateFilterURL(fj.Data.URL) + if err != nil { + msg := fmt.Sprintf("invalid url: %s", err) + http.Error(w, msg, http.StatusBadRequest) + return } diff --git a/internal/home/home.go b/internal/home/home.go index 54bae585..21307c05 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -570,8 +570,8 @@ func loadOptions() options { return o } -// prints IP addresses which user can use to open the admin interface -// proto is either "http" or "https" +// printHTTPAddresses prints the IP addresses which user can use to open the +// admin interface. proto is either schemeHTTP or schemeHTTPS. func printHTTPAddresses(proto string) { tlsConf := tlsConfigSettings{} if Context.tls != nil { @@ -579,12 +579,12 @@ func printHTTPAddresses(proto string) { } port := strconv.Itoa(config.BindPort) - if proto == "https" { + if proto == schemeHTTPS { port = strconv.Itoa(tlsConf.PortHTTPS) } var hostStr string - if proto == "https" && tlsConf.ServerName != "" { + if proto == schemeHTTPS && tlsConf.ServerName != "" { if tlsConf.PortHTTPS == 443 { log.Printf("Go to https://%s", tlsConf.ServerName) } else { diff --git a/internal/home/mobileconfig.go b/internal/home/mobileconfig.go index 3953e2e6..9ea3e7f5 100644 --- a/internal/home/mobileconfig.go +++ b/internal/home/mobileconfig.go @@ -58,7 +58,7 @@ func getMobileConfig(d dnsSettings) ([]byte, error) { dspName = fmt.Sprintf("%s DoH", d.ServerName) u := &url.URL{ - Scheme: "https", + Scheme: schemeHTTPS, Host: d.ServerName, Path: "/dns-query", } diff --git a/internal/home/service.go b/internal/home/service.go index 94364cbb..caebb582 100644 --- a/internal/home/service.go +++ b/internal/home/service.go @@ -236,7 +236,7 @@ func handleServiceInstallCommand(s service.Service) { AdGuard Home is successfully installed and will automatically start on boot. There are a few more things that must be configured before you can use it. Click on the link below and follow the Installation Wizard steps to finish setup.`) - printHTTPAddresses("http") + printHTTPAddresses(schemeHTTP) } } diff --git a/internal/home/web.go b/internal/home/web.go index 38ba6dcb..6b650576 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -15,6 +15,12 @@ import ( "github.com/gobuffalo/packr" ) +// HTTP scheme constants. +const ( + schemeHTTP = "http" + schemeHTTPS = "https" +) + const ( // readTimeout is the maximum duration for reading the entire request, // including the body. @@ -159,7 +165,7 @@ func (web *Web) Start() { // this loop is used as an ability to change listening host and/or port for !web.httpsServer.shutdown { - printHTTPAddresses("http") + printHTTPAddresses(schemeHTTP) errs := make(chan error, 2) hostStr := web.conf.BindHost.String() @@ -256,7 +262,7 @@ func (web *Web) tlsServerLoop() { WriteTimeout: web.conf.WriteTimeout, } - printHTTPAddresses("https") + printHTTPAddresses(schemeHTTPS) err := web.httpsServer.server.ListenAndServeTLS("", "") if err != http.ErrServerClosed { cleanupAlways() diff --git a/internal/querylog/http.go b/internal/querylog/http.go index b235c2fe..4f71c21d 100644 --- a/internal/querylog/http.go +++ b/internal/querylog/http.go @@ -8,8 +8,6 @@ import ( "strconv" "time" - "github.com/AdguardTeam/AdGuardHome/internal/util" - "github.com/AdguardTeam/golibs/jsonutil" "github.com/AdguardTeam/golibs/log" ) @@ -127,6 +125,17 @@ func getDoubleQuotesEnclosedValue(s *string) bool { return false } +// inStr checks if string is in the slice of strings. +func inStr(strs []string, str string) (ok bool) { + for _, s := range strs { + if s == str { + return true + } + } + + return false +} + // parseSearchCriteria - parses "searchCriteria" from the specified query parameter func (l *queryLog) parseSearchCriteria(q url.Values, name string, ct criteriaType) (bool, searchCriteria, error) { val := q.Get(name) @@ -142,7 +151,7 @@ func (l *queryLog) parseSearchCriteria(q url.Values, name string, ct criteriaTyp c.strict = true } - if ct == ctFilteringStatus && !util.ContainsString(filteringStatusValues, c.value) { + if ct == ctFilteringStatus && !inStr(filteringStatusValues, c.value) { return false, c, fmt.Errorf("invalid value %s", c.value) } diff --git a/internal/querylog/qlogreader.go b/internal/querylog/qlogreader.go index 19909110..03a93a0f 100644 --- a/internal/querylog/qlogreader.go +++ b/internal/querylog/qlogreader.go @@ -4,8 +4,10 @@ import ( "errors" "fmt" "io" + "os" "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/golibs/log" ) // QLogReader allows reading from multiple query log files in the reverse order. @@ -30,8 +32,16 @@ func NewQLogReader(files []string) (*QLogReader, error) { for _, f := range files { q, err := NewQLogFile(f) if err != nil { - // Close what we've already opened - _ = closeQFiles(qFiles) + if errors.Is(err, os.ErrNotExist) { + continue + } + + // Close what we've already opened. + cerr := closeQFiles(qFiles) + if cerr != nil { + log.Debug("querylog: closing files: %s", cerr) + } + return nil, err } diff --git a/internal/querylog/search.go b/internal/querylog/search.go index b114c4e9..216f9167 100644 --- a/internal/querylog/search.go +++ b/internal/querylog/search.go @@ -5,7 +5,6 @@ import ( "sort" "time" - "github.com/AdguardTeam/AdGuardHome/internal/util" "github.com/AdguardTeam/golibs/log" ) @@ -87,9 +86,15 @@ func (l *queryLog) searchFiles(params *searchParams) ([]*logEntry, time.Time, in entries := make([]*logEntry, 0) oldest := time.Time{} - r, err := l.openReader() + files := []string{ + l.logFile + ".1", + l.logFile, + } + + r, err := NewQLogReader(files) if err != nil { - log.Error("Failed to open qlog reader: %v", err) + log.Error("querylog: failed to open qlog reader: %s", err) + return entries, oldest, 0 } defer r.Close() @@ -174,17 +179,3 @@ func (l *queryLog) readNextEntry(r *QLogReader, params *searchParams) (*logEntry return &entry, timestamp, nil } - -// openReader - opens QLogReader instance -func (l *queryLog) openReader() (*QLogReader, error) { - files := make([]string, 0) - - if util.FileExists(l.logFile + ".1") { - files = append(files, l.logFile+".1") - } - if util.FileExists(l.logFile) { - files = append(files, l.logFile) - } - - return NewQLogReader(files) -} diff --git a/internal/updater/updater.go b/internal/updater/updater.go index 4bc96fc3..d6a09a04 100644 --- a/internal/updater/updater.go +++ b/internal/updater/updater.go @@ -5,6 +5,7 @@ import ( "archive/tar" "archive/zip" "compress/gzip" + "errors" "fmt" "io" "io/ioutil" @@ -19,7 +20,6 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghio" - "github.com/AdguardTeam/AdGuardHome/internal/util" "github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/golibs/log" ) @@ -160,7 +160,7 @@ func (u *Updater) VersionCheckURL() (vcu string) { return u.versionCheckURL } -func (u *Updater) prepare() error { +func (u *Updater) prepare() (err error) { u.updateDir = filepath.Join(u.workDir, fmt.Sprintf("agh-update-%s", u.newVersion)) _, pkgNameOnly := filepath.Split(u.packageURL) @@ -181,12 +181,11 @@ func (u *Updater) prepare() error { log.Info("Updating from %s to %s. URL:%s", version.Version(), u.newVersion, u.packageURL) - // If the binary file isn't found in working directory, we won't be able - // to auto-update. Getting the full path to the current binary file on - // Unix and checking write permissions is more difficult. + // TODO(a.garipov): Use os.Args[0] instead? u.currentExeName = filepath.Join(u.workDir, exeName) - if !util.FileExists(u.currentExeName) { - return fmt.Errorf("executable file %s doesn't exist", u.currentExeName) + _, err = os.Stat(u.currentExeName) + if err != nil { + return fmt.Errorf("checking %q: %w", u.currentExeName, err) } return nil @@ -367,11 +366,14 @@ func tarGzFileUnpack(tarfile, outdir string) ([]string, error) { } err = os.Mkdir(outputName, os.FileMode(header.Mode&0o777)) - if err != nil && !os.IsExist(err) { - err2 = fmt.Errorf("os.Mkdir(%s): %w", outputName, err) + if err != nil && !errors.Is(err, os.ErrExist) { + err2 = fmt.Errorf("os.Mkdir(%q): %w", outputName, err) + break } - log.Debug("updater: created directory %s", outputName) + + log.Debug("updater: created directory %q", outputName) + continue } else if header.Typeflag != tar.TypeReg { log.Debug("updater: %s: unknown file type %d, skipping", inputNameOnly, header.Typeflag) @@ -443,11 +445,14 @@ func zipFileUnpack(zipfile, outdir string) ([]string, error) { } err = os.Mkdir(outputName, fi.Mode()) - if err != nil && !os.IsExist(err) { - err2 = fmt.Errorf("os.Mkdir(): %w", err) + if err != nil && !errors.Is(err, os.ErrExist) { + err2 = fmt.Errorf("os.Mkdir(%q): %w", outputName, err) + break } - log.Tracef("created directory %s", outputName) + + log.Tracef("created directory %q", outputName) + continue } @@ -501,11 +506,12 @@ func copySupportingFiles(files []string, srcdir, dstdir string) error { dst := filepath.Join(dstdir, name) err := copyFile(src, dst) - if err != nil && !os.IsNotExist(err) { + if err != nil && !errors.Is(err, os.ErrNotExist) { return err } - log.Debug("updater: copied: %s -> %s", src, dst) + log.Debug("updater: copied: %q -> %q", src, dst) } + return nil } diff --git a/internal/util/autohosts.go b/internal/util/autohosts.go index ad157171..c3156920 100644 --- a/internal/util/autohosts.go +++ b/internal/util/autohosts.go @@ -2,10 +2,12 @@ package util import ( "bufio" + "errors" "io" "io/ioutil" "net" "os" + "path/filepath" "runtime" "strings" "sync" @@ -76,7 +78,7 @@ func (a *AutoHosts) Init(hostsFn string) { var err error a.watcher, err = fsnotify.NewWatcher() if err != nil { - log.Error("AutoHosts: %s", err) + log.Error("autohosts: %s", err) } } @@ -127,7 +129,7 @@ func (a *AutoHosts) Process(host string, qtype uint16) []net.IP { copy(ipsCopy, ips) } - log.Debug("AutoHosts: answer: %s -> %v", host, ipsCopy) + log.Debug("autohosts: answer: %s -> %v", host, ipsCopy) return ipsCopy } @@ -153,7 +155,7 @@ func (a *AutoHosts) ProcessReverse(addr string, qtype uint16) (hosts []string) { return nil // not found } - log.Debug("AutoHosts: reverse-lookup: %s -> %s", addr, hosts) + log.Debug("autohosts: reverse-lookup: %s -> %s", addr, hosts) return hosts } @@ -191,7 +193,7 @@ func (a *AutoHosts) updateTable(table map[string][]net.IP, host string, ipAddr n ok = true } if ok { - log.Debug("AutoHosts: added %s -> %s", ipAddr, host) + log.Debug("autohosts: added %s -> %s", ipAddr, host) } } @@ -201,7 +203,7 @@ func (a *AutoHosts) updateTableRev(tableRev map[string][]string, newHost string, hosts, ok := tableRev[ipStr] if !ok { tableRev[ipStr] = []string{newHost} - log.Debug("AutoHosts: added reverse-address %s -> %s", ipStr, newHost) + log.Debug("autohosts: added reverse-address %s -> %s", ipStr, newHost) return } @@ -213,7 +215,7 @@ func (a *AutoHosts) updateTableRev(tableRev map[string][]string, newHost string, } tableRev[ipStr] = append(tableRev[ipStr], newHost) - log.Debug("AutoHosts: added reverse-address %s -> %s", ipStr, newHost) + log.Debug("autohosts: added reverse-address %s -> %s", ipStr, newHost) } // Read IP-hostname pairs from file @@ -221,12 +223,12 @@ func (a *AutoHosts) updateTableRev(tableRev map[string][]string, newHost string, func (a *AutoHosts) load(table map[string][]net.IP, tableRev map[string][]string, fn string) { f, err := os.Open(fn) if err != nil { - log.Error("AutoHosts: %s", err) + log.Error("autohosts: %s", err) return } defer f.Close() r := bufio.NewReader(f) - log.Debug("AutoHosts: loading hosts from file %s", fn) + log.Debug("autohosts: loading hosts from file %s", fn) for done := false; !done; { var line string @@ -234,7 +236,7 @@ func (a *AutoHosts) load(table map[string][]net.IP, tableRev map[string][]string if err == io.EOF { done = true } else if err != nil { - log.Error("AutoHosts: %s", err) + log.Error("autohosts: %s", err) return } @@ -310,7 +312,7 @@ func (a *AutoHosts) watcherLoop() { } if event.Op&fsnotify.Write == fsnotify.Write { - log.Debug("AutoHosts: modified: %s", event.Name) + log.Debug("autohosts: modified: %s", event.Name) a.updateHosts() } @@ -318,7 +320,7 @@ func (a *AutoHosts) watcherLoop() { if !ok { return } - log.Error("AutoHosts: %s", err) + log.Error("autohosts: %s", err) } } } @@ -333,14 +335,15 @@ func (a *AutoHosts) updateHosts() { for _, dir := range a.hostsDirs { fis, err := ioutil.ReadDir(dir) if err != nil { - if !os.IsNotExist(err) { - log.Error("AutoHosts: Opening directory: %s: %s", dir, err) + if !errors.Is(err, os.ErrNotExist) { + log.Error("autohosts: Opening directory: %q: %s", dir, err) } + continue } for _, fi := range fis { - a.load(table, tableRev, dir+"/"+fi.Name()) + a.load(table, tableRev, filepath.Join(dir, fi.Name())) } } diff --git a/internal/util/helpers.go b/internal/util/helpers.go index 0dad08b9..2ccc9bbe 100644 --- a/internal/util/helpers.go +++ b/internal/util/helpers.go @@ -8,29 +8,12 @@ import ( "bytes" "fmt" "io/ioutil" - "os" "os/exec" "path/filepath" "runtime" "strings" ) -// ContainsString checks if string is in the slice of strings. -func ContainsString(strs []string, str string) bool { - for _, s := range strs { - if s == str { - return true - } - } - return false -} - -// FileExists returns true if file exists. -func FileExists(fn string) bool { - _, err := os.Stat(fn) - return err == nil || !os.IsNotExist(err) -} - // RunCommand runs shell command. func RunCommand(command string, arguments ...string) (int, string, error) { cmd := exec.Command(command, arguments...) diff --git a/internal/util/tls.go b/internal/util/tls.go index f1aa4976..ed70d24f 100644 --- a/internal/util/tls.go +++ b/internal/util/tls.go @@ -3,6 +3,7 @@ package util import ( "crypto/tls" "crypto/x509" + "errors" "io/ioutil" "os" "runtime" @@ -29,8 +30,8 @@ func LoadSystemRootCAs() *x509.CertPool { for _, dir := range dirs { fis, err := ioutil.ReadDir(dir) if err != nil { - if !os.IsNotExist(err) { - log.Error("Opening directory: %s: %s", dir, err) + if !errors.Is(err, os.ErrNotExist) { + log.Error("opening directory: %q: %s", dir, err) } continue