// Safe Browsing, Parental Control

package dnsfilter

import (
	"bytes"
	"crypto/sha256"
	"encoding/binary"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"sort"
	"strings"
	"time"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/cache"
	"github.com/AdguardTeam/golibs/log"
	"github.com/miekg/dns"
	"golang.org/x/net/publicsuffix"
)

const (
	// dnsTimeout is the timeout passed to the Safe Browsing and Parental
	// Control upstreams.
	dnsTimeout = 3 * time.Second

	// Default DNS-over-HTTPS servers queried for hash prefixes.
	defaultSafebrowsingServer = "https://dns-family.adguard.com/dns-query"
	defaultParentalServer     = "https://dns-family.adguard.com/dns-query"

	// Suffixes appended to the hex-encoded hash prefixes to form the TXT
	// query name for each service.
	sbTXTSuffix = "sb.dns.adguard.com."
	pcTXTSuffix = "pc.dns.adguard.com."
)

// initSecurityServices points the Safe Browsing and Parental Control lookups
// at their default servers and creates an upstream for each of them.
//
// The parental upstream is created first; the first error returned by
// upstream.AddressToUpstream is returned as-is, in which case the Safe
// Browsing upstream is not created.
func (d *Dnsfilter) initSecurityServices() error {
	d.safeBrowsingServer = defaultSafebrowsingServer
	d.parentalServer = defaultParentalServer

	// Fixed server addresses handed to the upstreams via ServerIPAddrs.
	// NOTE(review): presumably used instead of resolving the server's
	// hostname — confirm against the dnsproxy upstream.Options docs.
	serverIPs := []net.IP{
		net.ParseIP("176.103.130.132"),
		net.ParseIP("176.103.130.134"),
		net.ParseIP("2a00:5a60::bad1:ff"),
		net.ParseIP("2a00:5a60::bad2:ff"),
	}
	opts := upstream.Options{
		Timeout:       dnsTimeout,
		ServerIPAddrs: serverIPs,
	}

	var err error
	if d.parentalUpstream, err = upstream.AddressToUpstream(d.parentalServer, opts); err != nil {
		return err
	}
	if d.safeBrowsingUpstream, err = upstream.AddressToUpstream(d.safeBrowsingServer, opts); err != nil {
		return err
	}

	return nil
}

// setCache stores hashes in c.cache under the given hash prefix.
//
// Layout of a cache entry:
//
//	expire  [4]byte   big-endian Unix time (seconds); the entry is treated
//	                  as missing once the current time reaches it
//	hash    [32]byte  first hash
//	...               further 32-byte hashes
//
// hashes is expected to be a concatenation of 32-byte hashes (it may be
// empty). The expiry is set to now plus c.cacheTime minutes.
func (c *sbCtx) setCache(prefix []byte, hashes []byte) {
	lifetime := c.cacheTime * 60
	expire := uint(time.Now().Unix()) + lifetime

	entry := make([]byte, 4, 4+len(hashes))
	binary.BigEndian.PutUint32(entry, uint32(expire))
	entry = append(entry, hashes...)

	c.cache.Set(prefix, entry)
	log.Debug("%s: stored in cache: %v", c.svc, prefix)
}

// getCached looks up the 2-byte prefix of every hash in c.hashToHost in the
// cache and reports the outcome:
//
//	 1  a fresh cache entry contains one of the host's hashes (blocked);
//	-1  every prefix has a fresh entry and none matches (not blocked);
//	 0  some prefixes are missing or expired; c.hashToHost is then replaced
//	    with only the hashes that still have to be requested.
//
// An entry is fresh while the current Unix time is below its 4-byte
// big-endian expiry header (see setCache).
func (c *sbCtx) getCached() int {
	now := time.Now().Unix()
	pending := map[[32]byte]string{}

	for hash, host := range c.hashToHost {
		entry := c.cache.Get(hash[0:2])
		fresh := entry != nil && now < int64(binary.BigEndian.Uint32(entry))
		if !fresh {
			pending[hash] = host
			continue
		}

		// Scan the 32-byte hashes that follow the expiry header.
		for off := 4; off < len(entry); off += 32 {
			var cached [32]byte
			copy(cached[:], entry[off:off+32])
			if _, ok := c.hashToHost[cached]; ok {
				log.Debug("%s: found in cache: %s: blocked by %v", c.svc, c.host, cached)
				return 1
			}
		}
	}

	if len(pending) == 0 {
		log.Debug("%s: found in cache: %s: not blocked", c.svc, c.host)
		return -1
	}

	c.hashToHost = pending
	return 0
}

