package dnsfilter

import (
	"net"
	"testing"

	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
)

// TestRewrites checks basic rewrite resolution: exact A/AAAA answers,
// CNAME-style answers (an answer that is itself a domain name), wildcard
// entries, exact entries taking priority over a matching wildcard, and CNAME
// chains, including one that ends at a wildcard entry.
//
// Every index into r.IPList is guarded by a length assertion so that an
// unexpected result is reported as a test failure instead of an
// index-out-of-range panic that would hide the remaining checks.
func TestRewrites(t *testing.T) {
	d := Dnsfilter{}
	// CNAME, A, AAAA
	d.Rewrites = []RewriteEntry{
		{"somecname", "somehost.com", 0, nil},
		{"somehost.com", "0.0.0.0", 0, nil},

		{"host.com", "1.2.3.4", 0, nil},
		{"host.com", "1.2.3.5", 0, nil},
		{"host.com", "1:2:3::4", 0, nil},
		{"www.host.com", "host.com", 0, nil},
	}
	d.prepareRewrites()

	// A domain with no rewrite entry is not filtered.
	r := d.processRewrites("host2.com", dns.TypeA)
	assert.Equal(t, NotFilteredNotFound, r.Reason)

	// www.host.com is a CNAME for host.com: an A query yields only the IPv4
	// answers of host.com, in declaration order.
	r = d.processRewrites("www.host.com", dns.TypeA)
	assert.Equal(t, ReasonRewrite, r.Reason)
	assert.Equal(t, "host.com", r.CanonName)
	if assert.Len(t, r.IPList, 2) {
		assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
		assert.True(t, r.IPList[1].Equal(net.ParseIP("1.2.3.5")))
	}

	// The same CNAME queried for AAAA yields only the IPv6 answer.
	r = d.processRewrites("www.host.com", dns.TypeAAAA)
	assert.Equal(t, ReasonRewrite, r.Reason)
	assert.Equal(t, "host.com", r.CanonName)
	if assert.Len(t, r.IPList, 1) {
		assert.True(t, r.IPList[0].Equal(net.ParseIP("1:2:3::4")))
	}

	// wildcard
	d.Rewrites = []RewriteEntry{
		{"host.com", "1.2.3.4", 0, nil},
		{"*.host.com", "1.2.3.5", 0, nil},
	}
	d.prepareRewrites()

	// The bare domain is answered by its exact entry.
	r = d.processRewrites("host.com", dns.TypeA)
	assert.Equal(t, ReasonRewrite, r.Reason)
	if assert.NotEmpty(t, r.IPList) {
		assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
	}

	// A subdomain is answered by the wildcard entry.
	r = d.processRewrites("www.host.com", dns.TypeA)
	assert.Equal(t, ReasonRewrite, r.Reason)
	if assert.NotEmpty(t, r.IPList) {
		assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.5")))
	}

	// The wildcard does not leak into an unrelated domain.
	r = d.processRewrites("www.host2.com", dns.TypeA)
	assert.Equal(t, NotFilteredNotFound, r.Reason)

	// override a wildcard: an exact entry wins over a matching wildcard and
	// the wildcard's answer is not merged in.
	d.Rewrites = []RewriteEntry{
		{"a.host.com", "1.2.3.4", 0, nil},
		{"*.host.com", "1.2.3.5", 0, nil},
	}
	d.prepareRewrites()
	r = d.processRewrites("a.host.com", dns.TypeA)
	assert.Equal(t, ReasonRewrite, r.Reason)
	if assert.Len(t, r.IPList, 1) {
		assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
	}

	// wildcard + CNAME: the wildcard answer is a domain name that is then
	// resolved through its own exact entry.
	d.Rewrites = []RewriteEntry{
		{"host.com", "1.2.3.4", 0, nil},
		{"*.host.com", "host.com", 0, nil},
	}
	d.prepareRewrites()
	r = d.processRewrites("www.host.com", dns.TypeA)
	assert.Equal(t, ReasonRewrite, r.Reason)
	assert.Equal(t, "host.com", r.CanonName)
	if assert.NotEmpty(t, r.IPList) {
		assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
	}

	// 2 CNAMEs: b -> a -> host.com; CanonName is the last name in the chain.
	d.Rewrites = []RewriteEntry{
		{"b.host.com", "a.host.com", 0, nil},
		{"a.host.com", "host.com", 0, nil},
		{"host.com", "1.2.3.4", 0, nil},
	}
	d.prepareRewrites()
	r = d.processRewrites("b.host.com", dns.TypeA)
	assert.Equal(t, ReasonRewrite, r.Reason)
	assert.Equal(t, "host.com", r.CanonName)
	if assert.Len(t, r.IPList, 1) {
		assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
	}

	// 2 CNAMEs + wildcard: the chain ends at a name only matched by a
	// wildcard entry.
	d.Rewrites = []RewriteEntry{
		{"b.host.com", "a.host.com", 0, nil},
		{"a.host.com", "x.somehost.com", 0, nil},
		{"*.somehost.com", "1.2.3.4", 0, nil},
	}
	d.prepareRewrites()
	r = d.processRewrites("b.host.com", dns.TypeA)
	assert.Equal(t, ReasonRewrite, r.Reason)
	assert.Equal(t, "x.somehost.com", r.CanonName)
	if assert.Len(t, r.IPList, 1) {
		assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
	}
}

