Pull request: all: imp code, err handling

Closes #2571.

Squashed commit of the following:

commit a5b50ee011a995f4ab3d93314acd6f0ca82d99cf
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Mar 15 14:05:25 2021 +0300

    all: imp code

commit bc610f8f438549e8c6b04c8a213b5422dda2aff5
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Mar 12 20:00:14 2021 +0300

    all: imp code, err handling
This commit is contained in:
Ainar Garipov 2021-03-15 14:19:04 +03:00
parent d970b79f2b
commit 313fd7107f
15 changed files with 121 additions and 99 deletions

View File

@ -4,6 +4,7 @@ package dhcpd
import ( import (
"encoding/json" "encoding/json"
"errors"
"io/ioutil" "io/ioutil"
"net" "net"
"os" "os"
@ -39,9 +40,10 @@ func (s *Server) dbLoad() {
data, err := ioutil.ReadFile(s.conf.DBFilePath) data, err := ioutil.ReadFile(s.conf.DBFilePath)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !errors.Is(err, os.ErrNotExist) {
log.Error("DHCP: can't read file %s: %v", s.conf.DBFilePath, err) log.Error("dhcp: can't read file %q: %v", s.conf.DBFilePath, err)
} }
return return
} }

View File

@ -508,8 +508,8 @@ func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
s.Stop() s.Stop()
err := os.Remove(s.conf.DBFilePath) err := os.Remove(s.conf.DBFilePath)
if err != nil && !os.IsNotExist(err) { if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error("DHCP: os.Remove: %s: %s", s.conf.DBFilePath, err) log.Error("dhcp: removing %q: %s", s.conf.DBFilePath, err)
} }
oldconf := s.conf oldconf := s.conf

View File

@ -239,7 +239,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
} }
httpsURL := &url.URL{ httpsURL := &url.URL{
Scheme: "https", Scheme: schemeHTTPS,
Host: hostPort, Host: hostPort,
Path: r.URL.Path, Path: r.URL.Path,
RawQuery: r.URL.RawQuery, RawQuery: r.URL.RawQuery,
@ -255,7 +255,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
// //
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin. // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin.
originURL := &url.URL{ originURL := &url.URL{
Scheme: "http", Scheme: schemeHTTP,
Host: r.Host, Host: r.Host,
} }
w.Header().Set("Access-Control-Allow-Origin", originURL.String()) w.Header().Set("Access-Control-Allow-Origin", originURL.String())

View File

@ -12,26 +12,31 @@ import (
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// isValidURL - return TRUE if URL or file path is valid // validateFilterURL validates the filter list URL or file name.
func isValidURL(rawurl string) bool { func validateFilterURL(urlStr string) (err error) {
if filepath.IsAbs(rawurl) { if filepath.IsAbs(urlStr) {
// this is a file path _, err = os.Stat(urlStr)
return util.FileExists(rawurl) if err != nil {
return fmt.Errorf("checking filter file: %w", err)
}
return nil
} }
url, err := url.ParseRequestURI(rawurl) url, err := url.ParseRequestURI(urlStr)
if err != nil { if err != nil {
return false // Couldn't even parse the rawurl return fmt.Errorf("checking filter url: %w", err)
} }
if len(url.Scheme) == 0 {
return false // No Scheme found if s := url.Scheme; s != schemeHTTP && s != schemeHTTPS {
return fmt.Errorf("checking filter url: invalid scheme %q", s)
} }
return true
return nil
} }
type filterAddJSON struct { type filterAddJSON struct {
@ -48,8 +53,11 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
return return
} }
if !isValidURL(fj.URL) { err = validateFilterURL(fj.URL)
http.Error(w, "Invalid URL or file path", http.StatusBadRequest) if err != nil {
msg := fmt.Sprintf("invalid url: %s", err)
http.Error(w, msg, http.StatusBadRequest)
return return
} }
@ -168,8 +176,11 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
return return
} }
if !isValidURL(fj.Data.URL) { err = validateFilterURL(fj.Data.URL)
http.Error(w, "invalid URL or file path", http.StatusBadRequest) if err != nil {
msg := fmt.Sprintf("invalid url: %s", err)
http.Error(w, msg, http.StatusBadRequest)
return return
} }

View File