// sbCtx holds the state of a single Safe Browsing or Parental Control lookup.
type sbCtx struct {
	// host is the hostname being checked; used in log messages.
	host       string
	// svc is the service name used in log messages; "SafeBrowsing" selects
	// sbTXTSuffix in getQuestion, anything else selects pcTXTSuffix.
	svc        string
	// hashToHost maps the SHA-256 of each candidate name to that name.
	// getCached narrows it to the hashes that still need a server lookup.
	hashToHost map[[32]byte]string
	// cache stores entries keyed by a 2-byte hash prefix (see setCache).
	cache      cache.Cache
	// cacheTime is the lifetime of new cache entries, in minutes.
	cacheTime  uint
}

// hostnameToHashes returns the SHA-256 hashes of host and of its parent
// domains, each mapped to the name it was computed from.
//
// Only the last four labels of host are considered ("xxx.a.b.c.d" ->
// "a.b.c.d"). When the public suffix of host is ICANN-managed, the walk stops
// before reaching it, so the suffix itself is not hashed. For private suffixes
// (like cloudfront.net) the walk continues down to the last label.
func hostnameToHashes(host string) map[[32]byte]string {
	suffix, icann := publicsuffix.PublicSuffix(host)
	if !icann {
		// private suffixes like cloudfront.net
		suffix = ""
	}

	// Locate the fourth dot from the end, if any, and keep what follows it.
	name := host
	cut := len(host)
	for found := 0; found < 4 && cut >= 0; found++ {
		cut = strings.LastIndexByte(host[:cut], '.')
	}
	if cut >= 0 {
		name = host[cut+1:]
	}

	hashes := map[[32]byte]string{}
	for name != "" && (suffix == "" || name != suffix) {
		hashes[sha256.Sum256([]byte(name))] = name

		dot := strings.IndexByte(name, '.')
		if dot < 0 {
			break
		}
		name = name[dot+1:]
	}

	return hashes
}

// convert hash array to string
func (c *sbCtx) getQuestion() string {
	q := ""
	for hash := range c.hashToHost {
		q += fmt.Sprintf("%s.", hex.EncodeToString(hash[0:2]))
	}
	if c.svc == "SafeBrowsing" {
		q += sbTXTSuffix
	} else {
		q += pcTXTSuffix
	}
	return q
}

// processTXT extracts hashes from the TXT records of resp.
//
// Every TXT string that is exactly 64 characters of valid hex is decoded into
// a 32-byte hash; other strings and non-TXT records are skipped. It returns
// whether any decoded hash is a key of c.hashToHost, together with all decoded
// hashes in response order.
func (c *sbCtx) processTXT(resp *dns.Msg) (bool, [][]byte) {
	matched := false
	hashes := make([][]byte, 0)

	for _, rr := range resp.Answer {
		txt, isTXT := rr.(*dns.TXT)
		if !isTXT {
			continue
		}
		log.Debug("%s: received hashes for %s: %v", c.svc, c.host, txt.Txt)

		for _, s := range txt.Txt {
			if len(s) != 32*2 {
				continue
			}
			hash, err := hex.DecodeString(s)
			if err != nil {
				continue
			}
			hashes = append(hashes, hash)

			// Keep collecting hashes for the cache, but report only the
			// first match.
			if matched {
				continue
			}
			var key [32]byte
			copy(key[:], hash)
			if hashHost, found := c.hashToHost[key]; found {
				log.Debug("%s: matched %s by %s/%s", c.svc, c.host, hashHost, s)
				matched = true
			}
		}
	}

	return matched, hashes
}

