diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 40ca57f0..61e9bee0 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -555,9 +555,18 @@ func (d *Dnsfilter) matchHost(host string, qtype uint16) (Result, error) { // either IPv4 or IPv4-mapped IPv6 address res.IP = hostRule.IP.To4() return res, nil - } else if qtype == dns.TypeAAAA && hostRule.IP.To4() == nil { - res.IP = hostRule.IP - return res, nil + + } else if qtype == dns.TypeAAAA { + ip4 := hostRule.IP.To4() + if ip4 == nil { + res.IP = hostRule.IP + return res, nil + } + if bytes.Equal(ip4, []byte{0, 0, 0, 0}) { + // send IP="::" response for a rule "0.0.0.0 blockdomain" + res.IP = net.IPv6zero + return res, nil + } } continue diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 51c6d3f6..b728c90f 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -378,7 +378,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error var res dnsfilter.Result var err error - res, err = dnsFilter.CheckHost(host) + res, err = dnsFilter.CheckHost(host, d.Req.Question[0].Qtype) if err != nil { // Return immediately if there's an error return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host) @@ -394,7 +394,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Result) *dns.Msg { m := d.Req - if m.Question[0].Qtype != dns.TypeA { + if m.Question[0].Qtype != dns.TypeA && m.Question[0].Qtype != dns.TypeAAAA { return s.genNXDomain(m) } @@ -405,11 +405,25 @@ func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Resu return s.genBlockedHost(m, parentalBlockHost, d) default: if result.IP != nil { - return s.genARecord(m, result.IP) + if m.Question[0].Qtype == dns.TypeA { + return s.genARecord(m, result.IP) + } else if m.Question[0].Qtype == dns.TypeAAAA { + return s.genAAAARecord(m, result.IP) + } + + // empty response + resp := dns.Msg{} + resp.SetReply(m) + return &resp } if s.conf.BlockingMode == "null_ip" { - return s.genARecord(m, net.IPv4zero) + switch m.Question[0].Qtype { + case dns.TypeA: + return s.genARecord(m, []byte{0, 0, 0, 0}) + case dns.TypeAAAA: + return s.genAAAARecord(m, net.IPv6zero) + } } return s.genNXDomain(m) @@ -426,15 +440,41 @@ func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg { func (s *Server) genARecord(request *dns.Msg, ip net.IP) *dns.Msg { resp := dns.Msg{} resp.SetReply(request) - answer, err := dns.NewRR(fmt.Sprintf("%s %d A %s", request.Question[0].Name, s.conf.BlockedResponseTTL, ip.String())) - if err != nil { - log.Printf("Couldn't generate A record for replacement host '%s': %s", ip.String(), err) - return s.genServerFailure(request) - } - resp.Answer = append(resp.Answer, answer) + resp.Answer = append(resp.Answer, s.genAAnswer(request, ip)) return &resp } +func (s *Server) genAAAARecord(request *dns.Msg, ip net.IP) *dns.Msg { + resp := dns.Msg{} + resp.SetReply(request) + resp.Answer = append(resp.Answer, s.genAAAAAnswer(request, ip)) + return &resp +} + +func (s *Server) genAAnswer(req *dns.Msg, ip net.IP) *dns.A { + answer := new(dns.A) + answer.Hdr = dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypeA, + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, + } + answer.A = ip + return answer +} + +func (s *Server) genAAAAAnswer(req *dns.Msg, ip net.IP) *dns.AAAA { + answer := new(dns.AAAA) + answer.Hdr = dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypeAAAA, + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, + } + answer.AAAA = ip + return answer +} + func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg { // look up the hostname, TODO: cache replReq := dns.Msg{}