package dnsforward

import (
	"bytes"
	"crypto/tls"
	"fmt"
	"io/ioutil"
	"log"
	"math/rand"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/joomcode/errorx"
	"github.com/miekg/dns"
)

// defaultTimeout is the timeout configured on every package-level DNS and
// HTTP client declared in this file.
const defaultTimeout = 10 * time.Second

// Upstream is a DNS server that queries can be forwarded to. The
// implementations in this file cover plain DNS, DNS-over-TLS and
// DNS-over-HTTPS.
type Upstream interface {
	// Exchange sends the query m to the upstream and returns its reply, or
	// an error if the exchange failed.
	Exchange(m *dns.Msg) (*dns.Msg, error)
	// Address returns the address string the upstream was created with.
	Address() string
}

//
// plain DNS
//

// plainDNS is an Upstream that talks classic, unencrypted DNS: UDP first,
// with a TCP retry handled in Exchange.
type plainDNS struct {
	// address is handed as-is to the DNS client on every exchange.
	address string
}

// Shared DNS clients used by plainDNS. Both accept messages of up to
// dns.MaxMsgSize bytes and use defaultTimeout.
var (
	// defaultUDPClient leaves Net unset, which the dns package treats as UDP.
	defaultUDPClient = dns.Client{
		UDPSize: dns.MaxMsgSize,
		Timeout: defaultTimeout,
	}

	// defaultTCPClient is the TCP counterpart, used when a UDP reply is
	// truncated.
	defaultTCPClient = dns.Client{
		Net:     "tcp",
		Timeout: defaultTimeout,
		UDPSize: dns.MaxMsgSize,
	}
)

// Address returns the address this plain DNS upstream sends queries to.
func (p *plainDNS) Address() string {
	return p.address
}

// Exchange sends m to the upstream over UDP and returns the reply. When the
// reply has the truncated (TC) bit set, the same query is retried over TCP
// and the TCP result is returned instead.
func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) {
	reply, _, err := defaultUDPClient.Exchange(m, p.address)

	// The TC bit alone decides whether to retry. Requiring a non-nil error
	// as well would skip the retry whenever the DNS library hands back a
	// truncated reply together with a nil error.
	if reply != nil && reply.Truncated {
		// Guard the index: a message without a question section must not
		// panic just because we want to log it.
		question := "<none>"
		if len(m.Question) > 0 {
			question = m.Question[0].String()
		}
		log.Printf("Truncated message was received, retrying over TCP, question: %s", question)
		reply, _, err = defaultTCPClient.Exchange(m, p.address)
	}
	return reply, err
}

//
// DNS-over-TLS
//

// dnsOverTLS is an Upstream that exchanges DNS messages over connections
// obtained from a TLSPool.
type dnsOverTLS struct {
	// address is the upstream address; it is also given to the pool as its
	// Address when the pool is created.
	address string
	// pool is nil until the first Exchange call creates it.
	pool    *TLSPool

	sync.RWMutex // protects pool
}

// defaultTLSClient is a DNS client configured for the "tcp-tls" network with
// an empty TLS configuration, defaultTimeout and a dns.MaxMsgSize buffer.
// NOTE(review): nothing in this part of the file references it — confirm it
// is used elsewhere (e.g. by TLSPool).
var defaultTLSClient = dns.Client{
	Net:       "tcp-tls",
	Timeout:   defaultTimeout,
	UDPSize:   dns.MaxMsgSize,
	TLSConfig: &tls.Config{},
}

// Address returns the address this DNS-over-TLS upstream was created with.
func (p *dnsOverTLS) Address() string {
	return p.address
}

// Exchange sends m to the upstream over a pooled connection and returns the
// reply. The connection pool is created on first use. A connection that
// fails while writing or reading is closed instead of being returned to the
// pool.
func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
	p.RLock()
	pool := p.pool
	p.RUnlock()
	if pool == nil {
		p.Lock()
		// Re-check under the write lock: another goroutine may have created
		// the pool between our read above and acquiring the lock. Without
		// this check the later caller would overwrite the first pool and
		// connections returned to it would be lost.
		if p.pool == nil {
			p.pool = &TLSPool{Address: p.address}
		}
		pool = p.pool
		p.Unlock()
	}

	// From here on the local snapshot is used, so the mutex is not held
	// while a connection is being obtained or returned.
	poolConn, err := pool.Get()
	if err != nil {
		return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.address)
	}

	c := dns.Conn{Conn: poolConn}
	if err = c.WriteMsg(m); err != nil {
		poolConn.Close()
		return nil, errorx.Decorate(err, "Failed to send a request to %s", p.address)
	}

	reply, err := c.ReadMsg()
	if err != nil {
		poolConn.Close()
		return nil, errorx.Decorate(err, "Failed to read a request from %s", p.address)
	}

	// Only a connection that completed a full round trip goes back to the pool.
	pool.Put(poolConn)
	return reply, nil
}

