diff --git a/go.mod b/go.mod index ce16e99b..0e6d2f59 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.16 require ( github.com/AdguardTeam/dnsproxy v0.39.2 - github.com/AdguardTeam/golibs v0.8.4 + github.com/AdguardTeam/golibs v0.9.0 github.com/AdguardTeam/urlfilter v0.14.6 github.com/NYTimes/gziphandler v1.1.1 github.com/ameshkov/dnscrypt/v2 v2.2.1 @@ -25,7 +25,7 @@ require ( github.com/ti-mo/netfilter v0.4.0 go.etcd.io/bbolt v1.3.5 golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a - golang.org/x/net v0.0.0-20210510120150-4163338589ed + golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985 golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/yaml.v2 v2.4.0 diff --git a/go.sum b/go.sum index 09bf0b42..cce9ed46 100644 --- a/go.sum +++ b/go.sum @@ -14,8 +14,8 @@ github.com/AdguardTeam/dnsproxy v0.39.2/go.mod h1:aNXKNdTyKfgAG2OS712SYSaGIM9Aas github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.8.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= -github.com/AdguardTeam/golibs v0.8.4 h1:jd6GwvQQtfSLOKn30qisDVujvas3q7Agjm3BOEqRWpQ= -github.com/AdguardTeam/golibs v0.8.4/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= +github.com/AdguardTeam/golibs v0.9.0 h1:QwmHqeZOVs9XpkmPb2iYpZ35OBArjgTesE8gLtEFRFg= +github.com/AdguardTeam/golibs v0.9.0/go.mod h1:fCAMwPBJ8S7YMYbTWvYS+eeTLblP5E04IDtNAo7y7IY= github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= github.com/AdguardTeam/urlfilter v0.14.6 h1:emqoKZElooHACYehRBYENeKVN1a/rspxiqTIMYLuoIo= github.com/AdguardTeam/urlfilter v0.14.6/go.mod h1:klx4JbOfc4EaNb5lWLqOwfg+pVcyRukmoJRvO55lL5U= @@ -300,8 +300,8 @@ golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20210510120150-4163338589ed h1:p9UgmWI9wKpfYmgaV/IZKGdXc5qEK45tDwwwDyjS26I= -golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985 h1:4CSI6oo7cOjJKajidEljs9h+uP0rRZBPPPhcCbj5mw8= +golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= diff --git a/internal/aghnet/addr.go b/internal/aghnet/addr.go deleted file mode 100644 index 4eb957c5..00000000 --- a/internal/aghnet/addr.go +++ /dev/null @@ -1,206 +0,0 @@ -package aghnet - -import ( - "fmt" - "net" - "strconv" - "strings" - - "github.com/AdguardTeam/golibs/errors" - "golang.org/x/net/idna" -) - -// CloneIP returns a clone of an IP address. -func CloneIP(ip net.IP) (clone net.IP) { - if ip != nil && len(ip) == 0 { - return net.IP{} - } - - return append(clone, ip...) -} - -// CloneMAC returns a clone of a MAC address. -func CloneMAC(mac net.HardwareAddr) (clone net.HardwareAddr) { - if mac != nil && len(mac) == 0 { - return net.HardwareAddr{} - } - - return append(clone, mac...) -} - -// IPFromAddr returns an IP address from addr. If addr is neither -// a *net.TCPAddr nor a *net.UDPAddr, it returns nil. -func IPFromAddr(addr net.Addr) (ip net.IP) { - switch addr := addr.(type) { - case *net.TCPAddr: - return addr.IP - case *net.UDPAddr: - return addr.IP - } - - return nil -} - -// IsValidHostOuterRune returns true if r is a valid initial or final rune for -// a hostname label. -func IsValidHostOuterRune(r rune) (ok bool) { - return (r >= 'a' && r <= 'z') || - (r >= 'A' && r <= 'Z') || - (r >= '0' && r <= '9') -} - -// JoinHostPort is a convinient wrapper for net.JoinHostPort with port of type -// int. -func JoinHostPort(host string, port int) (hostport string) { - return net.JoinHostPort(host, strconv.Itoa(port)) -} - -// isValidHostRune returns true if r is a valid rune for a hostname label. -func isValidHostRune(r rune) (ok bool) { - return r == '-' || IsValidHostOuterRune(r) -} - -// ValidateHardwareAddress returns an error if hwa is not a valid EUI-48, -// EUI-64, or 20-octet InfiniBand link-layer address. -func ValidateHardwareAddress(hwa net.HardwareAddr) (err error) { - defer func() { err = errors.Annotate(err, "validating hardware address %q: %w", hwa) }() - - switch l := len(hwa); l { - case 0: - return errors.Error("address is empty") - case 6, 8, 20: - return nil - default: - return fmt.Errorf("bad len: %d", l) - } -} - -// maxDomainLabelLen is the maximum allowed length of a domain name label -// according to RFC 1035. -const maxDomainLabelLen = 63 - -// MaxDomainNameLen is the maximum allowed length of a full domain name -// according to RFC 1035. -// -// See https://stackoverflow.com/a/32294443/1892060. -const MaxDomainNameLen = 253 - -// ValidateDomainNameLabel returns an error if label is not a valid label of -// a domain name. -func ValidateDomainNameLabel(label string) (err error) { - defer func() { err = errors.Annotate(err, "validating label %q: %w", label) }() - - l := len(label) - if l > maxDomainLabelLen { - return fmt.Errorf("label is too long, max: %d", maxDomainLabelLen) - } else if l == 0 { - return errors.Error("label is empty") - } - - if r := label[0]; !IsValidHostOuterRune(rune(r)) { - return fmt.Errorf("invalid char %q at index %d", r, 0) - } else if l == 1 { - return nil - } - - for i, r := range label[1 : l-1] { - if !isValidHostRune(r) { - return fmt.Errorf("invalid char %q at index %d", r, i+1) - } - } - - if r := label[l-1]; !IsValidHostOuterRune(rune(r)) { - return fmt.Errorf("invalid char %q at index %d", r, l-1) - } - - return nil -} - -// ValidateDomainName validates the domain name in accordance to RFC 952, RFC -// 1035, and with RFC-1123's inclusion of digits at the start of the host. It -// doesn't validate against two or more hyphens to allow punycode and -// internationalized domains. -// -// TODO(a.garipov): After making sure that this works correctly, port this into -// module golibs. -func ValidateDomainName(name string) (err error) { - defer func() { err = errors.Annotate(err, "validating domain name %q: %w", name) }() - - name, err = idna.ToASCII(name) - if err != nil { - return err - } - - l := len(name) - if l == 0 { - return errors.Error("domain name is empty") - } else if l > MaxDomainNameLen { - return fmt.Errorf("too long, max: %d", MaxDomainNameLen) - } - - labels := strings.Split(name, ".") - for i, l := range labels { - err = ValidateDomainNameLabel(l) - if err != nil { - return fmt.Errorf("invalid domain name label at index %d: %w", i, err) - } - } - - return nil -} - -// The maximum lengths of generated hostnames for different IP versions. -const ( - ipv4HostnameMaxLen = len("192-168-100-100") - ipv6HostnameMaxLen = len("ff80-f076-0000-0000-0000-0000-0000-0010") -) - -// generateIPv4Hostname generates the hostname for specific IP version. -func generateIPv4Hostname(ipv4 net.IP) (hostname string) { - hnData := make([]byte, 0, ipv4HostnameMaxLen) - for i, part := range ipv4 { - if i > 0 { - hnData = append(hnData, '-') - } - hnData = strconv.AppendUint(hnData, uint64(part), 10) - } - - return string(hnData) -} - -// generateIPv6Hostname generates the hostname for specific IP version. -func generateIPv6Hostname(ipv6 net.IP) (hostname string) { - hnData := make([]byte, 0, ipv6HostnameMaxLen) - for i, partsNum := 0, net.IPv6len/2; i < partsNum; i++ { - if i > 0 { - hnData = append(hnData, '-') - } - for _, val := range ipv6[i*2 : i*2+2] { - if val < 10 { - hnData = append(hnData, '0') - } - hnData = strconv.AppendUint(hnData, uint64(val), 16) - } - } - - return string(hnData) -} - -// GenerateHostname generates the hostname from ip. In case of using IPv4 the -// result should be like: -// -// 192-168-10-1 -// -// In case of using IPv6, the result is like: -// -// ff80-f076-0000-0000-0000-0000-0000-0010 -// -func GenerateHostname(ip net.IP) (hostname string) { - if ipv4 := ip.To4(); ipv4 != nil { - return generateIPv4Hostname(ipv4) - } else if ipv6 := ip.To16(); ipv6 != nil { - return generateIPv6Hostname(ipv6) - } - - return "" -} diff --git a/internal/aghnet/addr_test.go b/internal/aghnet/addr_test.go deleted file mode 100644 index 0711a597..00000000 --- a/internal/aghnet/addr_test.go +++ /dev/null @@ -1,228 +0,0 @@ -package aghnet - -import ( - "net" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestCloneIP(t *testing.T) { - assert.Equal(t, net.IP(nil), CloneIP(nil)) - assert.Equal(t, net.IP{}, CloneIP(net.IP{})) - - ip := net.IP{1, 2, 3, 4} - clone := CloneIP(ip) - assert.Equal(t, ip, clone) - assert.NotSame(t, &ip[0], &clone[0]) -} - -func TestCloneMAC(t *testing.T) { - assert.Equal(t, net.HardwareAddr(nil), CloneMAC(nil)) - assert.Equal(t, net.HardwareAddr{}, CloneMAC(net.HardwareAddr{})) - - mac := net.HardwareAddr{0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC} - clone := CloneMAC(mac) - assert.Equal(t, mac, clone) - assert.NotSame(t, &mac[0], &clone[0]) -} - -func TestIPFromAddr(t *testing.T) { - ip := net.IP{1, 2, 3, 4} - assert.Equal(t, net.IP(nil), IPFromAddr(nil)) - assert.Equal(t, net.IP(nil), IPFromAddr(struct{ net.Addr }{})) - assert.Equal(t, ip, IPFromAddr(&net.TCPAddr{IP: ip})) - assert.Equal(t, ip, IPFromAddr(&net.UDPAddr{IP: ip})) -} - -func TestValidateHardwareAddress(t *testing.T) { - testCases := []struct { - name string - wantErrMsg string - in net.HardwareAddr - }{{ - name: "success_eui_48", - wantErrMsg: "", - in: net.HardwareAddr{0x00, 0x01, 0x02, 0x03, 0x04, 0x05}, - }, { - name: "success_eui_64", - wantErrMsg: "", - in: net.HardwareAddr{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, - }, { - name: "success_infiniband", - wantErrMsg: "", - in: net.HardwareAddr{ - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, - 0x10, 0x11, 0x12, 0x13, - }, - }, { - name: "error_nil", - wantErrMsg: `validating hardware address "": address is empty`, - in: nil, - }, { - name: "error_empty", - wantErrMsg: `validating hardware address "": address is empty`, - in: net.HardwareAddr{}, - }, { - name: "error_bad", - wantErrMsg: `validating hardware address "00:01:02:03": bad len: 4`, - in: net.HardwareAddr{0x00, 0x01, 0x02, 0x03}, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := ValidateHardwareAddress(tc.in) - if tc.wantErrMsg == "" { - assert.NoError(t, err) - } else { - require.Error(t, err) - - assert.Equal(t, tc.wantErrMsg, err.Error()) - } - }) - } -} - -func TestJoinHostPort(t *testing.T) { - assert.Equal(t, ":0", JoinHostPort("", 0)) - assert.Equal(t, "host:12345", JoinHostPort("host", 12345)) - assert.Equal(t, "1.2.3.4:12345", JoinHostPort("1.2.3.4", 12345)) - assert.Equal(t, "[1234::5678]:12345", JoinHostPort("1234::5678", 12345)) - assert.Equal(t, "[1234::5678%lo]:12345", JoinHostPort("1234::5678%lo", 12345)) -} - -func repeatStr(b *strings.Builder, s string, n int) { - for i := 0; i < n; i++ { - _, _ = b.WriteString(s) - } -} - -func TestValidateDomainName(t *testing.T) { - b := &strings.Builder{} - repeatStr(b, "a", 255) - longDomainName := b.String() - - b.Reset() - repeatStr(b, "a", 64) - longLabel := b.String() - - _, _ = b.WriteString(".com") - longLabelDomainName := b.String() - - testCases := []struct { - name string - in string - wantErrMsg string - }{{ - name: "success", - in: "example.com", - wantErrMsg: "", - }, { - name: "success_idna", - in: "пример.рф", - wantErrMsg: "", - }, { - name: "success_one", - in: "e", - wantErrMsg: "", - }, { - name: "empty", - in: "", - wantErrMsg: `validating domain name "": domain name is empty`, - }, { - name: "bad_symbol", - in: "!!!", - wantErrMsg: `validating domain name "!!!": invalid domain name label at index 0: ` + - `validating label "!!!": invalid char '!' at index 0`, - }, { - name: "bad_length", - in: longDomainName, - wantErrMsg: `validating domain name "` + longDomainName + `": too long, max: 253`, - }, { - name: "bad_label_length", - in: longLabelDomainName, - wantErrMsg: `validating domain name "` + longLabelDomainName + `": ` + - `invalid domain name label at index 0: validating label "` + longLabel + - `": label is too long, max: 63`, - }, { - name: "bad_label_empty", - in: "example..com", - wantErrMsg: `validating domain name "example..com": ` + - `invalid domain name label at index 1: ` + - `validating label "": label is empty`, - }, { - name: "bad_label_first_symbol", - in: "example.-aa.com", - wantErrMsg: `validating domain name "example.-aa.com": ` + - `invalid domain name label at index 1: ` + - `validating label "-aa": invalid char '-' at index 0`, - }, { - name: "bad_label_last_symbol", - in: "example-.aa.com", - wantErrMsg: `validating domain name "example-.aa.com": ` + - `invalid domain name label at index 0: ` + - `validating label "example-": invalid char '-' at index 7`, - }, { - name: "bad_label_symbol", - in: "example.a!!!.com", - wantErrMsg: `validating domain name "example.a!!!.com": ` + - `invalid domain name label at index 1: ` + - `validating label "a!!!": invalid char '!' at index 1`, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := ValidateDomainName(tc.in) - if tc.wantErrMsg == "" { - assert.NoError(t, err) - } else { - require.Error(t, err) - - assert.Equal(t, tc.wantErrMsg, err.Error()) - } - }) - } -} - -func TestGenerateHostName(t *testing.T) { - testCases := []struct { - name string - want string - ip net.IP - }{{ - name: "good_ipv4", - want: "127-0-0-1", - ip: net.IP{127, 0, 0, 1}, - }, { - name: "bad_ipv4", - want: "", - ip: net.IP{127, 0, 0, 1, 0}, - }, { - name: "good_ipv6", - want: "fe00-0000-0000-0000-0000-0000-0000-0001", - ip: net.ParseIP("fe00::1"), - }, { - name: "bad_ipv6", - want: "", - ip: net.IP{ - 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, - }, - }, { - name: "nil", - want: "", - ip: nil, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - hostname := GenerateHostname(tc.ip) - assert.Equal(t, tc.want, hostname) - }) - } -} diff --git a/internal/aghnet/etchostscontainer.go b/internal/aghnet/etchostscontainer.go index f95896ce..af597e61 100644 --- a/internal/aghnet/etchostscontainer.go +++ b/internal/aghnet/etchostscontainer.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/fsnotify/fsnotify" "github.com/miekg/dns" ) @@ -29,7 +30,7 @@ type EtcHostsContainer struct { table map[string][]net.IP // tableReverse is the IP-to-hosts map. The type of the values in the // map is []string. - tableReverse *IPMap + tableReverse *netutil.IPMap hostsFn string // path to the main hosts-file hostsDirs []string // paths to OS-specific directories with hosts-files @@ -150,8 +151,10 @@ func (ehc *EtcHostsContainer) ProcessReverse(addr string, qtype uint16) (hosts [ return nil } - ip := UnreverseAddr(addr) - if ip == nil { + ip, err := netutil.IPFromReversedAddr(addr) + if err != nil { + log.Error("etchosts: reversed addr: %s", err) + return nil } @@ -179,7 +182,7 @@ func (ehc *EtcHostsContainer) ProcessReverse(addr string, qtype uint16) (hosts [ // List returns an IP-to-hostnames table. The type of the values in the map is // []string. It is safe for concurrent use. -func (ehc *EtcHostsContainer) List() (ipToHosts *IPMap) { +func (ehc *EtcHostsContainer) List() (ipToHosts *netutil.IPMap) { ehc.lock.RLock() defer ehc.lock.RUnlock() @@ -211,7 +214,7 @@ func (ehc *EtcHostsContainer) updateTable(table map[string][]net.IP, host string } // updateTableRev updates the reverse address table. -func (ehc *EtcHostsContainer) updateTableRev(tableRev *IPMap, newHost string, ip net.IP) { +func (ehc *EtcHostsContainer) updateTableRev(tableRev *netutil.IPMap, newHost string, ip net.IP) { v, ok := tableRev.Get(ip) if !ok { tableRev.Set(ip, []string{newHost}) @@ -258,7 +261,7 @@ func parseHostsLine(fields []string) (hosts []string) { // line for one IP are supported. func (ehc *EtcHostsContainer) load( table map[string][]net.IP, - tableRev *IPMap, + tableRev *netutil.IPMap, fn string, ) { f, err := os.Open(fn) @@ -353,7 +356,7 @@ func (ehc *EtcHostsContainer) watcherLoop() { // updateHosts - loads system hosts func (ehc *EtcHostsContainer) updateHosts() { table := make(map[string][]net.IP) - tableRev := NewIPMap(0) + tableRev := netutil.NewIPMap(0) ehc.load(table, tableRev, ehc.hostsFn) diff --git a/internal/aghnet/hostgen.go b/internal/aghnet/hostgen.go new file mode 100644 index 00000000..d9278515 --- /dev/null +++ b/internal/aghnet/hostgen.go @@ -0,0 +1,62 @@ +package aghnet + +import ( + "net" + "strconv" +) + +// The maximum lengths of generated hostnames for different IP versions. +const ( + ipv4HostnameMaxLen = len("192-168-100-100") + ipv6HostnameMaxLen = len("ff80-f076-0000-0000-0000-0000-0000-0010") +) + +// generateIPv4Hostname generates the hostname for specific IP version. +func generateIPv4Hostname(ipv4 net.IP) (hostname string) { + hnData := make([]byte, 0, ipv4HostnameMaxLen) + for i, part := range ipv4 { + if i > 0 { + hnData = append(hnData, '-') + } + hnData = strconv.AppendUint(hnData, uint64(part), 10) + } + + return string(hnData) +} + +// generateIPv6Hostname generates the hostname for specific IP version. +func generateIPv6Hostname(ipv6 net.IP) (hostname string) { + hnData := make([]byte, 0, ipv6HostnameMaxLen) + for i, partsNum := 0, net.IPv6len/2; i < partsNum; i++ { + if i > 0 { + hnData = append(hnData, '-') + } + for _, val := range ipv6[i*2 : i*2+2] { + if val < 10 { + hnData = append(hnData, '0') + } + hnData = strconv.AppendUint(hnData, uint64(val), 16) + } + } + + return string(hnData) +} + +// GenerateHostname generates the hostname from ip. In case of using IPv4 the +// result should be like: +// +// 192-168-10-1 +// +// In case of using IPv6, the result is like: +// +// ff80-f076-0000-0000-0000-0000-0000-0010 +// +func GenerateHostname(ip net.IP) (hostname string) { + if ipv4 := ip.To4(); ipv4 != nil { + return generateIPv4Hostname(ipv4) + } else if ipv6 := ip.To16(); ipv6 != nil { + return generateIPv6Hostname(ipv6) + } + + return "" +} diff --git a/internal/aghnet/hostgen_test.go b/internal/aghnet/hostgen_test.go new file mode 100644 index 00000000..37121628 --- /dev/null +++ b/internal/aghnet/hostgen_test.go @@ -0,0 +1,48 @@ +package aghnet + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGenerateHostName(t *testing.T) { + testCases := []struct { + name string + want string + ip net.IP + }{{ + name: "good_ipv4", + want: "127-0-0-1", + ip: net.IP{127, 0, 0, 1}, + }, { + name: "bad_ipv4", + want: "", + ip: net.IP{127, 0, 0, 1, 0}, + }, { + name: "good_ipv6", + want: "fe00-0000-0000-0000-0000-0000-0000-0001", + ip: net.ParseIP("fe00::1"), + }, { + name: "bad_ipv6", + want: "", + ip: net.IP{ + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, + }, + }, { + name: "nil", + want: "", + ip: nil, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + hostname := GenerateHostname(tc.ip) + assert.Equal(t, tc.want, hostname) + }) + } +} diff --git a/internal/aghnet/ipmap.go b/internal/aghnet/ipmap.go deleted file mode 100644 index 281e27ee..00000000 --- a/internal/aghnet/ipmap.go +++ /dev/null @@ -1,117 +0,0 @@ -package aghnet - -import ( - "fmt" - "net" -) - -// ipArr is a representation of an IP address as an array of bytes. -type ipArr [16]byte - -// String implements the fmt.Stringer interface for ipArr. -func (a ipArr) String() (s string) { - return net.IP(a[:]).String() -} - -// IPMap is a map of IP addresses. -type IPMap struct { - m map[ipArr]interface{} -} - -// NewIPMap returns a new empty IP map using hint as a size hint for the -// underlying map. -func NewIPMap(hint int) (m *IPMap) { - return &IPMap{ - m: make(map[ipArr]interface{}, hint), - } -} - -// ipToArr converts a net.IP into an ipArr. -// -// TODO(a.garipov): Use the slice-to-array conversion in Go 1.17. -func ipToArr(ip net.IP) (a ipArr) { - copy(a[:], ip.To16()) - - return a -} - -// Del deletes ip from the map. Calling Del on a nil *IPMap has no effect, just -// like delete on an empty map doesn't. -func (m *IPMap) Del(ip net.IP) { - if m != nil { - delete(m.m, ipToArr(ip)) - } -} - -// Get returns the value from the map. Calling Get on a nil *IPMap returns nil -// and false, just like indexing on an empty map does. -func (m *IPMap) Get(ip net.IP) (v interface{}, ok bool) { - if m != nil { - v, ok = m.m[ipToArr(ip)] - - return v, ok - } - - return nil, false -} - -// Len returns the length of the map. A nil *IPMap has a length of zero, just -// like an empty map. -func (m *IPMap) Len() (n int) { - if m == nil { - return 0 - } - - return len(m.m) -} - -// Range calls f for each key and value present in the map in an undefined -// order. If cont is false, range stops the iteration. Calling Range on a nil -// *IPMap has no effect, just like ranging over a nil map. -func (m *IPMap) Range(f func(ip net.IP, v interface{}) (cont bool)) { - if m == nil { - return - } - - for k, v := range m.m { - // Array slicing produces a pointer, so copy the array here. - // - // See https://github.com/AdguardTeam/AdGuardHome/issues/3346 - // as well as https://github.com/kyoh86/looppointer/issues/9. - k := k - if !f(net.IP(k[:]), v) { - break - } - } -} - -// Set sets the value. Set panics if the m is a nil *IPMap, just like a nil map -// does. -func (m *IPMap) Set(ip net.IP, v interface{}) { - m.m[ipToArr(ip)] = v -} - -// ShallowClone returns a shallow clone of the map. -func (m *IPMap) ShallowClone() (sclone *IPMap) { - if m == nil { - return nil - } - - sclone = NewIPMap(m.Len()) - m.Range(func(ip net.IP, v interface{}) (cont bool) { - sclone.Set(ip, v) - - return true - }) - - return sclone -} - -// String implements the fmt.Stringer interface for *IPMap. -func (m *IPMap) String() (s string) { - if m == nil { - return "" - } - - return fmt.Sprint(m.m) -} diff --git a/internal/aghnet/ipmap_test.go b/internal/aghnet/ipmap_test.go deleted file mode 100644 index 3d3e765d..00000000 --- a/internal/aghnet/ipmap_test.go +++ /dev/null @@ -1,142 +0,0 @@ -package aghnet - -import ( - "net" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestIPMap_allocs(t *testing.T) { - ip4 := net.IP{1, 2, 3, 4} - m := NewIPMap(0) - m.Set(ip4, 42) - - t.Run("get", func(t *testing.T) { - var v interface{} - var ok bool - allocs := testing.AllocsPerRun(100, func() { - v, ok = m.Get(ip4) - }) - - require.True(t, ok) - require.Equal(t, 42, v) - - assert.Equal(t, float64(0), allocs) - }) - - t.Run("len", func(t *testing.T) { - var n int - allocs := testing.AllocsPerRun(100, func() { - n = m.Len() - }) - - require.Equal(t, 1, n) - - assert.Equal(t, float64(0), allocs) - }) -} - -func TestIPMap(t *testing.T) { - ip4 := net.IP{1, 2, 3, 4} - ip6 := net.IP{ - 0x12, 0x34, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x56, 0x78, - } - - val := 42 - - t.Run("nil", func(t *testing.T) { - var m *IPMap - - assert.NotPanics(t, func() { - m.Del(ip4) - m.Del(ip6) - }) - - assert.NotPanics(t, func() { - v, ok := m.Get(ip4) - assert.Nil(t, v) - assert.False(t, ok) - - v, ok = m.Get(ip6) - assert.Nil(t, v) - assert.False(t, ok) - }) - - assert.NotPanics(t, func() { - assert.Equal(t, 0, m.Len()) - }) - - assert.NotPanics(t, func() { - n := 0 - m.Range(func(_ net.IP, _ interface{}) (cont bool) { - n++ - - return true - }) - - assert.Equal(t, 0, n) - }) - - assert.Panics(t, func() { - m.Set(ip4, val) - }) - - assert.Panics(t, func() { - m.Set(ip6, val) - }) - - assert.NotPanics(t, func() { - sclone := m.ShallowClone() - assert.Nil(t, sclone) - }) - }) - - testIPMap := func(t *testing.T, ip net.IP, s string) { - m := NewIPMap(0) - assert.Equal(t, 0, m.Len()) - - v, ok := m.Get(ip) - assert.Nil(t, v) - assert.False(t, ok) - - m.Set(ip, val) - v, ok = m.Get(ip) - assert.Equal(t, val, v) - assert.True(t, ok) - - n := 0 - m.Range(func(ipKey net.IP, v interface{}) (cont bool) { - assert.Equal(t, ip.To16(), ipKey) - assert.Equal(t, val, v) - - n++ - - return false - }) - assert.Equal(t, 1, n) - - sclone := m.ShallowClone() - assert.Equal(t, m, sclone) - - assert.Equal(t, s, m.String()) - - m.Del(ip) - v, ok = m.Get(ip) - assert.Nil(t, v) - assert.False(t, ok) - assert.Equal(t, 0, m.Len()) - } - - t.Run("ipv4", func(t *testing.T) { - testIPMap(t, ip4, "map[1.2.3.4:42]") - }) - - t.Run("ipv6", func(t *testing.T) { - testIPMap(t, ip6, "map[1234::5678:42]") - }) -} diff --git a/internal/aghnet/net.go b/internal/aghnet/net.go index c063cbea..6e68da03 100644 --- a/internal/aghnet/net.go +++ b/internal/aghnet/net.go @@ -8,14 +8,13 @@ import ( "os" "os/exec" "runtime" - "strconv" "strings" "syscall" "time" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/stringutil" + "github.com/AdguardTeam/golibs/netutil" ) // ErrNoStaticIPInfo is returned by IfaceHasStaticIP when no information about @@ -79,14 +78,14 @@ func CanBindPrivilegedPorts() (can bool, err error) { // NetInterface represents an entry of network interfaces map. type NetInterface struct { - MTU int `json:"mtu"` + // Addresses are the network interface addresses. + Addresses []net.IP `json:"ip_addresses,omitempty"` + // Subnets are the IP networks for this network interface. + Subnets []*net.IPNet `json:"-"` Name string `json:"name"` HardwareAddr net.HardwareAddr `json:"hardware_address"` Flags net.Flags `json:"flags"` - // Array with the network interface addresses. - Addresses []net.IP `json:"ip_addresses,omitempty"` - // Array with IP networks for this network interface. - Subnets []*net.IPNet `json:"-"` + MTU int `json:"mtu"` } // MarshalJSON implements the json.Marshaler interface for NetInterface. @@ -192,7 +191,7 @@ func GetSubnet(ifaceName string) *net.IPNet { // CheckPortAvailable - check if TCP port is available func CheckPortAvailable(host net.IP, port int) error { - ln, err := net.Listen("tcp", JoinHostPort(host.String(), port)) + ln, err := net.Listen("tcp", netutil.JoinHostPort(host.String(), port)) if err != nil { return err } @@ -206,7 +205,7 @@ func CheckPortAvailable(host net.IP, port int) error { // CheckPacketPortAvailable - check if UDP port is available func CheckPacketPortAvailable(host net.IP, port int) error { - ln, err := net.ListenPacket("udp", JoinHostPort(host.String(), port)) + ln, err := net.ListenPacket("udp", netutil.JoinHostPort(host.String(), port)) if err != nil { return err } @@ -265,141 +264,6 @@ func SplitHost(hostport string) (host string, err error) { return host, nil } -// TODO(e.burkov): Inspect the charToHex, ipParseARPA6, ipReverse and -// UnreverseAddr and maybe refactor it. - -// charToHex converts character to a hexadecimal. -func charToHex(n byte) int8 { - if n >= '0' && n <= '9' { - return int8(n) - '0' - } else if (n|0x20) >= 'a' && (n|0x20) <= 'f' { - return (int8(n) | 0x20) - 'a' + 10 - } - return -1 -} - -// ipParseARPA6 parse IPv6 reverse address -func ipParseARPA6(s string) (ip6 net.IP) { - if len(s) != 63 { - return nil - } - - ip6 = make(net.IP, 16) - - for i := 0; i != 64; i += 4 { - // parse "0.1." - n := charToHex(s[i]) - n2 := charToHex(s[i+2]) - if s[i+1] != '.' || (i != 60 && s[i+3] != '.') || - n < 0 || n2 < 0 { - return nil - } - - ip6[16-i/4-1] = byte(n2<<4) | byte(n&0x0f) - } - return ip6 -} - -// ipReverse inverts byte order of ip. -func ipReverse(ip net.IP) (rev net.IP) { - ipLen := len(ip) - rev = make(net.IP, ipLen) - for i, b := range ip { - rev[ipLen-i-1] = b - } - - return rev -} - -// ARPA addresses' suffixes. -const ( - arpaV4Suffix = ".in-addr.arpa" - arpaV6Suffix = ".ip6.arpa" -) - -// UnreverseAddr tries to convert reversed ARPA to a normal IP address. -func UnreverseAddr(arpa string) (unreversed net.IP) { - // Unify the input data. - arpa = strings.TrimSuffix(arpa, ".") - arpa = strings.ToLower(arpa) - - if strings.HasSuffix(arpa, arpaV4Suffix) { - ip := strings.TrimSuffix(arpa, arpaV4Suffix) - ip4 := net.ParseIP(ip).To4() - if ip4 == nil { - return nil - } - - return ipReverse(ip4) - - } else if strings.HasSuffix(arpa, arpaV6Suffix) { - ip := strings.TrimSuffix(arpa, arpaV6Suffix) - return ipParseARPA6(ip) - } - - // The suffix unrecognizable. - return nil -} - -// The length of extreme cases of arpa formatted addresses. -// -// The example of IPv4 with maximum length: -// -// 49.91.20.104.in-addr.arpa -// -// The example of IPv6 with maximum length: -// -// 1.3.b.5.4.1.8.6.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.0.7.4.6.0.6.2.ip6.arpa -// -const ( - arpaV4MaxLen = len("000.000.000.000") + len(arpaV4Suffix) - arpaV6MaxLen = len("0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0") + - len(arpaV6Suffix) -) - -// ReverseAddr returns the ARPA hostname of the ip suitable for reverse DNS -// (PTR) record lookups. This is the modified version of ReverseAddr from -// github.com/miekg/dns package with no error among returned values. -func ReverseAddr(ip net.IP) (arpa string) { - const dot = "." - - var strLen int - var suffix string - var writeByte func(val byte) - b := &strings.Builder{} - if ip4 := ip.To4(); ip4 != nil { - strLen, suffix = arpaV4MaxLen, arpaV4Suffix[1:] - ip = ip4 - writeByte = func(val byte) { - stringutil.WriteToBuilder(b, strconv.Itoa(int(val)), dot) - } - - } else if ip6 := ip.To16(); ip6 != nil { - strLen, suffix = arpaV6MaxLen, arpaV6Suffix[1:] - ip = ip6 - writeByte = func(val byte) { - stringutil.WriteToBuilder( - b, - strconv.FormatUint(uint64(val&0xF), 16), - dot, - strconv.FormatUint(uint64(val>>4), 16), - dot, - ) - } - - } else { - return "" - } - - b.Grow(strLen) - for i := len(ip) - 1; i >= 0; i-- { - writeByte(ip[i]) - } - stringutil.WriteToBuilder(b, suffix) - - return b.String() -} - // CollectAllIfacesAddrs returns the slice of all network interfaces IP // addresses without port number. func CollectAllIfacesAddrs() (addrs []string, err error) { diff --git a/internal/aghnet/net_test.go b/internal/aghnet/net_test.go index 9c5afd07..3cd2fd6a 100644 --- a/internal/aghnet/net_test.go +++ b/internal/aghnet/net_test.go @@ -1,10 +1,8 @@ package aghnet import ( - "net" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -16,92 +14,3 @@ func TestGetValidNetInterfacesForWeb(t *testing.T) { require.NotEmptyf(t, iface.Addresses, "no addresses found for %s", iface.Name) } } - -func TestUnreverseAddr(t *testing.T) { - testCases := []struct { - name string - have string - want net.IP - }{{ - name: "good_ipv4", - have: "1.0.0.127.in-addr.arpa", - want: net.IP{127, 0, 0, 1}, - }, { - name: "good_ipv6", - have: "4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", - want: net.ParseIP("::abcd:1234"), - }, { - name: "good_ipv6_case", - have: "4.3.2.1.d.c.B.A.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.iP6.ArPa", - want: net.ParseIP("::abcd:1234"), - }, { - name: "good_ipv4_dot", - have: "1.0.0.127.in-addr.arpa.", - want: net.IP{127, 0, 0, 1}, - }, { - name: "good_ipv4_case", - have: "1.0.0.127.In-Addr.Arpa", - want: net.IP{127, 0, 0, 1}, - }, { - name: "wrong_ipv4", - have: ".0.0.127.in-addr.arpa", - want: nil, - }, { - name: "wrong_ipv6", - have: ".3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", - want: nil, - }, { - name: "bad_ipv6_dot", - have: "4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0..ip6.arpa", - want: nil, - }, { - name: "bad_ipv6_space", - have: "4.3.2.1.d.c.b. .0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", - want: nil, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ip := UnreverseAddr(tc.have) - assert.True(t, tc.want.Equal(ip)) - }) - } -} - -func TestReverseAddr(t *testing.T) { - testCases := []struct { - name string - want string - ip net.IP - }{{ - name: "valid_ipv4", - want: "4.3.2.1.in-addr.arpa", - ip: net.IP{1, 2, 3, 4}, - }, { - name: "valid_ipv6", - want: "1.3.b.5.4.1.8.6.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.0.7.4.6.0.6.2.ip6.arpa", - ip: net.ParseIP("2606:4700:10::6814:5b31"), - }, { - name: "nil_ip", - want: "", - ip: nil, - }, { - name: "unspecified_ipv6", - want: "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", - ip: net.IPv6unspecified, - }, { - name: "unspecified_ipv4", - want: "0.0.0.0.in-addr.arpa", - ip: net.IPv4zero, - }, { - name: "wrong_length_ip", - want: "", - ip: net.IP{1, 2, 3, 4, 5}, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.want, ReverseAddr(tc.ip)) - }) - } -} diff --git a/internal/aghnet/systemresolvers_others.go b/internal/aghnet/systemresolvers_others.go index 32c4ef0f..8acdb6c7 100644 --- a/internal/aghnet/systemresolvers_others.go +++ b/internal/aghnet/systemresolvers_others.go @@ -12,6 +12,7 @@ import ( "time" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" ) @@ -104,7 +105,7 @@ const dockerEmbeddedDNS = "127.0.0.11" func (sr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net.Conn, err error) { // Just validate the passed address is a valid IP. var host string - host, err = SplitHost(address) + host, err = netutil.SplitHost(address) if err != nil { // TODO(e.burkov): Maybe use a structured errBadAddrPassed to // allow unwrapping of the real error. diff --git a/internal/dhcpd/checkother.go b/internal/dhcpd/checkother.go index 07320c8d..39e60955 100644 --- a/internal/dhcpd/checkother.go +++ b/internal/dhcpd/checkother.go @@ -11,10 +11,10 @@ import ( "runtime" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv4/nclient4" "github.com/insomniacslk/dhcp/dhcpv6" @@ -45,7 +45,7 @@ func CheckIfOtherDHCPServersPresentV4(ifaceName string) (ok bool, err error) { } srcIP := ifaceIPNet[0] - src := aghnet.JoinHostPort(srcIP.String(), 68) + src := netutil.JoinHostPort(srcIP.String(), 68) dst := "255.255.255.255:67" hostname, _ := os.Hostname() @@ -176,7 +176,7 @@ func CheckIfOtherDHCPServersPresentV6(ifaceName string) (ok bool, err error) { } srcIP := ifaceIPNet[0] - src := aghnet.JoinHostPort(srcIP.String(), 546) + src := netutil.JoinHostPort(srcIP.String(), 546) dst := "[ff02::1:2]:547" req, err := dhcpv6.NewSolicit(iface.HardwareAddr) diff --git a/internal/dhcpd/dhcpd.go b/internal/dhcpd/dhcpd.go index 12bec4b7..b77b1e5b 100644 --- a/internal/dhcpd/dhcpd.go +++ b/internal/dhcpd/dhcpd.go @@ -10,8 +10,8 @@ import ( "runtime" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" ) const ( @@ -46,8 +46,8 @@ func (l *Lease) Clone() (clone *Lease) { return &Lease{ Expiry: l.Expiry, Hostname: l.Hostname, - HWAddr: aghnet.CloneMAC(l.HWAddr), - IP: aghnet.CloneIP(l.IP), + HWAddr: netutil.CloneMAC(l.HWAddr), + IP: netutil.CloneIP(l.IP), } } diff --git a/internal/dhcpd/routeradv.go b/internal/dhcpd/routeradv.go index 52d64bce..863ac5e1 100644 --- a/internal/dhcpd/routeradv.go +++ b/internal/dhcpd/routeradv.go @@ -7,8 +7,8 @@ import ( "sync/atomic" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "golang.org/x/net/icmp" "golang.org/x/net/ipv6" ) @@ -42,7 +42,7 @@ type icmpv6RA struct { // // See https://tools.ietf.org/html/rfc4861#section-4.6.1. func hwAddrToLinkLayerAddr(hwa net.HardwareAddr) (lla []byte, err error) { - err = aghnet.ValidateHardwareAddress(hwa) + err = netutil.ValidateMAC(hwa) if err != nil { // Don't wrap the error, because it already contains enough // context. @@ -56,8 +56,8 @@ func hwAddrToLinkLayerAddr(hwa net.HardwareAddr) (lla []byte, err error) { return lla, nil } - // Assume that aghnet.ValidateHardwareAddress prevents lengths other - // than 20 by now. + // Assume that netutil.ValidateMAC prevents lengths other than 20 by + // now. lla = make([]byte, 24) copy(lla, hwa) diff --git a/internal/dhcpd/v4.go b/internal/dhcpd/v4.go index 2b7c9b25..7848811f 100644 --- a/internal/dhcpd/v4.go +++ b/internal/dhcpd/v4.go @@ -14,6 +14,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "github.com/go-ping/ping" "github.com/insomniacslk/dhcp/dhcpv4" @@ -61,7 +62,7 @@ func normalizeHostname(hostname string) (norm string, err error) { norm = strings.ToLower(hostname) parts := strings.FieldsFunc(norm, func(c rune) (ok bool) { - return c != '.' && !aghnet.IsValidHostOuterRune(c) + return c != '.' && !netutil.IsValidHostOuterRune(c) }) if len(parts) == 0 { @@ -87,7 +88,7 @@ func (s *v4Server) validHostnameForClient(cliHostname string, ip net.IP) (hostna hostname = aghnet.GenerateHostname(ip) } - err = aghnet.ValidateDomainName(hostname) + err = netutil.ValidateDomainName(hostname) if err != nil { log.Info("dhcpv4: %s", err) hostname = "" @@ -335,7 +336,7 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) { l.Expiry = time.Unix(leaseExpireStatic, 0) - err = aghnet.ValidateHardwareAddress(l.HWAddr) + err = netutil.ValidateMAC(l.HWAddr) if err != nil { return err } @@ -346,7 +347,7 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) { return err } - err = aghnet.ValidateDomainName(hostname) + err = netutil.ValidateDomainName(hostname) if err != nil { return fmt.Errorf("validating hostname: %w", err) } @@ -402,7 +403,7 @@ func (s *v4Server) RemoveStaticLease(l *Lease) (err error) { return fmt.Errorf("invalid IP") } - err = aghnet.ValidateHardwareAddress(l.HWAddr) + err = netutil.ValidateMAC(l.HWAddr) if err != nil { return fmt.Errorf("validating lease: %w", err) } @@ -913,7 +914,7 @@ func (s *v4Server) packetHandler(conn net.PacketConn, peer net.Addr, req *dhcpv4 return } - err = aghnet.ValidateHardwareAddress(req.ClientHWAddr) + err = netutil.ValidateMAC(req.ClientHWAddr) if err != nil { log.Error("dhcpv4: invalid ClientHWAddr: %s", err) @@ -1061,7 +1062,7 @@ func v4Create(conf V4ServerConf) (srv DHCPServer, err error) { Mask: subnetMask, } - bcastIP := aghnet.CloneIP(routerIP) + bcastIP := netutil.CloneIP(routerIP) for i, b := range subnetMask { bcastIP[i] |= ^b } diff --git a/internal/dhcpd/v6.go b/internal/dhcpd/v6.go index ff5be152..b4a4424f 100644 --- a/internal/dhcpd/v6.go +++ b/internal/dhcpd/v6.go @@ -10,9 +10,9 @@ import ( "sync" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/insomniacslk/dhcp/dhcpv6" "github.com/insomniacslk/dhcp/dhcpv6/server6" "github.com/insomniacslk/dhcp/iana" @@ -175,7 +175,7 @@ func (s *v6Server) AddStaticLease(l *Lease) (err error) { return fmt.Errorf("invalid IP") } - err = aghnet.ValidateHardwareAddress(l.HWAddr) + err = netutil.ValidateMAC(l.HWAddr) if err != nil { return fmt.Errorf("validating lease: %w", err) } @@ -207,7 +207,7 @@ func (s *v6Server) RemoveStaticLease(l *Lease) (err error) { return fmt.Errorf("invalid IP") } - err = aghnet.ValidateHardwareAddress(l.HWAddr) + err = netutil.ValidateMAC(l.HWAddr) if err != nil { return fmt.Errorf("validating lease: %w", err) } @@ -633,7 +633,7 @@ func (s *v6Server) Start() (err error) { log.Debug("dhcpv6: listening...") - err = aghnet.ValidateHardwareAddress(iface.HardwareAddr) + err = netutil.ValidateMAC(iface.HardwareAddr) if err != nil { return fmt.Errorf("validating interface %s: %w", iface.Name, err) } diff --git a/internal/dnsforward/access.go b/internal/dnsforward/access.go index bbf899c9..17619884 100644 --- a/internal/dnsforward/access.go +++ b/internal/dnsforward/access.go @@ -7,8 +7,8 @@ import ( "net/http" "strings" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/urlfilter" "github.com/AdguardTeam/urlfilter/filterlist" @@ -17,8 +17,8 @@ import ( // accessCtx controls IP and client blocking that takes place before all other // processing. An accessCtx is safe for concurrent use. type accessCtx struct { - allowedIPs *aghnet.IPMap - blockedIPs *aghnet.IPMap + allowedIPs *netutil.IPMap + blockedIPs *netutil.IPMap allowedClientIDs *stringutil.Set blockedClientIDs *stringutil.Set @@ -26,7 +26,7 @@ type accessCtx struct { blockedHostsEng *urlfilter.DNSEngine // TODO(a.garipov): Create a type for a set of IP networks. - // aghnet.IPNetSet? + // netutil.IPNetSet? allowedNets []*net.IPNet blockedNets []*net.IPNet } @@ -38,7 +38,7 @@ type unit = struct{} // which may be an IP address, a CIDR, or a ClientID. func processAccessClients( clientStrs []string, - ips *aghnet.IPMap, + ips *netutil.IPMap, nets *[]*net.IPNet, clientIDs *stringutil.Set, ) (err error) { @@ -68,8 +68,8 @@ func processAccessClients( // newAccessCtx creates a new accessCtx. func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessCtx, err error) { a = &accessCtx{ - allowedIPs: aghnet.NewIPMap(0), - blockedIPs: aghnet.NewIPMap(0), + allowedIPs: netutil.NewIPMap(0), + blockedIPs: netutil.NewIPMap(0), allowedClientIDs: stringutil.NewSet(), blockedClientIDs: stringutil.NewSet(), diff --git a/internal/dnsforward/clientid.go b/internal/dnsforward/clientid.go index 01301611..00433358 100644 --- a/internal/dnsforward/clientid.go +++ b/internal/dnsforward/clientid.go @@ -7,15 +7,15 @@ import ( "path" "strings" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/netutil" "github.com/lucas-clemente/quic-go" ) // ValidateClientID returns an error if clientID is not a valid client ID. func ValidateClientID(clientID string) (err error) { - err = aghnet.ValidateDomainNameLabel(clientID) + err = netutil.ValidateDomainNameLabel(clientID) if err != nil { // Replace the domain name label wrapper with our own. return fmt.Errorf("invalid client id %q: %w", clientID, errors.Unwrap(err)) diff --git a/internal/dnsforward/clientid_test.go b/internal/dnsforward/clientid_test.go index b4adf8de..7b14a836 100644 --- a/internal/dnsforward/clientid_test.go +++ b/internal/dnsforward/clientid_test.go @@ -46,6 +46,8 @@ func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) { } func TestServer_clientIDFromDNSContext(t *testing.T) { + // TODO(a.garipov): Consider moving away from the text-based error + // checks and onto a more structured approach. testCases := []struct { name string proto proxy.Proto @@ -111,7 +113,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) { cliSrvName: "!!!.example.com", wantClientID: "", wantErrMsg: `client id check: invalid client id "!!!": ` + - `invalid char '!' at index 0`, + `bad domain name label rune '!'`, strictSNI: true, }, { name: "tls_client_id_too_long", @@ -122,7 +124,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) { wantClientID: "", wantErrMsg: `client id check: invalid client id "abcdefghijklmno` + `pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789": ` + - `label is too long, max: 63`, + `domain name label is too long: got 72, max 63`, strictSNI: true, }, { name: "quic_client_id", @@ -220,7 +222,7 @@ func TestClientIDFromDNSContextHTTPS(t *testing.T) { path: "/dns-query/!!!", wantClientID: "", wantErrMsg: `client id check: invalid client id "!!!": ` + - `invalid char '!' at index 0`, + `bad domain name label rune '!'`, }} for _, tc := range testCases { diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 201bcb82..88b44fb2 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -11,12 +11,12 @@ import ( "strings" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "github.com/ameshkov/dnscrypt/v2" ) @@ -451,7 +451,7 @@ func matchesDomainWildcard(host, pat string) (ok bool) { // anyNameMatches returns true if sni, the client's SNI value, matches any of // the DNS names and patterns from certificate. dnsNames must be sorted. func anyNameMatches(dnsNames []string, sni string) (ok bool) { - if aghnet.ValidateDomainName(sni) != nil { + if netutil.ValidateDomainName(sni) != nil { return false } diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index b11406f6..57fc7e0d 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -5,11 +5,11 @@ import ( "strings" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "github.com/miekg/dns" ) @@ -165,7 +165,7 @@ func (s *Server) setTableHostToIP(t hostToIPTable) { s.tableHostToIP = t } -func (s *Server) setTableIPToHost(t *aghnet.IPMap) { +func (s *Server) setTableIPToHost(t *netutil.IPMap) { s.tableIPToHostLock.Lock() defer s.tableIPToHostLock.Unlock() @@ -188,18 +188,18 @@ func (s *Server) onDHCPLeaseChanged(flags int) { } var hostToIP hostToIPTable - var ipToHost *aghnet.IPMap + var ipToHost *netutil.IPMap if add { ll := s.dhcpServer.Leases(dhcpd.LeasesAll) hostToIP = make(hostToIPTable, len(ll)) - ipToHost = aghnet.NewIPMap(len(ll)) + ipToHost = netutil.NewIPMap(len(ll)) for _, l := range ll { // TODO(a.garipov): Remove this after we're finished // with the client hostname validations in the DHCP // server code. - err = aghnet.ValidateDomainName(l.Hostname) + err = netutil.ValidateDomainName(l.Hostname) if err != nil { log.Debug( "dns: skipping invalid hostname %q from dhcp: %s", @@ -230,7 +230,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) { rc = resultCodeSuccess var ip net.IP - if ip = aghnet.IPFromAddr(dctx.proxyCtx.Addr); ip == nil { + if ip, _ = netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr); ip == nil { return rc } @@ -331,12 +331,11 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) { return resultCodeSuccess } - ip := aghnet.UnreverseAddr(q.Name) - if ip == nil { - // That's weird. - // - // TODO(e.burkov): Research the cases when it could happen. - return resultCodeSuccess + ip, err := netutil.IPFromReversedAddr(q.Name) + if err != nil { + log.Debug("dns: reversed addr: %s", err) + + return resultCodeError } // Restrict an access to local addresses for external clients. We also @@ -502,7 +501,7 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) { // ipStringFromAddr extracts an IP address string from net.Addr. func ipStringFromAddr(addr net.Addr) (ipStr string) { - if ip := aghnet.IPFromAddr(addr); ip != nil { + if ip, _ := netutil.IPAndPortFromAddr(addr); ip != nil { return ip.String() } diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 3af2e74c..6f9246bb 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -20,6 +20,7 @@ import ( "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "github.com/miekg/dns" ) @@ -81,7 +82,7 @@ type Server struct { tableHostToIP hostToIPTable tableHostToIPLock sync.Mutex - tableIPToHost *aghnet.IPMap + tableIPToHost *netutil.IPMap tableIPToHostLock sync.Mutex // clientIDCache is a temporary storage for clientIDs that were @@ -141,7 +142,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { if p.LocalDomain == "" { localDomainSuffix = defaultLocalDomainSuffix } else { - err = aghnet.ValidateDomainName(p.LocalDomain) + err = netutil.ValidateDomainName(p.LocalDomain) if err != nil { return nil, fmt.Errorf("local domain: %w", err) } @@ -281,7 +282,12 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) { return "", nil } - arpa := dns.Fqdn(aghnet.ReverseAddr(ip)) + arpa, err := netutil.IPToReversedAddr(ip) + if err != nil { + return "", fmt.Errorf("reversing ip: %w", err) + } + + arpa = dns.Fqdn(arpa) req := &dns.Msg{ MsgHdr: dns.MsgHdr{ Id: dns.Id(), diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 35477696..c7ceab58 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -1119,6 +1119,8 @@ func TestPTRResponseFromHosts(t *testing.T) { } func TestNewServer(t *testing.T) { + // TODO(a.garipov): Consider moving away from the text-based error + // checks and onto a more structured approach. testCases := []struct { name string in DNSCreateParams @@ -1144,9 +1146,8 @@ func TestNewServer(t *testing.T) { in: DNSCreateParams{ LocalDomain: "!!!", }, - wantErrMsg: `local domain: validating domain name "!!!": ` + - `invalid domain name label at index 0: ` + - `validating label "!!!": invalid char '!' at index 0`, + wantErrMsg: `local domain: bad domain name "!!!": ` + + `bad domain name label "!!!": bad domain name label rune '!'`, }} for _, tc := range testCases { diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 1f2d5998..5edca948 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -5,10 +5,10 @@ import ( "fmt" "strings" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" ) @@ -19,7 +19,7 @@ func (s *Server) beforeRequestHandler( _ *proxy.Proxy, pctx *proxy.DNSContext, ) (reply bool, err error) { - ip := aghnet.IPFromAddr(pctx.Addr) + ip, _ := netutil.IPAndPortFromAddr(pctx.Addr) clientID, err := s.clientIDFromDNSContext(pctx) if err != nil { return false, fmt.Errorf("getting clientid: %w", err) @@ -53,7 +53,8 @@ func (s *Server) beforeRequestHandler( func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings { setts := s.dnsFilter.GetConfig() if s.conf.FilterHandler != nil { - s.conf.FilterHandler(aghnet.IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts) + ip, _ := netutil.IPAndPortFromAddr(ctx.proxyCtx.Addr) + s.conf.FilterHandler(ip, ctx.clientID, &setts) } return &setts diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index ba163691..3623de71 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -9,11 +9,11 @@ import ( "strings" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "github.com/miekg/dns" ) @@ -443,7 +443,7 @@ func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err continue } - err = aghnet.ValidateDomainName(host) + err = netutil.ValidateDomainName(host) if err != nil { return "", false, fmt.Errorf("domain at index %d: %w", i, err) } diff --git a/internal/dnsforward/recursiondetector.go b/internal/dnsforward/recursiondetector.go index e0e1537b..4555f4b0 100644 --- a/internal/dnsforward/recursiondetector.go +++ b/internal/dnsforward/recursiondetector.go @@ -5,9 +5,9 @@ import ( "encoding/binary" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" ) @@ -77,7 +77,7 @@ func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDet // msgToSignature converts msg into it's signature represented in bytes. func msgToSignature(msg dns.Msg) (sig []byte) { - sig = make([]byte, uint16sz*2+aghnet.MaxDomainNameLen) + sig = make([]byte, uint16sz*2+netutil.MaxDomainNameLen) // The binary.BigEndian byte order is used everywhere except when the // real machine's endianess is needed. byteOrder := binary.BigEndian @@ -95,7 +95,7 @@ func msgToSignature(msg dns.Msg) (sig []byte) { // See BenchmarkMsgToSignature. func msgToSignatureSlow(msg dns.Msg) (sig []byte) { type msgSignature struct { - name [aghnet.MaxDomainNameLen]byte + name [netutil.MaxDomainNameLen]byte id uint16 qtype uint16 } diff --git a/internal/dnsforward/stats.go b/internal/dnsforward/stats.go index 3b2c9c4e..1714563f 100644 --- a/internal/dnsforward/stats.go +++ b/internal/dnsforward/stats.go @@ -4,11 +4,11 @@ import ( "strings" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" ) @@ -32,13 +32,14 @@ func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) { // Synchronize access to s.queryLog and s.stats so they won't be suddenly uninitialized while in use. // This can happen after proxy server has been stopped, but its workers haven't yet exited. if shouldLog && s.queryLog != nil { + ip, _ := netutil.IPAndPortFromAddr(pctx.Addr) p := querylog.AddParams{ Question: msg, Answer: pctx.Res, OrigAnswer: ctx.origResp, Result: ctx.result, Elapsed: elapsed, - ClientIP: aghnet.IPFromAddr(pctx.Addr), + ClientIP: ip, ClientID: ctx.clientID, } @@ -80,7 +81,7 @@ func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res filteri if clientID := ctx.clientID; clientID != "" { e.Client = clientID - } else if ip := aghnet.IPFromAddr(pctx.Addr); ip != nil { + } else if ip, _ := netutil.IPAndPortFromAddr(pctx.Addr); ip != nil { e.Client = ip.String() } diff --git a/internal/home/auth.go b/internal/home/auth.go index 80c1a3a0..a2c36c80 100644 --- a/internal/home/auth.go +++ b/internal/home/auth.go @@ -13,8 +13,8 @@ import ( "sync" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "go.etcd.io/bbolt" "golang.org/x/crypto/bcrypt" ) @@ -404,8 +404,7 @@ func realIP(r *http.Request) (ip net.IP, err error) { // When everything else fails, just return the remote address as // understood by the stdlib. - var ipStr string - ipStr, err = aghnet.SplitHost(r.RemoteAddr) + ipStr, err := netutil.SplitHost(r.RemoteAddr) if err != nil { return nil, fmt.Errorf("getting ip from client addr: %w", err) } @@ -428,7 +427,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) { // See https://github.com/AdguardTeam/AdGuardHome/issues/2799. // // TODO(e.burkov): Use realIP when the issue will be fixed. - if remoteAddr, err = aghnet.SplitHost(r.RemoteAddr); err != nil { + if remoteAddr, err = netutil.SplitHost(r.RemoteAddr); err != nil { httpError(w, http.StatusBadRequest, "auth: getting remote address: %s", err) return diff --git a/internal/home/clients.go b/internal/home/clients.go index 5fc81068..9ba8db0c 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -20,6 +20,7 @@ import ( "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" ) @@ -82,7 +83,7 @@ type clientsContainer struct { idIndex map[string]*Client // ID -> client // ipToRC is the IP address to *RuntimeClient map. - ipToRC *aghnet.IPMap + ipToRC *netutil.IPMap lock sync.Mutex @@ -112,7 +113,7 @@ func (clients *clientsContainer) Init( } clients.list = make(map[string]*Client) clients.idIndex = make(map[string]*Client) - clients.ipToRC = aghnet.NewIPMap(0) + clients.ipToRC = netutil.NewIPMap(0) clients.allTags = stringutil.NewSet(clientTags...) @@ -793,7 +794,7 @@ func (clients *clientsContainer) addFromSystemARP() { host := ln[:lparen] ipStr := ln[lparen+2 : rparen] ip := net.ParseIP(ipStr) - if aghnet.ValidateDomainName(host) != nil || ip == nil { + if netutil.ValidateDomainName(host) != nil || ip == nil { continue } diff --git a/internal/home/control.go b/internal/home/control.go index 8e64231b..a0c44ad8 100644 --- a/internal/home/control.go +++ b/internal/home/control.go @@ -13,6 +13,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/NYTimes/gziphandler" ) @@ -39,7 +40,7 @@ func appendDNSAddrs(dst []string, addrs ...net.IP) (res []string) { for _, addr := range addrs { var hostport string if config.DNS.Port != 53 { - hostport = aghnet.JoinHostPort(addr.String(), config.DNS.Port) + hostport = netutil.JoinHostPort(addr.String(), config.DNS.Port) } else { hostport = addr.String() } @@ -294,7 +295,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) { return true } - host, err := aghnet.SplitHost(r.Host) + host, err := netutil.SplitHost(r.Host) if err != nil { httpError(w, http.StatusBadRequest, "bad host: %s", err) @@ -304,7 +305,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) { if r.TLS == nil && web.forceHTTPS { hostPort := host if port := web.conf.PortHTTPS; port != defaultHTTPSPort { - hostPort = aghnet.JoinHostPort(host, port) + hostPort = netutil.JoinHostPort(host, port) } httpsURL := &url.URL{ diff --git a/internal/home/controlinstall.go b/internal/home/controlinstall.go index 3c9569a4..0d0d0134 100644 --- a/internal/home/controlinstall.go +++ b/internal/home/controlinstall.go @@ -16,6 +16,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" ) // getAddrsResponse is the response for /install/get_addresses endpoint. @@ -311,7 +312,7 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { w, http.StatusBadRequest, "can not listen on IP:port %s: %s", - aghnet.JoinHostPort(req.Web.IP.String(), req.Web.Port), + netutil.JoinHostPort(req.Web.IP.String(), req.Web.Port), err, ) diff --git a/internal/home/dns.go b/internal/home/dns.go index 8c0cc035..e3986acd 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -7,7 +7,6 @@ import ( "os" "path/filepath" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/querylog" @@ -15,6 +14,7 @@ import ( "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/ameshkov/dnscrypt/v2" yaml "gopkg.in/yaml.v2" ) @@ -106,7 +106,7 @@ func isRunning() bool { } func onDNSRequest(pctx *proxy.DNSContext) { - ip := aghnet.IPFromAddr(pctx.Addr) + ip, _ := netutil.IPAndPortFromAddr(pctx.Addr) if ip == nil { // This would be quite weird if we get here. return @@ -254,7 +254,7 @@ func getDNSEncryption() (de dnsEncryption) { if tlsConf.PortHTTPS != 0 { addr := hostname if tlsConf.PortHTTPS != 443 { - addr = aghnet.JoinHostPort(addr, tlsConf.PortHTTPS) + addr = netutil.JoinHostPort(addr, tlsConf.PortHTTPS) } de.https = (&url.URL{ @@ -267,14 +267,14 @@ func getDNSEncryption() (de dnsEncryption) { if tlsConf.PortDNSOverTLS != 0 { de.tls = (&url.URL{ Scheme: "tls", - Host: aghnet.JoinHostPort(hostname, tlsConf.PortDNSOverTLS), + Host: netutil.JoinHostPort(hostname, tlsConf.PortDNSOverTLS), }).String() } if tlsConf.PortDNSOverQUIC != 0 { de.quic = (&url.URL{ Scheme: "quic", - Host: aghnet.JoinHostPort(hostname, tlsConf.PortDNSOverQUIC), + Host: netutil.JoinHostPort(hostname, tlsConf.PortDNSOverQUIC), }).String() } } diff --git a/internal/home/home.go b/internal/home/home.go index a6ee025e..e99eee65 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -30,6 +30,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "gopkg.in/natefinch/lumberjack.v2" ) @@ -707,12 +708,12 @@ func printWebAddrs(proto, addr string, port, betaPort int) { hostBetaMsg = hostMsg + " (BETA)" ) - log.Printf(hostMsg, proto, aghnet.JoinHostPort(addr, port)) + log.Printf(hostMsg, proto, netutil.JoinHostPort(addr, port)) if betaPort == 0 { return } - log.Printf(hostBetaMsg, proto, aghnet.JoinHostPort(addr, config.BetaBindPort)) + log.Printf(hostBetaMsg, proto, netutil.JoinHostPort(addr, config.BetaBindPort)) } // printHTTPAddresses prints the IP addresses which user can use to access the diff --git a/internal/home/rdns_test.go b/internal/home/rdns_test.go index c125bf5e..202f9f5f 100644 --- a/internal/home/rdns_test.go +++ b/internal/home/rdns_test.go @@ -8,12 +8,12 @@ import ( "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -85,7 +85,7 @@ func TestRDNS_Begin(t *testing.T) { clients: &clientsContainer{ list: map[string]*Client{}, idIndex: tc.cliIDIndex, - ipToRC: aghnet.NewIPMap(0), + ipToRC: netutil.NewIPMap(0), allTags: stringutil.NewSet(), }, } @@ -205,7 +205,7 @@ func TestRDNS_WorkerLoop(t *testing.T) { cc := &clientsContainer{ list: map[string]*Client{}, idIndex: map[string]*Client{}, - ipToRC: aghnet.NewIPMap(0), + ipToRC: netutil.NewIPMap(0), allTags: stringutil.NewSet(), } ch := make(chan net.IP) diff --git a/internal/home/upgrade.go b/internal/home/upgrade.go index 175e73f1..5599eb3b 100644 --- a/internal/home/upgrade.go +++ b/internal/home/upgrade.go @@ -11,9 +11,9 @@ import ( "strings" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/google/renameio/maybe" "golang.org/x/crypto/bcrypt" yaml "gopkg.in/yaml.v2" @@ -545,7 +545,7 @@ func addQUICPort(ups string, port int) (withPort string) { } var host string - host, err = aghnet.SplitHost(upsURL.Host) + host, err = netutil.SplitHost(upsURL.Host) if err != nil || host != upsURL.Host { return ups } diff --git a/internal/home/web.go b/internal/home/web.go index a22c6bd0..9be037cc 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -11,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/NYTimes/gziphandler" ) @@ -175,7 +176,7 @@ func (web *Web) Start() { // we need to have new instance, because after Shutdown() the Server is not usable web.httpServer = &http.Server{ ErrorLog: log.StdLog("web: plain", log.DEBUG), - Addr: aghnet.JoinHostPort(hostStr, web.conf.BindPort), + Addr: netutil.JoinHostPort(hostStr, web.conf.BindPort), Handler: withMiddlewares(Context.mux, limitRequestBody), ReadTimeout: web.conf.ReadTimeout, ReadHeaderTimeout: web.conf.ReadHeaderTimeout, @@ -188,7 +189,7 @@ func (web *Web) Start() { if web.conf.BetaBindPort != 0 { web.httpServerBeta = &http.Server{ ErrorLog: log.StdLog("web: plain", log.DEBUG), - Addr: aghnet.JoinHostPort(hostStr, web.conf.BetaBindPort), + Addr: netutil.JoinHostPort(hostStr, web.conf.BetaBindPort), Handler: withMiddlewares(Context.mux, limitRequestBody, web.wrapIndexBeta), ReadTimeout: web.conf.ReadTimeout, ReadHeaderTimeout: web.conf.ReadHeaderTimeout, @@ -249,7 +250,7 @@ func (web *Web) tlsServerLoop() { web.httpsServer.cond.L.Unlock() // prepare HTTPS server - address := aghnet.JoinHostPort(web.conf.BindHost.String(), web.conf.PortHTTPS) + address := netutil.JoinHostPort(web.conf.BindHost.String(), web.conf.PortHTTPS) web.httpsServer.server = &http.Server{ ErrorLog: log.StdLog("web: https", log.DEBUG), Addr: address,