diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 9c36627d..5c21ae99 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -38,6 +38,11 @@ type Server struct { ServerConfig } +const ( + safeBrowsingBlockHost = "standard-block.dns.adguard.com" + parentalBlockHost = "family-block.dns.adguard.com" +) + // uncomment this block to have tracing of locks /* func (s *Server) Lock() { @@ -426,6 +431,9 @@ func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDP return s.genNotImpl(msg), nil, nil, nil } + // we need upstream to resolve A records + upstream := s.chooseUpstream() + host := strings.TrimSuffix(msg.Question[0].Name, ".") // use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise var res dnsfilter.Result @@ -437,6 +445,12 @@ func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDP return s.genServerFailure(msg), &res, nil, err } else if res.IsFiltered { log.Printf("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule) + switch res.Reason { + case dnsfilter.FilteredSafeBrowsing: + return s.genArecord(msg, safeBrowsingBlockHost, upstream), &res, nil, nil + case dnsfilter.FilteredParental: + return s.genArecord(msg, parentalBlockHost, upstream), &res, nil, nil + } return s.genNXDomain(msg), &res, nil, nil } } @@ -449,7 +463,6 @@ func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDP } // TODO: replace with single-socket implementation - upstream := s.chooseUpstream() reply, err := upstream.Exchange(msg) if err != nil { log.Printf("talking to upstream failed for host '%s': %s", host, err) @@ -563,6 +576,44 @@ func (s *Server) genNotImpl(request *dns.Msg) *dns.Msg { return &resp } +func (s *Server) genArecord(request *dns.Msg, newAddr string, upstream Upstream) *dns.Msg { + addr := net.ParseIP(newAddr) + if addr != nil { + // this is an IP address, return it + resp := dns.Msg{} + resp.SetReply(request) + answer, err := dns.NewRR(fmt.Sprintf("%s %d A %s", request.Question[0].Name, s.BlockedResponseTTL, newAddr)) + if err != nil { + log.Printf("Couldn't generate A record for up replacement host '%s': %s", newAddr, err) + return s.genServerFailure(request) + } + resp.Answer = append(resp.Answer, answer) + return &resp + } + + // look up the hostname, TODO: cache + replReq := dns.Msg{} + replReq.SetQuestion(dns.Fqdn(newAddr), request.Question[0].Qtype) + replReq.RecursionDesired = true + reply, err := upstream.Exchange(&replReq) + if err != nil { + log.Printf("Couldn't look up replacement host '%s' on upstream %s: %s", newAddr, upstream.Address(), err) + return s.genServerFailure(request) + } + + resp := dns.Msg{} + resp.SetReply(request) + resp.Authoritative, resp.RecursionAvailable = true, true + if reply != nil { + for _, answer := range reply.Answer { + answer.Header().Name = request.Question[0].Name + resp.Answer = append(resp.Answer, answer) + } + } + + return &resp +} + func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg { resp := dns.Msg{} resp.SetRcode(request, dns.RcodeNameError)