@ -570,8 +570,8 @@ func loadOptions() options {
return o return o
} }
// prints IP addresses which user can use to open the admin interface // printHTTPAddresses prints the IP addresses which user can use to open the
// proto is either "http" or "https" // admin interface. proto is either schemeHTTP or schemeHTTPS.
func printHTTPAddresses(proto string) { func printHTTPAddresses(proto string) {
tlsConf := tlsConfigSettings{} tlsConf := tlsConfigSettings{}
if Context.tls != nil { if Context.tls != nil {
@ -579,12 +579,12 @@ func printHTTPAddresses(proto string) {
} }
port := strconv.Itoa(config.BindPort) port := strconv.Itoa(config.BindPort)
if proto == "https" { if proto == schemeHTTPS {
port = strconv.Itoa(tlsConf.PortHTTPS) port = strconv.Itoa(tlsConf.PortHTTPS)
} }
var hostStr string var hostStr string
if proto == "https" && tlsConf.ServerName != "" { if proto == schemeHTTPS && tlsConf.ServerName != "" {
if tlsConf.PortHTTPS == 443 { if tlsConf.PortHTTPS == 443 {
log.Printf("Go to https://%s", tlsConf.ServerName) log.Printf("Go to https://%s", tlsConf.ServerName)
} else { } else {

View File

@ -58,7 +58,7 @@ func getMobileConfig(d dnsSettings) ([]byte, error) {
dspName = fmt.Sprintf("%s DoH", d.ServerName) dspName = fmt.Sprintf("%s DoH", d.ServerName)
u := &url.URL{ u := &url.URL{
Scheme: "https", Scheme: schemeHTTPS,
Host: d.ServerName, Host: d.ServerName,
Path: "/dns-query", Path: "/dns-query",
} }

View File

@ -236,7 +236,7 @@ func handleServiceInstallCommand(s service.Service) {
AdGuard Home is successfully installed and will automatically start on boot. AdGuard Home is successfully installed and will automatically start on boot.
There are a few more things that must be configured before you can use it. There are a few more things that must be configured before you can use it.
Click on the link below and follow the Installation Wizard steps to finish setup.`) Click on the link below and follow the Installation Wizard steps to finish setup.`)
printHTTPAddresses("http") printHTTPAddresses(schemeHTTP)
} }
} }

View File

@ -15,6 +15,12 @@ import (
"github.com/gobuffalo/packr" "github.com/gobuffalo/packr"
) )
// HTTP scheme constants.
const (
schemeHTTP = "http"
schemeHTTPS = "https"
)
const ( const (
// readTimeout is the maximum duration for reading the entire request, // readTimeout is the maximum duration for reading the entire request,
// including the body. // including the body.
@ -159,7 +165,7 @@ func (web *Web) Start() {
// this loop is used as an ability to change listening host and/or port // this loop is used as an ability to change listening host and/or port
for !web.httpsServer.shutdown { for !web.httpsServer.shutdown {
printHTTPAddresses("http") printHTTPAddresses(schemeHTTP)
errs := make(chan error, 2) errs := make(chan error, 2)
hostStr := web.conf.BindHost.String() hostStr := web.conf.BindHost.String()
@ -256,7 +262,7 @@ func (web *Web) tlsServerLoop() {
WriteTimeout: web.conf.WriteTimeout, WriteTimeout: web.conf.WriteTimeout,
} }
printHTTPAddresses("https") printHTTPAddresses(schemeHTTPS)
err := web.httpsServer.server.ListenAndServeTLS("", "") err := web.httpsServer.server.ListenAndServeTLS("", "")
if err != http.ErrServerClosed { if err != http.ErrServerClosed {
cleanupAlways() cleanupAlways()

View File

@ -8,8 +8,6 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/jsonutil" "github.com/AdguardTeam/golibs/jsonutil"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -127,6 +125,17 @@ func getDoubleQuotesEnclosedValue(s *string) bool {
return false return false
} }
// inStr checks if string is in the slice of strings.
func inStr(strs []string, str string) (ok bool) {
for _, s := range strs {
if s == str {
return true
}
}
return false
}
// parseSearchCriteria - parses "searchCriteria" from the specified query parameter // parseSearchCriteria - parses "searchCriteria" from the specified query parameter
func (l *queryLog) parseSearchCriteria(q url.Values, name string, ct criteriaType) (bool, searchCriteria, error) { func (l *queryLog) parseSearchCriteria(q url.Values, name string, ct criteriaType) (bool, searchCriteria, error) {
val := q.Get(name) val := q.Get(name)
@ -142,7 +151,7 @@ func (l *queryLog) parseSearchCriteria(q url.Values, name string, ct criteriaTyp
c.strict = true c.strict = true
} }
if ct == ctFilteringStatus && !util.ContainsString(filteringStatusValues, c.value) { if ct == ctFilteringStatus && !inStr(filteringStatusValues, c.value) {
return false, c, fmt.Errorf("invalid value %s", c.value) return false, c, fmt.Errorf("invalid value %s", c.value)
} }

View File

@ -4,8 +4,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"os"
"github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/golibs/log"
) )
// QLogReader allows reading from multiple query log files in the reverse order. // QLogReader allows reading from multiple query log files in the reverse order.
@ -30,8 +32,16 @@ func NewQLogReader(files []string) (*QLogReader, error) {
for _, f := range files { for _, f := range files {
q, err := NewQLogFile(f) q, err := NewQLogFile(f)
if err != nil { if err != nil {
// Close what we've already opened if errors.Is(err, os.ErrNotExist) {
_ = closeQFiles(qFiles) continue
}
// Close what we've already opened.
cerr := closeQFiles(qFiles)
if cerr != nil {
log.Debug("querylog: closing files: %s", cerr)
}
return nil, err return nil, err
} }

View File

@ -5,7 +5,6 @@ import (
"sort" "sort"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -87,9 +86,15 @@ func (l *queryLog) searchFiles(params *searchParams) ([]*logEntry, time.Time, in
entries := make([]*logEntry, 0) entries := make([]*logEntry, 0)
oldest := time.Time{} oldest := time.Time{}
r, err := l.openReader() files := []string{
l.logFile + ".1",
l.logFile,
}
r, err := NewQLogReader(files)
if err != nil { if err != nil {
log.Error("Failed to open qlog reader: %v", err) log.Error("querylog: failed to open qlog reader: %s", err)
return entries, oldest, 0 return entries, oldest, 0
} }
defer r.Close() defer r.Close()
@ -174,17 +179,3 @@ func (l *queryLog) readNextEntry(r *QLogReader, params *searchParams) (*logEntry
return &entry, timestamp, nil return &entry, timestamp, nil
} }
// openReader - opens QLogReader instance
func (l *queryLog) openReader() (*QLogReader, error) {
files := make([]string, 0)
if util.FileExists(l.logFile + ".1") {
files = append(files, l.logFile+".1")
}
if util.FileExists(l.logFile) {
files = append(files, l.logFile)
}
return NewQLogReader(files)
}

View File

@ -5,6 +5,7 @@ import (
"archive/tar" "archive/tar"
"archive/zip" "archive/zip"
"compress/gzip" "compress/gzip"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -19,7 +20,6 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -160,7 +160,7 @@ func (u *Updater) VersionCheckURL() (vcu string) {
return u.versionCheckURL return u.versionCheckURL
} }
func (u *Updater) prepare() error { func (u *Updater) prepare() (err error) {
u.updateDir = filepath.Join(u.workDir, fmt.Sprintf("agh-update-%s", u.newVersion)) u.updateDir = filepath.Join(u.workDir, fmt.Sprintf("agh-update-%s", u.newVersion))
_, pkgNameOnly := filepath.Split(u.packageURL) _, pkgNameOnly := filepath.Split(u.packageURL)
@ -181,12 +181,11 @@ func (u *Updater) prepare() error {
log.Info("Updating from %s to %s. URL:%s", version.Version(), u.newVersion, u.packageURL) log.Info("Updating from %s to %s. URL:%s", version.Version(), u.newVersion, u.packageURL)
// If the binary file isn't found in working directory, we won't be able // TODO(a.garipov): Use os.Args[0] instead?
// to auto-update. Getting the full path to the current binary file on
// Unix and checking write permissions is more difficult.
u.currentExeName = filepath.Join(u.workDir, exeName) u.currentExeName = filepath.Join(u.workDir, exeName)
if !util.FileExists(u.currentExeName) { _, err = os.Stat(u.currentExeName)
return fmt.Errorf("executable file %s doesn't exist", u.currentExeName) if err != nil {
return fmt.Errorf("checking %q: %w", u.currentExeName, err)
} }
return nil return nil
@ -367,11 +366,14 @@ func tarGzFileUnpack(tarfile, outdir string) ([]string, error) {
} }
err = os.Mkdir(outputName, os.FileMode(header.Mode&0o777)) err = os.Mkdir(outputName, os.FileMode(header.Mode&0o777))
if err != nil && !os.IsExist(err) { if err != nil && !errors.Is(err, os.ErrExist) {
err2 = fmt.Errorf("os.Mkdir(%s): %w", outputName, err) err2 = fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
break break
} }
log.Debug("updater: created directory %s", outputName)
log.Debug("updater: created directory %q", outputName)
continue continue
} else if header.Typeflag != tar.TypeReg { } else if header.Typeflag != tar.TypeReg {
log.Debug("updater: %s: unknown file type %d, skipping", inputNameOnly, header.Typeflag) log.Debug("updater: %s: unknown file type %d, skipping", inputNameOnly, header.Typeflag)
@ -443,11 +445,14 @@ func zipFileUnpack(zipfile, outdir string) ([]string, error) {
} }
err = os.Mkdir(outputName, fi.Mode()) err = os.Mkdir(outputName, fi.Mode())
if err != nil && !os.IsExist(err) { if err != nil && !errors.Is(err, os.ErrExist) {
err2 = fmt.Errorf("os.Mkdir(): %w", err) err2 = fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
break break
} }
log.Tracef("created directory %s", outputName)
log.Tracef("created directory %q", outputName)
continue continue
} }
@ -501,11 +506,12 @@ func copySupportingFiles(files []string, srcdir, dstdir string) error {
dst := filepath.Join(dstdir, name) dst := filepath.Join(dstdir, name)
err := copyFile(src, dst) err := copyFile(src, dst)
if err != nil && !os.IsNotExist(err) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return err return err
} }
log.Debug("updater: copied: %s -> %s", src, dst) log.Debug("updater: copied: %q -> %q", src, dst)
} }
return nil return nil
} }

View File

@ -2,10 +2,12 @@ package util
import ( import (
"bufio" "bufio"
"errors"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
"os" "os"
"path/filepath"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
@ -76,7 +78,7 @@ func (a *AutoHosts) Init(hostsFn string) {
var err error var err error
a.watcher, err = fsnotify.NewWatcher() a.watcher, err = fsnotify.NewWatcher()
if err != nil { if err != nil {
log.Error("AutoHosts: %s", err) log.Error("autohosts: %s", err)
} }
} }
@ -127,7 +129,7 @@ func (a *AutoHosts) Process(host string, qtype uint16) []net.IP {
copy(ipsCopy, ips) copy(ipsCopy, ips)
} }
log.Debug("AutoHosts: answer: %s -> %v", host, ipsCopy) log.Debug("autohosts: answer: %s -> %v", host, ipsCopy)
return ipsCopy return ipsCopy
} }
@ -153,7 +155,7 @@ func (a *AutoHosts) ProcessReverse(addr string, qtype uint16) (hosts []string) {
return nil // not found return nil // not found
} }
log.Debug("AutoHosts: reverse-lookup: %s -> %s", addr, hosts) log.Debug("autohosts: reverse-lookup: %s -> %s", addr, hosts)
return hosts return hosts
} }
@ -191,7 +193,7 @@ func (a *AutoHosts) updateTable(table map[string][]net.IP, host string, ipAddr n
ok = true ok = true
} }
if ok { if ok {
log.Debug("AutoHosts: added %s -> %s", ipAddr, host) log.Debug("autohosts: added %s -> %s", ipAddr, host)
} }
} }
@ -201,7 +203,7 @@ func (a *AutoHosts) updateTableRev(tableRev map[string][]string, newHost string,
hosts, ok := tableRev[ipStr] hosts, ok := tableRev[ipStr]
if !ok { if !ok {
tableRev[ipStr] = []string{newHost} tableRev[ipStr] = []string{newHost}
log.Debug("AutoHosts: added reverse-address %s -> %s", ipStr, newHost) log.Debug("autohosts: added reverse-address %s -> %s", ipStr, newHost)
return return
} }
@ -213,7 +215,7 @@ func (a *AutoHosts) updateTableRev(tableRev map[string][]string, newHost string,
} }
tableRev[ipStr] = append(tableRev[ipStr], newHost) tableRev[ipStr] = append(tableRev[ipStr], newHost)
log.Debug("AutoHosts: added reverse-address %s -> %s", ipStr, newHost) log.Debug("autohosts: added reverse-address %s -> %s", ipStr, newHost)
} }
// Read IP-hostname pairs from file // Read IP-hostname pairs from file
@ -221,12 +223,12 @@ func (a *AutoHosts) updateTableRev(tableRev map[string][]string, newHost string,
func (a *AutoHosts) load(table map[string][]net.IP, tableRev map[string][]string, fn string) { func (a *AutoHosts) load(table map[string][]net.IP, tableRev map[string][]string, fn string) {
f, err := os.Open(fn) f, err := os.Open(fn)
if err != nil { if err != nil {
log.Error("AutoHosts: %s", err) log.Error("autohosts: %s", err)
return return
} }
defer f.Close() defer f.Close()
r := bufio.NewReader(f) r := bufio.NewReader(f)
log.Debug("AutoHosts: loading hosts from file %s", fn) log.Debug("autohosts: loading hosts from file %s", fn)
for done := false; !done; { for done := false; !done; {
var line string var line string
@ -234,7 +236,7 @@ func (a *AutoHosts) load(table map[string][]net.IP, tableRev map[string][]string
if err == io.EOF { if err == io.EOF {
done = true done = true
} else if err != nil { } else if err != nil {
log.Error("AutoHosts: %s", err) log.Error("autohosts: %s", err)
return return
} }
@ -310,7 +312,7 @@ func (a *AutoHosts) watcherLoop() {
} }
if event.Op&fsnotify.Write == fsnotify.Write { if event.Op&fsnotify.Write == fsnotify.Write {
log.Debug("AutoHosts: modified: %s", event.Name) log.Debug("autohosts: modified: %s", event.Name)
a.updateHosts() a.updateHosts()
} }
@ -318,7 +320,7 @@ func (a *AutoHosts) watcherLoop() {
if !ok { if !ok {
return return
} }
log.Error("AutoHosts: %s", err) log.Error("autohosts: %s", err)
} }
} }
} }
@ -333,14 +335,15 @@ func (a *AutoHosts) updateHosts() {
for _, dir := range a.hostsDirs { for _, dir := range a.hostsDirs {
fis, err := ioutil.ReadDir(dir) fis, err := ioutil.ReadDir(dir)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !errors.Is(err, os.ErrNotExist) {
log.Error("AutoHosts: Opening directory: %s: %s", dir, err) log.Error("autohosts: Opening directory: %q: %s", dir, err)
} }
continue continue
} }
for _, fi := range fis { for _, fi := range fis {
a.load(table, tableRev, dir+"/"+fi.Name()) a.load(table, tableRev, filepath.Join(dir, fi.Name()))
} }
} }

View File

@ -8,29 +8,12 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
) )
// ContainsString checks if string is in the slice of strings.
func ContainsString(strs []string, str string) bool {
for _, s := range strs {
if s == str {
return true
}
}
return false
}
// FileExists returns true if file exists.
func FileExists(fn string) bool {
_, err := os.Stat(fn)
return err == nil || !os.IsNotExist(err)
}
// RunCommand runs shell command. // RunCommand runs shell command.
func RunCommand(command string, arguments ...string) (int, string, error) { func RunCommand(command string, arguments ...string) (int, string, error) {
cmd := exec.Command(command, arguments...) cmd := exec.Command(command, arguments...)

View File

@ -3,6 +3,7 @@ package util
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors"
"io/ioutil" "io/ioutil"
"os" "os"
"runtime" "runtime"
@ -29,8 +30,8 @@ func LoadSystemRootCAs() *x509.CertPool {
for _, dir := range dirs { for _, dir := range dirs {
fis, err := ioutil.ReadDir(dir) fis, err := ioutil.ReadDir(dir)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !errors.Is(err, os.ErrNotExist) {
log.Error("Opening directory: %s: %s", dir, err) log.Error("opening directory: %q: %s", dir, err)
} }
continue continue