Added a unit-test for custom upstreams
This commit is contained in:
parent
67a39045fc
commit
ae51de9335
@ -249,6 +249,39 @@ func TestBlockedRequest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerCustomClientUpstream(t *testing.T) {
|
||||
s := createTestServer(t)
|
||||
err := s.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %s", err)
|
||||
}
|
||||
s.conf.GetCustomUpstreamByClient = func(clientAddr string) *proxy.UpstreamConfig {
|
||||
uc := &proxy.UpstreamConfig{}
|
||||
u := &testUpstream{}
|
||||
u.ipv4 = map[string][]net.IP{}
|
||||
u.ipv4["host."] = []net.IP{net.ParseIP("192.168.0.1")}
|
||||
uc.Upstreams = append(uc.Upstreams, u)
|
||||
return uc
|
||||
}
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
|
||||
// Send test request
|
||||
req := dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.Question = []dns.Question{
|
||||
{Name: "host.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
}
|
||||
|
||||
reply, err := dns.Exchange(&req, addr.String())
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||
assert.NotNil(t, reply.Answer)
|
||||
assert.Equal(t, "192.168.0.1", reply.Answer[0].(*dns.A).A.String())
|
||||
assert.Nil(t, s.Stop())
|
||||
}
|
||||
|
||||
// testUpstream is a mock of real upstream.
|
||||
// specify fields with necessary values to simulate real upstream behaviour
|
||||
type testUpstream struct {
|
||||
|
@ -1,14 +0,0 @@
|
||||
package dnsforward
|
||||
|
||||
import "net"
|
||||
|
||||
// GetIPString is a helper function that extracts IP address from net.Addr
|
||||
func GetIPString(addr net.Addr) string {
|
||||
switch addr := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
return addr.IP.String()
|
||||
case *net.TCPAddr:
|
||||
return addr.IP.String()
|
||||
}
|
||||
return ""
|
||||
}
|
@ -8,6 +8,17 @@ import (
|
||||
"github.com/AdguardTeam/golibs/utils"
|
||||
)
|
||||
|
||||
// GetIPString is a helper function that extracts IP address from net.Addr
|
||||
func GetIPString(addr net.Addr) string {
|
||||
switch addr := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
return addr.IP.String()
|
||||
case *net.TCPAddr:
|
||||
return addr.IP.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func stringArrayDup(a []string) []string {
|
||||
a2 := make([]string, len(a))
|
||||
copy(a2, a)
|
||||
|
Loading…
Reference in New Issue
Block a user