// storeCache saves the hashes received from the server in the cache, grouped
// by their 2-byte prefix (one entry per prefix, see setCache).
//
// Each prefix of c.hashToHost was requested from the server, because getCached
// leaves only missing or expired prefixes there. A requested prefix for which
// the response carried no hashes is therefore stored as an empty entry
// ("nothing blocked under this prefix"). That write is unconditional. An
// expired entry is still returned by the cache, because expiry is encoded in
// the value. Skipping prefixes that merely have *some* entry would leave
// expired ones in place forever and force a server request on every lookup.
func (c *sbCtx) storeCache(hashes [][]byte) {
	// Sort so that the hashes of every prefix are stored in a stable order.
	sort.Slice(hashes, func(a, b int) bool {
		return bytes.Compare(hashes[a], hashes[b]) < 0
	})

	byPrefix := map[[2]byte][]byte{}
	for _, hash := range hashes {
		var prefix [2]byte
		copy(prefix[:], hash)
		byPrefix[prefix] = append(byPrefix[prefix], hash...)
	}

	// Requested prefixes that got no hashes back get an empty entry.
	for hash := range c.hashToHost {
		var prefix [2]byte
		copy(prefix[:], hash[:])
		if _, ok := byPrefix[prefix]; !ok {
			byPrefix[prefix] = nil
		}
	}

	for prefix, data := range byPrefix {
		// Pass a fresh slice so the cache key never aliases the loop variable.
		c.setCache([]byte{prefix[0], prefix[1]}, data)
	}
}

// Disabling "dupl": the algorithm of SB/PC is similar, but it uses different data
// nolint:dupl
//
// checkSafeBrowsing checks host against the Safe Browsing hash-prefix service.
// The cache is consulted first. Only hashes whose prefix has no fresh cache
// entry are sent, as a single TXT query, to d.safeBrowsingUpstream, and the
// hashes in the response are then cached. A match marks the result as filtered
// with Reason FilteredSafeBrowsing. An upstream error is returned together with
// an unfiltered result.
func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) {
	if log.GetLevel() >= log.DEBUG {
		timer := log.StartTimer()
		defer timer.LogElapsed("SafeBrowsing lookup for %s", host)
	}

	c := &sbCtx{
		host:       host,
		svc:        "SafeBrowsing",
		hashToHost: hostnameToHashes(host),
		cache:      gctx.safebrowsingCache,
		cacheTime:  d.Config.CacheTime,
	}

	var result Result
	block := func() {
		result.IsFiltered = true
		result.Reason = FilteredSafeBrowsing
		result.Rule = "adguard-malware-shavar"
	}

	switch cached := c.getCached(); {
	case cached < 0:
		return result, nil
	case cached > 0:
		block()
		return result, nil
	}

	question := c.getQuestion()
	log.Tracef("SafeBrowsing: checking %s: %s", host, question)

	var req dns.Msg
	req.SetQuestion(question, dns.TypeTXT)
	resp, err := d.safeBrowsingUpstream.Exchange(&req)
	if err != nil {
		return result, err
	}

	matched, received := c.processTXT(resp)
	if matched {
		block()
	}
	c.storeCache(received)

	return result, nil
}

// Disabling "dupl": the algorithm of SB/PC is similar, but it uses different data
// nolint:dupl
//
// checkParental checks host against the Parental Control hash-prefix service.
// The cache is consulted first. Only hashes whose prefix has no fresh cache
// entry are sent, as a single TXT query, to d.parentalUpstream, and the hashes
// in the response are then cached. A match marks the result as filtered with
// Reason FilteredParental. An upstream error is returned together with an
// unfiltered result.
func (d *Dnsfilter) checkParental(host string) (Result, error) {
	if log.GetLevel() >= log.DEBUG {
		timer := log.StartTimer()
		defer timer.LogElapsed("Parental lookup for %s", host)
	}

	c := &sbCtx{
		host:       host,
		svc:        "Parental",
		hashToHost: hostnameToHashes(host),
		cache:      gctx.parentalCache,
		cacheTime:  d.Config.CacheTime,
	}

	var result Result
	block := func() {
		result.IsFiltered = true
		result.Reason = FilteredParental
		result.Rule = "parental CATEGORY_BLACKLISTED"
	}

	switch cached := c.getCached(); {
	case cached < 0:
		return result, nil
	case cached > 0:
		block()
		return result, nil
	}

	question := c.getQuestion()
	log.Tracef("Parental: checking %s: %s", host, question)

	var req dns.Msg
	req.SetQuestion(question, dns.TypeTXT)
	resp, err := d.parentalUpstream.Exchange(&req)
	if err != nil {
		return result, err
	}

	matched, received := c.processTXT(resp)
	if matched {
		block()
	}
	c.storeCache(received)

	// err is necessarily nil here.
	return result, nil
}