// TestRewritesLevels checks that, among an exact entry and wildcards at
// different depths, each query is answered by the most specific match:
// the exact entry for host.com, *.host.com for sub.host.com, and
// *.sub.host.com for my.sub.host.com.
func TestRewritesLevels(t *testing.T) {
	d := Dnsfilter{}
	// exact host, wildcard L2, wildcard L3
	d.Rewrites = []RewriteEntry{
		{"host.com", "1.1.1.1", 0, nil},
		{"*.host.com", "2.2.2.2", 0, nil},
		{"*.sub.host.com", "3.3.3.3", 0, nil},
	}
	d.prepareRewrites()

	// match exact
	r := d.processRewrites("host.com", dns.TypeA)
	assert.Equal(t, ReasonRewrite, r.Reason)
	// Guard the index so a wrong-length result fails instead of panicking.
	if assert.Len(t, r.IPList, 1) {
		assert.Equal(t, "1.1.1.1", r.IPList[0].String())
	}

	// match L2
	r = d.processRewrites("sub.host.com", dns.TypeA)
	assert.Equal(t, ReasonRewrite, r.Reason)
	if assert.Len(t, r.IPList, 1) {
		assert.Equal(t, "2.2.2.2", r.IPList[0].String())
	}

	// match L3
	r = d.processRewrites("my.sub.host.com", dns.TypeA)
	assert.Equal(t, ReasonRewrite, r.Reason)
	if assert.Len(t, r.IPList, 1) {
		assert.Equal(t, "3.3.3.3", r.IPList[0].String())
	}
}

// TestRewritesExceptionCNAME checks that an entry whose answer is its own
// domain name (sub.host.com -> sub.host.com) acts as an exception: that
// domain is left unfiltered even though the *.host.com wildcard would
// otherwise match it, while other subdomains still get the wildcard answer.
func TestRewritesExceptionCNAME(t *testing.T) {
	d := Dnsfilter{}
	// wildcard; exception for a sub-domain
	d.Rewrites = []RewriteEntry{
		{"*.host.com", "2.2.2.2", 0, nil},
		{"sub.host.com", "sub.host.com", 0, nil},
	}
	d.prepareRewrites()

	// match sub-domain
	r := d.processRewrites("my.host.com", dns.TypeA)
	assert.Equal(t, ReasonRewrite, r.Reason)
	// Guard the index so a wrong-length result fails instead of panicking.
	if assert.Len(t, r.IPList, 1) {
		assert.Equal(t, "2.2.2.2", r.IPList[0].String())
	}

	// match sub-domain, but handle exception
	r = d.processRewrites("sub.host.com", dns.TypeA)
	assert.Equal(t, NotFilteredNotFound, r.Reason)
}

// TestRewritesExceptionWC checks that a wildcard entry whose answer is its
// own pattern (*.sub.host.com -> *.sub.host.com) acts as an exception for
// the whole sub-tree: names under sub.host.com are left unfiltered, while
// other subdomains of host.com still get the *.host.com answer.
func TestRewritesExceptionWC(t *testing.T) {
	d := Dnsfilter{}
	// wildcard; exception for a sub-wildcard
	d.Rewrites = []RewriteEntry{
		{"*.host.com", "2.2.2.2", 0, nil},
		{"*.sub.host.com", "*.sub.host.com", 0, nil},
	}
	d.prepareRewrites()

	// match sub-domain
	r := d.processRewrites("my.host.com", dns.TypeA)
	assert.Equal(t, ReasonRewrite, r.Reason)
	// Guard the index so a wrong-length result fails instead of panicking.
	if assert.Len(t, r.IPList, 1) {
		assert.Equal(t, "2.2.2.2", r.IPList[0].String())
	}

	// match sub-domain, but handle exception
	r = d.processRewrites("my.sub.host.com", dns.TypeA)
	assert.Equal(t, NotFilteredNotFound, r.Reason)
}

// TestRewritesExceptionIP checks entries whose answer is a record type name
// ("A" or "AAAA") rather than an address. The expectations pinned here are:
//   - host.com has an IPv4 answer plus an "AAAA" entry: A queries are
//     rewritten, AAAA queries are left unfiltered.
//   - host2.com has an IPv6 answer plus an "A" entry: A queries are left
//     unfiltered, AAAA queries are rewritten.
//   - host3.com has only an "A" entry: A queries are left unfiltered, while
//     AAAA queries are reported as rewritten with no addresses.
func TestRewritesExceptionIP(t *testing.T) {
	d := Dnsfilter{}
	// exception for AAAA record
	d.Rewrites = []RewriteEntry{
		{"host.com", "1.2.3.4", 0, nil},
		{"host.com", "AAAA", 0, nil},
		{"host2.com", "::1", 0, nil},
		{"host2.com", "A", 0, nil},
		{"host3.com", "A", 0, nil},
	}
	d.prepareRewrites()

	// match domain
	r := d.processRewrites("host.com", dns.TypeA)
	assert.Equal(t, ReasonRewrite, r.Reason)
	// Guard the index so a wrong-length result fails instead of panicking.
	if assert.Len(t, r.IPList, 1) {
		assert.Equal(t, "1.2.3.4", r.IPList[0].String())
	}

	// match exception
	r = d.processRewrites("host.com", dns.TypeAAAA)
	assert.Equal(t, NotFilteredNotFound, r.Reason)

	// match exception
	r = d.processRewrites("host2.com", dns.TypeA)
	assert.Equal(t, NotFilteredNotFound, r.Reason)

	// match domain
	r = d.processRewrites("host2.com", dns.TypeAAAA)
	assert.Equal(t, ReasonRewrite, r.Reason)
	if assert.Len(t, r.IPList, 1) {
		assert.Equal(t, "::1", r.IPList[0].String())
	}

	// match exception
	r = d.processRewrites("host3.com", dns.TypeA)
	assert.Equal(t, NotFilteredNotFound, r.Reason)

	// match domain: rewritten, but with no addresses to return
	r = d.processRewrites("host3.com", dns.TypeAAAA)
	assert.Equal(t, ReasonRewrite, r.Reason)
	assert.Empty(t, r.IPList)
}