package querylog

import (
	"net"
	"os"
	"testing"
	"time"

	"github.com/AdguardTeam/AdGuardHome/dnsfilter"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
)

func prepareTestDir() string {
	const dir = "./agh-test"
	_ = os.RemoveAll(dir)
	_ = os.MkdirAll(dir, 0755)
	return dir
}

// Check adding and loading (with filtering) entries from disk and memory
func TestQueryLog(t *testing.T) {
	conf := Config{
		Enabled:  true,
		Interval: 1,
		MemSize:  100,
	}
	conf.BaseDir = prepareTestDir()
	defer func() { _ = os.RemoveAll(conf.BaseDir) }()
	l := newQueryLog(conf)

	// add disk entries
	addEntry(l, "example.org", "1.1.1.1", "2.2.2.1")
	// write to disk (first file)
	_ = l.flushLogBuffer(true)
	// start writing to the second file
	_ = l.rotate()
	// add disk entries
	addEntry(l, "example.org", "1.1.1.2", "2.2.2.2")
	// write to disk
	_ = l.flushLogBuffer(true)
	// add memory entries
	addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3")
	addEntry(l, "example.com", "1.1.1.4", "2.2.2.4")

	// get all entries
	params := getDataParams{
		OlderThan: time.Time{},
	}
	d := l.getData(params)
	mdata := d["data"].([]map[string]interface{})
	assert.Equal(t, 4, len(mdata))
	assert.True(t, checkEntry(t, mdata[0], "example.com", "1.1.1.4", "2.2.2.4"))
	assert.True(t, checkEntry(t, mdata[1], "test.example.org", "1.1.1.3", "2.2.2.3"))
	assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.2", "2.2.2.2"))
	assert.True(t, checkEntry(t, mdata[3], "example.org", "1.1.1.1", "2.2.2.1"))

	// search by domain (strict)
	params = getDataParams{
		OlderThan:         time.Time{},
		Domain:            "test.example.org",
		StrictMatchDomain: true,
	}
	d = l.getData(params)
	mdata = d["data"].([]map[string]interface{})
	assert.True(t, len(mdata) == 1)
	assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3"))

	// search by domain (not strict)
	params = getDataParams{
		OlderThan:         time.Time{},
		Domain:            "example.org",
		StrictMatchDomain: false,
	}
	d = l.getData(params)
	mdata = d["data"].([]map[string]interface{})
	assert.Equal(t, 3, len(mdata))
	assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3"))
	assert.True(t, checkEntry(t, mdata[1], "example.org", "1.1.1.2", "2.2.2.2"))
	assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.1", "2.2.2.1"))

	// search by client IP (strict)
	params = getDataParams{
		OlderThan:         time.Time{},
		Client:            "2.2.2.2",
		StrictMatchClient: true,
	}
	d = l.getData(params)
	mdata = d["data"].([]map[string]interface{})
	assert.Equal(t, 1, len(mdata))
	assert.True(t, checkEntry(t, mdata[0], "example.org", "1.1.1.2", "2.2.2.2"))

	// search by client IP (part of)
	params = getDataParams{
		OlderThan:         time.Time{},
		Client:            "2.2.2",
		StrictMatchClient: false,
	}
	d = l.getData(params)
	mdata = d["data"].([]map[string]interface{})
	assert.Equal(t, 4, len(mdata))
	assert.True(t, checkEntry(t, mdata[0], "example.com", "1.1.1.4", "2.2.2.4"))
	assert.True(t, checkEntry(t, mdata[1], "test.example.org", "1.1.1.3", "2.2.2.3"))
	assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.2", "2.2.2.2"))
	assert.True(t, checkEntry(t, mdata[3], "example.org", "1.1.1.1", "2.2.2.1"))
}

func addEntry(l *queryLog, host, answerStr, client string) {
	q := dns.Msg{}
	q.Question = append(q.Question, dns.Question{
		Name:   host + ".",
		Qtype:  dns.TypeA,
		Qclass: dns.ClassINET,
	})

	a := dns.Msg{}
	a.Question = append(a.Question, q.Question[0])
	answer := new(dns.A)
	answer.Hdr = dns.RR_Header{
		Name:   q.Question[0].Name,
		Rrtype: dns.TypeA,
		Class:  dns.ClassINET,
	}
	answer.A = net.ParseIP(answerStr)
	a.Answer = append(a.Answer, answer)
	res := dnsfilter.Result{}
	params := AddParams{
		Question: &q,
		Answer:   &a,
		Result:   &res,
		ClientIP: net.ParseIP(client),
		Upstream: "upstream",
	}
	l.Add(params)
}

func checkEntry(t *testing.T, m map[string]interface{}, host, answer, client string) bool {
	mq := m["question"].(map[string]interface{})
	ma := m["answer"].([]map[string]interface{})
	ma0 := ma[0]
	if !assert.Equal(t, host, mq["host"].(string)) ||
		!assert.Equal(t, "IN", mq["class"].(string)) ||
		!assert.Equal(t, "A", mq["type"].(string)) ||
		!assert.Equal(t, answer, ma0["value"].(string)) ||
		!assert.Equal(t, client, m["client"].(string)) {
		return false
	}
	return true
}

func TestJSON(t *testing.T) {
	s := `
	{"keystr":"val","obj":{"keybool":true,"keyint":123456}}
	`
	k, v, jtype := readJSON(&s)
	assert.Equal(t, jtype, int32(jsonTStr))
	assert.Equal(t, "keystr", k)
	assert.Equal(t, "val", v)

	k, v, jtype = readJSON(&s)
	assert.Equal(t, jtype, int32(jsonTObj))
	assert.Equal(t, "obj", k)

	k, v, jtype = readJSON(&s)
	assert.Equal(t, jtype, int32(jsonTBool))
	assert.Equal(t, "keybool", k)
	assert.Equal(t, "true", v)

	k, v, jtype = readJSON(&s)
	assert.Equal(t, jtype, int32(jsonTNum))
	assert.Equal(t, "keyint", k)
	assert.Equal(t, "123456", v)

	k, v, jtype = readJSON(&s)
	assert.True(t, jtype == jsonTErr)
}