* SB/PC: use 4-character hash in request

* use hash prefix as the cache key

Squashed commit of the following:

commit d719a84ee9b9cf43aaab4f53d07451645ea836db
Merge: d9d6d443 97df1989
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Thu Aug 13 11:41:28 2020 +0300

    Merge remote-tracking branch 'origin/master' into sbpc

commit d9d6d44376c44959f2216b08e577d8e5c5f65bff
Merge: 0a8b2483 de92c852
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Thu Aug 13 11:23:14 2020 +0300

    Merge remote-tracking branch 'origin/master' into sbpc

commit 0a8b24839683683a9d327ecf57a7d182b3996b1d
Merge: 0255a24a 9b9902f0
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Wed Aug 12 20:21:22 2020 +0300

    Merge remote-tracking branch 'origin/master' into sbpc

commit 0255a24a191efd2e4ef23d6a00a7a9fed8831730
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Wed Aug 12 17:16:57 2020 +0300

    - TestServerCustomClientUpstream(): fix

commit d2311902f887be9621a9d9312c73f899dd269440
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Wed Aug 12 17:07:12 2020 +0300

    * SB/PC: hard-code Family server IP addresses to prevent from requesting them at runtime

commit ee340108f11f98d49a7af2a7e8a228c25ab1537a
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Wed Aug 12 17:05:57 2020 +0300

    * dnsproxy v0.30.1

commit f5f53ba7116ad525204d00b80352202eee88b78c
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Wed Aug 12 14:53:23 2020 +0300

    minor

commit fb4631e2cd570b0fd5ae26ec2b1890361275a5a8
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Mon Aug 10 20:07:27 2020 +0300

    * SB/PC: implement new cache

commit f9f58461a6efbcfacd798f7640a4645cf1971cb2
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Fri Aug 7 19:31:05 2020 +0300

    doc

commit ed69626a6c119ab1a3b187f5afbd4cef708c3159
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Fri Aug 7 18:25:57 2020 +0300

    * SB/PC: use hostname prefix for cache

commit afa8040c8c0836c7e59e6fb9aaf1caccd132ea8f
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Fri Jul 31 11:19:49 2020 +0300

    * SB/PC: use 4-character hash in request
This commit is contained in:
Simon Zolin 2020-08-13 11:49:42 +03:00
parent 97df19898f
commit 705a9d909d
9 changed files with 469 additions and 185 deletions

View File

@ -64,6 +64,7 @@ Contents:
* API: Log in * API: Log in
* API: Log out * API: Log out
* API: Get current user info * API: Get current user info
* Safe services
## Relations between subsystems ## Relations between subsystems
@ -1747,3 +1748,40 @@ Response:
} }
If no client is configured then authentication is disabled and server sends an empty response. If no client is configured then authentication is disabled and server sends an empty response.
### Safe services
Check if host name is blocked by SB/PC service:
* For each host name component, search for the result in cache by the first 2 bytes of SHA-256 hashes of host name components (max. is 4, i.e. sub2.sub1.host.com), excluding TLD:
hashes[] = cache_search(sha256(host.com)[0..1])
...
If hash prefix is found, search for a full hash sum in the cached data.
If found, the host is blocked.
If not found, the host is not blocked - don't request data for this prefix from the Family server again.
If hash prefix is not found, request data for this prefix from the Family server.
* Prepare query string which is generated from the first 2 bytes (converted to a 4-character string) of SHA-256 hashes of host name components (max. is 4, i.e. sub2.sub1.host.com), excluding TLD:
qs = ... + string(sha256(sub.host.com)[0..1]) + "." + string(sha256(host.com)[0..1]) + ".sb.dns.adguard.com."
For PC `.pc.dns.adguard.com` suffix is used.
* Send TXT query to Family server, receive response which contains the array of complete hash sums of the blocked hosts
* Check if one of received hash sums (`hashes[]`) matches hash sums for our host name
hashes[0] <> sha256(host.com)
hashes[0] <> sha256(sub.host.com)
hashes[1] <> sha256(host.com)
hashes[1] <> sha256(sub.host.com)
...
* Store all received hash sums in cache:
sha256(host.com)[0..1] -> hashes[0],hashes[1],...
sha256(sub.host.com)[0..1] -> hashes[2],...
...

