Pull request: all: imp code, err handling

Closes #2571.

Squashed commit of the following:

commit a5b50ee011a995f4ab3d93314acd6f0ca82d99cf
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Mar 15 14:05:25 2021 +0300

    all: imp code

commit bc610f8f438549e8c6b04c8a213b5422dda2aff5
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Mar 12 20:00:14 2021 +0300

    all: imp code, err handling
This commit is contained in:
Ainar Garipov 2021-03-15 14:19:04 +03:00
parent d970b79f2b
commit 313fd7107f
15 changed files with 121 additions and 99 deletions

View File

@ -4,6 +4,7 @@ package dhcpd
import (
"encoding/json"
"errors"
"io/ioutil"
"net"
"os"
@ -39,9 +40,10 @@ func (s *Server) dbLoad() {
data, err := ioutil.ReadFile(s.conf.DBFilePath)
if err != nil {
if !os.IsNotExist(err) {
log.Error("DHCP: can't read file %s: %v", s.conf.DBFilePath, err)
if !errors.Is(err, os.ErrNotExist) {
log.Error("dhcp: can't read file %q: %v", s.conf.DBFilePath, err)
}
return
}

View File

@ -508,8 +508,8 @@ func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
s.Stop()
err := os.Remove(s.conf.DBFilePath)
if err != nil && !os.IsNotExist(err) {
log.Error("DHCP: os.Remove: %s: %s", s.conf.DBFilePath, err)
if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error("dhcp: removing %q: %s", s.conf.DBFilePath, err)
}
oldconf := s.conf

View File

@ -239,7 +239,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
}
httpsURL := &url.URL{
Scheme: "https",
Scheme: schemeHTTPS,
Host: hostPort,
Path: r.URL.Path,
RawQuery: r.URL.RawQuery,
@ -255,7 +255,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
//
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin.
originURL := &url.URL{
Scheme: "http",
Scheme: schemeHTTP,
Host: r.Host,
}
w.Header().Set("Access-Control-Allow-Origin", originURL.String())

View File

@ -12,26 +12,31 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// isValidURL - return TRUE if URL or file path is valid
func isValidURL(rawurl string) bool {
if filepath.IsAbs(rawurl) {
// this is a file path
return util.FileExists(rawurl)
// validateFilterURL validates the filter list URL or file name.
func validateFilterURL(urlStr string) (err error) {
if filepath.IsAbs(urlStr) {
_, err = os.Stat(urlStr)
if err != nil {
return fmt.Errorf("checking filter file: %w", err)
}
return nil
}
url, err := url.ParseRequestURI(rawurl)
url, err := url.ParseRequestURI(urlStr)
if err != nil {
return false // Couldn't even parse the rawurl
return fmt.Errorf("checking filter url: %w", err)
}
if len(url.Scheme) == 0 {
return false // No Scheme found
if s := url.Scheme; s != schemeHTTP && s != schemeHTTPS {
return fmt.Errorf("checking filter url: invalid scheme %q", s)
}
return true
return nil
}
type filterAddJSON struct {
@ -48,8 +53,11 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
return
}
if !isValidURL(fj.URL) {
http.Error(w, "Invalid URL or file path", http.StatusBadRequest)
err = validateFilterURL(fj.URL)
if err != nil {
msg := fmt.Sprintf("invalid url: %s", err)
http.Error(w, msg, http.StatusBadRequest)
return
}
@ -168,8 +176,11 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
return
}
if !isValidURL(fj.Data.URL) {
http.Error(w, "invalid URL or file path", http.StatusBadRequest)
err = validateFilterURL(fj.Data.URL)
if err != nil {
msg := fmt.Sprintf("invalid url: %s", err)
http.Error(w, msg, http.StatusBadRequest)
return
}

View File

@ -570,8 +570,8 @@ func loadOptions() options {
return o
}
// prints IP addresses which user can use to open the admin interface
// proto is either "http" or "https"
// printHTTPAddresses prints the IP addresses which user can use to open the
// admin interface. proto is either schemeHTTP or schemeHTTPS.
func printHTTPAddresses(proto string) {
tlsConf := tlsConfigSettings{}
if Context.tls != nil {
@ -579,12 +579,12 @@ func printHTTPAddresses(proto string) {
}
port := strconv.Itoa(config.BindPort)
if proto == "https" {
if proto == schemeHTTPS {
port = strconv.Itoa(tlsConf.PortHTTPS)
}
var hostStr string
if proto == "https" && tlsConf.ServerName != "" {
if proto == schemeHTTPS && tlsConf.ServerName != "" {
if tlsConf.PortHTTPS == 443 {
log.Printf("Go to https://%s", tlsConf.ServerName)
} else {

View File

@ -58,7 +58,7 @@ func getMobileConfig(d dnsSettings) ([]byte, error) {
dspName = fmt.Sprintf("%s DoH", d.ServerName)
u := &url.URL{
Scheme: "https",
Scheme: schemeHTTPS,
Host: d.ServerName,
Path: "/dns-query",
}

View File

@ -236,7 +236,7 @@ func handleServiceInstallCommand(s service.Service) {
AdGuard Home is successfully installed and will automatically start on boot.
There are a few more things that must be configured before you can use it.
Click on the link below and follow the Installation Wizard steps to finish setup.`)
printHTTPAddresses("http")
printHTTPAddresses(schemeHTTP)
}
}

View File

@ -15,6 +15,12 @@ import (
"github.com/gobuffalo/packr"
)
// HTTP scheme constants.
const (
schemeHTTP = "http"
schemeHTTPS = "https"
)
const (
// readTimeout is the maximum duration for reading the entire request,
// including the body.
@ -159,7 +165,7 @@ func (web *Web) Start() {
// this loop is used as an ability to change listening host and/or port
for !web.httpsServer.shutdown {
printHTTPAddresses("http")
printHTTPAddresses(schemeHTTP)
errs := make(chan error, 2)
hostStr := web.conf.BindHost.String()
@ -256,7 +262,7 @@ func (web *Web) tlsServerLoop() {
WriteTimeout: web.conf.WriteTimeout,
}
printHTTPAddresses("https")
printHTTPAddresses(schemeHTTPS)
err := web.httpsServer.server.ListenAndServeTLS("", "")
if err != http.ErrServerClosed {
cleanupAlways()

View File

@ -8,8 +8,6 @@ import (
"strconv"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/jsonutil"
"github.com/AdguardTeam/golibs/log"
)
@ -127,6 +125,17 @@ func getDoubleQuotesEnclosedValue(s *string) bool {
return false
}
// inStr checks if string is in the slice of strings.
func inStr(strs []string, str string) (ok bool) {
for _, s := range strs {
if s == str {
return true
}
}
return false
}
// parseSearchCriteria - parses "searchCriteria" from the specified query parameter
func (l *queryLog) parseSearchCriteria(q url.Values, name string, ct criteriaType) (bool, searchCriteria, error) {
val := q.Get(name)
@ -142,7 +151,7 @@ func (l *queryLog) parseSearchCriteria(q url.Values, name string, ct criteriaTyp
c.strict = true
}
if ct == ctFilteringStatus && !util.ContainsString(filteringStatusValues, c.value) {
if ct == ctFilteringStatus && !inStr(filteringStatusValues, c.value) {
return false, c, fmt.Errorf("invalid value %s", c.value)
}

View File

@ -4,8 +4,10 @@ import (
"errors"
"fmt"
"io"
"os"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/golibs/log"
)
// QLogReader allows reading from multiple query log files in the reverse order.
@ -30,8 +32,16 @@ func NewQLogReader(files []string) (*QLogReader, error) {
for _, f := range files {
q, err := NewQLogFile(f)
if err != nil {
// Close what we've already opened
_ = closeQFiles(qFiles)
if errors.Is(err, os.ErrNotExist) {
continue
}
// Close what we've already opened.
cerr := closeQFiles(qFiles)
if cerr != nil {
log.Debug("querylog: closing files: %s", cerr)
}
return nil, err
}

View File

@ -5,7 +5,6 @@ import (
"sort"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/log"
)
@ -87,9 +86,15 @@ func (l *queryLog) searchFiles(params *searchParams) ([]*logEntry, time.Time, in
entries := make([]*logEntry, 0)
oldest := time.Time{}
r, err := l.openReader()
files := []string{
l.logFile + ".1",
l.logFile,
}
r, err := NewQLogReader(files)
if err != nil {
log.Error("Failed to open qlog reader: %v", err)
log.Error("querylog: failed to open qlog reader: %s", err)
return entries, oldest, 0
}
defer r.Close()
@ -174,17 +179,3 @@ func (l *queryLog) readNextEntry(r *QLogReader, params *searchParams) (*logEntry
return &entry, timestamp, nil
}
// openReader - opens QLogReader instance
func (l *queryLog) openReader() (*QLogReader, error) {
files := make([]string, 0)
if util.FileExists(l.logFile + ".1") {
files = append(files, l.logFile+".1")
}
if util.FileExists(l.logFile) {
files = append(files, l.logFile)
}
return NewQLogReader(files)
}

View File

@ -5,6 +5,7 @@ import (
"archive/tar"
"archive/zip"
"compress/gzip"
"errors"
"fmt"
"io"
"io/ioutil"
@ -19,7 +20,6 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/log"
)
@ -160,7 +160,7 @@ func (u *Updater) VersionCheckURL() (vcu string) {
return u.versionCheckURL
}
func (u *Updater) prepare() error {
func (u *Updater) prepare() (err error) {
u.updateDir = filepath.Join(u.workDir, fmt.Sprintf("agh-update-%s", u.newVersion))
_, pkgNameOnly := filepath.Split(u.packageURL)
@ -181,12 +181,11 @@ func (u *Updater) prepare() error {
log.Info("Updating from %s to %s. URL:%s", version.Version(), u.newVersion, u.packageURL)
// If the binary file isn't found in working directory, we won't be able
// to auto-update. Getting the full path to the current binary file on
// Unix and checking write permissions is more difficult.
// TODO(a.garipov): Use os.Args[0] instead?
u.currentExeName = filepath.Join(u.workDir, exeName)
if !util.FileExists(u.currentExeName) {
return fmt.Errorf("executable file %s doesn't exist", u.currentExeName)
_, err = os.Stat(u.currentExeName)
if err != nil {
return fmt.Errorf("checking %q: %w", u.currentExeName, err)
}
return nil
@ -367,11 +366,14 @@ func tarGzFileUnpack(tarfile, outdir string) ([]string, error) {
}
err = os.Mkdir(outputName, os.FileMode(header.Mode&0o777))
if err != nil && !os.IsExist(err) {
err2 = fmt.Errorf("os.Mkdir(%s): %w", outputName, err)
if err != nil && !errors.Is(err, os.ErrExist) {
err2 = fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
break
}
log.Debug("updater: created directory %s", outputName)
log.Debug("updater: created directory %q", outputName)
continue
} else if header.Typeflag != tar.TypeReg {
log.Debug("updater: %s: unknown file type %d, skipping", inputNameOnly, header.Typeflag)
@ -443,11 +445,14 @@ func zipFileUnpack(zipfile, outdir string) ([]string, error) {
}
err = os.Mkdir(outputName, fi.Mode())
if err != nil && !os.IsExist(err) {
err2 = fmt.Errorf("os.Mkdir(): %w", err)
if err != nil && !errors.Is(err, os.ErrExist) {
err2 = fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
break
}
log.Tracef("created directory %s", outputName)
log.Tracef("created directory %q", outputName)
continue
}
@ -501,11 +506,12 @@ func copySupportingFiles(files []string, srcdir, dstdir string) error {
dst := filepath.Join(dstdir, name)
err := copyFile(src, dst)
if err != nil && !os.IsNotExist(err) {
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
log.Debug("updater: copied: %s -> %s", src, dst)
log.Debug("updater: copied: %q -> %q", src, dst)
}
return nil
}

View File

@ -2,10 +2,12 @@ package util
import (
"bufio"
"errors"
"io"
"io/ioutil"
"net"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
@ -76,7 +78,7 @@ func (a *AutoHosts) Init(hostsFn string) {
var err error
a.watcher, err = fsnotify.NewWatcher()
if err != nil {
log.Error("AutoHosts: %s", err)
log.Error("autohosts: %s", err)
}
}
@ -127,7 +129,7 @@ func (a *AutoHosts) Process(host string, qtype uint16) []net.IP {
copy(ipsCopy, ips)
}
log.Debug("AutoHosts: answer: %s -> %v", host, ipsCopy)
log.Debug("autohosts: answer: %s -> %v", host, ipsCopy)
return ipsCopy
}
@ -153,7 +155,7 @@ func (a *AutoHosts) ProcessReverse(addr string, qtype uint16) (hosts []string) {
return nil // not found
}
log.Debug("AutoHosts: reverse-lookup: %s -> %s", addr, hosts)
log.Debug("autohosts: reverse-lookup: %s -> %s", addr, hosts)
return hosts
}
@ -191,7 +193,7 @@ func (a *AutoHosts) updateTable(table map[string][]net.IP, host string, ipAddr n
ok = true
}
if ok {
log.Debug("AutoHosts: added %s -> %s", ipAddr, host)
log.Debug("autohosts: added %s -> %s", ipAddr, host)
}
}
@ -201,7 +203,7 @@ func (a *AutoHosts) updateTableRev(tableRev map[string][]string, newHost string,
hosts, ok := tableRev[ipStr]
if !ok {
tableRev[ipStr] = []string{newHost}
log.Debug("AutoHosts: added reverse-address %s -> %s", ipStr, newHost)
log.Debug("autohosts: added reverse-address %s -> %s", ipStr, newHost)
return
}
@ -213,7 +215,7 @@ func (a *AutoHosts) updateTableRev(tableRev map[string][]string, newHost string,
}
tableRev[ipStr] = append(tableRev[ipStr], newHost)
log.Debug("AutoHosts: added reverse-address %s -> %s", ipStr, newHost)
log.Debug("autohosts: added reverse-address %s -> %s", ipStr, newHost)
}
// Read IP-hostname pairs from file
@ -221,12 +223,12 @@ func (a *AutoHosts) updateTableRev(tableRev map[string][]string, newHost string,
func (a *AutoHosts) load(table map[string][]net.IP, tableRev map[string][]string, fn string) {
f, err := os.Open(fn)
if err != nil {
log.Error("AutoHosts: %s", err)
log.Error("autohosts: %s", err)
return
}
defer f.Close()
r := bufio.NewReader(f)
log.Debug("AutoHosts: loading hosts from file %s", fn)
log.Debug("autohosts: loading hosts from file %s", fn)
for done := false; !done; {
var line string
@ -234,7 +236,7 @@ func (a *AutoHosts) load(table map[string][]net.IP, tableRev map[string][]string
if err == io.EOF {
done = true
} else if err != nil {
log.Error("AutoHosts: %s", err)
log.Error("autohosts: %s", err)
return
}
@ -310,7 +312,7 @@ func (a *AutoHosts) watcherLoop() {
}
if event.Op&fsnotify.Write == fsnotify.Write {
log.Debug("AutoHosts: modified: %s", event.Name)
log.Debug("autohosts: modified: %s", event.Name)
a.updateHosts()
}
@ -318,7 +320,7 @@ func (a *AutoHosts) watcherLoop() {
if !ok {
return
}
log.Error("AutoHosts: %s", err)
log.Error("autohosts: %s", err)
}
}
}
@ -333,14 +335,15 @@ func (a *AutoHosts) updateHosts() {
for _, dir := range a.hostsDirs {
fis, err := ioutil.ReadDir(dir)
if err != nil {
if !os.IsNotExist(err) {
log.Error("AutoHosts: Opening directory: %s: %s", dir, err)
if !errors.Is(err, os.ErrNotExist) {
log.Error("autohosts: Opening directory: %q: %s", dir, err)
}
continue
}
for _, fi := range fis {
a.load(table, tableRev, dir+"/"+fi.Name())
a.load(table, tableRev, filepath.Join(dir, fi.Name()))
}
}

View File

@ -8,29 +8,12 @@ import (
"bytes"
"fmt"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
)
// ContainsString checks if string is in the slice of strings.
func ContainsString(strs []string, str string) bool {
for _, s := range strs {
if s == str {
return true
}
}
return false
}
// FileExists returns true if file exists.
func FileExists(fn string) bool {
_, err := os.Stat(fn)
return err == nil || !os.IsNotExist(err)
}
// RunCommand runs shell command.
func RunCommand(command string, arguments ...string) (int, string, error) {
cmd := exec.Command(command, arguments...)

View File

@ -3,6 +3,7 @@ package util
import (
"crypto/tls"
"crypto/x509"
"errors"
"io/ioutil"
"os"
"runtime"
@ -29,8 +30,8 @@ func LoadSystemRootCAs() *x509.CertPool {
for _, dir := range dirs {
fis, err := ioutil.ReadDir(dir)
if err != nil {
if !os.IsNotExist(err) {
log.Error("Opening directory: %s: %s", dir, err)
if !errors.Is(err, os.ErrNotExist) {
log.Error("opening directory: %q: %s", dir, err)
}
continue