diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 341f2fad..42018265 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -627,6 +627,66 @@ func TestBlockedBySafeBrowsing(t *testing.T) { } } +func TestRewrite(t *testing.T) { + c := dnsfilter.Config{} + c.Rewrites = []dnsfilter.RewriteEntry{ + dnsfilter.RewriteEntry{ + Domain: "test.com", + Answer: "1.2.3.4", + Type: dns.TypeA, + }, + dnsfilter.RewriteEntry{ + Domain: "alias.test.com", + Answer: "test.com", + Type: dns.TypeCNAME, + }, + dnsfilter.RewriteEntry{ + Domain: "my.alias.example.org", + Answer: "example.org", + Type: dns.TypeCNAME, + }, + } + + f := dnsfilter.New(&c, nil) + s := NewServer(f, nil, nil) + conf := ServerConfig{} + conf.UDPListenAddr = &net.UDPAddr{Port: 0} + conf.TCPListenAddr = &net.TCPAddr{Port: 0} + conf.ProtectionEnabled = true + conf.UpstreamDNS = []string{"8.8.8.8:53"} + + err := s.Prepare(&conf) + assert.Nil(t, err) + err = s.Start() + assert.Nil(t, err) + addr := s.dnsProxy.Addr(proxy.ProtoUDP) + + req := createTestMessageWithType("test.com.", dns.TypeA) + reply, err := dns.Exchange(req, addr.String()) + assert.Nil(t, err) + assert.Equal(t, 1, len(reply.Answer)) + a, ok := reply.Answer[0].(*dns.A) + assert.True(t, ok) + assert.Equal(t, "1.2.3.4", a.A.String()) + + req = createTestMessageWithType("alias.test.com.", dns.TypeA) + reply, err = dns.Exchange(req, addr.String()) + assert.Nil(t, err) + assert.Equal(t, 2, len(reply.Answer)) + assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target) + assert.Equal(t, "1.2.3.4", reply.Answer[1].(*dns.A).A.String()) + + req = createTestMessageWithType("my.alias.example.org.", dns.TypeA) + reply, err = dns.Exchange(req, addr.String()) + assert.Nil(t, err) + assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name) // the original question is restored + assert.Equal(t, 2, len(reply.Answer)) + assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target) + assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype) + + _ = s.Stop() +} + func createTestServer(t *testing.T) *Server { rules := `||nxdomain.example.org ||null.example.org^ diff --git a/dnsforward/handle_dns.go b/dnsforward/handle_dns.go index 7aabd25e..87230e9b 100644 --- a/dnsforward/handle_dns.go +++ b/dnsforward/handle_dns.go @@ -211,12 +211,11 @@ func processFilteringAfterResponse(ctx *dnsContext) int { switch res.Reason { case dnsfilter.ReasonRewrite: - if d.Res != nil { - break // response is already prepared - } - if len(res.CanonName) == 0 { + if len(ctx.origQuestion.Name) == 0 { + // origQuestion is set in case we get only CNAME without IP from rewrites table break } + d.Req.Question[0] = ctx.origQuestion d.Res.Question[0] = ctx.origQuestion