diff --git a/AGHTechDoc.md b/AGHTechDoc.md index 8d1de3d7..7d223c5b 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -673,6 +673,7 @@ Response: { name: "client1" ids: ["...", ...] // IP, CIDR or MAC + tags: ["...", ...] use_global_settings: true filtering_enabled: false parental_enabled: false @@ -698,6 +699,7 @@ Response: } } ] + supported_tags: ["...", ...] } Supported keys for `whois_info`: orgname, country, city. @@ -712,6 +714,7 @@ Request: { name: "client1" ids: ["...", ...] // IP, CIDR or MAC + tags: ["...", ...] use_global_settings: true filtering_enabled: false parental_enabled: false @@ -742,6 +745,7 @@ Request: data: { name: "client1" ids: ["...", ...] // IP, CIDR or MAC + tags: ["...", ...] use_global_settings: true filtering_enabled: false parental_enabled: false diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 0f667850..4c84c1e8 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -31,6 +31,7 @@ type RequestFilteringSettings struct { SafeSearchEnabled bool SafeBrowsingEnabled bool ParentalEnabled bool + ClientTags []string ServicesRules []ServiceEntry } @@ -264,7 +265,7 @@ func (d *Dnsfilter) CheckHostRules(host string, qtype uint16, setts *RequestFilt return Result{}, nil } - return d.matchHost(host, qtype) + return d.matchHost(host, qtype, setts.ClientTags) } // CheckHost tries to match the host against filtering rules, @@ -286,7 +287,7 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering // try filter lists first if setts.FilteringEnabled { - result, err = d.matchHost(host, qtype) + result, err = d.matchHost(host, qtype, setts.ClientTags) if err != nil { return result, err } @@ -475,14 +476,14 @@ func (d *Dnsfilter) initFiltering(filters map[int]string) error { } // matchHost is a low-level way to check only if hostname is filtered by rules, skipping expensive safebrowsing and parental lookups -func (d *Dnsfilter) matchHost(host string, qtype uint16) (Result, error) { +func (d *Dnsfilter) matchHost(host string, qtype uint16, ctags []string) (Result, error) { d.engineLock.RLock() defer d.engineLock.RUnlock() if d.filteringEngine == nil { return Result{}, nil } - frules, ok := d.filteringEngine.Match(host) + frules, ok := d.filteringEngine.Match(host, ctags) if !ok { return Result{}, nil } diff --git a/go.mod b/go.mod index 61de5bee..902c44c4 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.13 require ( github.com/AdguardTeam/dnsproxy v0.23.7 github.com/AdguardTeam/golibs v0.3.0 - github.com/AdguardTeam/urlfilter v0.7.2 + github.com/AdguardTeam/urlfilter v0.8.1 github.com/NYTimes/gziphandler v1.1.1 github.com/etcd-io/bbolt v1.3.3 github.com/go-test/deep v1.0.4 // indirect diff --git a/go.sum b/go.sum index a407e811..c1fc4421 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,8 @@ github.com/AdguardTeam/golibs v0.3.0/go.mod h1:R3M+mAg3nWG4X4Hsag5eef/TckHFH12ZY github.com/AdguardTeam/gomitmproxy v0.1.2/go.mod h1:Mrt/3EfiXIYY2aZ7KsLuCUJzUARD/fWJ119IfzOB13M= github.com/AdguardTeam/urlfilter v0.7.0 h1:ffFLt4rA3GX8PJYGL3bGcT5bSxZlML5k6cKpSeN2UI8= github.com/AdguardTeam/urlfilter v0.7.0/go.mod h1:GHXPzEG59ezyff22lXSQ7dicj1kFZBrH5kmZ6EvQzfk= -github.com/AdguardTeam/urlfilter v0.7.2 h1:0XyepkVAvY1eYtNKgDRCkookfaBmROvY6VE23ebPUNc= -github.com/AdguardTeam/urlfilter v0.7.2/go.mod h1:GHXPzEG59ezyff22lXSQ7dicj1kFZBrH5kmZ6EvQzfk= +github.com/AdguardTeam/urlfilter v0.8.1 h1:9YRQOR15DU7+k01PWAgc/Ay12jjxVqSi6P0+whFm0f4= +github.com/AdguardTeam/urlfilter v0.8.1/go.mod h1:GHXPzEG59ezyff22lXSQ7dicj1kFZBrH5kmZ6EvQzfk= github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I= github.com/NYTimes/gziphandler v1.1.1/go.mod h1:n/CVRwUEOgIxrgPvAQhUUr9oeUtvrhMomdKFjzJNB0c= github.com/StackExchange/wmi v0.0.0-20181212234831-e0a55b97c705 h1:UUppSQnhf4Yc6xGxSkoQpPhb7RVzuv5Nb1mwJ5VId9s= diff --git a/home/clients.go b/home/clients.go index 65ce78e2..fb96fd76 100644 --- a/home/clients.go +++ b/home/clients.go @@ -8,6 +8,7 @@ import ( "os" "os/exec" "runtime" + "sort" "strings" "sync" "time" @@ -26,6 +27,7 @@ const ( // Client information type Client struct { IDs []string + Tags []string Name string UseOwnSettings bool // false: use global settings FilteringEnabled bool @@ -69,6 +71,8 @@ type clientsContainer struct { ipHost map[string]*ClientHost // IP -> Hostname lock sync.Mutex + allTags map[string]bool + // dhcpServer is used for looking up clients IP addresses by MAC addresses dhcpServer *dhcpd.Server @@ -84,6 +88,12 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd. clients.list = make(map[string]*Client) clients.idIndex = make(map[string]*Client) clients.ipHost = make(map[string]*ClientHost) + + clients.allTags = make(map[string]bool) + for _, t := range clientTags { + clients.allTags[t] = false + } + clients.dhcpServer = dhcpServer clients.addFromConfig(objects) @@ -96,6 +106,7 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd. type clientObject struct { Name string `yaml:"name"` + Tags []string `yaml:"tags"` IDs []string `yaml:"ids"` UseGlobalSettings bool `yaml:"use_global_settings"` FilteringEnabled bool `yaml:"filtering_enabled"` @@ -109,6 +120,11 @@ type clientObject struct { Upstreams []string `yaml:"upstreams"` } +func (clients *clientsContainer) tagKnown(tag string) bool { + _, ok := clients.allTags[tag] + return ok +} + func (clients *clientsContainer) addFromConfig(objects []clientObject) { for _, cy := range objects { cli := Client{ @@ -125,6 +141,16 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) { Upstreams: cy.Upstreams, } + + for _, t := range cy.Tags { + if !clients.tagKnown(t) { + log.Debug("Clients: skipping unknown tag '%s'", t) + continue + } + cli.Tags = append(cli.Tags, t) + } + sort.Strings(cli.Tags) + _, err := clients.Add(cli) if err != nil { log.Tracef("clientAdd: %s", err) @@ -146,14 +172,10 @@ func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) { UseGlobalBlockedServices: !cli.UseOwnBlockedServices, } - cy.IDs = make([]string, len(cli.IDs)) - copy(cy.IDs, cli.IDs) - - cy.BlockedServices = make([]string, len(cli.BlockedServices)) - copy(cy.BlockedServices, cli.BlockedServices) - - cy.Upstreams = make([]string, len(cli.Upstreams)) - copy(cy.Upstreams, cli.Upstreams) + cy.Tags = stringArrayDup(cli.Tags) + cy.IDs = stringArrayDup(cli.IDs) + cy.BlockedServices = stringArrayDup(cli.BlockedServices) + cy.Upstreams = stringArrayDup(cli.Upstreams) *objects = append(*objects, cy) } @@ -189,12 +211,26 @@ func (clients *clientsContainer) Exists(ip string, source clientSource) bool { return true } +func stringArrayDup(a []string) []string { + a2 := make([]string, len(a)) + copy(a2, a) + return a2 +} + // Find searches for a client by IP func (clients *clientsContainer) Find(ip string) (Client, bool) { clients.lock.Lock() defer clients.lock.Unlock() - return clients.findByIP(ip) + c, ok := clients.findByIP(ip) + if !ok { + return Client{}, false + } + c.IDs = stringArrayDup(c.IDs) + c.Tags = stringArrayDup(c.Tags) + c.BlockedServices = stringArrayDup(c.BlockedServices) + c.Upstreams = stringArrayDup(c.Upstreams) + return c, true } func upstreamArrayCopy(a []upstream.Upstream) []upstream.Upstream { @@ -297,7 +333,7 @@ func (clients *clientsContainer) FindAutoClient(ip string) (ClientHost, bool) { } // Check if Client object's fields are correct -func (c *Client) check() error { +func (clients *clientsContainer) check(c *Client) error { if len(c.Name) == 0 { return fmt.Errorf("Invalid Name") } @@ -326,6 +362,13 @@ func (c *Client) check() error { return fmt.Errorf("Invalid ID: %s", id) } + for _, t := range c.Tags { + if !clients.tagKnown(t) { + return fmt.Errorf("Invalid tag: %s", t) + } + } + sort.Strings(c.Tags) + if len(c.Upstreams) != 0 { err := dnsforward.ValidateUpstreams(c.Upstreams) if err != nil { @@ -339,7 +382,7 @@ func (c *Client) check() error { // Add a new client object // Return true: success; false: client exists. func (clients *clientsContainer) Add(c Client) (bool, error) { - e := c.check() + e := clients.check(&c) if e != nil { return false, e } @@ -408,7 +451,7 @@ func arraysEqual(a, b []string) bool { // Update a client func (clients *clientsContainer) Update(name string, c Client) error { - err := c.check() + err := clients.check(&c) if err != nil { return err } diff --git a/home/clients_http.go b/home/clients_http.go index aaaef5e0..c48e4937 100644 --- a/home/clients_http.go +++ b/home/clients_http.go @@ -9,6 +9,7 @@ import ( type clientJSON struct { IDs []string `json:"ids"` + Tags []string `json:"tags"` Name string `json:"name"` UseGlobalSettings bool `json:"use_global_settings"` FilteringEnabled bool `json:"filtering_enabled"` @@ -33,6 +34,7 @@ type clientHostJSON struct { type clientListJSON struct { Clients []clientJSON `json:"clients"` AutoClients []clientHostJSON `json:"auto_clients"` + Tags []string `json:"supported_tags"` } // respond with information about configured clients @@ -71,6 +73,8 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http } clients.lock.Unlock() + data.Tags = clientTags + w.Header().Set("Content-Type", "application/json") e := json.NewEncoder(w).Encode(data) if e != nil { @@ -84,6 +88,7 @@ func jsonToClient(cj clientJSON) (*Client, error) { c := Client{ Name: cj.Name, IDs: cj.IDs, + Tags: cj.Tags, UseOwnSettings: !cj.UseGlobalSettings, FilteringEnabled: cj.FilteringEnabled, ParentalEnabled: cj.ParentalEnabled, @@ -103,6 +108,7 @@ func clientToJSON(c *Client) clientJSON { cj := clientJSON{ Name: c.Name, IDs: c.IDs, + Tags: c.Tags, UseGlobalSettings: !c.UseOwnSettings, FilteringEnabled: c.FilteringEnabled, ParentalEnabled: c.ParentalEnabled, diff --git a/home/clients_tags.go b/home/clients_tags.go new file mode 100644 index 00000000..c40caa2a --- /dev/null +++ b/home/clients_tags.go @@ -0,0 +1,26 @@ +package home + +var clientTags = []string{ + "device_audio", + "device_gameconsole", + "device_laptop", + "device_nas", // Network-attached Storage + "device_other", + "device_pc", + "device_phone", + "device_printer", + "device_tablet", + "device_tv", + + "os_android", + "os_ios", + "os_linux", + "os_macos", + "os_other", + "os_windows", + + "user_admin", + "user_child", + "user_other", + "user_regular", +} diff --git a/home/dns.go b/home/dns.go index 6a8be1fd..4d5dceb3 100644 --- a/home/dns.go +++ b/home/dns.go @@ -204,6 +204,8 @@ func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteri ApplyBlockedServices(setts, c.BlockedServices) } + setts.ClientTags = c.Tags + if !c.UseOwnSettings { return }