badguardhome/internal/aghtest/upstream.go

178 lines
3.7 KiB
Go

package aghtest
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
"strings"
"sync"
"github.com/miekg/dns"
)
// TestUpstream is a mock of real upstream.
type TestUpstream struct {
// Addr is the address for Address method.
Addr string
// CName is a map of hostname to canonical name.
CName map[string]string
// IPv4 is a map of hostname to IPv4.
IPv4 map[string][]net.IP
// IPv6 is a map of hostname to IPv6.
IPv6 map[string][]net.IP
// Reverse is a map of address to domain name.
Reverse map[string][]string
}
// Exchange implements upstream.Upstream interface for *TestUpstream.
func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
resp = &dns.Msg{}
resp.SetReply(m)
if len(m.Question) == 0 {
return nil, fmt.Errorf("question should not be empty")
}
name := m.Question[0].Name
if cname, ok := u.CName[name]; ok {
resp.Answer = append(resp.Answer, &dns.CNAME{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeCNAME,
},
Target: cname,
})
}
var hasRec bool
var rrType uint16
var ips []net.IP
switch m.Question[0].Qtype {
case dns.TypeA:
rrType = dns.TypeA
if ipv4addr, ok := u.IPv4[name]; ok {
hasRec = true
ips = ipv4addr
}
case dns.TypeAAAA:
rrType = dns.TypeAAAA
if ipv6addr, ok := u.IPv6[name]; ok {
hasRec = true
ips = ipv6addr
}
case dns.TypePTR:
names, ok := u.Reverse[name]
if !ok {
break
}
for _, n := range names {
resp.Answer = append(resp.Answer, &dns.PTR{
Hdr: dns.RR_Header{
Name: n,
Rrtype: rrType,
},
Ptr: n,
})
}
}
for _, ip := range ips {
resp.Answer = append(resp.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: name,
Rrtype: rrType,
},
A: ip,
})
}
if len(resp.Answer) == 0 {
if hasRec {
// Set no error RCode if there are some records for
// given Qname but we didn't apply them.
resp.SetRcode(m, dns.RcodeSuccess)
return resp, nil
}
// Set NXDomain RCode otherwise.
resp.SetRcode(m, dns.RcodeNameError)
}
return resp, nil
}
// Address implements upstream.Upstream interface for *TestUpstream.
func (u *TestUpstream) Address() string {
return u.Addr
}
// TestBlockUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type TestBlockUpstream struct {
Hostname string
Block bool
requestsCount int
lock sync.RWMutex
}
// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
// pair.
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
u.lock.Lock()
defer u.lock.Unlock()
u.requestsCount++
hash := sha256.Sum256([]byte(u.Hostname))
hashToReturn := hex.EncodeToString(hash[:])
if !u.Block {
hashToReturn = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
}
m := &dns.Msg{}
m.Answer = []dns.RR{
&dns.TXT{
Hdr: dns.RR_Header{
Name: r.Question[0].Name,
},
Txt: []string{
hashToReturn,
},
},
}
return m, nil
}
// Address always returns an empty string.
func (u *TestBlockUpstream) Address() string {
return ""
}
// RequestsCount returns the number of handled requests. It's safe for
// concurrent use.
func (u *TestBlockUpstream) RequestsCount() int {
u.lock.Lock()
defer u.lock.Unlock()
return u.requestsCount
}
// TestErrUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type TestErrUpstream struct {
// The error returned by Exchange may be unwraped to the Err.
Err error
}
// Exchange always returns nil Msg and non-nil error.
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
return nil, fmt.Errorf("errupstream: %w", u.Err)
}
// Address always returns an empty string.
func (u *TestErrUpstream) Address() string {
return ""
}