View File

@ -47,9 +47,9 @@ func NewForTest(c *Config, filters []Filter) *Dnsfilter {
setts = RequestFilteringSettings{} setts = RequestFilteringSettings{}
setts.FilteringEnabled = true setts.FilteringEnabled = true
if c != nil { if c != nil {
c.SafeBrowsingCacheSize = 1000 c.SafeBrowsingCacheSize = 10000
c.ParentalCacheSize = 10000
c.SafeSearchCacheSize = 1000 c.SafeSearchCacheSize = 1000
c.ParentalCacheSize = 1000
c.CacheTime = 30 c.CacheTime = 30
setts.SafeSearchEnabled = c.SafeSearchEnabled setts.SafeSearchEnabled = c.SafeSearchEnabled
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
@ -146,12 +146,6 @@ func TestEtcHostsMatching(t *testing.T) {
// SAFE BROWSING // SAFE BROWSING
func TestSafeBrowsingHash(t *testing.T) {
q, hashes := hostnameToHashParam("1.2.3.4.5.6")
assert.Equal(t, "0132d0fa.b5413b4e.5fa067c1.e7f6c011.", q)
assert.Equal(t, 4, len(hashes))
}
func TestSafeBrowsing(t *testing.T) { func TestSafeBrowsing(t *testing.T) {
d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil)
defer d.Close() defer d.Close()

149
dnsfilter/safe_search.go Normal file
View File

@ -0,0 +1,149 @@
package dnsfilter
import (
"bytes"
"encoding/binary"
"encoding/gob"
"encoding/json"
"fmt"
"net"
"net/http"
"time"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
)
/*
expire byte[4]
res Result
*/
func (d *Dnsfilter) setCacheResult(cache cache.Cache, host string, res Result) int {
var buf bytes.Buffer
expire := uint(time.Now().Unix()) + d.Config.CacheTime*60
var exp []byte
exp = make([]byte, 4)
binary.BigEndian.PutUint32(exp, uint32(expire))
_, _ = buf.Write(exp)
enc := gob.NewEncoder(&buf)
err := enc.Encode(res)
if err != nil {
log.Error("gob.Encode(): %s", err)
return 0
}
val := buf.Bytes()
_ = cache.Set([]byte(host), val)
return len(val)
}
func getCachedResult(cache cache.Cache, host string) (Result, bool) {
data := cache.Get([]byte(host))
if data == nil {
return Result{}, false
}
exp := int(binary.BigEndian.Uint32(data[:4]))
if exp <= int(time.Now().Unix()) {
cache.Del([]byte(host))
return Result{}, false
}
var buf bytes.Buffer
buf.Write(data[4:])
dec := gob.NewDecoder(&buf)
r := Result{}
err := dec.Decode(&r)
if err != nil {
log.Debug("gob.Decode(): %s", err)
return Result{}, false
}
return r, true
}
// SafeSearchDomain returns replacement address for search engine
func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) {
val, ok := safeSearchDomains[host]
return val, ok
}
func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("SafeSearch: lookup for %s", host)
}
// Check cache. Return cached result if it was found
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, host)
if isFound {
// atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1)
log.Tracef("SafeSearch: found in cache: %s", host)
return cachedValue, nil
}
safeHost, ok := d.SafeSearchDomain(host)
if !ok {
return Result{}, nil
}
res := Result{IsFiltered: true, Reason: FilteredSafeSearch}
if ip := net.ParseIP(safeHost); ip != nil {
res.IP = ip
valLen := d.setCacheResult(gctx.safeSearchCache, host, res)
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, valLen)
return res, nil
}
// TODO this address should be resolved with upstream that was configured in dnsforward
addrs, err := net.LookupIP(safeHost)
if err != nil {
log.Tracef("SafeSearchDomain for %s was found but failed to lookup for %s cause %s", host, safeHost, err)
return Result{}, err
}
for _, i := range addrs {
if ipv4 := i.To4(); ipv4 != nil {
res.IP = ipv4
break
}
}
if len(res.IP) == 0 {
return Result{}, fmt.Errorf("no ipv4 addresses in safe search response for %s", safeHost)
}
// Cache result
valLen := d.setCacheResult(gctx.safeSearchCache, host, res)
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, valLen)
return res, nil
}
func (d *Dnsfilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) {
d.Config.SafeSearchEnabled = true
d.Config.ConfigModified()
}
func (d *Dnsfilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) {
d.Config.SafeSearchEnabled = false
d.Config.ConfigModified()
}
func (d *Dnsfilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"enabled": d.Config.SafeSearchEnabled,
}
jsonVal, err := json.Marshal(data)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}