//
// DNS-over-HTTPS
//

// dnsOverHTTPS is an Upstream that POSTs packed DNS messages to an HTTP(S)
// endpoint.
type dnsOverHTTPS struct {
	// address is the URL that queries are POSTed to.
	address string
}

var (
	// defaultHTTPSTransport is the zero-value transport backing
	// defaultHTTPSClient.
	defaultHTTPSTransport = http.Transport{}

	// defaultHTTPSClient is a shared HTTP client that uses
	// defaultHTTPSTransport and gives up after defaultTimeout.
	defaultHTTPSClient = http.Client{
		Timeout:   defaultTimeout,
		Transport: &defaultHTTPSTransport,
	}
)

// Address returns the URL this DNS-over-HTTPS upstream posts queries to.
func (p *dnsOverHTTPS) Address() string {
	return p.address
}

// Exchange packs m into wire format, POSTs it to the upstream URL with the
// "application/dns-message" content type and unpacks the response body into
// a DNS message.
//
// An error is returned when the request cannot be packed or sent, when the
// body cannot be read, when the status code is not 200 OK, when the body is
// empty, or when the body is not a valid DNS message.
func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) {
	buf, err := m.Pack()
	if err != nil {
		return nil, errorx.Decorate(err, "Couldn't pack request msg")
	}

	// Use the package client rather than http.Post: the latter goes through
	// http.DefaultClient, which has no timeout, so a stalled server could
	// block this call indefinitely.
	resp, err := defaultHTTPSClient.Post(p.address, "application/dns-message", bytes.NewReader(buf))
	if err != nil {
		return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", p.address)
	}
	defer resp.Body.Close()

	// The body is read before the status check so that it is consumed even
	// for error responses.
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", p.address)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, p.address)
	}
	if len(body) == 0 {
		return nil, fmt.Errorf("Got an unexpected empty body from '%s'", p.address)
	}

	response := dns.Msg{}
	if err = response.Unpack(body); err != nil {
		return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", p.address, string(body))
	}
	return &response, nil
}

// chooseUpstream picks the upstream for the next query. It uses the
// server's configured upstreams, falling back to defaultValues.Upstreams
// when none are configured, and selects one at random when there are several.
//
// It panics if the default list is empty as well, since that indicates a
// programming error rather than a configuration problem.
func (s *Server) chooseUpstream() Upstream {
	upstreams := s.Upstreams
	// Test the length rather than nil-ness: an empty but non-nil slice also
	// means "nothing configured" and must fall back to the defaults instead
	// of reaching the panic below.
	if len(upstreams) == 0 {
		upstreams = defaultValues.Upstreams
	}

	switch len(upstreams) {
	case 0:
		panic("SHOULD NOT HAPPEN: no default upstreams specified")
	case 1:
		// Nothing to choose from; skip the random number generator.
		return upstreams[0]
	default:
		return upstreams[rand.Intn(len(upstreams))]
	}
}

// GetUpstream creates an Upstream from its textual address.
//
// Supported forms:
//
//	host, host:port     plain DNS (port defaults to 53)
//	dns://host[:port]   plain DNS (port defaults to 53)
//	tls://host[:port]   DNS-over-TLS (port defaults to 853)
//	https://...         DNS-over-HTTPS, the URL is used as given
//
// Any other scheme is treated as plain DNS. An error is returned only when
// an address containing "://" cannot be parsed as a URL.
func GetUpstream(address string) (Upstream, error) {
	if !strings.Contains(address, "://") {
		// No scheme, so this is a plain DNS host with an optional port.
		if _, _, err := net.SplitHostPort(address); err != nil {
			// No usable port: default to 53. Strip the brackets of a bare
			// IPv6 literal ("[::1]") first, because JoinHostPort adds its
			// own and would otherwise produce "[[::1]]:53".
			host := address
			if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
				host = host[1 : len(host)-1]
			}
			address = net.JoinHostPort(host, "53")
		}
		return &plainDNS{address: address}, nil
	}

	// Named u rather than url so the net/url package is not shadowed.
	u, err := url.Parse(address)
	if err != nil {
		return nil, errorx.Decorate(err, "Failed to parse %s", address)
	}

	switch u.Scheme {
	case "tls":
		if u.Port() == "" {
			u.Host += ":853"
		}
		// The full URL is kept here and later passed to TLSPool as its
		// Address. NOTE(review): TLSPool is defined elsewhere — confirm it
		// expects a URL rather than a bare host:port.
		return &dnsOverTLS{address: u.String()}, nil
	case "https":
		return &dnsOverHTTPS{address: u.String()}, nil
	default:
		// "dns" and any unrecognized scheme are treated as plain DNS.
		if u.Port() == "" {
			u.Host += ":53"
		}
		// plainDNS hands its address straight to the DNS client, which
		// dials it as host:port, so the scheme must not be part of it.
		return &plainDNS{address: u.Host}, nil
	}
}