Added a unit-test for custom upstreams
This commit is contained in:
parent
67a39045fc
commit
ae51de9335
@ -249,6 +249,39 @@ func TestBlockedRequest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServerCustomClientUpstream(t *testing.T) {
|
||||||
|
s := createTestServer(t)
|
||||||
|
err := s.Start()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to start server: %s", err)
|
||||||
|
}
|
||||||
|
s.conf.GetCustomUpstreamByClient = func(clientAddr string) *proxy.UpstreamConfig {
|
||||||
|
uc := &proxy.UpstreamConfig{}
|
||||||
|
u := &testUpstream{}
|
||||||
|
u.ipv4 = map[string][]net.IP{}
|
||||||
|
u.ipv4["host."] = []net.IP{net.ParseIP("192.168.0.1")}
|
||||||
|
uc.Upstreams = append(uc.Upstreams, u)
|
||||||
|
return uc
|
||||||
|
}
|
||||||
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
|
|
||||||
|
// Send test request
|
||||||
|
req := dns.Msg{}
|
||||||
|
req.Id = dns.Id()
|
||||||
|
req.RecursionDesired = true
|
||||||
|
req.Question = []dns.Question{
|
||||||
|
{Name: "host.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||||
|
}
|
||||||
|
|
||||||
|
reply, err := dns.Exchange(&req, addr.String())
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||||
|
assert.NotNil(t, reply.Answer)
|
||||||
|
assert.Equal(t, "192.168.0.1", reply.Answer[0].(*dns.A).A.String())
|
||||||
|
assert.Nil(t, s.Stop())
|
||||||
|
}
|
||||||
|
|
||||||
// testUpstream is a mock of real upstream.
|
// testUpstream is a mock of real upstream.
|
||||||
// specify fields with necessary values to simulate real upstream behaviour
|
// specify fields with necessary values to simulate real upstream behaviour
|
||||||
type testUpstream struct {
|
type testUpstream struct {
|
||||||
|
@ -1,14 +0,0 @@
|
|||||||
package dnsforward
|
|
||||||
|
|
||||||
import "net"
|
|
||||||
|
|
||||||
// GetIPString is a helper function that extracts IP address from net.Addr
|
|
||||||
func GetIPString(addr net.Addr) string {
|
|
||||||
switch addr := addr.(type) {
|
|
||||||
case *net.UDPAddr:
|
|
||||||
return addr.IP.String()
|
|
||||||
case *net.TCPAddr:
|
|
||||||
return addr.IP.String()
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
@ -8,6 +8,17 @@ import (
|
|||||||
"github.com/AdguardTeam/golibs/utils"
|
"github.com/AdguardTeam/golibs/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GetIPString is a helper function that extracts IP address from net.Addr
|
||||||
|
func GetIPString(addr net.Addr) string {
|
||||||
|
switch addr := addr.(type) {
|
||||||
|
case *net.UDPAddr:
|
||||||
|
return addr.IP.String()
|
||||||
|
case *net.TCPAddr:
|
||||||
|
return addr.IP.String()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func stringArrayDup(a []string) []string {
|
func stringArrayDup(a []string) []string {
|
||||||
a2 := make([]string, len(a))
|
a2 := make([]string, len(a))
|
||||||
copy(a2, a)
|
copy(a2, a)
|
||||||
|
Loading…
Reference in New Issue
Block a user