diff --git a/CHANGELOG.md b/CHANGELOG.md index e2fa248b..3a7951ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,6 +66,7 @@ and this project adheres to ### Fixed +- Unnecessary conversions from `string` to `net.IP`, and vice versa ([#2508]). - Inability to set DNS cache TTL limits ([#2459]). - Possible freezes on slower machines ([#2225]). - A mitigation against records being shown in the wrong order on the query log @@ -79,9 +80,13 @@ and this project adheres to [#2345]: https://github.com/AdguardTeam/AdGuardHome/issues/2345 [#2355]: https://github.com/AdguardTeam/AdGuardHome/issues/2355 [#2459]: https://github.com/AdguardTeam/AdGuardHome/issues/2459 +[#2508]: https://github.com/AdguardTeam/AdGuardHome/issues/2508 ### Removed +- The undocumented ability to use hostnames as any of `bind_host` values in + configuration. Documentation requires them to be valid IP addresses, and now + the implementation makes sure that that is the case ([#2508]). - `Dockerfile` ([#2276]). Replaced with the script `scripts/make/build-docker.sh` which uses `scripts/make/Dockerfile`. - Support for pre-v0.99.3 format of query logs ([#2102]). diff --git a/internal/dhcpd/dhcpd.go b/internal/dhcpd/dhcpd.go index 17d817cf..18509991 100644 --- a/internal/dhcpd/dhcpd.go +++ b/internal/dhcpd/dhcpd.go @@ -297,9 +297,6 @@ func parseOptionString(s string) (uint8, []byte) { return 0, nil } val = ip - if ip.To4() != nil { - val = ip.To4() - } default: return 0, nil diff --git a/internal/dhcpd/dhcpd_test.go b/internal/dhcpd/dhcpd_test.go index 1d3aab21..f65e6823 100644 --- a/internal/dhcpd/dhcpd_test.go +++ b/internal/dhcpd/dhcpd_test.go @@ -61,11 +61,11 @@ func TestDB(t *testing.T) { ll := s.srv4.GetLeases(LeasesAll) assert.Equal(t, "aa:aa:aa:aa:aa:bb", ll[0].HWAddr.String()) - assert.Equal(t, "192.168.10.101", ll[0].IP.String()) + assert.True(t, net.IP{192, 168, 10, 101}.Equal(ll[0].IP)) assert.EqualValues(t, leaseExpireStatic, ll[0].Expiry.Unix()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ll[1].HWAddr.String()) - assert.Equal(t, "192.168.10.100", ll[1].IP.String()) + assert.True(t, net.IP{192, 168, 10, 100}.Equal(ll[1].IP)) assert.Equal(t, exp1.Unix(), ll[1].Expiry.Unix()) _ = os.Remove("leases.db") @@ -117,7 +117,7 @@ func TestOptions(t *testing.T) { code, val = parseOptionString("123 ip 1.2.3.4") assert.EqualValues(t, 123, code) - assert.Equal(t, "1.2.3.4", net.IP(string(val)).String()) + assert.True(t, net.IP{1, 2, 3, 4}.Equal(net.IP(val))) code, _ = parseOptionString("256 ip 1.1.1.1") assert.EqualValues(t, 0, code) diff --git a/internal/dhcpd/dhcphttp.go b/internal/dhcpd/dhcphttp.go index 36bc64ef..e38afbca 100644 --- a/internal/dhcpd/dhcphttp.go +++ b/internal/dhcpd/dhcphttp.go @@ -40,7 +40,7 @@ func v4JSONToServerConf(j v4ServerConfJSON) V4ServerConf { } type v6ServerConfJSON struct { - RangeStart string `json:"range_start"` + RangeStart net.IP `json:"range_start"` LeaseDuration uint32 `json:"lease_duration"` } @@ -331,7 +331,7 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque result.V4.StaticIP.Error = err.Error() } else if !isStaticIP { result.V4.StaticIP.Static = "no" - result.V4.StaticIP.IP = util.GetSubnet(interfaceName) + result.V4.StaticIP.IP = util.GetSubnet(interfaceName).String() } if found4 { diff --git a/internal/dhcpd/server.go b/internal/dhcpd/server.go index 261ad4db..4adbca5a 100644 --- a/internal/dhcpd/server.go +++ b/internal/dhcpd/server.go @@ -79,7 +79,7 @@ type V6ServerConf struct { // The first IP address for dynamic leases // The last allowed IP address ends with 0xff byte - RangeStart string `yaml:"range_start" json:"range_start"` + RangeStart net.IP `yaml:"range_start"` LeaseDuration uint32 `yaml:"lease_duration" json:"lease_duration"` // in seconds diff --git a/internal/dhcpd/v4_test.go b/internal/dhcpd/v4_test.go index e3086026..74eab484 100644 --- a/internal/dhcpd/v4_test.go +++ b/internal/dhcpd/v4_test.go @@ -40,7 +40,7 @@ func TestV4StaticLeaseAddRemove(t *testing.T) { // check ls = s.GetLeases(LeasesStatic) assert.Len(t, ls, 1) - assert.Equal(t, "192.168.10.150", ls[0].IP.String()) + assert.True(t, net.IP{192, 168, 10, 150}.Equal(ls[0].IP)) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix()) @@ -102,11 +102,11 @@ func TestV4StaticLeaseAddReplaceDynamic(t *testing.T) { ls := s.GetLeases(LeasesStatic) assert.Len(t, ls, 2) - assert.Equal(t, "192.168.10.150", ls[0].IP.String()) + assert.True(t, net.IP{192, 168, 10, 150}.Equal(ls[0].IP)) assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix()) - assert.Equal(t, "192.168.10.152", ls[1].IP.String()) + assert.True(t, net.IP{192, 168, 10, 152}.Equal(ls[1].IP)) assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String()) assert.EqualValues(t, leaseExpireStatic, ls[1].Expiry.Unix()) } @@ -139,10 +139,10 @@ func TestV4StaticLeaseGet(t *testing.T) { // check "Offer" assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String()) - assert.Equal(t, "192.168.10.150", resp.YourIPAddr.String()) - assert.Equal(t, "192.168.10.1", resp.Router()[0].String()) - assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String()) - assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String()) + assert.True(t, net.IP{192, 168, 10, 150}.Equal(resp.YourIPAddr)) + assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.Router()[0])) + assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.ServerIdentifier())) + assert.True(t, net.IP{255, 255, 255, 0}.Equal(net.IP(resp.SubnetMask()))) assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) // "Request" @@ -153,20 +153,20 @@ func TestV4StaticLeaseGet(t *testing.T) { // check "Ack" assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String()) - assert.Equal(t, "192.168.10.150", resp.YourIPAddr.String()) - assert.Equal(t, "192.168.10.1", resp.Router()[0].String()) - assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String()) - assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String()) + assert.True(t, net.IP{192, 168, 10, 150}.Equal(resp.YourIPAddr)) + assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.Router()[0])) + assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.ServerIdentifier())) + assert.True(t, net.IP{255, 255, 255, 0}.Equal(net.IP(resp.SubnetMask()))) assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) dnsAddrs := resp.DNS() assert.Len(t, dnsAddrs, 1) - assert.Equal(t, "192.168.10.1", dnsAddrs[0].String()) + assert.True(t, net.IP{192, 168, 10, 1}.Equal(dnsAddrs[0])) // check lease ls := s.GetLeases(LeasesStatic) assert.Len(t, ls, 1) - assert.Equal(t, "192.168.10.150", ls[0].IP.String()) + assert.True(t, net.IP{192, 168, 10, 150}.Equal(ls[0].IP)) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) } @@ -197,13 +197,13 @@ func TestV4DynamicLeaseGet(t *testing.T) { // check "Offer" assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String()) - assert.Equal(t, "192.168.10.100", resp.YourIPAddr.String()) - assert.Equal(t, "192.168.10.1", resp.Router()[0].String()) - assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String()) - assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String()) + assert.True(t, net.IP{192, 168, 10, 100}.Equal(resp.YourIPAddr)) + assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.Router()[0])) + assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.ServerIdentifier())) + assert.True(t, net.IP{255, 255, 255, 0}.Equal(net.IP(resp.SubnetMask()))) assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) assert.Equal(t, []byte("012"), resp.Options[uint8(dhcpv4.OptionFQDN)]) - assert.Equal(t, "1.2.3.4", net.IP(resp.Options[uint8(dhcpv4.OptionRelayAgentInformation)]).String()) + assert.True(t, net.IP{1, 2, 3, 4}.Equal(net.IP(resp.Options[uint8(dhcpv4.OptionRelayAgentInformation)]))) // "Request" req, _ = dhcpv4.NewRequestFromOffer(resp) @@ -213,20 +213,20 @@ func TestV4DynamicLeaseGet(t *testing.T) { // check "Ack" assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String()) - assert.Equal(t, "192.168.10.100", resp.YourIPAddr.String()) - assert.Equal(t, "192.168.10.1", resp.Router()[0].String()) - assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String()) - assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String()) + assert.True(t, net.IP{192, 168, 10, 100}.Equal(resp.YourIPAddr)) + assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.Router()[0])) + assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.ServerIdentifier())) + assert.True(t, net.IP{255, 255, 255, 0}.Equal(net.IP(resp.SubnetMask()))) assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) dnsAddrs := resp.DNS() assert.Len(t, dnsAddrs, 1) - assert.Equal(t, "192.168.10.1", dnsAddrs[0].String()) + assert.True(t, net.IP{192, 168, 10, 1}.Equal(dnsAddrs[0])) // check lease ls := s.GetLeases(LeasesDynamic) assert.Len(t, ls, 1) - assert.Equal(t, "192.168.10.100", ls[0].IP.String()) + assert.True(t, net.IP{192, 168, 10, 100}.Equal(ls[0].IP)) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) start := net.IP{192, 168, 10, 100} diff --git a/internal/dhcpd/v6.go b/internal/dhcpd/v6.go index 0f8cddfd..f9dab113 100644 --- a/internal/dhcpd/v6.go +++ b/internal/dhcpd/v6.go @@ -660,7 +660,7 @@ func v6Create(conf V6ServerConf) (DHCPServer, error) { return s, nil } - s.conf.ipStart = net.ParseIP(conf.RangeStart) + s.conf.ipStart = conf.RangeStart if s.conf.ipStart == nil || s.conf.ipStart.To16() == nil { return s, fmt.Errorf("dhcpv6: invalid range-start IP: %s", conf.RangeStart) } diff --git a/internal/dhcpd/v6_test.go b/internal/dhcpd/v6_test.go index fd3dd89b..1c54163a 100644 --- a/internal/dhcpd/v6_test.go +++ b/internal/dhcpd/v6_test.go @@ -17,7 +17,7 @@ func notify6(flags uint32) { func TestV6StaticLeaseAddRemove(t *testing.T) { conf := V6ServerConf{ Enabled: true, - RangeStart: "2001::1", + RangeStart: net.ParseIP("2001::1"), notify: notify6, } s, err := v6Create(conf) @@ -60,7 +60,7 @@ func TestV6StaticLeaseAddRemove(t *testing.T) { func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) { conf := V6ServerConf{ Enabled: true, - RangeStart: "2001::1", + RangeStart: net.ParseIP("2001::1"), notify: notify6, } sIface, err := v6Create(conf) @@ -109,7 +109,7 @@ func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) { func TestV6GetLease(t *testing.T) { conf := V6ServerConf{ Enabled: true, - RangeStart: "2001::1", + RangeStart: net.ParseIP("2001::1"), notify: notify6, } sIface, err := v6Create(conf) @@ -169,7 +169,7 @@ func TestV6GetLease(t *testing.T) { func TestV6GetDynamicLease(t *testing.T) { conf := V6ServerConf{ Enabled: true, - RangeStart: "2001::2", + RangeStart: net.ParseIP("2001::2"), notify: notify6, } sIface, err := v6Create(conf) diff --git a/internal/dnsfilter/dnsfilter.go b/internal/dnsfilter/dnsfilter.go index c5c28aff..45ce5b8a 100644 --- a/internal/dnsfilter/dnsfilter.go +++ b/internal/dnsfilter/dnsfilter.go @@ -36,7 +36,7 @@ type RequestFilteringSettings struct { ParentalEnabled bool ClientName string - ClientIP string + ClientIP net.IP ClientTags []string ServicesRules []ServiceEntry @@ -676,9 +676,10 @@ func (d *DNSFilter) matchHost(host string, qtype uint16, setts RequestFilteringS ureq := urlfilter.DNSRequest{ Hostname: host, SortedClientTags: setts.ClientTags, - ClientIP: setts.ClientIP, - ClientName: setts.ClientName, - DNSType: qtype, + // TODO(e.burkov): Wait for urlfilter update to pass net.IP. + ClientIP: setts.ClientIP.String(), + ClientName: setts.ClientName, + DNSType: qtype, } if d.filteringEngineAllow != nil { diff --git a/internal/dnsfilter/rewrites_test.go b/internal/dnsfilter/rewrites_test.go index a38f3f9d..12a59409 100644 --- a/internal/dnsfilter/rewrites_test.go +++ b/internal/dnsfilter/rewrites_test.go @@ -117,19 +117,19 @@ func TestRewritesLevels(t *testing.T) { r := d.processRewrites("host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) assert.Len(t, r.IPList, 1) - assert.Equal(t, "1.1.1.1", r.IPList[0].String()) + assert.True(t, net.IP{1, 1, 1, 1}.Equal(r.IPList[0])) // match L2 r = d.processRewrites("sub.host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) assert.Len(t, r.IPList, 1) - assert.Equal(t, "2.2.2.2", r.IPList[0].String()) + assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0])) // match L3 r = d.processRewrites("my.sub.host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) assert.Len(t, r.IPList, 1) - assert.Equal(t, "3.3.3.3", r.IPList[0].String()) + assert.True(t, net.IP{3, 3, 3, 3}.Equal(r.IPList[0])) } func TestRewritesExceptionCNAME(t *testing.T) { @@ -145,7 +145,7 @@ func TestRewritesExceptionCNAME(t *testing.T) { r := d.processRewrites("my.host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) assert.Len(t, r.IPList, 1) - assert.Equal(t, "2.2.2.2", r.IPList[0].String()) + assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0])) // match sub-domain, but handle exception r = d.processRewrites("sub.host.com", dns.TypeA) @@ -165,7 +165,7 @@ func TestRewritesExceptionWC(t *testing.T) { r := d.processRewrites("my.host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) assert.Len(t, r.IPList, 1) - assert.Equal(t, "2.2.2.2", r.IPList[0].String()) + assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0])) // match sub-domain, but handle exception r = d.processRewrites("my.sub.host.com", dns.TypeA) @@ -188,7 +188,7 @@ func TestRewritesExceptionIP(t *testing.T) { r := d.processRewrites("host.com", dns.TypeA) assert.Equal(t, Rewritten, r.Reason) assert.Len(t, r.IPList, 1) - assert.Equal(t, "1.2.3.4", r.IPList[0].String()) + assert.True(t, net.IP{1, 2, 3, 4}.Equal(r.IPList[0])) // match exception r = d.processRewrites("host.com", dns.TypeAAAA) diff --git a/internal/dnsforward/access.go b/internal/dnsforward/access.go index 5038a89a..8afae955 100644 --- a/internal/dnsforward/access.go +++ b/internal/dnsforward/access.go @@ -83,20 +83,21 @@ func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []strin // Returns the item from the "disallowedClients" list that lead to blocking IP. // If it returns TRUE and an empty string, it means that the "allowedClients" is not empty, // but the ip does not belong to it. -func (a *accessCtx) IsBlockedIP(ip string) (bool, string) { +func (a *accessCtx) IsBlockedIP(ip net.IP) (bool, string) { + ipStr := ip.String() + a.lock.Lock() defer a.lock.Unlock() if len(a.allowedClients) != 0 || len(a.allowedClientsIPNet) != 0 { - _, ok := a.allowedClients[ip] + _, ok := a.allowedClients[ipStr] if ok { return false, "" } if len(a.allowedClientsIPNet) != 0 { - ipAddr := net.ParseIP(ip) for _, ipnet := range a.allowedClientsIPNet { - if ipnet.Contains(ipAddr) { + if ipnet.Contains(ip) { return false, "" } } @@ -105,15 +106,14 @@ func (a *accessCtx) IsBlockedIP(ip string) (bool, string) { return true, "" } - _, ok := a.disallowedClients[ip] + _, ok := a.disallowedClients[ipStr] if ok { - return true, ip + return true, ipStr } if len(a.disallowedClientsIPNet) != 0 { - ipAddr := net.ParseIP(ip) for _, ipnet := range a.disallowedClientsIPNet { - if ipnet.Contains(ipAddr) { + if ipnet.Contains(ip) { return true, ipnet.String() } } diff --git a/internal/dnsforward/access_test.go b/internal/dnsforward/access_test.go index 5c225b21..af13b02e 100644 --- a/internal/dnsforward/access_test.go +++ b/internal/dnsforward/access_test.go @@ -1,6 +1,7 @@ package dnsforward import ( + "net" "testing" "github.com/stretchr/testify/assert" @@ -10,19 +11,19 @@ func TestIsBlockedIPAllowed(t *testing.T) { a := &accessCtx{} assert.Nil(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil)) - disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1") + disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1)) assert.False(t, disallowed) assert.Empty(t, disallowedRule) - disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2") + disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2)) assert.True(t, disallowed) assert.Empty(t, disallowedRule) - disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1") + disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1)) assert.False(t, disallowed) assert.Empty(t, disallowedRule) - disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1") + disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1)) assert.True(t, disallowed) assert.Empty(t, disallowedRule) } @@ -31,19 +32,19 @@ func TestIsBlockedIPDisallowed(t *testing.T) { a := &accessCtx{} assert.Nil(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil)) - disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1") + disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1)) assert.True(t, disallowed) assert.Equal(t, "1.1.1.1", disallowedRule) - disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2") + disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2)) assert.False(t, disallowed) assert.Empty(t, disallowedRule) - disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1") + disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1)) assert.True(t, disallowed) assert.Equal(t, "2.2.0.0/16", disallowedRule) - disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1") + disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1)) assert.False(t, disallowed) assert.Empty(t, disallowedRule) } diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 0ff078f7..ee3f6aa0 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -25,11 +25,11 @@ type FilteringConfig struct { // -- // Filtering callback function - FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"` + FilterHandler func(clientAddr net.IP, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"` // GetCustomUpstreamByClient - a callback function that returns upstreams configuration // based on the client IP address. Returns nil if there are no custom upstreams for the client - // TODO(e.burkov): replace argument type with net.IP. + // TODO(e.burkov): Replace argument type with net.IP. GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"` // Protection configuration diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index ab6bea27..0ada0640 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -298,6 +298,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // IsBlockedIP - return TRUE if this client should be blocked -func (s *Server) IsBlockedIP(ip string) (bool, string) { +func (s *Server) IsBlockedIP(ip net.IP) (bool, string) { return s.access.IsBlockedIP(ip) } diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index ab9bb03c..ec56c9e8 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -322,7 +322,7 @@ func TestServerCustomClientUpstream(t *testing.T) { assert.Nil(t, err) assert.Equal(t, dns.RcodeSuccess, reply.Rcode) assert.NotNil(t, reply.Answer) - assert.Equal(t, "192.168.0.1", reply.Answer[0].(*dns.A).A.String()) + assert.True(t, net.IP{192, 168, 0, 1}.Equal(reply.Answer[0].(*dns.A).A)) assert.Nil(t, s.Stop()) } @@ -473,7 +473,7 @@ func TestBlockCNAME(t *testing.T) { func TestClientRulesForCNAMEMatching(t *testing.T) { s := createTestServer(t) testUpstm := &testUpstream{testCNAMEs, testIPv4, nil} - s.conf.FilterHandler = func(_ string, settings *dnsfilter.RequestFilteringSettings) { + s.conf.FilterHandler = func(_ net.IP, settings *dnsfilter.RequestFilteringSettings) { settings.FilteringEnabled = false } err := s.startWithUpstream(testUpstm) @@ -568,7 +568,7 @@ func TestBlockedCustomIP(t *testing.T) { assert.Len(t, reply.Answer, 1) a, ok := reply.Answer[0].(*dns.A) assert.True(t, ok) - assert.Equal(t, "0.0.0.1", a.A.String()) + assert.True(t, net.IP{0, 0, 0, 1}.Equal(a.A)) req = createTestMessageWithType("null.example.org.", dns.TypeAAAA) reply, err = dns.Exchange(req, addr.String()) @@ -713,7 +713,7 @@ func TestRewrite(t *testing.T) { assert.Len(t, reply.Answer, 1) a, ok := reply.Answer[0].(*dns.A) assert.True(t, ok) - assert.Equal(t, "1.2.3.4", a.A.String()) + assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A)) req = createTestMessageWithType("test.com.", dns.TypeAAAA) reply, err = dns.Exchange(req, addr.String()) @@ -725,7 +725,7 @@ func TestRewrite(t *testing.T) { assert.Nil(t, err) assert.Len(t, reply.Answer, 2) assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target) - assert.Equal(t, "1.2.3.4", reply.Answer[1].(*dns.A).A.String()) + assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A)) req = createTestMessageWithType("my.alias.example.org.", dns.TypeA) reply, err = dns.Exchange(req, addr.String()) diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 77ae30a9..4d319288 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -12,7 +12,7 @@ import ( ) func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) { - ip := IPStringFromAddr(d.Addr) + ip := IPFromAddr(d.Addr) disallowed, _ := s.access.IsBlockedIP(ip) if disallowed { log.Tracef("Client IP %s is blocked by settings", ip) @@ -36,8 +36,7 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt setts := s.dnsFilter.GetConfig() setts.FilteringEnabled = true if s.conf.FilterHandler != nil { - clientAddr := IPStringFromAddr(d.Addr) - s.conf.FilterHandler(clientAddr, &setts) + s.conf.FilterHandler(IPFromAddr(d.Addr), &setts) } return &setts } diff --git a/internal/dnsforward/stats.go b/internal/dnsforward/stats.go index 822df6a0..be45b0f9 100644 --- a/internal/dnsforward/stats.go +++ b/internal/dnsforward/stats.go @@ -36,7 +36,7 @@ func processQueryLogsAndStats(ctx *dnsContext) int { OrigAnswer: ctx.origResp, Result: ctx.result, Elapsed: elapsed, - ClientIP: ipFromAddr(d.Addr), + ClientIP: IPFromAddr(d.Addr), } switch d.Proto { diff --git a/internal/dnsforward/util.go b/internal/dnsforward/util.go index 3a8c1cb3..4b57768b 100644 --- a/internal/dnsforward/util.go +++ b/internal/dnsforward/util.go @@ -8,8 +8,8 @@ import ( "github.com/AdguardTeam/golibs/utils" ) -// ipFromAddr gets IP address from addr. -func ipFromAddr(addr net.Addr) (ip net.IP) { +// IPFromAddr gets IP address from addr. +func IPFromAddr(addr net.Addr) (ip net.IP) { switch addr := addr.(type) { case *net.UDPAddr: return addr.IP @@ -22,8 +22,8 @@ func ipFromAddr(addr net.Addr) (ip net.IP) { // IPStringFromAddr extracts IP address from net.Addr. // Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone: // https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261 -func IPStringFromAddr(addr net.Addr) (ipstr string) { - if ip := ipFromAddr(addr); ip != nil { +func IPStringFromAddr(addr net.Addr) (ipStr string) { + if ip := IPFromAddr(addr); ip != nil { return ip.String() } diff --git a/internal/home/clients.go b/internal/home/clients.go index 3c6bfa48..b631f63e 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -70,10 +70,12 @@ type ClientHost struct { } type clientsContainer struct { - list map[string]*Client // name -> client - idIndex map[string]*Client // IP -> client - ipHost map[string]*ClientHost // IP -> Hostname - lock sync.Mutex + list map[string]*Client // name -> client + idIndex map[string]*Client // IP -> client + // TODO(e.burkov): Think of a way to not require string conversion for + // IP addresses. + ipHost map[string]*ClientHost // IP -> Hostname + lock sync.Mutex allTags map[string]bool @@ -239,7 +241,7 @@ func (clients *clientsContainer) onHostsChanged() { } // Exists checks if client with this IP already exists -func (clients *clientsContainer) Exists(ip string, source clientSource) bool { +func (clients *clientsContainer) Exists(ip net.IP, source clientSource) bool { clients.lock.Lock() defer clients.lock.Unlock() @@ -248,7 +250,7 @@ func (clients *clientsContainer) Exists(ip string, source clientSource) bool { return true } - ch, ok := clients.ipHost[ip] + ch, ok := clients.ipHost[ip.String()] if !ok { return false } @@ -265,7 +267,7 @@ func stringArrayDup(a []string) []string { } // Find searches for a client by IP -func (clients *clientsContainer) Find(ip string) (Client, bool) { +func (clients *clientsContainer) Find(ip net.IP) (Client, bool) { clients.lock.Lock() defer clients.lock.Unlock() @@ -287,7 +289,7 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig clients.lock.Lock() defer clients.lock.Unlock() - c, ok := clients.findByIP(ip) + c, ok := clients.findByIP(net.ParseIP(ip)) if !ok { return nil } @@ -307,13 +309,12 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig } // Find searches for a client by IP (and does not lock anything) -func (clients *clientsContainer) findByIP(ip string) (Client, bool) { - ipAddr := net.ParseIP(ip) - if ipAddr == nil { +func (clients *clientsContainer) findByIP(ip net.IP) (Client, bool) { + if ip == nil { return Client{}, false } - c, ok := clients.idIndex[ip] + c, ok := clients.idIndex[ip.String()] if ok { return *c, true } @@ -324,7 +325,7 @@ func (clients *clientsContainer) findByIP(ip string) (Client, bool) { if err != nil { continue } - if ipnet.Contains(ipAddr) { + if ipnet.Contains(ip) { return *c, true } } @@ -333,7 +334,7 @@ func (clients *clientsContainer) findByIP(ip string) (Client, bool) { if clients.dhcpServer == nil { return Client{}, false } - macFound := clients.dhcpServer.FindMACbyIP(ipAddr) + macFound := clients.dhcpServer.FindMACbyIP(ip) if macFound == nil { return Client{}, false } @@ -353,16 +354,15 @@ func (clients *clientsContainer) findByIP(ip string) (Client, bool) { } // FindAutoClient - search for an auto-client by IP -func (clients *clientsContainer) FindAutoClient(ip string) (ClientHost, bool) { - ipAddr := net.ParseIP(ip) - if ipAddr == nil { +func (clients *clientsContainer) FindAutoClient(ip net.IP) (ClientHost, bool) { + if ip == nil { return ClientHost{}, false } clients.lock.Lock() defer clients.lock.Unlock() - ch, ok := clients.ipHost[ip] + ch, ok := clients.ipHost[ip.String()] if ok { return *ch, true } @@ -539,7 +539,7 @@ func (clients *clientsContainer) Update(name string, c Client) error { } // SetWhoisInfo - associate WHOIS information with a client -func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) { +func (clients *clientsContainer) SetWhoisInfo(ip net.IP, info [][]string) { clients.lock.Lock() defer clients.lock.Unlock() @@ -549,7 +549,8 @@ func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) { return } - ch, ok := clients.ipHost[ip] + ipStr := ip.String() + ch, ok := clients.ipHost[ipStr] if ok { ch.WhoisInfo = info log.Debug("Clients: set WHOIS info for auto-client %s: %v", ch.Host, ch.WhoisInfo) @@ -561,7 +562,7 @@ func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) { Source: ClientSourceWHOIS, } ch.WhoisInfo = info - clients.ipHost[ip] = ch + clients.ipHost[ipStr] = ch log.Debug("Clients: set WHOIS info for auto-client with IP %s: %v", ip, ch.WhoisInfo) } diff --git a/internal/home/clients_test.go b/internal/home/clients_test.go index 69f2bada..94ff8009 100644 --- a/internal/home/clients_test.go +++ b/internal/home/clients_test.go @@ -36,21 +36,21 @@ func TestClients(t *testing.T) { assert.True(t, b) assert.Nil(t, err) - c, b = clients.Find("1.1.1.1") + c, b = clients.Find(net.IPv4(1, 1, 1, 1)) assert.True(t, b) assert.Equal(t, c.Name, "client1") - c, b = clients.Find("1:2:3::4") + c, b = clients.Find(net.ParseIP("1:2:3::4")) assert.True(t, b) assert.Equal(t, c.Name, "client1") - c, b = clients.Find("2.2.2.2") + c, b = clients.Find(net.IPv4(2, 2, 2, 2)) assert.True(t, b) assert.Equal(t, c.Name, "client2") - assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile)) - assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) - assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile)) + assert.False(t, clients.Exists(net.IPv4(1, 2, 3, 4), ClientSourceHostsFile)) + assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile)) + assert.True(t, clients.Exists(net.IPv4(2, 2, 2, 2), ClientSourceHostsFile)) }) t.Run("add_fail_name", func(t *testing.T) { @@ -112,8 +112,8 @@ func TestClients(t *testing.T) { err := clients.Update("client1", c) assert.Nil(t, err) - assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) - assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) + assert.False(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile)) + assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 2), ClientSourceHostsFile)) c = Client{ IDs: []string{"1.1.1.2"}, @@ -124,7 +124,7 @@ func TestClients(t *testing.T) { err = clients.Update("client1", c) assert.Nil(t, err) - c, b := clients.Find("1.1.1.2") + c, b := clients.Find(net.IPv4(1, 1, 1, 2)) assert.True(t, b) assert.Equal(t, "client1-renamed", c.Name) assert.Equal(t, "1.1.1.2", c.IDs[0]) @@ -135,7 +135,7 @@ func TestClients(t *testing.T) { t.Run("del_success", func(t *testing.T) { b := clients.Del("client1-renamed") assert.True(t, b) - assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) + assert.False(t, clients.Exists(net.IPv4(1, 1, 1, 2), ClientSourceHostsFile)) }) t.Run("del_fail", func(t *testing.T) { @@ -156,7 +156,7 @@ func TestClients(t *testing.T) { assert.True(t, b) assert.Nil(t, err) - assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) + assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile)) }) t.Run("addhost_fail", func(t *testing.T) { @@ -174,12 +174,12 @@ func TestClientsWhois(t *testing.T) { whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}} // set whois info on new client - clients.SetWhoisInfo("1.1.1.255", whois) + clients.SetWhoisInfo(net.IPv4(1, 1, 1, 255), whois) assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.255"].WhoisInfo[0][1]) // set whois info on existing auto-client _, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS) - clients.SetWhoisInfo("1.1.1.1", whois) + clients.SetWhoisInfo(net.IPv4(1, 1, 1, 1), whois) assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.1"].WhoisInfo[0][1]) // Check that we cannot set whois info on a manually-added client @@ -188,7 +188,7 @@ func TestClientsWhois(t *testing.T) { Name: "client1", } _, _ = clients.Add(c) - clients.SetWhoisInfo("1.1.1.2", whois) + clients.SetWhoisInfo(net.IPv4(1, 1, 1, 2), whois) assert.Nil(t, clients.ipHost["1.1.1.2"]) _ = clients.Del("client1") } diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 51a193f9..aecc68e5 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -3,6 +3,7 @@ package home import ( "encoding/json" "fmt" + "net" "net/http" ) @@ -229,8 +230,9 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http q := r.URL.Query() data := []map[string]interface{}{} for i := 0; ; i++ { - ip := q.Get(fmt.Sprintf("ip%d", i)) - if len(ip) == 0 { + ipStr := q.Get(fmt.Sprintf("ip%d", i)) + ip := net.ParseIP(ipStr) + if ip == nil { break } @@ -248,7 +250,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip) } - el[ip] = cj + el[ipStr] = cj data = append(data, el) } @@ -267,7 +269,8 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http // findTemporary looks up the IP in temporary storages, like autohosts or // blocklists. -func (clients *clientsContainer) findTemporary(ip string) (cj clientJSON, found bool) { +func (clients *clientsContainer) findTemporary(ip net.IP) (cj clientJSON, found bool) { + ipStr := ip.String() ch, ok := clients.FindAutoClient(ip) if !ok { // It is still possible that the IP used to be in the runtime @@ -281,7 +284,7 @@ func (clients *clientsContainer) findTemporary(ip string) (cj clientJSON, found } cj = clientJSON{ - IDs: []string{ip}, + IDs: []string{ipStr}, Disallowed: disallowed, DisallowedRule: rule, } @@ -289,7 +292,7 @@ func (clients *clientsContainer) findTemporary(ip string) (cj clientJSON, found return cj, true } - cj = clientHostToJSON(ip, ch) + cj = clientHostToJSON(ipStr, ch) cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip) return cj, true diff --git a/internal/home/config.go b/internal/home/config.go index f7b799dc..6c8381ab 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -2,6 +2,7 @@ package home import ( "io/ioutil" + "net" "os" "path/filepath" "sync" @@ -40,7 +41,7 @@ type configuration struct { // It's reset after config is parsed fileData []byte - BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to + BindHost net.IP `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server BetaBindPort int `yaml:"beta_bind_port"` // BetaBindPort is the port for new client Users []User `yaml:"users"` // Users that can access HTTP server @@ -74,7 +75,7 @@ type configuration struct { // field ordering is important -- yaml fields will mirror ordering from here type dnsConfig struct { - BindHost string `yaml:"bind_host"` + BindHost net.IP `yaml:"bind_host"` Port int `yaml:"port"` // time interval for statistics (in days) @@ -121,9 +122,9 @@ type tlsConfigSettings struct { var config = configuration{ BindPort: 3000, BetaBindPort: 0, - BindHost: "0.0.0.0", + BindHost: net.IP{0, 0, 0, 0}, DNS: dnsConfig{ - BindHost: "0.0.0.0", + BindHost: net.IP{0, 0, 0, 0}, Port: 53, StatsInterval: 1, FilteringConfig: dnsforward.FilteringConfig{ diff --git a/internal/home/control.go b/internal/home/control.go index 7caac5e6..f309663f 100644 --- a/internal/home/control.go +++ b/internal/home/control.go @@ -36,11 +36,12 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface // --------------- // dns run control // --------------- -func addDNSAddress(dnsAddresses *[]string, addr string) { +func addDNSAddress(dnsAddresses *[]string, addr net.IP) { + hostport := addr.String() if config.DNS.Port != 53 { - addr = fmt.Sprintf("%s:%d", addr, config.DNS.Port) + hostport = net.JoinHostPort(hostport, strconv.Itoa(config.DNS.Port)) } - *dnsAddresses = append(*dnsAddresses, addr) + *dnsAddresses = append(*dnsAddresses, hostport) } func handleStatus(w http.ResponseWriter, _ *http.Request) { diff --git a/internal/home/controlinstall.go b/internal/home/controlinstall.go index 2abfc4ba..ae93184c 100644 --- a/internal/home/controlinstall.go +++ b/internal/home/controlinstall.go @@ -31,7 +31,7 @@ type netInterfaceJSON struct { Name string `json:"name"` MTU int `json:"mtu"` HardwareAddr string `json:"hardware_address"` - Addresses []string `json:"ip_addresses"` + Addresses []net.IP `json:"ip_addresses"` Flags string `json:"flags"` } @@ -69,7 +69,7 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request type checkConfigReqEnt struct { Port int `json:"port"` - IP string `json:"ip"` + IP net.IP `json:"ip"` Autofix bool `json:"autofix"` } @@ -138,7 +138,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) if err != nil { respData.DNS.Status = err.Error() - } else if reqData.DNS.IP != "0.0.0.0" { + } else if !reqData.DNS.IP.IsUnspecified() { respData.StaticIP = handleStaticIP(reqData.DNS.IP, reqData.SetStaticIP) } } @@ -154,7 +154,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) // handleStaticIP - handles static IP request // It either checks if we have a static IP // Or if set=true, it tries to set it -func handleStaticIP(ip string, set bool) staticIPJSON { +func handleStaticIP(ip net.IP, set bool) staticIPJSON { resp := staticIPJSON{} interfaceName := util.GetInterfaceByIP(ip) @@ -186,7 +186,7 @@ func handleStaticIP(ip string, set bool) staticIPJSON { if isStaticIP { resp.Static = "yes" } - resp.IP = util.GetSubnet(interfaceName) + resp.IP = util.GetSubnet(interfaceName).String() } return resp } @@ -262,7 +262,7 @@ func disableDNSStubListener() error { } type applyConfigReqEnt struct { - IP string `json:"ip"` + IP net.IP `json:"ip"` Port int `json:"port"` } @@ -297,7 +297,7 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { } restartHTTP := true - if config.BindHost == newSettings.Web.IP && config.BindPort == newSettings.Web.Port { + if config.BindHost.Equal(newSettings.Web.IP) && config.BindPort == newSettings.Web.Port { // no need to rebind restartHTTP = false } @@ -307,7 +307,7 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { err = util.CheckPortAvailable(newSettings.Web.IP, newSettings.Web.Port) if err != nil { httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s", - net.JoinHostPort(newSettings.Web.IP, strconv.Itoa(newSettings.Web.Port)), err) + net.JoinHostPort(newSettings.Web.IP.String(), strconv.Itoa(newSettings.Web.Port)), err) return } @@ -388,18 +388,18 @@ func (web *Web) registerInstallHandlers() { // checkConfigReqEntBeta is a struct representing new client's config check // request entry. It supports multiple IP values unlike the checkConfigReqEnt. // -// TODO(e.burkov): this should removed with the API v1 when the appropriate +// TODO(e.burkov): This should removed with the API v1 when the appropriate // functionality will appear in default checkConfigReqEnt. type checkConfigReqEntBeta struct { Port int `json:"port"` - IP []string `json:"ip"` + IP []net.IP `json:"ip"` Autofix bool `json:"autofix"` } // checkConfigReqBeta is a struct representing new client's config check request // body. It uses checkConfigReqEntBeta instead of checkConfigReqEnt. // -// TODO(e.burkov): this should removed with the API v1 when the appropriate +// TODO(e.burkov): This should removed with the API v1 when the appropriate // functionality will appear in default checkConfigReq. type checkConfigReqBeta struct { Web checkConfigReqEntBeta `json:"web"` @@ -410,7 +410,7 @@ type checkConfigReqBeta struct { // handleInstallCheckConfigBeta is a substitution of /install/check_config // handler for new client. // -// TODO(e.burkov): this should removed with the API v1 when the appropriate +// TODO(e.burkov): This should removed with the API v1 when the appropriate // functionality will appear in default handleInstallCheckConfig. func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Request) { reqData := checkConfigReqBeta{} @@ -456,17 +456,17 @@ func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Requ // applyConfigReqEntBeta is a struct representing new client's config setting // request entry. It supports multiple IP values unlike the applyConfigReqEnt. // -// TODO(e.burkov): this should removed with the API v1 when the appropriate +// TODO(e.burkov): This should removed with the API v1 when the appropriate // functionality will appear in default applyConfigReqEnt. type applyConfigReqEntBeta struct { - IP []string `json:"ip"` + IP []net.IP `json:"ip"` Port int `json:"port"` } // applyConfigReqBeta is a struct representing new client's config setting // request body. It uses applyConfigReqEntBeta instead of applyConfigReqEnt. // -// TODO(e.burkov): this should removed with the API v1 when the appropriate +// TODO(e.burkov): This should removed with the API v1 when the appropriate // functionality will appear in default applyConfigReq. type applyConfigReqBeta struct { Web applyConfigReqEntBeta `json:"web"` @@ -478,7 +478,7 @@ type applyConfigReqBeta struct { // handleInstallConfigureBeta is a substitution of /install/configure handler // for new client. // -// TODO(e.burkov): this should removed with the API v1 when the appropriate +// TODO(e.burkov): This should removed with the API v1 when the appropriate // functionality will appear in default handleInstallConfigure. func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Request) { reqData := applyConfigReqBeta{} @@ -523,7 +523,7 @@ func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Reques // firstRunDataBeta is a struct representing new client's getting addresses // request body. It uses array of structs instead of map. // -// TODO(e.burkov): this should removed with the API v1 when the appropriate +// TODO(e.burkov): This should removed with the API v1 when the appropriate // functionality will appear in default firstRunData. type firstRunDataBeta struct { WebPort int `json:"web_port"` @@ -534,7 +534,7 @@ type firstRunDataBeta struct { // handleInstallConfigureBeta is a substitution of /install/get_addresses // handler for new client. // -// TODO(e.burkov): this should removed with the API v1 when the appropriate +// TODO(e.burkov): This should removed with the API v1 when the appropriate // functionality will appear in default handleInstallGetAddresses. func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Request) { data := firstRunDataBeta{} @@ -570,7 +570,7 @@ func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Req // registerBetaInstallHandlers registers the install handlers for new client // with the structures it supports. // -// TODO(e.burkov): this should removed with the API v1 when the appropriate +// TODO(e.burkov): This should removed with the API v1 when the appropriate // functionality will appear in default handlers. func (web *Web) registerBetaInstallHandlers() { Context.mux.HandleFunc("/control/install/get_addresses_beta", preInstall(ensureGET(web.handleInstallGetAddressesBeta))) diff --git a/internal/home/dns.go b/internal/home/dns.go index a9880629..82b844a4 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -55,8 +55,8 @@ func initDNSServer() error { filterConf := config.DNS.DnsfilterConf bindhost := config.DNS.BindHost - if config.DNS.BindHost == "0.0.0.0" { - bindhost = "127.0.0.1" + if config.DNS.BindHost.IsUnspecified() { + bindhost = net.IPv4(127, 0, 0, 1) } filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port) filterConf.AutoHosts = &Context.autoHosts @@ -98,26 +98,24 @@ func isRunning() bool { } func onDNSRequest(d *proxy.DNSContext) { - ip := dnsforward.IPStringFromAddr(d.Addr) - if ip == "" { + ip := dnsforward.IPFromAddr(d.Addr) + if ip == nil { // This would be quite weird if we get here return } - ipAddr := net.ParseIP(ip) - if !ipAddr.IsLoopback() { + if !ip.IsLoopback() { Context.rdns.Begin(ip) } - if !Context.ipDetector.detectSpecialNetwork(ipAddr) { + if !Context.ipDetector.detectSpecialNetwork(ip) { Context.whois.Begin(ip) } } func generateServerConfig() (newconfig dnsforward.ServerConfig, err error) { - bindHost := net.ParseIP(config.DNS.BindHost) newconfig = dnsforward.ServerConfig{ - UDPListenAddr: &net.UDPAddr{IP: bindHost, Port: config.DNS.Port}, - TCPListenAddr: &net.TCPAddr{IP: bindHost, Port: config.DNS.Port}, + UDPListenAddr: &net.UDPAddr{IP: config.DNS.BindHost, Port: config.DNS.Port}, + TCPListenAddr: &net.TCPAddr{IP: config.DNS.BindHost, Port: config.DNS.Port}, FilteringConfig: config.DNS.FilteringConfig, ConfigModified: onConfigModified, HTTPRegister: httpRegister, @@ -131,20 +129,20 @@ func generateServerConfig() (newconfig dnsforward.ServerConfig, err error) { if tlsConf.PortDNSOverTLS != 0 { newconfig.TLSListenAddr = &net.TCPAddr{ - IP: bindHost, + IP: config.DNS.BindHost, Port: tlsConf.PortDNSOverTLS, } } if tlsConf.PortDNSOverQUIC != 0 { newconfig.QUICListenAddr = &net.UDPAddr{ - IP: bindHost, + IP: config.DNS.BindHost, Port: int(tlsConf.PortDNSOverQUIC), } } if tlsConf.PortDNSCrypt != 0 { - newconfig.DNSCryptConfig, err = newDNSCrypt(bindHost, tlsConf) + newconfig.DNSCryptConfig, err = newDNSCrypt(config.DNS.BindHost, tlsConf) if err != nil { // Don't wrap the error, because it's already // wrapped by newDNSCrypt. @@ -245,7 +243,7 @@ func getDNSEncryption() dnsEncryption { func getDNSAddresses() []string { dnsAddresses := []string{} - if config.DNS.BindHost == "0.0.0.0" { + if config.DNS.BindHost.IsUnspecified() { ifaces, e := util.GetValidNetInterfacesForWeb() if e != nil { log.Error("Couldn't get network interfaces: %v", e) @@ -276,10 +274,10 @@ func getDNSAddresses() []string { } // If a client has his own settings, apply them -func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) { +func applyAdditionalFiltering(clientAddr net.IP, setts *dnsfilter.RequestFilteringSettings) { Context.dnsFilter.ApplyBlockedServices(setts, nil, true) - if len(clientAddr) == 0 { + if clientAddr == nil { return } setts.ClientIP = clientAddr @@ -328,13 +326,11 @@ func startDNSServer() error { Context.queryLog.Start() const topClientsNumber = 100 // the number of clients to get - topClients := Context.stats.GetTopClientsIP(topClientsNumber) - for _, ip := range topClients { - ipAddr := net.ParseIP(ip) - if !ipAddr.IsLoopback() { + for _, ip := range Context.stats.GetTopClientsIP(topClientsNumber) { + if !ip.IsLoopback() { Context.rdns.Begin(ip) } - if !Context.ipDetector.detectSpecialNetwork(ipAddr) { + if !Context.ipDetector.detectSpecialNetwork(ip) { Context.whois.Begin(ip) } } diff --git a/internal/home/home.go b/internal/home/home.go index fbead57a..55d6c20d 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -206,7 +206,7 @@ func setupConfig(args options) { } // override bind host/port from the console - if args.bindHost != "" { + if args.bindHost != nil { config.BindHost = args.bindHost } if args.bindPort != 0 { @@ -575,36 +575,40 @@ func printHTTPAddresses(proto string) { port = strconv.Itoa(tlsConf.PortHTTPS) } + var hostStr string if proto == "https" && tlsConf.ServerName != "" { if tlsConf.PortHTTPS == 443 { log.Printf("Go to https://%s", tlsConf.ServerName) } else { log.Printf("Go to https://%s:%s", tlsConf.ServerName, port) } - } else if config.BindHost == "0.0.0.0" { + } else if config.BindHost.IsUnspecified() { log.Println("AdGuard Home is available on the following addresses:") ifaces, err := util.GetValidNetInterfacesForWeb() if err != nil { // That's weird, but we'll ignore it - log.Printf("Go to %s://%s", proto, net.JoinHostPort(config.BindHost, port)) + hostStr = config.BindHost.String() + log.Printf("Go to %s://%s", proto, net.JoinHostPort(hostStr, port)) if config.BetaBindPort != 0 { - log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(config.BindHost, strconv.Itoa(config.BetaBindPort))) + log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BetaBindPort))) } return } for _, iface := range ifaces { for _, addr := range iface.Addresses { - log.Printf("Go to %s://%s", proto, net.JoinHostPort(addr, strconv.Itoa(config.BindPort))) + hostStr = addr.String() + log.Printf("Go to %s://%s", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BindPort))) if config.BetaBindPort != 0 { - log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(addr, strconv.Itoa(config.BetaBindPort))) + log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BetaBindPort))) } } } } else { - log.Printf("Go to %s://%s", proto, net.JoinHostPort(config.BindHost, port)) + hostStr = config.BindHost.String() + log.Printf("Go to %s://%s", proto, net.JoinHostPort(hostStr, port)) if config.BetaBindPort != 0 { - log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(config.BindHost, strconv.Itoa(config.BetaBindPort))) + log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BetaBindPort))) } } } diff --git a/internal/home/home_test.go b/internal/home/home_test.go index b21d7d46..344657e9 100644 --- a/internal/home/home_test.go +++ b/internal/home/home_test.go @@ -1,6 +1,6 @@ // +build !race -// TODO(e.burkov): remove this weird buildtag. +// TODO(e.burkov): Remove this weird buildtag. package home diff --git a/internal/home/options.go b/internal/home/options.go index 0493e856..897cbd08 100644 --- a/internal/home/options.go +++ b/internal/home/options.go @@ -2,6 +2,7 @@ package home import ( "fmt" + "net" "os" "strconv" @@ -13,7 +14,7 @@ type options struct { verbose bool // is verbose logging enabled configFilename string // path to the config file workDir string // path to the working directory where we will store the filters data and the querylog - bindHost string // host address to bind HTTP server on + bindHost net.IP // host address to bind HTTP server on bindPort int // port to serve HTTP pages on logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog pidFile string // File name to save PID to @@ -54,10 +55,19 @@ type arg struct { // against its zero value and return nil if the parameter value is // zero otherwise they return a string slice of the parameter +func ipSliceOrNil(ip net.IP) []string { + if ip == nil { + return nil + } + + return []string{ip.String()} +} + func stringSliceOrNil(s string) []string { if s == "" { return nil } + return []string{s} } @@ -65,6 +75,7 @@ func intSliceOrNil(i int) []string { if i == 0 { return nil } + return []string{strconv.Itoa(i)} } @@ -72,6 +83,7 @@ func boolSliceOrNil(b bool) []string { if b { return []string{} } + return nil } @@ -96,8 +108,8 @@ var workDirArg = arg{ var hostArg = arg{ "Host address to bind HTTP server on", "host", "h", - func(o options, v string) (options, error) { o.bindHost = v; return o, nil }, nil, nil, - func(o options) []string { return stringSliceOrNil(o.bindHost) }, + func(o options, v string) (options, error) { o.bindHost = net.ParseIP(v); return o, nil }, nil, nil, + func(o options) []string { return ipSliceOrNil(o.bindHost) }, } var portArg = arg{ diff --git a/internal/home/options_test.go b/internal/home/options_test.go index afaa873f..f24dc816 100644 --- a/internal/home/options_test.go +++ b/internal/home/options_test.go @@ -2,6 +2,7 @@ package home import ( "fmt" + "net" "testing" ) @@ -65,14 +66,14 @@ func TestParseWorkDir(t *testing.T) { } func TestParseBindHost(t *testing.T) { - if testParseOk(t).bindHost != "" { + if testParseOk(t).bindHost != nil { t.Fatal("empty is no host") } - if testParseOk(t, "-h", "addr").bindHost != "addr" { + if !testParseOk(t, "-h", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) { t.Fatal("-h is host") } testParseParamMissing(t, "-h") - if testParseOk(t, "--host", "addr").bindHost != "addr" { + if !testParseOk(t, "--host", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) { t.Fatal("--host is host") } testParseParamMissing(t, "--host") @@ -204,7 +205,7 @@ func TestSerializeWorkDir(t *testing.T) { } func TestSerializeBindHost(t *testing.T) { - testSerialize(t, options{bindHost: "addr"}, "-h", "addr") + testSerialize(t, options{bindHost: net.IP{1, 2, 3, 4}}, "-h", "1.2.3.4") } func TestSerializeBindPort(t *testing.T) { diff --git a/internal/home/rdns.go b/internal/home/rdns.go index 05df66ef..c71f3822 100644 --- a/internal/home/rdns.go +++ b/internal/home/rdns.go @@ -2,6 +2,7 @@ package home import ( "encoding/binary" + "net" "strings" "time" @@ -15,7 +16,7 @@ import ( type RDNS struct { dnsServer *dnsforward.Server clients *clientsContainer - ipChannel chan string // pass data from DNS request handling thread to rDNS thread + ipChannel chan net.IP // pass data from DNS request handling thread to rDNS thread // Contains IP addresses of clients to be resolved by rDNS // If IP address is resolved, it stays here while it's inside Clients. @@ -35,15 +36,15 @@ func InitRDNS(dnsServer *dnsforward.Server, clients *clientsContainer) *RDNS { cconf.MaxCount = 10000 r.ipAddrs = cache.New(cconf) - r.ipChannel = make(chan string, 256) + r.ipChannel = make(chan net.IP, 256) go r.workerLoop() return &r } // Begin - add IP address to rDNS queue -func (r *RDNS) Begin(ip string) { +func (r *RDNS) Begin(ip net.IP) { now := uint64(time.Now().Unix()) - expire := r.ipAddrs.Get([]byte(ip)) + expire := r.ipAddrs.Get(ip) if len(expire) != 0 { exp := binary.BigEndian.Uint64(expire) if exp > now { @@ -54,7 +55,7 @@ func (r *RDNS) Begin(ip string) { expire = make([]byte, 8) const ttl = 1 * 60 * 60 binary.BigEndian.PutUint64(expire, now+ttl) - _ = r.ipAddrs.Set([]byte(ip), expire) + _ = r.ipAddrs.Set(ip, expire) if r.clients.Exists(ip, ClientSourceRDNS) { return @@ -70,7 +71,7 @@ func (r *RDNS) Begin(ip string) { } // Use rDNS to get hostname by IP address -func (r *RDNS) resolve(ip string) string { +func (r *RDNS) resolve(ip net.IP) string { log.Tracef("Resolving host for %s", ip) req := dns.Msg{} @@ -83,7 +84,7 @@ func (r *RDNS) resolve(ip string) string { }, } var err error - req.Question[0].Name, err = dns.ReverseAddr(ip) + req.Question[0].Name, err = dns.ReverseAddr(ip.String()) if err != nil { log.Debug("Error while calling dns.ReverseAddr(%s): %s", ip, err) return "" @@ -123,6 +124,6 @@ func (r *RDNS) workerLoop() { continue } - _, _ = r.clients.AddHost(ip, host, ClientSourceRDNS) + _, _ = r.clients.AddHost(ip.String(), host, ClientSourceRDNS) } } diff --git a/internal/home/rdns_test.go b/internal/home/rdns_test.go index 516b0ed5..53dd093d 100644 --- a/internal/home/rdns_test.go +++ b/internal/home/rdns_test.go @@ -1,6 +1,7 @@ package home import ( + "net" "testing" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" @@ -16,6 +17,6 @@ func TestResolveRDNS(t *testing.T) { clients := &clientsContainer{} rdns := InitRDNS(dns, clients) - r := rdns.resolve("1.1.1.1") + r := rdns.resolve(net.IP{1, 1, 1, 1}) assert.Equal(t, "one.one.one.one", r, r) } diff --git a/internal/home/web.go b/internal/home/web.go index 83fe9db4..72a2eb62 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -31,7 +31,7 @@ const ( type webConfig struct { firstRun bool - BindHost string + BindHost net.IP BindPort int BetaBindPort int PortHTTPS int @@ -161,10 +161,11 @@ func (web *Web) Start() { printHTTPAddresses("http") errs := make(chan error, 2) + hostStr := web.conf.BindHost.String() // we need to have new instance, because after Shutdown() the Server is not usable web.httpServer = &http.Server{ ErrorLog: log.StdLog("web: http", log.DEBUG), - Addr: net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.BindPort)), + Addr: net.JoinHostPort(hostStr, strconv.Itoa(web.conf.BindPort)), Handler: withMiddlewares(Context.mux, limitRequestBody), ReadTimeout: web.conf.ReadTimeout, ReadHeaderTimeout: web.conf.ReadHeaderTimeout, @@ -177,7 +178,7 @@ func (web *Web) Start() { if web.conf.BetaBindPort != 0 { web.httpServerBeta = &http.Server{ ErrorLog: log.StdLog("web: http", log.DEBUG), - Addr: net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.BetaBindPort)), + Addr: net.JoinHostPort(hostStr, strconv.Itoa(web.conf.BetaBindPort)), Handler: withMiddlewares(Context.mux, limitRequestBody, web.wrapIndexBeta), ReadTimeout: web.conf.ReadTimeout, ReadHeaderTimeout: web.conf.ReadHeaderTimeout, @@ -236,7 +237,7 @@ func (web *Web) tlsServerLoop() { web.httpsServer.cond.L.Unlock() // prepare HTTPS server - address := net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.PortHTTPS)) + address := net.JoinHostPort(web.conf.BindHost.String(), strconv.Itoa(web.conf.PortHTTPS)) web.httpsServer.server = &http.Server{ ErrorLog: log.StdLog("web: https", log.DEBUG), Addr: address, diff --git a/internal/home/whois.go b/internal/home/whois.go index 4884d776..6c40ed54 100644 --- a/internal/home/whois.go +++ b/internal/home/whois.go @@ -26,7 +26,7 @@ const ( // Whois - module context type Whois struct { clients *clientsContainer - ipChan chan string + ipChan chan net.IP timeoutMsec uint // Contains IP addresses of clients @@ -46,7 +46,7 @@ func initWhois(clients *clientsContainer) *Whois { cconf.MaxCount = 10000 w.ipAddrs = cache.New(cconf) - w.ipChan = make(chan string, 255) + w.ipChan = make(chan net.IP, 255) go w.workerLoop() return &w } @@ -183,9 +183,9 @@ func (w *Whois) queryAll(target string) (string, error) { } // Request WHOIS information -func (w *Whois) process(ip string) [][]string { +func (w *Whois) process(ip net.IP) [][]string { data := [][]string{} - resp, err := w.queryAll(ip) + resp, err := w.queryAll(ip.String()) if err != nil { log.Debug("Whois: error: %s IP:%s", err, ip) return data @@ -209,7 +209,7 @@ func (w *Whois) process(ip string) [][]string { } // Begin - begin requesting WHOIS info -func (w *Whois) Begin(ip string) { +func (w *Whois) Begin(ip net.IP) { now := uint64(time.Now().Unix()) expire := w.ipAddrs.Get([]byte(ip)) if len(expire) != 0 { diff --git a/internal/querylog/decode.go b/internal/querylog/decode.go index ed721489..ad0948ba 100644 --- a/internal/querylog/decode.go +++ b/internal/querylog/decode.go @@ -22,9 +22,11 @@ var logEntryHandlers = map[string]logEntryHandler{ if !ok { return nil } - if len(ent.IP) == 0 { - ent.IP = v + + if ent.IP == nil { + ent.IP = net.ParseIP(v) } + return nil }, "T": func(t json.Token, ent *logEntry) error { diff --git a/internal/querylog/decode_test.go b/internal/querylog/decode_test.go index a599084d..40052fea 100644 --- a/internal/querylog/decode_test.go +++ b/internal/querylog/decode_test.go @@ -47,7 +47,7 @@ func TestDecodeLogEntry(t *testing.T) { assert.Nil(t, err) want := &logEntry{ - IP: "127.0.0.1", + IP: net.IPv4(127, 0, 0, 1), Time: time.Date(2020, 11, 25, 15, 55, 56, 519796000, time.UTC), QHost: "an.yandex.ru", QType: "A", diff --git a/internal/querylog/json.go b/internal/querylog/json.go index 3beeb0f1..152f6ce2 100644 --- a/internal/querylog/json.go +++ b/internal/querylog/json.go @@ -14,22 +14,19 @@ import ( // TODO(a.garipov): Use a proper structured approach here. // Get Client IP address -func (l *queryLog) getClientIP(clientIP string) string { - if l.conf.AnonymizeClientIP { - ip := net.ParseIP(clientIP) - if ip != nil { - ip4 := ip.To4() - const AnonymizeClientIP4Mask = 16 - const AnonymizeClientIP6Mask = 112 - if ip4 != nil { - clientIP = ip4.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)).String() - } else { - clientIP = ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)).String() - } +func (l *queryLog) getClientIP(ip net.IP) (clientIP net.IP) { + if l.conf.AnonymizeClientIP && ip != nil { + const AnonymizeClientIPv4Mask = 16 + const AnonymizeClientIPv6Mask = 112 + + if ip.To4() != nil { + return ip.Mask(net.CIDRMask(AnonymizeClientIPv4Mask, 32)) } + + return ip.Mask(net.CIDRMask(AnonymizeClientIPv6Mask, 128)) } - return clientIP + return ip } // jobject is a JSON object alias. @@ -153,9 +150,9 @@ func answerToMap(a *dns.Msg) (answers []jobject) { // try most common record types switch v := k.(type) { case *dns.A: - answer["value"] = v.A.String() + answer["value"] = v.A case *dns.AAAA: - answer["value"] = v.AAAA.String() + answer["value"] = v.AAAA case *dns.MX: answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx) case *dns.CNAME: diff --git a/internal/querylog/qlog.go b/internal/querylog/qlog.go index 97343006..30054a92 100644 --- a/internal/querylog/qlog.go +++ b/internal/querylog/qlog.go @@ -3,6 +3,7 @@ package querylog import ( "fmt" + "net" "os" "path/filepath" "strings" @@ -60,7 +61,7 @@ func NewClientProto(s string) (cp ClientProto, err error) { // logEntry - represents a single log entry type logEntry struct { - IP string `json:"IP"` // Client IP + IP net.IP `json:"IP"` // Client IP Time time.Time `json:"T"` QHost string `json:"QH"` @@ -147,7 +148,7 @@ func (l *queryLog) Add(params AddParams) { now := time.Now() entry := logEntry{ - IP: l.getClientIP(params.ClientIP.String()), + IP: l.getClientIP(params.ClientIP), Time: now, Result: *params.Result, diff --git a/internal/querylog/qlog_test.go b/internal/querylog/qlog_test.go index 0fa072c1..24d9064e 100644 --- a/internal/querylog/qlog_test.go +++ b/internal/querylog/qlog_test.go @@ -40,27 +40,27 @@ func TestQueryLog(t *testing.T) { l := newQueryLog(conf) // add disk entries - addEntry(l, "example.org", "1.1.1.1", "2.2.2.1") + addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) // write to disk (first file) _ = l.flushLogBuffer(true) // start writing to the second file _ = l.rotate() // add disk entries - addEntry(l, "example.org", "1.1.1.2", "2.2.2.2") + addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2)) // write to disk _ = l.flushLogBuffer(true) // add memory entries - addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3") - addEntry(l, "example.com", "1.1.1.4", "2.2.2.4") + addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3)) + addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4)) // get all entries params := newSearchParams() entries, _ := l.search(params) assert.Len(t, entries, 4) - assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4") - assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3") - assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2") - assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1") + assertLogEntry(t, entries[0], "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4)) + assertLogEntry(t, entries[1], "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3)) + assertLogEntry(t, entries[2], "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2)) + assertLogEntry(t, entries[3], "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) // search by domain (strict) params = newSearchParams() @@ -71,7 +71,7 @@ func TestQueryLog(t *testing.T) { }) entries, _ = l.search(params) assert.Len(t, entries, 1) - assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3") + assertLogEntry(t, entries[0], "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3)) // search by domain (not strict) params = newSearchParams() @@ -82,9 +82,9 @@ func TestQueryLog(t *testing.T) { }) entries, _ = l.search(params) assert.Len(t, entries, 3) - assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3") - assertLogEntry(t, entries[1], "example.org", "1.1.1.2", "2.2.2.2") - assertLogEntry(t, entries[2], "example.org", "1.1.1.1", "2.2.2.1") + assertLogEntry(t, entries[0], "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3)) + assertLogEntry(t, entries[1], "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2)) + assertLogEntry(t, entries[2], "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) // search by client IP (strict) params = newSearchParams() @@ -95,7 +95,7 @@ func TestQueryLog(t *testing.T) { }) entries, _ = l.search(params) assert.Len(t, entries, 1) - assertLogEntry(t, entries[0], "example.org", "1.1.1.2", "2.2.2.2") + assertLogEntry(t, entries[0], "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2)) // search by client IP (part of) params = newSearchParams() @@ -106,10 +106,10 @@ func TestQueryLog(t *testing.T) { }) entries, _ = l.search(params) assert.Len(t, entries, 4) - assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4") - assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3") - assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2") - assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1") + assertLogEntry(t, entries[0], "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4)) + assertLogEntry(t, entries[1], "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3)) + assertLogEntry(t, entries[2], "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2)) + assertLogEntry(t, entries[3], "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) } func TestQueryLogOffsetLimit(t *testing.T) { @@ -124,13 +124,13 @@ func TestQueryLogOffsetLimit(t *testing.T) { // add 10 entries to the log for i := 0; i < 10; i++ { - addEntry(l, "second.example.org", "1.1.1.1", "2.2.2.1") + addEntry(l, "second.example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) } // write them to disk (first file) _ = l.flushLogBuffer(true) // add 10 more entries to the log (memory) for i := 0; i < 10; i++ { - addEntry(l, "first.example.org", "1.1.1.1", "2.2.2.1") + addEntry(l, "first.example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) } // First page @@ -178,7 +178,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) { // add 10 entries to the log for i := 0; i < 10; i++ { - addEntry(l, "example.org", "1.1.1.1", "2.2.2.1") + addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) } // write them to disk (first file) _ = l.flushLogBuffer(true) @@ -204,9 +204,9 @@ func TestQueryLogFileDisabled(t *testing.T) { defer func() { _ = os.RemoveAll(conf.BaseDir) }() l := newQueryLog(conf) - addEntry(l, "example1.org", "1.1.1.1", "2.2.2.1") - addEntry(l, "example2.org", "1.1.1.1", "2.2.2.1") - addEntry(l, "example3.org", "1.1.1.1", "2.2.2.1") + addEntry(l, "example1.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) + addEntry(l, "example2.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) + addEntry(l, "example3.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) // the oldest entry is now removed from mem buffer params := newSearchParams() @@ -216,7 +216,7 @@ func TestQueryLogFileDisabled(t *testing.T) { assert.Equal(t, "example2.org", ll[1].QHost) } -func addEntry(l *queryLog, host, answerStr, client string) { +func addEntry(l *queryLog, host string, answerStr, client net.IP) { q := dns.Msg{} q.Question = append(q.Question, dns.Question{ Name: host + ".", @@ -232,7 +232,7 @@ func addEntry(l *queryLog, host, answerStr, client string) { Rrtype: dns.TypeA, Class: dns.ClassINET, } - answer.A = net.ParseIP(answerStr) + answer.A = answerStr a.Answer = append(a.Answer, answer) res := dnsfilter.Result{ IsFiltered: true, @@ -248,13 +248,13 @@ func addEntry(l *queryLog, host, answerStr, client string) { Answer: &a, OrigAnswer: &a, Result: &res, - ClientIP: net.ParseIP(client), + ClientIP: client, Upstream: "upstream", } l.Add(params) } -func assertLogEntry(t *testing.T, entry *logEntry, host, answer, client string) bool { +func assertLogEntry(t *testing.T, entry *logEntry, host string, answer, client net.IP) bool { assert.Equal(t, host, entry.QHost) assert.Equal(t, client, entry.IP) assert.Equal(t, "A", entry.QType) @@ -263,9 +263,9 @@ func assertLogEntry(t *testing.T, entry *logEntry, host, answer, client string) msg := new(dns.Msg) assert.Nil(t, msg.Unpack(entry.Answer)) assert.Len(t, msg.Answer, 1) - ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0]) + ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0]).To16() assert.NotNil(t, ip) - assert.Equal(t, answer, ip.String()) + assert.Equal(t, answer, ip) return true } diff --git a/internal/querylog/searchcriteria.go b/internal/querylog/searchcriteria.go index f6b0ee98..1c2b26e3 100644 --- a/internal/querylog/searchcriteria.go +++ b/internal/querylog/searchcriteria.go @@ -94,16 +94,20 @@ func (c *searchCriteria) ctDomainOrClientCase(entry *logEntry) bool { if c.strict && qhost == searchVal { return true } + if !c.strict && strings.Contains(qhost, searchVal) { return true } - if c.strict && entry.IP == c.value { + ipStr := entry.IP.String() + if c.strict && ipStr == c.value { return true } - if !c.strict && strings.Contains(entry.IP, c.value) { + + if !c.strict && strings.Contains(ipStr, c.value) { return true } + return false } diff --git a/internal/stats/stats.go b/internal/stats/stats.go index 5cd5910d..1addbebd 100644 --- a/internal/stats/stats.go +++ b/internal/stats/stats.go @@ -48,7 +48,7 @@ type Stats interface { Update(e Entry) // Get IP addresses of the clients with the most number of requests - GetTopClientsIP(limit uint) []string + GetTopClientsIP(limit uint) []net.IP // WriteDiskConfig - write configuration WriteDiskConfig(dc *DiskConfig) diff --git a/internal/stats/stats_test.go b/internal/stats/stats_test.go index 47e68799..5fd9ff55 100644 --- a/internal/stats/stats_test.go +++ b/internal/stats/stats_test.go @@ -80,7 +80,7 @@ func TestStats(t *testing.T) { assert.EqualValues(t, 0.123456, d["avg_processing_time"].(float64)) topClients := s.GetTopClientsIP(2) - assert.Equal(t, "127.0.0.1", topClients[0]) + assert.True(t, net.IP{127, 0, 0, 1}.Equal(topClients[0])) s.clear() s.Close() diff --git a/internal/stats/unit.go b/internal/stats/unit.go index a8bd224c..35deb327 100644 --- a/internal/stats/unit.go +++ b/internal/stats/unit.go @@ -443,22 +443,19 @@ func (s *statsCtx) clear() { } // Get Client IP address -func (s *statsCtx) getClientIP(clientIP string) string { - if s.conf.AnonymizeClientIP { - ip := net.ParseIP(clientIP) - if ip != nil { - ip4 := ip.To4() - const AnonymizeClientIP4Mask = 16 - const AnonymizeClientIP6Mask = 112 - if ip4 != nil { - clientIP = ip4.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)).String() - } else { - clientIP = ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)).String() - } +func (s *statsCtx) getClientIP(ip net.IP) (clientIP net.IP) { + if s.conf.AnonymizeClientIP && ip != nil { + const AnonymizeClientIP4Mask = 16 + const AnonymizeClientIP6Mask = 112 + + if ip.To4() != nil { + return ip.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)) } + + return ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)) } - return clientIP + return ip } func (s *statsCtx) Update(e Entry) { @@ -468,7 +465,7 @@ func (s *statsCtx) Update(e Entry) { !(len(e.Client) == 4 || len(e.Client) == 16) { return } - client := s.getClientIP(e.Client.String()) + client := s.getClientIP(e.Client) s.unitLock.Lock() u := s.unit @@ -481,7 +478,7 @@ func (s *statsCtx) Update(e Entry) { u.blockedDomains[e.Domain]++ } - u.clients[client]++ + u.clients[client.String()]++ u.timeSum += uint64(e.Time) u.nTotal++ s.unitLock.Unlock() @@ -658,7 +655,7 @@ func (s *statsCtx) getData() map[string]interface{} { return d } -func (s *statsCtx) GetTopClientsIP(maxCount uint) []string { +func (s *statsCtx) GetTopClientsIP(maxCount uint) []net.IP { units, _ := s.loadUnits(s.conf.limit) if units == nil { return nil @@ -672,9 +669,9 @@ func (s *statsCtx) GetTopClientsIP(maxCount uint) []string { } } a := convertMapToArray(m, int(maxCount)) - d := []string{} + d := []net.IP{} for _, it := range a { - d = append(d, it.Name) + d = append(d, net.ParseIP(it.Name)) } return d } diff --git a/internal/sysutil/net_linux.go b/internal/sysutil/net_linux.go index 06d27eb2..8f47cf42 100644 --- a/internal/sysutil/net_linux.go +++ b/internal/sysutil/net_linux.go @@ -119,17 +119,13 @@ func ifacesStaticConfig(r io.Reader, ifaceName string) (has bool, err error) { } func ifaceSetStaticIP(ifaceName string) (err error) { - ip := util.GetSubnet(ifaceName) - if len(ip) == 0 { + ipNet := util.GetSubnet(ifaceName) + if ipNet.IP == nil { return errors.New("can't get IP address") } - ip4, _, err := net.ParseCIDR(ip) - if err != nil { - return err - } gatewayIP := GatewayIP(ifaceName) - add := updateStaticIPdhcpcdConf(ifaceName, ip, gatewayIP, ip4) + add := updateStaticIPdhcpcdConf(ifaceName, ipNet.String(), gatewayIP, ipNet.IP) body, err := ioutil.ReadFile("/etc/dhcpcd.conf") if err != nil { diff --git a/internal/util/autohosts_test.go b/internal/util/autohosts_test.go index 393646c7..b9632855 100644 --- a/internal/util/autohosts_test.go +++ b/internal/util/autohosts_test.go @@ -108,11 +108,11 @@ func TestAutoHostsFSNotify(t *testing.T) { ips = ah.Process("newhost", dns.TypeA) assert.NotNil(t, ips) assert.Len(t, ips, 1) - assert.Equal(t, "127.0.0.2", ips[0].String()) + assert.True(t, net.IP{127, 0, 0, 2}.Equal(ips[0])) } func TestIP(t *testing.T) { - assert.Equal(t, "127.0.0.1", DNSUnreverseAddr("1.0.0.127.in-addr.arpa").String()) + assert.True(t, net.IP{127, 0, 0, 1}.Equal(DNSUnreverseAddr("1.0.0.127.in-addr.arpa"))) assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String()) assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.B.A.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String()) diff --git a/internal/util/network.go b/internal/util/network.go index 1731ed08..f83bbd63 100644 --- a/internal/util/network.go +++ b/internal/util/network.go @@ -15,12 +15,12 @@ import ( // NetInterface represents a list of network interfaces type NetInterface struct { - Name string // Network interface name - MTU int // MTU - HardwareAddr string // Hardware address - Addresses []string // Array with the network interface addresses - Subnets []string // Array with CIDR addresses of this network interface - Flags string // Network interface flags (up, broadcast, etc) + Name string // Network interface name + MTU int // MTU + HardwareAddr string // Hardware address + Addresses []net.IP // Array with the network interface addresses + Subnets []*net.IPNet // Array with CIDR addresses of this network interface + Flags string // Network interface flags (up, broadcast, etc) } // GetValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP @@ -78,8 +78,8 @@ func GetValidNetInterfacesForWeb() ([]NetInterface, error) { if ipNet.IP.IsLinkLocalUnicast() { continue } - netIface.Addresses = append(netIface.Addresses, ipNet.IP.String()) - netIface.Subnets = append(netIface.Subnets, ipNet.String()) + netIface.Addresses = append(netIface.Addresses, ipNet.IP) + netIface.Subnets = append(netIface.Subnets, ipNet) } // Discard interfaces with no addresses @@ -91,8 +91,8 @@ func GetValidNetInterfacesForWeb() ([]NetInterface, error) { return netInterfaces, nil } -// GetInterfaceByIP - Get interface name by its IP address. -func GetInterfaceByIP(ip string) string { +// GetInterfaceByIP returns the name of interface containing provided ip. +func GetInterfaceByIP(ip net.IP) string { ifaces, err := GetValidNetInterfacesForWeb() if err != nil { return "" @@ -100,7 +100,7 @@ func GetInterfaceByIP(ip string) string { for _, iface := range ifaces { for _, addr := range iface.Addresses { - if ip == addr { + if ip.Equal(addr) { return iface.Name } } @@ -109,13 +109,13 @@ func GetInterfaceByIP(ip string) string { return "" } -// GetSubnet - Get IP address with netmask for the specified interface -// Returns an empty string if it fails to find it -func GetSubnet(ifaceName string) string { +// GetSubnet returns pointer to net.IPNet for the specified interface or nil if +// the search fails. +func GetSubnet(ifaceName string) *net.IPNet { netIfaces, err := GetValidNetInterfacesForWeb() if err != nil { log.Error("Could not get network interfaces info: %v", err) - return "" + return nil } for _, netIface := range netIfaces { @@ -124,12 +124,12 @@ func GetSubnet(ifaceName string) string { } } - return "" + return nil } // CheckPortAvailable - check if TCP port is available -func CheckPortAvailable(host string, port int) error { - ln, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port))) +func CheckPortAvailable(host net.IP, port int) error { + ln, err := net.Listen("tcp", net.JoinHostPort(host.String(), strconv.Itoa(port))) if err != nil { return err } @@ -142,8 +142,8 @@ func CheckPortAvailable(host string, port int) error { } // CheckPacketPortAvailable - check if UDP port is available -func CheckPacketPortAvailable(host string, port int) error { - ln, err := net.ListenPacket("udp", net.JoinHostPort(host, strconv.Itoa(port))) +func CheckPacketPortAvailable(host net.IP, port int) error { + ln, err := net.ListenPacket("udp", net.JoinHostPort(host.String(), strconv.Itoa(port))) if err != nil { return err }