View File

@ -1,4 +1,4 @@
// Parental Control, Safe Browsing, Safe Search // Safe Browsing, Parental Control
package dnsfilter package dnsfilter
@ -6,12 +6,12 @@ import (
"bytes" "bytes"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"encoding/gob"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"sort"
"strings" "strings"
"time" "time"
@ -22,9 +22,6 @@ import (
"golang.org/x/net/publicsuffix" "golang.org/x/net/publicsuffix"
) )
// Servers to use for resolution of SB/PC server name
var bootstrapServers = []string{"176.103.130.130", "176.103.130.131"}
const dnsTimeout = 3 * time.Second const dnsTimeout = 3 * time.Second
const defaultSafebrowsingServer = "https://dns-family.adguard.com/dns-query" const defaultSafebrowsingServer = "https://dns-family.adguard.com/dns-query"
const defaultParentalServer = "https://dns-family.adguard.com/dns-query" const defaultParentalServer = "https://dns-family.adguard.com/dns-query"
@ -35,7 +32,15 @@ func (d *Dnsfilter) initSecurityServices() error {
var err error var err error
d.safeBrowsingServer = defaultSafebrowsingServer d.safeBrowsingServer = defaultSafebrowsingServer
d.parentalServer = defaultParentalServer d.parentalServer = defaultParentalServer
opts := upstream.Options{Timeout: dnsTimeout, Bootstrap: bootstrapServers} opts := upstream.Options{
Timeout: dnsTimeout,
ServerIPAddrs: []net.IP{
net.ParseIP("176.103.130.132"),
net.ParseIP("176.103.130.134"),
net.ParseIP("2a00:5a60::bad1:ff"),
net.ParseIP("2a00:5a60::bad2:ff"),
},
}
d.parentalUpstream, err = upstream.AddressToUpstream(d.parentalServer, opts) d.parentalUpstream, err = upstream.AddressToUpstream(d.parentalServer, opts)
if err != nil { if err != nil {
@ -52,115 +57,65 @@ func (d *Dnsfilter) initSecurityServices() error {
/* /*
expire byte[4] expire byte[4]
res Result hash byte[32]
...
*/ */
func (d *Dnsfilter) setCacheResult(cache cache.Cache, host string, res Result) int { func (c *sbCtx) setCache(prefix []byte, hashes []byte) {
var buf bytes.Buffer d := make([]byte, 4+len(hashes))
expire := uint(time.Now().Unix()) + c.cacheTime*60
expire := uint(time.Now().Unix()) + d.Config.CacheTime*60 binary.BigEndian.PutUint32(d[:4], uint32(expire))
var exp []byte copy(d[4:], hashes)
exp = make([]byte, 4) c.cache.Set(prefix, d)
binary.BigEndian.PutUint32(exp, uint32(expire)) log.Debug("%s: stored in cache: %v", c.svc, prefix)
_, _ = buf.Write(exp)
enc := gob.NewEncoder(&buf)
err := enc.Encode(res)
if err != nil {
log.Error("gob.Encode(): %s", err)
return 0
}
val := buf.Bytes()
_ = cache.Set([]byte(host), val)
return len(val)
} }
func getCachedResult(cache cache.Cache, host string) (Result, bool) { func (c *sbCtx) getCached() int {
data := cache.Get([]byte(host)) now := time.Now().Unix()
if data == nil { hashesToRequest := map[[32]byte]string{}
return Result{}, false for k, v := range c.hashToHost {
} key := k[0:2]
val := c.cache.Get(key)
exp := int(binary.BigEndian.Uint32(data[:4])) if val != nil {
if exp <= int(time.Now().Unix()) { expire := binary.BigEndian.Uint32(val)
cache.Del([]byte(host)) if now >= int64(expire) {
return Result{}, false val = nil
} } else {
for i := 4; i < len(val); i += 32 {
var buf bytes.Buffer hash := val[i : i+32]
buf.Write(data[4:]) var hash32 [32]byte
dec := gob.NewDecoder(&buf) copy(hash32[:], hash[0:32])
r := Result{} _, found := c.hashToHost[hash32]
err := dec.Decode(&r) if found {
if err != nil { log.Debug("%s: found in cache: %s: blocked by %v", c.svc, c.host, hash32)
log.Debug("gob.Decode(): %s", err) return 1
return Result{}, false }
} }
}
return r, true }
} if val == nil {
hashesToRequest[k] = v
// SafeSearchDomain returns replacement address for search engine
func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) {
val, ok := safeSearchDomains[host]
return val, ok
}
func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("SafeSearch: lookup for %s", host)
}
// Check cache. Return cached result if it was found
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, host)
if isFound {
// atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1)
log.Tracef("SafeSearch: found in cache: %s", host)
return cachedValue, nil
}
safeHost, ok := d.SafeSearchDomain(host)
if !ok {
return Result{}, nil
}
res := Result{IsFiltered: true, Reason: FilteredSafeSearch}
if ip := net.ParseIP(safeHost); ip != nil {
res.IP = ip
valLen := d.setCacheResult(gctx.safeSearchCache, host, res)
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, valLen)
return res, nil
}
// TODO this address should be resolved with upstream that was configured in dnsforward
addrs, err := net.LookupIP(safeHost)
if err != nil {
log.Tracef("SafeSearchDomain for %s was found but failed to lookup for %s cause %s", host, safeHost, err)
return Result{}, err
}
for _, i := range addrs {
if ipv4 := i.To4(); ipv4 != nil {
res.IP = ipv4
break
} }
} }
if len(res.IP) == 0 { if len(hashesToRequest) == 0 {
return Result{}, fmt.Errorf("no ipv4 addresses in safe search response for %s", safeHost) log.Debug("%s: found in cache: %s: not blocked", c.svc, c.host)
return -1
} }
// Cache result c.hashToHost = hashesToRequest
valLen := d.setCacheResult(gctx.safeSearchCache, host, res) return 0
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, valLen)
return res, nil
} }
// for each dot, hash it and add it to string type sbCtx struct {
// The maximum is 4 components: "a.b.c.d" host string
func hostnameToHashParam(host string) (string, map[string]bool) { svc string
var hashparam bytes.Buffer hashToHost map[[32]byte]string
hashes := map[string]bool{} cache cache.Cache
cacheTime uint
}
func hostnameToHashes(host string) map[[32]byte]string {
hashes := map[[32]byte]string{}
tld, icann := publicsuffix.PublicSuffix(host) tld, icann := publicsuffix.PublicSuffix(host)
if !icann { if !icann {
// private suffixes like cloudfront.net // private suffixes like cloudfront.net
@ -190,8 +145,7 @@ func hostnameToHashParam(host string) (string, map[string]bool) {
} }
sum := sha256.Sum256([]byte(curhost)) sum := sha256.Sum256([]byte(curhost))
hashes[hex.EncodeToString(sum[:])] = true hashes[sum] = curhost
hashparam.WriteString(fmt.Sprintf("%s.", hex.EncodeToString(sum[0:4])))
pos := strings.IndexByte(curhost, byte('.')) pos := strings.IndexByte(curhost, byte('.'))
if pos < 0 { if pos < 0 {
@ -199,26 +153,91 @@ func hostnameToHashParam(host string) (string, map[string]bool) {
} }
curhost = curhost[pos+1:] curhost = curhost[pos+1:]
} }
return hashparam.String(), hashes return hashes
}
// convert hash array to string
func (c *sbCtx) getQuestion() string {
q := ""
for hash := range c.hashToHost {
q += fmt.Sprintf("%s.", hex.EncodeToString(hash[0:2]))
}
if c.svc == "SafeBrowsing" {
q += sbTXTSuffix
} else {
q += pcTXTSuffix
}
return q
} }
// Find the target hash in TXT response // Find the target hash in TXT response
func (d *Dnsfilter) processTXT(svc, host string, resp *dns.Msg, hashes map[string]bool) bool { func (c *sbCtx) processTXT(resp *dns.Msg) (bool, [][]byte) {
matched := false
hashes := [][]byte{}
for _, a := range resp.Answer { for _, a := range resp.Answer {
txt, ok := a.(*dns.TXT) txt, ok := a.(*dns.TXT)
if !ok { if !ok {
continue continue
} }
log.Tracef("%s: hashes for %s: %v", svc, host, txt.Txt) log.Debug("%s: received hashes for %s: %v", c.svc, c.host, txt.Txt)
for _, t := range txt.Txt { for _, t := range txt.Txt {
_, ok := hashes[t]
if ok { if len(t) != 32*2 {
log.Tracef("%s: matched %s by %s", svc, host, t) continue
return true }
hash, err := hex.DecodeString(t)
if err != nil {
continue
}
hashes = append(hashes, hash)
if !matched {
var hash32 [32]byte
copy(hash32[:], hash)
hashHost, ok := c.hashToHost[hash32]
if ok {
log.Debug("%s: matched %s by %s/%s", c.svc, c.host, hashHost, t)
matched = true
}
} }
} }
} }
return false
return matched, hashes
}
func (c *sbCtx) storeCache(hashes [][]byte) {
sort.Slice(hashes, func(a, b int) bool {
return bytes.Compare(hashes[a], hashes[b]) < 0
})
var curData []byte
var prevPrefix []byte
for i, hash := range hashes {
prefix := hash[0:2]
if !bytes.Equal(prefix, prevPrefix) {
if i != 0 {
c.setCache(prevPrefix, curData)
curData = nil
}
prevPrefix = hashes[i][0:2]
}
curData = append(curData, hash...)
}
if len(prevPrefix) != 0 {
c.setCache(prevPrefix, curData)
}
for hash := range c.hashToHost {
prefix := hash[0:2]
val := c.cache.Get(prefix)
if val == nil {
c.setCache(prefix, nil)
}
}
} }
// Disabling "dupl": the algorithm of SB/PC is similar, but it uses different data // Disabling "dupl": the algorithm of SB/PC is similar, but it uses different data
@ -229,18 +248,29 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) {
defer timer.LogElapsed("SafeBrowsing lookup for %s", host) defer timer.LogElapsed("SafeBrowsing lookup for %s", host)
} }
// check cache result := Result{}
cachedValue, isFound := getCachedResult(gctx.safebrowsingCache, host) hashes := hostnameToHashes(host)
if isFound {
// atomic.AddUint64(&gctx.stats.Safebrowsing.CacheHits, 1) c := &sbCtx{
log.Tracef("SafeBrowsing: found in cache: %s", host) host: host,
return cachedValue, nil svc: "SafeBrowsing",
hashToHost: hashes,
cache: gctx.safebrowsingCache,
cacheTime: d.Config.CacheTime,
} }
result := Result{} // check cache
question, hashes := hostnameToHashParam(host) match := c.getCached()
question = question + sbTXTSuffix if match < 0 {
return result, nil
} else if match > 0 {
result.IsFiltered = true
result.Reason = FilteredSafeBrowsing
result.Rule = "adguard-malware-shavar"
return result, nil
}
question := c.getQuestion()
log.Tracef("SafeBrowsing: checking %s: %s", host, question) log.Tracef("SafeBrowsing: checking %s: %s", host, question)
req := dns.Msg{} req := dns.Msg{}
@ -250,14 +280,14 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) {
return result, err return result, err
} }
if d.processTXT("SafeBrowsing", host, resp, hashes) { matched, receivedHashes := c.processTXT(resp)
if matched {
result.IsFiltered = true result.IsFiltered = true
result.Reason = FilteredSafeBrowsing result.Reason = FilteredSafeBrowsing
result.Rule = "adguard-malware-shavar" result.Rule = "adguard-malware-shavar"
} }
c.storeCache(receivedHashes)
valLen := d.setCacheResult(gctx.safebrowsingCache, host, result)
log.Debug("SafeBrowsing: stored in cache: %s (%d bytes)", host, valLen)
return result, nil return result, nil
} }
@ -269,18 +299,29 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) {
defer timer.LogElapsed("Parental lookup for %s", host) defer timer.LogElapsed("Parental lookup for %s", host)
} }
// check cache result := Result{}
cachedValue, isFound := getCachedResult(gctx.parentalCache, host) hashes := hostnameToHashes(host)
if isFound {
// atomic.AddUint64(&gctx.stats.Parental.CacheHits, 1) c := &sbCtx{
log.Tracef("Parental: found in cache: %s", host) host: host,
return cachedValue, nil svc: "Parental",
hashToHost: hashes,
cache: gctx.parentalCache,
cacheTime: d.Config.CacheTime,
} }
result := Result{} // check cache
question, hashes := hostnameToHashParam(host) match := c.getCached()
question = question + pcTXTSuffix if match < 0 {
return result, nil
} else if match > 0 {
result.IsFiltered = true
result.Reason = FilteredParental
result.Rule = "parental CATEGORY_BLACKLISTED"
return result, nil
}
question := c.getQuestion()
log.Tracef("Parental: checking %s: %s", host, question) log.Tracef("Parental: checking %s: %s", host, question)
req := dns.Msg{} req := dns.Msg{}
@ -290,14 +331,14 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) {
return result, err return result, err
} }
if d.processTXT("Parental", host, resp, hashes) { matched, receivedHashes := c.processTXT(resp)
if matched {
result.IsFiltered = true result.IsFiltered = true
result.Reason = FilteredParental result.Reason = FilteredParental
result.Rule = "parental CATEGORY_BLACKLISTED" result.Rule = "parental CATEGORY_BLACKLISTED"
} }
c.storeCache(receivedHashes)
valLen := d.setCacheResult(gctx.parentalCache, host, result)
log.Debug("Parental: stored in cache: %s (%d bytes)", host, valLen)
return result, err return result, err
} }
@ -362,34 +403,6 @@ func (d *Dnsfilter) handleParentalStatus(w http.ResponseWriter, r *http.Request)
} }
} }
func (d *Dnsfilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) {
d.Config.SafeSearchEnabled = true
d.Config.ConfigModified()
}
func (d *Dnsfilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) {
d.Config.SafeSearchEnabled = false
d.Config.ConfigModified()
}
func (d *Dnsfilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"enabled": d.Config.SafeSearchEnabled,
}
jsonVal, err := json.Marshal(data)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}
func (d *Dnsfilter) registerSecurityHandlers() { func (d *Dnsfilter) registerSecurityHandlers() {
d.Config.HTTPRegister("POST", "/control/safebrowsing/enable", d.handleSafeBrowsingEnable) d.Config.HTTPRegister("POST", "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
d.Config.HTTPRegister("POST", "/control/safebrowsing/disable", d.handleSafeBrowsingDisable) d.Config.HTTPRegister("POST", "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)

91
dnsfilter/sb_pc_test.go Normal file
View File

@ -0,0 +1,91 @@
package dnsfilter
import (
"crypto/sha256"
"strings"
"testing"
"github.com/AdguardTeam/golibs/cache"
"github.com/stretchr/testify/assert"
)
func TestSafeBrowsingHash(t *testing.T) {
// test hostnameToHashes()
hashes := hostnameToHashes("1.2.3.sub.host.com")
assert.Equal(t, 3, len(hashes))
_, ok := hashes[sha256.Sum256([]byte("3.sub.host.com"))]
assert.True(t, ok)
_, ok = hashes[sha256.Sum256([]byte("sub.host.com"))]
assert.True(t, ok)
_, ok = hashes[sha256.Sum256([]byte("host.com"))]
assert.True(t, ok)
_, ok = hashes[sha256.Sum256([]byte("com"))]
assert.False(t, ok)
c := &sbCtx{
svc: "SafeBrowsing",
}
// test getQuestion()
c.hashToHost = hashes
q := c.getQuestion()
assert.True(t, strings.Index(q, "7a1b.") >= 0)
assert.True(t, strings.Index(q, "af5a.") >= 0)
assert.True(t, strings.Index(q, "eb11.") >= 0)
assert.True(t, strings.Index(q, "sb.dns.adguard.com.") > 0)
}
func TestSafeBrowsingCache(t *testing.T) {
c := &sbCtx{
svc: "SafeBrowsing",
cacheTime: 100,
}
conf := cache.Config{}
c.cache = cache.New(conf)
// store in cache hashes for "3.sub.host.com" and "host.com"
// and empty data for hash-prefix for "sub.host.com"
hash := sha256.Sum256([]byte("sub.host.com"))
c.hashToHost = make(map[[32]byte]string)
c.hashToHost[hash] = "sub.host.com"
var hashesArray [][]byte
hash4 := sha256.Sum256([]byte("3.sub.host.com"))
hashesArray = append(hashesArray, hash4[:])
hash2 := sha256.Sum256([]byte("host.com"))
hashesArray = append(hashesArray, hash2[:])
c.storeCache(hashesArray)
// match "3.sub.host.com" or "host.com" from cache
c.hashToHost = make(map[[32]byte]string)
hash = sha256.Sum256([]byte("3.sub.host.com"))
c.hashToHost[hash] = "3.sub.host.com"
hash = sha256.Sum256([]byte("sub.host.com"))
c.hashToHost[hash] = "sub.host.com"
hash = sha256.Sum256([]byte("host.com"))
c.hashToHost[hash] = "host.com"
assert.Equal(t, 1, c.getCached())
// match "sub.host.com" from cache
c.hashToHost = make(map[[32]byte]string)
hash = sha256.Sum256([]byte("sub.host.com"))
c.hashToHost[hash] = "sub.host.com"
assert.Equal(t, -1, c.getCached())
// match "sub.host.com" from cache,
// but another hash for "nonexisting.com" is not in cache
// which means that we must get data from server for it
c.hashToHost = make(map[[32]byte]string)
hash = sha256.Sum256([]byte("sub.host.com"))
c.hashToHost[hash] = "sub.host.com"
hash = sha256.Sum256([]byte("nonexisting.com"))
c.hashToHost[hash] = "nonexisting.com"
assert.Equal(t, 0, c.getCached())
hash = sha256.Sum256([]byte("sub.host.com"))
_, ok := c.hashToHost[hash]
assert.False(t, ok)
hash = sha256.Sum256([]byte("nonexisting.com"))
_, ok = c.hashToHost[hash]
assert.True(t, ok)
}

View File

@ -133,8 +133,8 @@ var defaultValues = ServerConfig{
// createProxyConfig creates and validates configuration for the main proxy // createProxyConfig creates and validates configuration for the main proxy
func (s *Server) createProxyConfig() (proxy.Config, error) { func (s *Server) createProxyConfig() (proxy.Config, error) {
proxyConfig := proxy.Config{ proxyConfig := proxy.Config{
UDPListenAddr: s.conf.UDPListenAddr, UDPListenAddr: []*net.UDPAddr{s.conf.UDPListenAddr},
TCPListenAddr: s.conf.TCPListenAddr, TCPListenAddr: []*net.TCPAddr{s.conf.TCPListenAddr},
Ratelimit: int(s.conf.Ratelimit), Ratelimit: int(s.conf.Ratelimit),
RatelimitWhitelist: s.conf.RatelimitWhitelist, RatelimitWhitelist: s.conf.RatelimitWhitelist,
RefuseAny: s.conf.RefuseAny, RefuseAny: s.conf.RefuseAny,
@ -229,7 +229,7 @@ func (s *Server) prepareIntlProxy() {
// prepareTLS - prepares TLS configuration for the DNS proxy // prepareTLS - prepares TLS configuration for the DNS proxy
func (s *Server) prepareTLS(proxyConfig *proxy.Config) error { func (s *Server) prepareTLS(proxyConfig *proxy.Config) error {
if s.conf.TLSListenAddr != nil && len(s.conf.CertificateChainData) != 0 && len(s.conf.PrivateKeyData) != 0 { if s.conf.TLSListenAddr != nil && len(s.conf.CertificateChainData) != 0 && len(s.conf.PrivateKeyData) != 0 {
proxyConfig.TLSListenAddr = s.conf.TLSListenAddr proxyConfig.TLSListenAddr = []*net.TCPAddr{s.conf.TLSListenAddr}
var err error var err error
s.conf.cert, err = tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData) s.conf.cert, err = tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData)
if err != nil { if err != nil {

View File

@ -252,10 +252,6 @@ func TestBlockedRequest(t *testing.T) {
func TestServerCustomClientUpstream(t *testing.T) { func TestServerCustomClientUpstream(t *testing.T) {
s := createTestServer(t) s := createTestServer(t)
err := s.Start()
if err != nil {
t.Fatalf("Failed to start server: %s", err)
}
s.conf.GetCustomUpstreamByClient = func(clientAddr string) *proxy.UpstreamConfig { s.conf.GetCustomUpstreamByClient = func(clientAddr string) *proxy.UpstreamConfig {
uc := &proxy.UpstreamConfig{} uc := &proxy.UpstreamConfig{}
u := &testUpstream{} u := &testUpstream{}
@ -264,6 +260,9 @@ func TestServerCustomClientUpstream(t *testing.T) {
uc.Upstreams = append(uc.Upstreams, u) uc.Upstreams = append(uc.Upstreams, u)
return uc return uc
} }
assert.Nil(t, s.Start())
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
// Send test request // Send test request

2
go.mod
View File

@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome
go 1.14 go 1.14
require ( require (
github.com/AdguardTeam/dnsproxy v0.29.1 github.com/AdguardTeam/dnsproxy v0.30.1
github.com/AdguardTeam/golibs v0.4.2 github.com/AdguardTeam/golibs v0.4.2
github.com/AdguardTeam/urlfilter v0.11.2 github.com/AdguardTeam/urlfilter v0.11.2
github.com/NYTimes/gziphandler v1.1.1 github.com/NYTimes/gziphandler v1.1.1

4
go.sum
View File

@ -1,5 +1,5 @@
github.com/AdguardTeam/dnsproxy v0.29.1 h1:Stc+JLh67C9K38vbrH2920+3FnbXKkFzYQqRiu5auUo= github.com/AdguardTeam/dnsproxy v0.30.1 h1:SnsL5kM/eFTrtLLdww1EePOhVDZTWzMkse+5tadGhvc=
github.com/AdguardTeam/dnsproxy v0.29.1/go.mod h1:hOYFV9TW+pd5XKYz7KZf2FFD8SvSPqjyGTxUae86s58= github.com/AdguardTeam/dnsproxy v0.30.1/go.mod h1:hOYFV9TW+pd5XKYz7KZf2FFD8SvSPqjyGTxUae86s58=
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.4.2 h1:7M28oTZFoFwNmp8eGPb3ImmYbxGaJLyQXeIFVHjME0o= github.com/AdguardTeam/golibs v0.4.2 h1:7M28oTZFoFwNmp8eGPb3ImmYbxGaJLyQXeIFVHjME0o=
github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=