// httpError formats a message from format and args and logs it at info level
// along with the request method and URL. It then replies to the client with
// that message and the given status code via http.Error.
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
	message := fmt.Sprintf(format, args...)
	log.Info("DNSFilter: %s %s: %s", r.Method, r.URL, message)
	http.Error(w, message, code)
}

// handleSafeBrowsingEnable sets d.Config.SafeBrowsingEnabled to true and then
// calls d.Config.ConfigModified(). The handler itself writes nothing to w.
func (d *Dnsfilter) handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) {
	d.Config.SafeBrowsingEnabled = true
	d.Config.ConfigModified()
}

// handleSafeBrowsingDisable sets d.Config.SafeBrowsingEnabled to false and
// then calls d.Config.ConfigModified(). The handler itself writes nothing to w.
func (d *Dnsfilter) handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) {
	d.Config.SafeBrowsingEnabled = false
	d.Config.ConfigModified()
}

// handleSafeBrowsingStatus responds with a JSON object of the form
// {"enabled": <bool>} reflecting d.Config.SafeBrowsingEnabled.
// Marshalling or write failures are reported through httpError with status
// 500.
func (d *Dnsfilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
	data := map[string]interface{}{
		"enabled": d.Config.SafeBrowsingEnabled,
	}
	jsonVal, err := json.Marshal(data)
	if err != nil {
		httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
		// The error response has already been sent; don't go on to set
		// headers and write a body after it.
		return
	}

	w.Header().Set("Content-Type", "application/json")
	_, err = w.Write(jsonVal)
	if err != nil {
		httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
		return
	}
}

// handleParentalEnable sets d.Config.ParentalEnabled to true and then calls
// d.Config.ConfigModified(). The handler itself writes nothing to w.
func (d *Dnsfilter) handleParentalEnable(w http.ResponseWriter, r *http.Request) {
	d.Config.ParentalEnabled = true
	d.Config.ConfigModified()
}

// handleParentalDisable sets d.Config.ParentalEnabled to false and then calls
// d.Config.ConfigModified(). The handler itself writes nothing to w.
func (d *Dnsfilter) handleParentalDisable(w http.ResponseWriter, r *http.Request) {
	d.Config.ParentalEnabled = false
	d.Config.ConfigModified()
}

// handleParentalStatus responds with a JSON object of the form
// {"enabled": <bool>} reflecting d.Config.ParentalEnabled.
// Marshalling or write failures are reported through httpError with status
// 500.
func (d *Dnsfilter) handleParentalStatus(w http.ResponseWriter, r *http.Request) {
	status := map[string]interface{}{"enabled": d.Config.ParentalEnabled}

	body, err := json.Marshal(status)
	if err != nil {
		httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if _, err = w.Write(body); err != nil {
		httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
	}
}

// registerSecurityHandlers registers the control endpoints for Safe Browsing,
// Parental Control and Safe Search through d.Config.HTTPRegister. Each
// feature gets POST .../enable, POST .../disable and GET .../status, and they
// are registered in that order.
func (d *Dnsfilter) registerSecurityHandlers() {
	routes := []struct {
		method  string
		path    string
		handler func(http.ResponseWriter, *http.Request)
	}{
		{"POST", "/control/safebrowsing/enable", d.handleSafeBrowsingEnable},
		{"POST", "/control/safebrowsing/disable", d.handleSafeBrowsingDisable},
		{"GET", "/control/safebrowsing/status", d.handleSafeBrowsingStatus},

		{"POST", "/control/parental/enable", d.handleParentalEnable},
		{"POST", "/control/parental/disable", d.handleParentalDisable},
		{"GET", "/control/parental/status", d.handleParentalStatus},

		{"POST", "/control/safesearch/enable", d.handleSafeSearchEnable},
		{"POST", "/control/safesearch/disable", d.handleSafeSearchDisable},
		{"GET", "/control/safesearch/status", d.handleSafeSearchStatus},
	}

	for _, route := range routes {
		d.Config.HTTPRegister(route.method, route.path, route.handler)
	}
}