diff --git a/internal/aghnet/systemresolvers.go b/internal/aghnet/systemresolvers.go index 4cba0f92..96a674c7 100644 --- a/internal/aghnet/systemresolvers.go +++ b/internal/aghnet/systemresolvers.go @@ -15,9 +15,6 @@ import ( // implementation must be safe for concurrent use. type HostGenFunc func() (host string) -// unit is an alias for an existing map value. -type unit = struct{} - // SystemResolvers helps to work with local resolvers' addresses provided by OS. type SystemResolvers interface { // Get returns the slice of local resolvers' addresses. diff --git a/internal/aghnet/systemresolvers_others.go b/internal/aghnet/systemresolvers_others.go index 975ff744..d2bec062 100644 --- a/internal/aghnet/systemresolvers_others.go +++ b/internal/aghnet/systemresolvers_others.go @@ -11,6 +11,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" ) // defaultHostGen is the default method of generating host for Refresh. @@ -24,8 +25,8 @@ type systemResolvers struct { resolver *net.Resolver hostGenFunc HostGenFunc - // addrs is the map that contains cached local resolvers' addresses. - addrs map[string]unit + // addrs is the set that contains cached local resolvers' addresses. + addrs *aghstrings.Set addrsLock sync.RWMutex } @@ -50,7 +51,7 @@ func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr S PreferGo: true, }, hostGenFunc: hostGenFunc, - addrs: make(map[string]unit), + addrs: aghstrings.NewSet(), } s.resolver.Dial = s.dialFunc @@ -75,7 +76,7 @@ func (sr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net sr.addrsLock.Lock() defer sr.addrsLock.Unlock() - sr.addrs[host] = unit{} + sr.addrs.Add(host) return nil, fakeDialErr } @@ -84,13 +85,5 @@ func (sr *systemResolvers) Get() (rs []string) { sr.addrsLock.RLock() defer sr.addrsLock.RUnlock() - addrs := sr.addrs - rs = make([]string, len(addrs)) - var i int - for addr := range addrs { - rs[i] = addr - i++ - } - - return rs + return sr.addrs.Values() } diff --git a/internal/dnsfilter/dnsfilter.go b/internal/dnsfilter/dnsfilter.go index 3a5abdf3..5b656ca2 100644 --- a/internal/dnsfilter/dnsfilter.go +++ b/internal/dnsfilter/dnsfilter.go @@ -14,6 +14,7 @@ import ( "sync" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" @@ -477,10 +478,10 @@ func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) { res.Reason = Rewritten } - cnames := map[string]bool{} + cnames := aghstrings.NewSet() origHost := host for len(rr) != 0 && rr[0].Type == dns.TypeCNAME { - log.Debug("Rewrite: CNAME for %s is %s", host, rr[0].Answer) + log.Debug("rewrite: CNAME for %s is %s", host, rr[0].Answer) if host == rr[0].Answer { // "host == CNAME" is an exception res.Reason = NotFilteredNotFound @@ -489,12 +490,13 @@ func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) { } host = rr[0].Answer - _, ok := cnames[host] - if ok { - log.Info("Rewrite: breaking CNAME redirection loop: %s. Question: %s", host, origHost) + if cnames.Has(host) { + log.Info("rewrite: breaking CNAME redirection loop: %s. Question: %s", host, origHost) + return res } - cnames[host] = false + + cnames.Add(host) res.CanonName = rr[0].Answer rr = findRewrites(d.Rewrites, host) } @@ -509,7 +511,7 @@ func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) { } res.IPList = append(res.IPList, r.IP) - log.Debug("Rewrite: A/AAAA for %s is %s", host, r.IP) + log.Debug("rewrite: A/AAAA for %s is %s", host, r.IP) } } diff --git a/internal/dnsforward/access.go b/internal/dnsforward/access.go index c3e5aa7c..dba0015d 100644 --- a/internal/dnsforward/access.go +++ b/internal/dnsforward/access.go @@ -17,8 +17,11 @@ import ( type accessCtx struct { lock sync.Mutex - allowedClients map[string]bool // IP addresses of whitelist clients - disallowedClients map[string]bool // IP addresses of clients that should be blocked + // allowedClients are the IP addresses of clients in the allowlist. + allowedClients *aghstrings.Set + + // disallowedClients are the IP addresses of clients in the blocklist. + disallowedClients *aghstrings.Set allowedClientsIPNet []net.IPNet // CIDRs of whitelist clients disallowedClientsIPNet []net.IPNet // CIDRs of clients that should be blocked @@ -26,15 +29,20 @@ type accessCtx struct { blockedHostsEngine *urlfilter.DNSEngine // finds hosts that should be blocked } -func (a *accessCtx) Init(allowedClients, disallowedClients, blockedHosts []string) error { - err := processIPCIDRArray(&a.allowedClients, &a.allowedClientsIPNet, allowedClients) - if err != nil { - return err +func newAccessCtx(allowedClients, disallowedClients, blockedHosts []string) (a *accessCtx, err error) { + a = &accessCtx{ + allowedClients: aghstrings.NewSet(), + disallowedClients: aghstrings.NewSet(), } - err = processIPCIDRArray(&a.disallowedClients, &a.disallowedClientsIPNet, disallowedClients) + err = processIPCIDRArray(a.allowedClients, &a.allowedClientsIPNet, allowedClients) if err != nil { - return err + return nil, fmt.Errorf("processing allowed clients: %w", err) + } + + err = processIPCIDRArray(a.disallowedClients, &a.disallowedClientsIPNet, disallowedClients) + if err != nil { + return nil, fmt.Errorf("processing disallowed clients: %w", err) } b := &strings.Builder{} @@ -51,21 +59,20 @@ func (a *accessCtx) Init(allowedClients, disallowedClients, blockedHosts []strin listArray = append(listArray, list) rulesStorage, err := filterlist.NewRuleStorage(listArray) if err != nil { - return fmt.Errorf("filterlist.NewRuleStorage(): %w", err) + return nil, fmt.Errorf("filterlist.NewRuleStorage(): %w", err) } a.blockedHostsEngine = urlfilter.NewDNSEngine(rulesStorage) - return nil + return a, nil } // Split array of IP or CIDR into 2 containers for fast search -func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []string) error { - *dst = make(map[string]bool) - +func processIPCIDRArray(dst *aghstrings.Set, dstIPNet *[]net.IPNet, src []string) error { for _, s := range src { ip := net.ParseIP(s) if ip != nil { - (*dst)[s] = true + dst.Add(s) + continue } @@ -73,6 +80,7 @@ func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []strin if err != nil { return err } + *dstIPNet = append(*dstIPNet, *ipnet) } @@ -89,9 +97,8 @@ func (a *accessCtx) IsBlockedIP(ip net.IP) (bool, string) { a.lock.Lock() defer a.lock.Unlock() - if len(a.allowedClients) != 0 || len(a.allowedClientsIPNet) != 0 { - _, ok := a.allowedClients[ipStr] - if ok { + if a.allowedClients.Len() != 0 || len(a.allowedClientsIPNet) != 0 { + if a.allowedClients.Has(ipStr) { return false, "" } @@ -106,8 +113,7 @@ func (a *accessCtx) IsBlockedIP(ip net.IP) (bool, string) { return true, "" } - _, ok := a.disallowedClients[ipStr] - if ok { + if a.disallowedClients.Has(ipStr) { return true, ipStr } @@ -186,10 +192,11 @@ func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) { return } - a := &accessCtx{} - err = a.Init(j.AllowedClients, j.DisallowedClients, j.BlockedHosts) + var a *accessCtx + a, err = newAccessCtx(j.AllowedClients, j.DisallowedClients, j.BlockedHosts) if err != nil { - httpError(r, w, http.StatusBadRequest, "access.Init: %s", err) + httpError(r, w, http.StatusBadRequest, "creating access ctx: %s", err) + return } diff --git a/internal/dnsforward/access_test.go b/internal/dnsforward/access_test.go index 6c8a3766..eec5c511 100644 --- a/internal/dnsforward/access_test.go +++ b/internal/dnsforward/access_test.go @@ -82,7 +82,6 @@ func TestIsBlockedIP(t *testing.T) { } t.Run(prefix+tc.name, func(t *testing.T) { - aCtx := &accessCtx{} allowedRules := rules var disallowedRules []string @@ -90,7 +89,8 @@ func TestIsBlockedIP(t *testing.T) { allowedRules, disallowedRules = disallowedRules, allowedRules } - require.Nil(t, aCtx.Init(allowedRules, disallowedRules, nil)) + aCtx, err := newAccessCtx(allowedRules, disallowedRules, nil) + require.NoError(t, err) disallowed, rule := aCtx.IsBlockedIP(tc.ip) assert.Equal(t, tc.wantDis, disallowed) @@ -100,12 +100,12 @@ func TestIsBlockedIP(t *testing.T) { } func TestIsBlockedDomain(t *testing.T) { - aCtx := &accessCtx{} - require.Nil(t, aCtx.Init(nil, nil, []string{ + aCtx, err := newAccessCtx(nil, nil, []string{ "host1", "*.host.com", "||host3.com^", - })) + }) + require.NoError(t, err) testCases := []struct { name string diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 3757d2f8..ff6eaa27 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -340,19 +340,6 @@ func (s *Server) collectDNSIPAddrs() (addrs []string, err error) { return addrs[:i], nil } -// unit is used to show the presence of a value in a set. -type unit = struct{} - -// sliceToSet converts a slice of strings into a string set. -func sliceToSet(strs []string) (set map[string]unit) { - set = make(map[string]unit, len(strs)) - for _, s := range strs { - set[s] = unit{} - } - - return set -} - // setupResolvers initializes the resolvers for local addresses. For internal // use only. func (s *Server) setupResolvers(localAddrs []string) (err error) { @@ -377,16 +364,14 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) { return err } - ourAddrsSet := sliceToSet(ourAddrs) + ourAddrsSet := aghstrings.NewSet(ourAddrs...) // TODO(e.burkov): The approach of subtracting sets of strings is not // really applicable here since in case of listening on all network // interfaces we should check the whole interface's network to cut off // all the loopback addresses as well. localAddrs = aghstrings.FilterOut(localAddrs, func(s string) (ok bool) { - _, ok = ourAddrsSet[s] - - return ok + return ourAddrsSet.Has(s) }) var upsConfig proxy.UpstreamConfig @@ -464,10 +449,7 @@ func (s *Server) Prepare(config *ServerConfig) error { // -- s.prepareIntlProxy() - // Initialize DNS access module - // -- - s.access = &accessCtx{} - err = s.access.Init(s.conf.AllowedClients, s.conf.DisallowedClients, s.conf.BlockedHosts) + s.access, err = newAccessCtx(s.conf.AllowedClients, s.conf.DisallowedClients, s.conf.BlockedHosts) if err != nil { return err } diff --git a/internal/home/clients.go b/internal/home/clients.go index dccf95c5..12f5b32d 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -83,7 +83,7 @@ type clientsContainer struct { ipToRC map[string]*RuntimeClient // IP -> runtime client lock sync.Mutex - allTags map[string]bool + allTags *aghstrings.Set // dhcpServer is used for looking up clients IP addresses by MAC addresses dhcpServer *dhcpd.Server @@ -111,10 +111,7 @@ func (clients *clientsContainer) Init( clients.idIndex = make(map[string]*Client) clients.ipToRC = make(map[string]*RuntimeClient) - clients.allTags = make(map[string]bool) - for _, t := range clientTags { - clients.allTags[t] = false - } + clients.allTags = aghstrings.NewSet(clientTags...) clients.dhcpServer = dhcpServer clients.etcHosts = etcHosts @@ -163,9 +160,8 @@ type clientObject struct { Upstreams []string `yaml:"upstreams"` } -func (clients *clientsContainer) tagKnown(tag string) bool { - _, ok := clients.allTags[tag] - return ok +func (clients *clientsContainer) tagKnown(tag string) (ok bool) { + return clients.allTags.Has(tag) } func (clients *clientsContainer) addFromConfig(objects []clientObject) { diff --git a/internal/home/i18n.go b/internal/home/i18n.go index adbc95aa..623cdef6 100644 --- a/internal/home/i18n.go +++ b/internal/home/i18n.go @@ -6,52 +6,46 @@ import ( "net/http" "strings" + "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/golibs/log" ) -// -------------------- -// internationalization -// -------------------- -var allowedLanguages = map[string]bool{ - "be": true, - "bg": true, - "cs": true, - "da": true, - "de": true, - "en": true, - "es": true, - "fa": true, - "fr": true, - "hr": true, - "hu": true, - "id": true, - "it": true, - "ja": true, - "ko": true, - "nl": true, - "no": true, - "pl": true, - "pt-br": true, - "pt-pt": true, - "ro": true, - "ru": true, - "si-lk": true, - "sk": true, - "sl": true, - "sr-cs": true, - "sv": true, - "th": true, - "tr": true, - "vi": true, - "zh-cn": true, - "zh-hk": true, - "zh-tw": true, -} - -func isLanguageAllowed(language string) bool { - l := strings.ToLower(language) - return allowedLanguages[l] -} +// TODO(a.garipov): Get rid of a global variable? +var allowedLanguages = aghstrings.NewSet( + "be", + "bg", + "cs", + "da", + "de", + "en", + "es", + "fa", + "fr", + "hr", + "hu", + "id", + "it", + "ja", + "ko", + "nl", + "no", + "pl", + "pt-br", + "pt-pt", + "ro", + "ru", + "si-lk", + "sk", + "sl", + "sr-cs", + "sv", + "th", + "tr", + "vi", + "zh-cn", + "zh-hk", + "zh-tw", +) func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") @@ -80,12 +74,15 @@ func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) { msg := "empty language specified" log.Println(msg) http.Error(w, msg, http.StatusBadRequest) + return } - if !isLanguageAllowed(language) { + + if !allowedLanguages.Has(language) { msg := fmt.Sprintf("unknown language specified: %s", language) log.Println(msg) http.Error(w, msg, http.StatusBadRequest) + return } diff --git a/internal/home/rdns_test.go b/internal/home/rdns_test.go index ecc1a2bc..fd111565 100644 --- a/internal/home/rdns_test.go +++ b/internal/home/rdns_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/cache" @@ -82,7 +83,7 @@ func TestRDNS_Begin(t *testing.T) { list: map[string]*Client{}, idIndex: tc.cliIDIndex, ipToRC: map[string]*RuntimeClient{}, - allTags: map[string]bool{}, + allTags: aghstrings.NewSet(), }, } ipCache.Clear() @@ -172,7 +173,7 @@ func TestRDNS_WorkerLoop(t *testing.T) { list: map[string]*Client{}, idIndex: map[string]*Client{}, ipToRC: map[string]*RuntimeClient{}, - allTags: map[string]bool{}, + allTags: aghstrings.NewSet(), } ch := make(chan net.IP) rdns := &RDNS{