Merge: Don't use the system default resolver implicitly
Close #647 * commit 'a1b18776678ee894cd0b558cf8683845056a2dfa': + parental, safesearch: use our own DNS resolver instead of system default + control: use our own DNS resolver instead of system default
This commit is contained in:
commit
00ba63341b
|
@ -35,8 +35,13 @@ var protocols = []string{"tls://", "https://", "tcp://", "sdns://"}
|
|||
const versionCheckURL = "https://adguardteam.github.io/AdGuardHome/version.json"
|
||||
const versionCheckPeriod = time.Hour * 8
|
||||
|
||||
var transport = &http.Transport{
|
||||
DialContext: customDialContext,
|
||||
}
|
||||
|
||||
var client = &http.Client{
|
||||
Timeout: time.Minute * 5,
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
var controlLock sync.Mutex
|
||||
|
|
1
dns.go
1
dns.go
|
@ -50,6 +50,7 @@ func generateServerConfig() dnsforward.ServerConfig {
|
|||
FilteringConfig: config.DNS.FilteringConfig,
|
||||
Filters: filters,
|
||||
}
|
||||
newconfig.ResolverAddress = fmt.Sprintf("%s:%d", config.DNS.BindHost, config.DNS.Port)
|
||||
|
||||
if config.TLS.Enabled {
|
||||
newconfig.TLSConfig = config.TLS.TLSConfig
|
||||
|
|
|
@ -3,6 +3,7 @@ package dnsfilter
|
|||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
@ -16,6 +17,7 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/bluele/gcache"
|
||||
"golang.org/x/net/publicsuffix"
|
||||
|
@ -49,6 +51,7 @@ type Config struct {
|
|||
ParentalEnabled bool `yaml:"parental_enabled"`
|
||||
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
|
||||
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
|
||||
ResolverAddress string // DNS server address
|
||||
}
|
||||
|
||||
type privateConfig struct {
|
||||
|
@ -159,6 +162,8 @@ var (
|
|||
safeSearchCache gcache.Cache
|
||||
)
|
||||
|
||||
var resolverAddr string // DNS server address
|
||||
|
||||
// Result holds state of hostname check
|
||||
type Result struct {
|
||||
IsFiltered bool `json:",omitempty"` // True if the host name is filtered
|
||||
|
@ -971,6 +976,47 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) {
|
|||
// lifecycle helper functions
|
||||
//
|
||||
|
||||
// Connect to a remote server resolving hostname using our own DNS server
|
||||
func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
log.Tracef("network:%v addr:%v", network, addr)
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: time.Minute * 5,
|
||||
}
|
||||
|
||||
if net.ParseIP(host) != nil {
|
||||
con, err := dialer.DialContext(ctx, network, addr)
|
||||
return con, err
|
||||
}
|
||||
|
||||
r := upstream.NewResolver(resolverAddr, 30*time.Second)
|
||||
addrs, e := r.LookupIPAddr(ctx, host)
|
||||
log.Tracef("LookupIPAddr: %s: %v", host, addrs)
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
|
||||
var firstErr error
|
||||
firstErr = nil
|
||||
for _, a := range addrs {
|
||||
addr = fmt.Sprintf("%s:%s", a.String(), port)
|
||||
con, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
return con, err
|
||||
}
|
||||
return nil, firstErr
|
||||
}
|
||||
|
||||
// New creates properly initialized DNS Filter that is ready to be used
|
||||
func New(c *Config) *Dnsfilter {
|
||||
d := new(Dnsfilter)
|
||||
|
@ -990,6 +1036,10 @@ func New(c *Config) *Dnsfilter {
|
|||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
if len(c.ResolverAddress) != 0 {
|
||||
resolverAddr = c.ResolverAddress
|
||||
d.transport.DialContext = customDialContext
|
||||
}
|
||||
d.client = http.Client{
|
||||
Transport: d.transport,
|
||||
Timeout: defaultHTTPTimeout,
|
||||
|
|
46
helpers.go
46
helpers.go
|
@ -2,6 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -14,7 +15,10 @@ import (
|
|||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/joomcode/errorx"
|
||||
)
|
||||
|
||||
|
@ -300,6 +304,48 @@ func checkPacketPortAvailable(host string, port int) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// Connect to a remote server resolving hostname using our own DNS server
|
||||
func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
log.Tracef("network:%v addr:%v", network, addr)
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: time.Minute * 5,
|
||||
}
|
||||
|
||||
if net.ParseIP(host) != nil {
|
||||
con, err := dialer.DialContext(ctx, network, addr)
|
||||
return con, err
|
||||
}
|
||||
|
||||
resolverAddr := fmt.Sprintf("%s:%d", config.DNS.BindHost, config.DNS.Port)
|
||||
r := upstream.NewResolver(resolverAddr, 30*time.Second)
|
||||
addrs, e := r.LookupIPAddr(ctx, host)
|
||||
log.Tracef("LookupIPAddr: %s: %v", host, addrs)
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
|
||||
var firstErr error
|
||||
firstErr = nil
|
||||
for _, a := range addrs {
|
||||
addr = fmt.Sprintf("%s:%s", a.String(), port)
|
||||
con, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
return con, err
|
||||
}
|
||||
return nil, firstErr
|
||||
}
|
||||
|
||||
// ---------------------
|
||||
// debug logging helpers
|
||||
// ---------------------
|
||||
|
|
Loading…
Reference in New Issue