package dnsforward

import (
	"bytes"
	"encoding/binary"
	"time"

	"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
	"github.com/AdguardTeam/golibs/cache"
	"github.com/AdguardTeam/golibs/log"
	"github.com/miekg/dns"
)

// uint* sizes in bytes to improve readability.
//
// TODO(e.burkov): Remove when there will be a more regardful way to define
// those.  See https://github.com/golang/go/issues/29982.
const (
	uint16sz = 2
	uint64sz = 8
)

// recursionDetector detects recursion in DNS forwarding.
type recursionDetector struct {
	// recentRequests maps the signature of a message (see msgToSignature) to
	// the big-endian encoded expiration time in Unix nanoseconds.
	recentRequests cache.Cache

	// ttl is added to the current time to get the expiration time of every
	// message passed to add.
	ttl time.Duration
}

// check checks if the passed req was already sent by the server.  It returns
// true only if msg has at least one question and its signature is stored in
// the cache with an expiration time that is still in the future.
func (rd *recursionDetector) check(msg dns.Msg) (ok bool) {
	if len(msg.Question) == 0 {
		return false
	}

	key := msgToSignature(msg)
	expireData := rd.recentRequests.Get(key)
	if expireData == nil {
		return false
	}

	// The value is the expiration time written by add.
	expire := time.Unix(0, int64(binary.BigEndian.Uint64(expireData)))

	return time.Now().Before(expire)
}

// add caches the msg if it has anything in the questions section.  The stored
// value is the moment at which the entry stops being considered by check.
func (rd *recursionDetector) add(msg dns.Msg) {
	if len(msg.Question) == 0 {
		return
	}

	key := msgToSignature(msg)

	// Only take the current time once it's known that the message is going
	// to be stored.
	expire64 := uint64(time.Now().Add(rd.ttl).UnixNano())
	expire := make([]byte, uint64sz)
	binary.BigEndian.PutUint64(expire, expire64)

	rd.recentRequests.Set(key, expire)
}

// clear clears the recent requests cache.
func (rd *recursionDetector) clear() {
	rd.recentRequests.Clear()
}

// newRecursionDetector returns the initialized *recursionDetector.  ttl is
// the lifetime of a single entry, and suspectsNum is passed to the underlying
// LRU-enabled cache as its MaxCount.
func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDetector) {
	return &recursionDetector{
		recentRequests: cache.New(cache.Config{
			EnableLRU: true,
			MaxCount:  suspectsNum,
		}),
		ttl: ttl,
	}
}

// msgToSignature converts msg into its signature represented in bytes.  The
// signature is a fixed-size slice with the following layout:
//
//   - bytes [0, 2): the message ID;
//   - bytes [2, 4): the type of the first question;
//   - the rest: the name of the first question, zero-padded, or truncated if
//     it's longer than aghnet.MaxDomainNameLen bytes.
//
// msg must have at least one question, since the first one is accessed
// unconditionally.
func msgToSignature(msg dns.Msg) (sig []byte) {
	sig = make([]byte, uint16sz*2+aghnet.MaxDomainNameLen)

	// The binary.BigEndian byte order is used everywhere except when the
	// real machine's endianness is needed.
	byteOrder := binary.BigEndian

	byteOrder.PutUint16(sig[0:], msg.Id)

	q := msg.Question[0]
	byteOrder.PutUint16(sig[uint16sz:], q.Qtype)

	// Copy directly from the string, since the built-in copy supports it and
	// no intermediate byte slice is needed.
	copy(sig[2*uint16sz:], q.Name)

	return sig
}

// msgToSignatureSlow converts msg into its signature represented in bytes in
// the less efficient way.  Note that the fields are written in the order of
// their declaration in the helper structure, i.e. the name goes first, so the
// result isn't byte-compatible with the one of msgToSignature.
//
// msg must have at least one question, since the first one is accessed
// unconditionally.
//
// See BenchmarkMsgToSignature.
func msgToSignatureSlow(msg dns.Msg) (sig []byte) {
	type msgSignature struct {
		name  [aghnet.MaxDomainNameLen]byte
		id    uint16
		qtype uint16
	}

	b := bytes.NewBuffer(sig)
	q := msg.Question[0]
	signature := msgSignature{
		id:    msg.Id,
		qtype: q.Qtype,
	}
	copy(signature.name[:], q.Name)

	// A writing error is only logged, and whatever has been written so far is
	// still returned.
	if err := binary.Write(b, binary.BigEndian, signature); err != nil {
		log.Debug("writing message signature: %s", err)
	}

	return b.Bytes()
}