diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index a642abe4..20d17493 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -39,7 +39,7 @@ const defaultParentalURL = "%s://%s/check-parental-control-hash?prefixes=%s&sens const defaultParentalSensitivity = 13 // use "TEEN" by default const maxDialCacheSize = 2 // the number of host names for safebrowsing and parental control -// Custom filtering settings +// RequestFilteringSettings is custom filtering settings type RequestFilteringSettings struct { FilteringEnabled bool SafeSearchEnabled bool @@ -47,6 +47,12 @@ type RequestFilteringSettings struct { ParentalEnabled bool } +// RewriteEntry is a rewrite array element +type RewriteEntry struct { + Domain string `yaml:"domain"` + Answer string `yaml:"answer"` // IP address or canonical name +} + // Config allows you to configure DNS filtering with New() or just change variables directly. type Config struct { ParentalSensitivity int `yaml:"parental_sensitivity"` // must be either 3, 10, 13 or 17 @@ -60,6 +66,8 @@ type Config struct { SafeSearchCacheSize int `yaml:"safesearch_cache_size"` ParentalCacheSize int `yaml:"parental_cache_size"` + Rewrites []RewriteEntry `yaml:"rewrites"` + // Filtering callback function FilterHandler func(clientAddr string, settings *RequestFilteringSettings) `yaml:"-"` } @@ -131,6 +139,9 @@ const ( FilteredInvalid // FilteredSafeSearch - the host was replaced with safesearch variant FilteredSafeSearch + + // ReasonRewrite - rewrite rule was applied + ReasonRewrite ) func (i Reason) String() string { @@ -138,11 +149,14 @@ func (i Reason) String() string { "NotFilteredNotFound", "NotFilteredWhiteList", "NotFilteredError", + "FilteredBlackList", "FilteredSafeBrowsing", "FilteredParental", "FilteredInvalid", "FilteredSafeSearch", + + "Rewrite", } if uint(i) >= uint(len(names)) { return "" @@ -167,6 +181,10 @@ type Result struct { Rule string `json:",omitempty"` // Original rule text IP net.IP `json:",omitempty"` // Not nil only in the case of a hosts file syntax FilterID int64 `json:",omitempty"` // Filter ID the rule belongs to + + // for ReasonRewrite: + CanonName string `json:",omitempty"` // CNAME value + IPList []net.IP `json:",omitempty"` // list of IP addresses } // Matched can be used to see if any match at all was found, no matter filtered or not @@ -197,6 +215,12 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, clientAddr string) (Res var result Result var err error + + result = d.processRewrites(host, qtype) + if result.Reason == ReasonRewrite { + return result, nil + } + // try filter lists first if setts.FilteringEnabled { result, err = d.matchHost(host, qtype) @@ -251,6 +275,57 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, clientAddr string) (Res return Result{}, nil } +// Process rewrites table +// . Find CNAME for a domain name +// . if found, set domain name to canonical name +// . Find A or AAAA record for a domain name +// . if found, return IP addresses +func (d *Dnsfilter) processRewrites(host string, qtype uint16) Result { + var res Result + + for _, r := range d.Rewrites { + if r.Domain != host { + continue + } + + ip := net.ParseIP(r.Answer) + if ip == nil { + log.Debug("Rewrite: CNAME for %s is %s", host, r.Answer) + host = r.Answer + res.CanonName = r.Answer + res.Reason = ReasonRewrite + break + } + } + + for _, r := range d.Rewrites { + if r.Domain != host { + continue + } + + ip := net.ParseIP(r.Answer) + if ip == nil { + continue + } + ip4 := ip.To4() + + if qtype == dns.TypeA && ip4 != nil { + res.IPList = append(res.IPList, ip4) + log.Debug("Rewrite: A for %s is %s", host, ip4) + + } else if qtype == dns.TypeAAAA && ip4 == nil { + res.IPList = append(res.IPList, ip) + log.Debug("Rewrite: AAAA for %s is %s", host, ip) + } + } + + if len(res.IPList) != 0 { + res.Reason = ReasonRewrite + } + + return res +} + func setCacheResult(cache *fastcache.Cache, host string, res Result) { var buf bytes.Buffer enc := gob.NewEncoder(&buf) diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index ef9fbc20..1661f68b 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -453,11 +453,31 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { } if d.Res == nil { + answer := []dns.RR{} + originalQuestion := d.Req.Question[0] + + if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 { + answer = append(answer, s.genCNAMEAnswer(d.Req, res.CanonName)) + // resolve canonical name, not the original host name + d.Req.Question[0].Name = dns.Fqdn(res.CanonName) + } + // request was not filtered so let it be processed further err = p.Resolve(d) if err != nil { return err } + + if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 { + + d.Req.Question[0] = originalQuestion + d.Res.Question[0] = originalQuestion + + if len(d.Res.Answer) != 0 { + answer = append(answer, d.Res.Answer...) // host -> IP + d.Res.Answer = answer + } + } } shouldLog := true @@ -485,8 +505,10 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { // filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) { - msg := d.Req - host := strings.TrimSuffix(msg.Question[0].Name, ".") + var res dnsfilter.Result + req := d.Req + host := strings.TrimSuffix(req.Question[0].Name, ".") + origHost := host s.RLock() protectionEnabled := s.conf.ProtectionEnabled @@ -497,7 +519,10 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error return nil, nil } - var res dnsfilter.Result + if host != origHost { + log.Debug("Rewrite: not supported: CNAME for %s is %s", origHost, host) + } + var err error clientAddr := "" @@ -508,9 +533,35 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error if err != nil { // Return immediately if there's an error return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host) + } else if res.IsFiltered { // log.Tracef("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule) d.Res = s.genDNSFilterMessage(d, &res) + + } else if res.Reason == dnsfilter.ReasonRewrite && len(res.IPList) != 0 { + resp := dns.Msg{} + resp.SetReply(req) + + name := host + if len(res.CanonName) != 0 { + resp.Answer = append(resp.Answer, s.genCNAMEAnswer(req, res.CanonName)) + name = res.CanonName + } + + for _, ip := range res.IPList { + if req.Question[0].Qtype == dns.TypeA { + a := s.genAAnswer(req, ip) + a.Hdr.Name = dns.Fqdn(name) + resp.Answer = append(resp.Answer, a) + + } else if req.Question[0].Qtype == dns.TypeAAAA { + a := s.genAAAAAnswer(req, res.IP) + a.Hdr.Name = dns.Fqdn(name) + resp.Answer = append(resp.Answer, a) + } + } + + d.Res = &resp } return &res, err @@ -644,6 +695,19 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo return &resp } +// Make a CNAME response +func (s *Server) genCNAMEAnswer(req *dns.Msg, cname string) *dns.CNAME { + answer := new(dns.CNAME) + answer.Hdr = dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypeCNAME, + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, + } + answer.Target = dns.Fqdn(cname) + return answer +} + func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg { resp := dns.Msg{} resp.SetRcode(request, dns.RcodeNameError)