package aghtest import ( "crypto/sha256" "encoding/hex" "fmt" "net" "strings" "sync" "github.com/miekg/dns" ) // TestUpstream is a mock of real upstream. type TestUpstream struct { // Addr is the address for Address method. Addr string // CName is a map of hostname to canonical name. CName map[string]string // IPv4 is a map of hostname to IPv4. IPv4 map[string][]net.IP // IPv6 is a map of hostname to IPv6. IPv6 map[string][]net.IP // Reverse is a map of address to domain name. Reverse map[string][]string } // Exchange implements upstream.Upstream interface for *TestUpstream. func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { resp = &dns.Msg{} resp.SetReply(m) if len(m.Question) == 0 { return nil, fmt.Errorf("question should not be empty") } name := m.Question[0].Name if cname, ok := u.CName[name]; ok { resp.Answer = append(resp.Answer, &dns.CNAME{ Hdr: dns.RR_Header{ Name: name, Rrtype: dns.TypeCNAME, }, Target: cname, }) } var hasRec bool var rrType uint16 var ips []net.IP switch m.Question[0].Qtype { case dns.TypeA: rrType = dns.TypeA if ipv4addr, ok := u.IPv4[name]; ok { hasRec = true ips = ipv4addr } case dns.TypeAAAA: rrType = dns.TypeAAAA if ipv6addr, ok := u.IPv6[name]; ok { hasRec = true ips = ipv6addr } case dns.TypePTR: names, ok := u.Reverse[name] if !ok { break } for _, n := range names { resp.Answer = append(resp.Answer, &dns.PTR{ Hdr: dns.RR_Header{ Name: n, Rrtype: rrType, }, Ptr: n, }) } } for _, ip := range ips { resp.Answer = append(resp.Answer, &dns.A{ Hdr: dns.RR_Header{ Name: name, Rrtype: rrType, }, A: ip, }) } if len(resp.Answer) == 0 { if hasRec { // Set no error RCode if there are some records for // given Qname but we didn't apply them. resp.SetRcode(m, dns.RcodeSuccess) return resp, nil } // Set NXDomain RCode otherwise. resp.SetRcode(m, dns.RcodeNameError) } return resp, nil } // Address implements upstream.Upstream interface for *TestUpstream. 
func (u *TestUpstream) Address() string {
	return u.Addr
}

// TestBlockUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type TestBlockUpstream struct {
	Hostname string
	Block    bool

	// requestsCount is the number of handled requests, guarded by lock.
	requestsCount int
	lock          sync.RWMutex
}

// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
// pair.  The message contains a single TXT record named after the first
// question of r.  When Block is true, the text is the hex-encoded SHA-256 of
// Hostname.  Otherwise it is the first hex byte of that hash followed by a
// fixed filler.  An error is returned if r contains no questions.
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
	if len(r.Question) == 0 {
		return nil, fmt.Errorf("question should not be empty")
	}

	u.lock.Lock()
	u.requestsCount++
	u.lock.Unlock()

	hash := sha256.Sum256([]byte(u.Hostname))
	hashToReturn := hex.EncodeToString(hash[:])
	if !u.Block {
		hashToReturn = hashToReturn[:2] + strings.Repeat("ab", 28)
	}

	m := &dns.Msg{}
	m.Answer = []dns.RR{
		&dns.TXT{
			Hdr: dns.RR_Header{
				Name:   r.Question[0].Name,
				Rrtype: dns.TypeTXT,
			},
			Txt: []string{hashToReturn},
		},
	}

	return m, nil
}

// Address always returns an empty string.
func (u *TestBlockUpstream) Address() string {
	return ""
}

// RequestsCount returns the number of handled requests.  It's safe for
// concurrent use.
func (u *TestBlockUpstream) RequestsCount() int {
	u.lock.RLock()
	defer u.lock.RUnlock()

	return u.requestsCount
}

// TestErrUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type TestErrUpstream struct {
	// Err is the error wrapped by Exchange; the error Exchange returns can be
	// unwrapped to it.
	Err error
}

// Exchange always returns nil Msg and non-nil error.
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
	return nil, fmt.Errorf("errupstream: %w", u.Err)
}

// Address always returns an empty string.
func (u *TestErrUpstream) Address() string {
	return ""
}