diff --git a/internal/aghnet/etchostscontainer.go b/internal/aghnet/etchostscontainer.go index 3a1487be..9b6e07ef 100644 --- a/internal/aghnet/etchostscontainer.go +++ b/internal/aghnet/etchostscontainer.go @@ -2,7 +2,6 @@ package aghnet import ( "bufio" - "io" "net" "os" "path/filepath" @@ -231,14 +230,41 @@ func (ehc *EtcHostsContainer) updateTableRev(tableRev map[string][]string, newHo log.Debug("etchostscontainer: added reverse-address %s -> %s", ipStr, newHost) } -// Read IP-hostname pairs from file -// Multiple hostnames per line (per one IP) is supported. -func (ehc *EtcHostsContainer) load(table map[string][]net.IP, tableRev map[string][]string, fn string) { +// parseHostsLine parses hosts from the fields. +func parseHostsLine(fields []string) (hosts []string) { + for _, f := range fields { + hashIdx := strings.IndexByte(f, '#') + if hashIdx == 0 { + // The rest of the fields are a part of the comment. + // Skip immediately. + return + } else if hashIdx > 0 { + // Only a part of the field is a comment. + hosts = append(hosts, f[:hashIdx]) + + return hosts + } + + hosts = append(hosts, f) + } + + return hosts +} + +// load reads IP-hostname pairs from the hosts file. Multiple hostnames per +// line for one IP are supported. +func (ehc *EtcHostsContainer) load( + table map[string][]net.IP, + tableRev map[string][]string, + fn string, +) { f, err := os.Open(fn) if err != nil { log.Error("etchostscontainer: %s", err) + return } + defer func() { derr := f.Close() if derr != nil { @@ -246,25 +272,11 @@ func (ehc *EtcHostsContainer) load(table map[string][]net.IP, tableRev map[strin } }() - r := bufio.NewReader(f) log.Debug("etchostscontainer: loading hosts from file %s", fn) - for done := false; !done; { - var line string - line, err = r.ReadString('\n') - if err == io.EOF { - done = true - } else if err != nil { - log.Error("etchostscontainer: %s", err) - - return - } - - line = strings.TrimSpace(line) - if len(line) == 0 || line[0] == '#' { - continue - } - + s := bufio.NewScanner(f) + for s.Scan() { + line := strings.TrimSpace(s.Text()) fields := strings.Fields(line) if len(fields) < 2 { continue @@ -275,28 +287,17 @@ func (ehc *EtcHostsContainer) load(table map[string][]net.IP, tableRev map[strin continue } - for i := 1; i != len(fields); i++ { - host := fields[i] - if len(host) == 0 { - break - } - - sharp := strings.IndexByte(host, '#') - if sharp == 0 { - // Skip the comments. - break - } else if sharp > 0 { - host = host[:sharp] - } - + hosts := parseHostsLine(fields[1:]) + for _, host := range hosts { ehc.updateTable(table, host, ip) ehc.updateTableRev(tableRev, host, ip) - if sharp >= 0 { - // Skip the comments again. - break - } } } + + err = s.Err() + if err != nil { + log.Error("etchostscontainer: %s", err) + } } // onlyWrites is a filter for (*fsnotify.Watcher).Events. diff --git a/internal/aghnet/etshostscontainer_test.go b/internal/aghnet/etchostscontainer_test.go similarity index 90% rename from internal/aghnet/etshostscontainer_test.go rename to internal/aghnet/etchostscontainer_test.go index 85c61d4a..74e4f46f 100644 --- a/internal/aghnet/etshostscontainer_test.go +++ b/internal/aghnet/etchostscontainer_test.go @@ -23,10 +23,11 @@ func prepareTestFile(t *testing.T) (f *os.File) { dir := t.TempDir() f, err := os.CreateTemp(dir, "") - require.Nil(t, err) + require.NoError(t, err) require.NotNil(t, f) + t.Cleanup(func() { - assert.Nil(t, f.Close()) + assert.NoError(t, f.Close()) }) return f @@ -37,7 +38,7 @@ func assertWriting(t *testing.T, f *os.File, strs ...string) { for _, str := range strs { n, err := f.WriteString(str) - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, n, len(str)) } } @@ -77,16 +78,16 @@ func TestEtcHostsContainerResolution(t *testing.T) { t.Run("ptr", func(t *testing.T) { testCases := []struct { wantIP string - wantLen int wantHost string + wantLen int }{ - {wantIP: "127.0.0.1", wantLen: 2, wantHost: "host"}, - {wantIP: "::1", wantLen: 1, wantHost: "localhost"}, + {wantIP: "127.0.0.1", wantHost: "host", wantLen: 2}, + {wantIP: "::1", wantHost: "localhost", wantLen: 1}, } for _, tc := range testCases { a, err := dns.ReverseAddr(tc.wantIP) - require.Nil(t, err) + require.NoError(t, err) a = strings.TrimSuffix(a, ".") hosts := ehc.ProcessReverse(a, dns.TypePTR) @@ -114,7 +115,7 @@ func TestEtcHostsContainerFSNotify(t *testing.T) { t.Cleanup(ehc.Close) assertWriting(t, f, "127.0.0.2 newhost\n") - require.Nil(t, f.Sync()) + require.NoError(t, f.Sync()) // Wait until fsnotify has triggerred and processed the // file-modification event. diff --git a/internal/aghnet/net_linux.go b/internal/aghnet/net_linux.go index b945dea6..d0d01f7c 100644 --- a/internal/aghnet/net_linux.go +++ b/internal/aghnet/net_linux.go @@ -68,40 +68,41 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) { return false, ErrNoStaticIPInfo } +// findIfaceLine scans s until it finds the line that declares an interface with +// the given name. If findIfaceLine can't find the line, it returns false. +func findIfaceLine(s *bufio.Scanner, name string) (ok bool) { + for s.Scan() { + line := strings.TrimSpace(s.Text()) + fields := strings.Fields(line) + if len(fields) == 2 && fields[0] == "interface" && fields[1] == name { + return true + } + } + + return false +} + // dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to // have a static IP. func dhcpcdStaticConfig(r io.Reader, ifaceName string) (has bool, err error) { s := bufio.NewScanner(r) - var withinInterfaceCtx bool + ifaceFound := findIfaceLine(s, ifaceName) + if !ifaceFound { + return false, s.Err() + } for s.Scan() { line := strings.TrimSpace(s.Text()) - - if withinInterfaceCtx && len(line) == 0 { - // An empty line resets our state. - withinInterfaceCtx = false - } - - if len(line) == 0 || line[0] == '#' { - continue - } - fields := strings.Fields(line) - - if withinInterfaceCtx { - if len(fields) >= 2 && fields[0] == "static" && strings.HasPrefix(fields[1], "ip_address=") { - return true, nil - } - if len(fields) > 0 && fields[0] == "interface" { - // Another interface found. - withinInterfaceCtx = false - } - continue + if len(fields) >= 2 && + fields[0] == "static" && + strings.HasPrefix(fields[1], "ip_address=") { + return true, s.Err() } - if len(fields) == 2 && fields[0] == "interface" && fields[1] == ifaceName { - // The interface found. - withinInterfaceCtx = true + if len(fields) > 0 && fields[0] == "interface" { + // Another interface found. + break } } diff --git a/internal/aghnet/systemresolvers.go b/internal/aghnet/systemresolvers.go index 4133fc24..0543c615 100644 --- a/internal/aghnet/systemresolvers.go +++ b/internal/aghnet/systemresolvers.go @@ -3,7 +3,6 @@ package aghnet import ( "time" - "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" ) @@ -25,18 +24,6 @@ type SystemResolvers interface { refresh() (err error) } -const ( - // errBadAddrPassed is returned when dialFunc can't parse an IP address. - errBadAddrPassed errors.Error = "the passed string is not a valid IP address" - - // errFakeDial is an error which dialFunc is expected to return. - errFakeDial errors.Error = "this error signals the successful dialFunc work" - - // errUnexpectedHostFormat is returned by validateDialedHost when the host has - // more than one percent sign. - errUnexpectedHostFormat errors.Error = "unexpected host format" -) - // refreshWithTicker refreshes the cache of sr after each tick form tickCh. func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) { defer log.OnPanic("systemResolvers") diff --git a/internal/aghnet/systemresolvers_others.go b/internal/aghnet/systemresolvers_others.go index c0ba457a..dd44fa0b 100644 --- a/internal/aghnet/systemresolvers_others.go +++ b/internal/aghnet/systemresolvers_others.go @@ -32,6 +32,18 @@ type systemResolvers struct { addrsLock sync.RWMutex } +const ( + // errBadAddrPassed is returned when dialFunc can't parse an IP address. + errBadAddrPassed errors.Error = "the passed string is not a valid IP address" + + // errFakeDial is an error which dialFunc is expected to return. + errFakeDial errors.Error = "this error signals the successful dialFunc work" + + // errUnexpectedHostFormat is returned by validateDialedHost when the host has + // more than one percent sign. + errUnexpectedHostFormat errors.Error = "unexpected host format" +) + func (sr *systemResolvers) refresh() (err error) { defer func() { err = errors.Annotate(err, "systemResolvers: %w") }() diff --git a/internal/aghnet/systemresolvers_windows.go b/internal/aghnet/systemresolvers_windows.go index f3095860..7dfec410 100644 --- a/internal/aghnet/systemresolvers_windows.go +++ b/internal/aghnet/systemresolvers_windows.go @@ -45,6 +45,57 @@ func (sr *systemResolvers) Get() (rs []string) { return rs } +// writeExit writes "exit" to w and closes it. It is supposed to be run in +// a goroutine. +func writeExit(w io.WriteCloser) { + defer log.OnPanic("systemResolvers: writeExit") + + defer func() { + derr := w.Close() + if derr != nil { + log.Error("systemResolvers: writeExit: closing: %s", derr) + } + }() + + _, err := io.WriteString(w, "exit") + if err != nil { + log.Error("systemResolvers: writeExit: writing: %s", err) + } +} + +// scanAddrs scans the DNS addresses from nslookup's output. The expected +// output of nslookup looks like this: +// +// Default Server: 192-168-1-1.qualified.domain.ru +// Address: 192.168.1.1 +// +func scanAddrs(s *bufio.Scanner) (addrs []string) { + for s.Scan() { + line := strings.TrimSpace(s.Text()) + fields := strings.Fields(line) + if len(fields) != 2 || fields[0] != "Address:" { + continue + } + + // If the address contains port then it is separated with '#'. + ipPort := strings.Split(fields[1], "#") + if len(ipPort) == 0 { + continue + } + + addr := ipPort[0] + if net.ParseIP(addr) == nil { + log.Debug("systemResolvers: %q is not a valid ip", addr) + + continue + } + + addrs = append(addrs, addr) + } + + return addrs +} + // getAddrs gets local resolvers' addresses from OS in a special Windows way. // // TODO(e.burkov): This whole function needs more detailed research on getting @@ -71,73 +122,30 @@ func (sr *systemResolvers) getAddrs() (addrs []string, err error) { return nil, fmt.Errorf("limiting stdout reader: %w", err) } - go func() { - defer log.OnPanic("systemResolvers") - - defer func() { - derr := stdin.Close() - if derr != nil { - log.Error("systemResolvers: closing stdin pipe: %s", derr) - } - }() - - _, werr := io.WriteString(stdin, "exit") - if werr != nil { - log.Error("systemResolvers: writing to command pipe: %s", werr) - } - }() + go writeExit(stdin) err = cmd.Start() if err != nil { return nil, fmt.Errorf("start command executing: %w", err) } - // The output of nslookup looks like this: - // - // Default Server: 192-168-1-1.qualified.domain.ru - // Address: 192.168.1.1 - - var possibleIPs []string s := bufio.NewScanner(stdoutLimited) - for s.Scan() { - line := s.Text() - if len(line) == 0 { - continue - } - - fields := strings.Fields(line) - if len(fields) != 2 || fields[0] != "Address:" { - continue - } - - // If the address contains port then it is separated with '#'. - ipStrs := strings.Split(fields[1], "#") - if len(ipStrs) == 0 { - continue - } - - possibleIPs = append(possibleIPs, ipStrs[0]) - } + addrs = scanAddrs(s) err = cmd.Wait() if err != nil { return nil, fmt.Errorf("executing the command: %w", err) } + err = s.Err() + if err != nil { + return nil, fmt.Errorf("scanning output: %w", err) + } + // Don't close StdoutPipe since Wait do it for us in ¿most? cases. // // See go doc os/exec.Cmd.StdoutPipe. - for _, addr := range possibleIPs { - if net.ParseIP(addr) == nil { - log.Debug("systemResolvers: %q is not a valid ip", addr) - - continue - } - - addrs = append(addrs, addr) - } - return addrs, nil } diff --git a/internal/aghtest/upstream.go b/internal/aghtest/upstream.go index c69982ed..cffdca3a 100644 --- a/internal/aghtest/upstream.go +++ b/internal/aghtest/upstream.go @@ -13,8 +13,6 @@ import ( // TestUpstream is a mock of real upstream. type TestUpstream struct { - // Addr is the address for Address method. - Addr string // CName is a map of hostname to canonical name. CName map[string]string // IPv4 is a map of hostname to IPv4. @@ -23,9 +21,13 @@ type TestUpstream struct { IPv6 map[string][]net.IP // Reverse is a map of address to domain name. Reverse map[string][]string + // Addr is the address for Address method. + Addr string } // Exchange implements upstream.Upstream interface for *TestUpstream. +// +// TODO(a.garipov): Split further into handlers. func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { resp = &dns.Msg{} resp.SetReply(m) @@ -33,70 +35,69 @@ func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { if len(m.Question) == 0 { return nil, fmt.Errorf("question should not be empty") } + name := m.Question[0].Name if cname, ok := u.CName[name]; ok { - resp.Answer = append(resp.Answer, &dns.CNAME{ + ans := &dns.CNAME{ Hdr: dns.RR_Header{ Name: name, Rrtype: dns.TypeCNAME, }, Target: cname, - }) + } + + resp.Answer = append(resp.Answer, ans) } - var hasRec bool - var rrType uint16 + rrType := m.Question[0].Qtype + hdr := dns.RR_Header{ + Name: name, + Rrtype: rrType, + } + + var names []string var ips []net.IP switch m.Question[0].Qtype { case dns.TypeA: - rrType = dns.TypeA - if ipv4addr, ok := u.IPv4[name]; ok { - hasRec = true - ips = ipv4addr - } + ips = u.IPv4[name] case dns.TypeAAAA: - rrType = dns.TypeAAAA - if ipv6addr, ok := u.IPv6[name]; ok { - hasRec = true - ips = ipv6addr - } + ips = u.IPv6[name] case dns.TypePTR: - names, ok := u.Reverse[name] - if !ok { - break - } - - for _, n := range names { - resp.Answer = append(resp.Answer, &dns.PTR{ - Hdr: dns.RR_Header{ - Name: n, - Rrtype: rrType, - }, - Ptr: n, - }) - } + names = u.Reverse[name] } for _, ip := range ips { - resp.Answer = append(resp.Answer, &dns.A{ - Hdr: dns.RR_Header{ - Name: name, - Rrtype: rrType, - }, - A: ip, - }) + var ans dns.RR + if rrType == dns.TypeA { + ans = &dns.A{ + Hdr: hdr, + A: ip, + } + + resp.Answer = append(resp.Answer, ans) + + continue + } + + ans = &dns.AAAA{ + Hdr: hdr, + AAAA: ip, + } + + resp.Answer = append(resp.Answer, ans) + } + + for _, n := range names { + ans := &dns.PTR{ + Hdr: hdr, + Ptr: n, + } + + resp.Answer = append(resp.Answer, ans) } if len(resp.Answer) == 0 { - if hasRec { - // Set no error RCode if there are some records for - // given Qname but we didn't apply them. - resp.SetRcode(m, dns.RcodeSuccess) - - return resp, nil - } - // Set NXDomain RCode otherwise. resp.SetRcode(m, dns.RcodeNameError) } @@ -111,10 +112,13 @@ func (u *TestUpstream) Address() string { // TestBlockUpstream implements upstream.Upstream interface for replacing real // upstream in tests. type TestBlockUpstream struct { - Hostname string - Block bool - requestsCount int - lock sync.RWMutex + Hostname string + + // lock protects reqNum. + lock sync.RWMutex + reqNum int + + Block bool } // Exchange returns a message unique for TestBlockUpstream's Hostname-Block @@ -122,7 +126,7 @@ type TestBlockUpstream struct { func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) { u.lock.Lock() defer u.lock.Unlock() - u.requestsCount++ + u.reqNum++ hash := sha256.Sum256([]byte(u.Hostname)) hashToReturn := hex.EncodeToString(hash[:]) @@ -156,7 +160,7 @@ func (u *TestBlockUpstream) RequestsCount() int { u.lock.Lock() defer u.lock.Unlock() - return u.requestsCount + return u.reqNum } // TestErrUpstream implements upstream.Upstream interface for replacing real diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/dns_test.go index 565844b5..edf54f51 100644 --- a/internal/dnsforward/dns_test.go +++ b/internal/dnsforward/dns_test.go @@ -326,7 +326,7 @@ func TestServer_ProcessRestrictLocal(t *testing.T) { require.Len(t, pctx.Res.Answer, tc.wantLen) if tc.wantLen > 0 { - assert.Equal(t, tc.want, pctx.Res.Answer[0].Header().Name) + assert.Equal(t, tc.want, pctx.Res.Answer[0].(*dns.PTR).Ptr) } }) } @@ -368,7 +368,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) { require.Equal(t, resultCodeSuccess, rc) require.NotEmpty(t, proxyCtx.Res.Answer) - assert.Equal(t, locDomain, proxyCtx.Res.Answer[0].Header().Name) + assert.Equal(t, locDomain, proxyCtx.Res.Answer[0].(*dns.PTR).Ptr) }) t.Run("disabled", func(t *testing.T) { diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 35069656..da9497aa 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -284,7 +284,7 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) { StartTime: time.Now(), } - var resolver *proxy.Proxy = s.internalProxy + resolver := s.internalProxy if s.subnetDetector.IsLocallyServedNetwork(ip) { if !s.conf.UsePrivateRDNS { return "", nil diff --git a/scripts/make/go-lint.sh b/scripts/make/go-lint.sh index 40de04bc..8bc66031 100644 --- a/scripts/make/go-lint.sh +++ b/scripts/make/go-lint.sh @@ -175,8 +175,15 @@ golint --set_exit_status ./... "$GO" vet ./... -# Here and below, don't use quotes to get word splitting. -gocyclo --over 17 $go_files +# Apply more lax standards to the code we haven't properly refactored yet. +gocyclo --over 17 ./internal/dhcpd/ ./internal/dnsforward/\ + ./internal/filtering/ ./internal/home/ ./internal/querylog/\ + ./internal/stats/ ./internal/updater/ + +# Apply stricter standards to new or vetted code +gocyclo --over 10 ./internal/aghio/ ./internal/aghnet/ ./internal/aghos/\ + ./internal/aghstrings/ ./internal/aghtest/ ./internal/tools/\ + ./internal/version/ ./main.go gosec --quiet $go_files