diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 07f65084..a1afd7a9 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -44,7 +44,7 @@ type Server struct { once sync.Once sync.RWMutex - ServerConfig + conf ServerConfig } // NewServer creates a new instance of the dnsforward.Server @@ -123,7 +123,7 @@ func (s *Server) Start(config *ServerConfig) error { // startInternal starts without locking func (s *Server) startInternal(config *ServerConfig) error { if config != nil { - s.ServerConfig = *config + s.conf = *config } if s.dnsFilter != nil || s.dnsProxy != nil { @@ -158,21 +158,21 @@ func (s *Server) startInternal(config *ServerConfig) error { }) proxyConfig := proxy.Config{ - UDPListenAddr: s.UDPListenAddr, - TCPListenAddr: s.TCPListenAddr, - Ratelimit: s.Ratelimit, - RatelimitWhitelist: s.RatelimitWhitelist, - RefuseAny: s.RefuseAny, + UDPListenAddr: s.conf.UDPListenAddr, + TCPListenAddr: s.conf.TCPListenAddr, + Ratelimit: s.conf.Ratelimit, + RatelimitWhitelist: s.conf.RatelimitWhitelist, + RefuseAny: s.conf.RefuseAny, CacheEnabled: true, - Upstreams: s.Upstreams, - DomainsReservedUpstreams: s.DomainsReservedUpstreams, + Upstreams: s.conf.Upstreams, + DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams, Handler: s.handleDNSRequest, - AllServers: s.AllServers, + AllServers: s.conf.AllServers, } - if s.TLSListenAddr != nil && s.CertificateChain != "" && s.PrivateKey != "" { - proxyConfig.TLSListenAddr = s.TLSListenAddr - keypair, err := tls.X509KeyPair([]byte(s.CertificateChain), []byte(s.PrivateKey)) + if s.conf.TLSListenAddr != nil && s.conf.CertificateChain != "" && s.conf.PrivateKey != "" { + proxyConfig.TLSListenAddr = s.conf.TLSListenAddr + keypair, err := tls.X509KeyPair([]byte(s.conf.CertificateChain), []byte(s.conf.PrivateKey)) if err != nil { return errorx.Decorate(err, "Failed to parse TLS keypair") } @@ -202,10 +202,10 @@ func (s *Server) startInternal(config *ServerConfig) error { // Initializes the DNS filter func (s *Server) initDNSFilter() error { log.Tracef("Creating dnsfilter") - s.dnsFilter = dnsfilter.New(&s.Config) + s.dnsFilter = dnsfilter.New(&s.conf.Config) // add rules only if they are enabled - if s.FilteringEnabled { - err := s.dnsFilter.AddRules(s.Filters) + if s.conf.FilteringEnabled { + err := s.dnsFilter.AddRules(s.conf.Filters) if err != nil { return errorx.Decorate(err, "could not initialize dnsfilter") } @@ -336,11 +336,11 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { msg := d.Req // don't log ANY request if refuseAny is enabled - if len(msg.Question) >= 1 && msg.Question[0].Qtype == dns.TypeANY && s.RefuseAny { + if len(msg.Question) >= 1 && msg.Question[0].Qtype == dns.TypeANY && s.conf.RefuseAny { shouldLog = false } - if s.QueryLogEnabled && shouldLog { + if s.conf.QueryLogEnabled && shouldLog { elapsed := time.Since(start) upstreamAddr := "" if d.Upstream != nil { @@ -361,7 +361,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error host := strings.TrimSuffix(msg.Question[0].Name, ".") s.RLock() - protectionEnabled := s.ProtectionEnabled + protectionEnabled := s.conf.ProtectionEnabled dnsFilter := s.dnsFilter s.RUnlock() @@ -402,7 +402,7 @@ func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Resu return s.genARecord(m, result.IP) } - if s.BlockingMode == "null_ip" { + if s.conf.BlockingMode == "null_ip" { return s.genARecord(m, net.IPv4zero) } @@ -420,7 +420,7 @@ func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg { func (s *Server) genARecord(request *dns.Msg, ip net.IP) *dns.Msg { resp := dns.Msg{} resp.SetReply(request) - answer, err := dns.NewRR(fmt.Sprintf("%s %d A %s", request.Question[0].Name, s.BlockedResponseTTL, ip.String())) + answer, err := dns.NewRR(fmt.Sprintf("%s %d A %s", request.Question[0].Name, s.conf.BlockedResponseTTL, ip.String())) if err != nil { log.Printf("Couldn't generate A record for replacement host '%s': %s", ip.String(), err) return s.genServerFailure(request) @@ -489,7 +489,7 @@ func (s *Server) genSOA(request *dns.Msg) []dns.RR { Hdr: dns.RR_Header{ Name: zone, Rrtype: dns.TypeSOA, - Ttl: s.BlockedResponseTTL, + Ttl: s.conf.BlockedResponseTTL, Class: dns.ClassINET, }, Mbox: "hostmaster.", // zone will be appended later if it's not empty or "." diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 0c1325c7..1a38e96b 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -86,7 +86,7 @@ func TestDotServer(t *testing.T) { s := createTestServer(t) defer removeDataDir(t) - s.TLSConfig = TLSConfig{ + s.conf.TLSConfig = TLSConfig{ TLSListenAddr: &net.TCPAddr{Port: 0}, CertificateChain: string(certPem), PrivateKey: string(keyPem), @@ -149,7 +149,7 @@ func TestServerRace(t *testing.T) { func TestSafeSearch(t *testing.T) { s := createTestServer(t) - s.SafeSearchEnabled = true + s.conf.SafeSearchEnabled = true defer removeDataDir(t) err := s.Start(nil) if err != nil { @@ -295,7 +295,7 @@ func TestBlockedRequest(t *testing.T) { func TestNullBlockedRequest(t *testing.T) { s := createTestServer(t) - s.FilteringConfig.BlockingMode = "null_ip" + s.conf.FilteringConfig.BlockingMode = "null_ip" defer removeDataDir(t) err := s.Start(nil) if err != nil { @@ -451,14 +451,14 @@ func TestBlockedBySafeBrowsing(t *testing.T) { func createTestServer(t *testing.T) *Server { s := NewServer(createDataDir(t)) - s.UDPListenAddr = &net.UDPAddr{Port: 0} - s.TCPListenAddr = &net.TCPAddr{Port: 0} + s.conf.UDPListenAddr = &net.UDPAddr{Port: 0} + s.conf.TCPListenAddr = &net.TCPAddr{Port: 0} - s.QueryLogEnabled = true - s.FilteringConfig.FilteringEnabled = true - s.FilteringConfig.ProtectionEnabled = true - s.FilteringConfig.SafeBrowsingEnabled = true - s.Filters = make([]dnsfilter.Filter, 0) + s.conf.QueryLogEnabled = true + s.conf.FilteringConfig.FilteringEnabled = true + s.conf.FilteringConfig.ProtectionEnabled = true + s.conf.FilteringConfig.SafeBrowsingEnabled = true + s.conf.Filters = make([]dnsfilter.Filter, 0) rules := []string{ "||nxdomain.example.org^", @@ -466,7 +466,7 @@ func createTestServer(t *testing.T) *Server { "127.0.0.1 host.example.org", } filter := dnsfilter.Filter{ID: 1, Rules: rules} - s.Filters = append(s.Filters, filter) + s.conf.Filters = append(s.conf.Filters, filter) return s }