Pull request: all: allow clientid in access settings

Updates #2624.
Updates #3162.

Squashed commit of the following:

commit 68860da717a23a0bfeba14b7fe10b5e4ad38726d
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jun 29 15:41:33 2021 +0300

    all: imp types, names

commit ebd4ec26636853d0d58c4e331e6a78feede20813
Merge: 239eb721 16e5e09c
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jun 29 15:14:33 2021 +0300

    Merge branch 'master' into 2624-clientid-access

commit 239eb7215abc47e99a0300a0f4cf56002689b1a9
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jun 29 15:13:10 2021 +0300

    all: fix client blocking check

commit e6bece3ea8367b3cbe3d90702a3368c870ad4f13
Merge: 9935f2a3 9d1656b5
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jun 29 13:12:28 2021 +0300

    Merge branch 'master' into 2624-clientid-access

commit 9935f2a30bcfae2b853f3ef610c0ab7a56a8f448
Author: Ildar Kamalov <ik@adguard.com>
Date:   Tue Jun 29 11:26:51 2021 +0300

    client: show block button for client id

commit ed786a6a74a081cd89e9d67df3537a4fadd54831
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Jun 25 15:56:23 2021 +0300

    client: imp i18n

commit 4fed21c68473ad408960c08a7d87624cabce1911
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Jun 25 15:34:09 2021 +0300

    all: imp i18n, docs

commit 55e65c0d6b939560c53dcb834a4557eb3853d194
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Jun 25 13:34:01 2021 +0300

    all: fix cache, imp code, docs, tests

commit c1e5a83e76deb44b1f92729bb9ddfcc6a96ac4a8
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Jun 24 19:27:12 2021 +0300

    all: allow clientid in access settings
This commit is contained in:
Ainar Garipov 2021-06-29 15:53:28 +03:00
parent 16e5e09c2e
commit e08a64ebe4
33 changed files with 955 additions and 604 deletions

View File

@ -15,6 +15,7 @@ and this project adheres to
### Added ### Added
- Blocking access using client IDs ([#2624], [#3162]).
- `source` directives support in `/etc/network/interfaces` on Linux ([#3257]). - `source` directives support in `/etc/network/interfaces` on Linux ([#3257]).
- RFC 9000 support in DNS-over-QUIC. - RFC 9000 support in DNS-over-QUIC.
- Completely disabling statistics by setting the statistics interval to zero - Completely disabling statistics by setting the statistics interval to zero
@ -80,9 +81,11 @@ released by then.
[#2439]: https://github.com/AdguardTeam/AdGuardHome/issues/2439 [#2439]: https://github.com/AdguardTeam/AdGuardHome/issues/2439
[#2441]: https://github.com/AdguardTeam/AdGuardHome/issues/2441 [#2441]: https://github.com/AdguardTeam/AdGuardHome/issues/2441
[#2443]: https://github.com/AdguardTeam/AdGuardHome/issues/2443 [#2443]: https://github.com/AdguardTeam/AdGuardHome/issues/2443
[#2624]: https://github.com/AdguardTeam/AdGuardHome/issues/2624
[#2763]: https://github.com/AdguardTeam/AdGuardHome/issues/2763 [#2763]: https://github.com/AdguardTeam/AdGuardHome/issues/2763
[#3013]: https://github.com/AdguardTeam/AdGuardHome/issues/3013 [#3013]: https://github.com/AdguardTeam/AdGuardHome/issues/3013
[#3136]: https://github.com/AdguardTeam/AdGuardHome/issues/3136 [#3136]: https://github.com/AdguardTeam/AdGuardHome/issues/3136
[#3162]: https://github.com/AdguardTeam/AdGuardHome/issues/3162
[#3166]: https://github.com/AdguardTeam/AdGuardHome/issues/3166 [#3166]: https://github.com/AdguardTeam/AdGuardHome/issues/3166
[#3172]: https://github.com/AdguardTeam/AdGuardHome/issues/3172 [#3172]: https://github.com/AdguardTeam/AdGuardHome/issues/3172
[#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184 [#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184

View File

@ -159,8 +159,10 @@ attributes to make it work in Markdown renderers that strip "id". -->
* Minimize scope of variables as much as possible. * Minimize scope of variables as much as possible.
* No shadowing, since it can often lead to subtle bugs, especially with * No name shadowing, including of predeclared identifiers, since it can often
errors. lead to subtle bugs, especially with errors. This rule does not apply to
struct fields, since they are always used together with the name of the
struct value, so there isn't any confusion.
* Prefer constants to variables where possible. Avoid global variables. Use * Prefer constants to variables where possible. Avoid global variables. Use
[constant errors] instead of `errors.New`. [constant errors] instead of `errors.New`.

View File

@ -426,9 +426,9 @@
"access_title": "Access settings", "access_title": "Access settings",
"access_desc": "Here you can configure access rules for the AdGuard Home DNS server.", "access_desc": "Here you can configure access rules for the AdGuard Home DNS server.",
"access_allowed_title": "Allowed clients", "access_allowed_title": "Allowed clients",
"access_allowed_desc": "A list of CIDR or IP addresses. If configured, AdGuard Home will accept requests from these IP addresses only.", "access_allowed_desc": "A list of CIDRs, IP addresses, or client IDs. If configured, AdGuard Home will accept requests only from these clients.",
"access_disallowed_title": "Disallowed clients", "access_disallowed_title": "Disallowed clients",
"access_disallowed_desc": "A list of CIDR or IP addresses. If configured, AdGuard Home will drop requests from these IP addresses.", "access_disallowed_desc": "A list of CIDRs, IP addresses, or client IDs. If configured, AdGuard Home will drop requests from these clients. If allowed clients are configured, this field is ignored.",
"access_blocked_title": "Disallowed domains", "access_blocked_title": "Disallowed domains",
"access_blocked_desc": "Not to be confused with filters. AdGuard Home drops DNS queries matching these domains, and these queries don't even appear in the query log. You can specify exact domain names, wildcards, or URL filter rules, e.g. \"example.org\", \"*.example.org\", or \"||example.org^\" correspondingly.", "access_blocked_desc": "Not to be confused with filters. AdGuard Home drops DNS queries matching these domains, and these queries don't even appear in the query log. You can specify exact domain names, wildcards, or URL filter rules, e.g. \"example.org\", \"*.example.org\", or \"||example.org^\" correspondingly.",
"access_settings_saved": "Access settings successfully saved", "access_settings_saved": "Access settings successfully saved",

View File

@ -9,7 +9,7 @@ import Card from '../ui/Card';
import Cell from '../ui/Cell'; import Cell from '../ui/Cell';
import { getPercent, sortIp } from '../../helpers/helpers'; import { getPercent, sortIp } from '../../helpers/helpers';
import { BLOCK_ACTIONS, R_CLIENT_ID, STATUS_COLORS } from '../../helpers/constants'; import { BLOCK_ACTIONS, STATUS_COLORS } from '../../helpers/constants';
import { toggleClientBlock } from '../../actions/access'; import { toggleClientBlock } from '../../actions/access';
import { renderFormattedClientCell } from '../../helpers/renderFormattedClientCell'; import { renderFormattedClientCell } from '../../helpers/renderFormattedClientCell';
import { getStats } from '../../actions/stats'; import { getStats } from '../../actions/stats';
@ -35,10 +35,6 @@ const CountCell = (row) => {
}; };
const renderBlockingButton = (ip, disallowed, disallowed_rule) => { const renderBlockingButton = (ip, disallowed, disallowed_rule) => {
if (R_CLIENT_ID.test(ip)) {
return null;
}
const dispatch = useDispatch(); const dispatch = useDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const processingSet = useSelector((state) => state.access.processingSet); const processingSet = useSelector((state) => state.access.processingSet);

2
go.mod
View File

@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome
go 1.16 go 1.16
require ( require (
github.com/AdguardTeam/dnsproxy v0.37.7 github.com/AdguardTeam/dnsproxy v0.38.0
github.com/AdguardTeam/golibs v0.8.0 github.com/AdguardTeam/golibs v0.8.0
github.com/AdguardTeam/urlfilter v0.14.6 github.com/AdguardTeam/urlfilter v0.14.6
github.com/NYTimes/gziphandler v1.1.1 github.com/NYTimes/gziphandler v1.1.1

4
go.sum
View File

@ -9,8 +9,8 @@ dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg=
github.com/AdguardTeam/dhcp v0.0.0-20210519141215-51808c73c0bf h1:gc042VRSIRSUzZ+Px6xQCRWNJZTaPkomisDfUZmoFNk= github.com/AdguardTeam/dhcp v0.0.0-20210519141215-51808c73c0bf h1:gc042VRSIRSUzZ+Px6xQCRWNJZTaPkomisDfUZmoFNk=
github.com/AdguardTeam/dhcp v0.0.0-20210519141215-51808c73c0bf/go.mod h1:TKl4jN3Voofo4UJIicyNhWGp/nlQqQkFxmwIFTvBkKI= github.com/AdguardTeam/dhcp v0.0.0-20210519141215-51808c73c0bf/go.mod h1:TKl4jN3Voofo4UJIicyNhWGp/nlQqQkFxmwIFTvBkKI=
github.com/AdguardTeam/dnsproxy v0.37.7 h1:yp0vEVYobf/1l8iY7es9yMqguw8BUEeC74OGA4G2v2A= github.com/AdguardTeam/dnsproxy v0.38.0 h1:7GyyNJOieIVOgdnhu47exqWjHPQro7wQhqzvQjaZt6M=
github.com/AdguardTeam/dnsproxy v0.37.7/go.mod h1:xMfevPAwpK1ULoLO0CARg/OiUsPH92kfyliXhPTc62M= github.com/AdguardTeam/dnsproxy v0.38.0/go.mod h1:xMfevPAwpK1ULoLO0CARg/OiUsPH92kfyliXhPTc62M=
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.4.4/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.4/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=

View File

@ -27,10 +27,9 @@ type EtcHostsContainer struct {
lock sync.RWMutex lock sync.RWMutex
// table is the host-to-IPs map. // table is the host-to-IPs map.
table map[string][]net.IP table map[string][]net.IP
// tableReverse is the IP-to-hosts map. // tableReverse is the IP-to-hosts map. The type of the values in the
// // map is []string.
// TODO(a.garipov): Make better use of newtypes. Perhaps a custom map. tableReverse *IPMap
tableReverse map[string][]string
hostsFn string // path to the main hosts-file hostsFn string // path to the main hosts-file
hostsDirs []string // paths to OS-specific directories with hosts-files hostsDirs []string // paths to OS-specific directories with hosts-files
@ -80,7 +79,7 @@ func (ehc *EtcHostsContainer) Init(hostsFn string) {
var err error var err error
ehc.watcher, err = fsnotify.NewWatcher() ehc.watcher, err = fsnotify.NewWatcher()
if err != nil { if err != nil {
log.Error("etchostscontainer: %s", err) log.Error("etchosts: %s", err)
} }
} }
@ -141,7 +140,7 @@ func (ehc *EtcHostsContainer) Process(host string, qtype uint16) []net.IP {
copy(ipsCopy, ips) copy(ipsCopy, ips)
} }
log.Debug("etchostscontainer: answer: %s -> %v", host, ipsCopy) log.Debug("etchosts: answer: %s -> %v", host, ipsCopy)
return ipsCopy return ipsCopy
} }
@ -151,38 +150,40 @@ func (ehc *EtcHostsContainer) ProcessReverse(addr string, qtype uint16) (hosts [
return nil return nil
} }
ipReal := UnreverseAddr(addr) ip := UnreverseAddr(addr)
if ipReal == nil { if ip == nil {
return nil return nil
} }
ipStr := ipReal.String()
ehc.lock.RLock() ehc.lock.RLock()
defer ehc.lock.RUnlock() defer ehc.lock.RUnlock()
hosts = ehc.tableReverse[ipStr] v, ok := ehc.tableReverse.Get(ip)
if !ok {
if len(hosts) == 0 { return nil
return nil // not found
} }
log.Debug("etchostscontainer: reverse-lookup: %s -> %s", addr, hosts) hosts, ok = v.([]string)
if !ok {
log.Error("etchosts: bad type %T in tableReverse for %s", v, ip)
return nil
} else if len(hosts) == 0 {
return nil
}
log.Debug("etchosts: reverse-lookup: %s -> %s", addr, hosts)
return hosts return hosts
} }
// List returns an IP-to-hostnames table. It is safe for concurrent use. // List returns an IP-to-hostnames table. The type of the values in the map is
func (ehc *EtcHostsContainer) List() (ipToHosts map[string][]string) { // []string. It is safe for concurrent use.
func (ehc *EtcHostsContainer) List() (ipToHosts *IPMap) {
ehc.lock.RLock() ehc.lock.RLock()
defer ehc.lock.RUnlock() defer ehc.lock.RUnlock()
ipToHosts = make(map[string][]string, len(ehc.tableReverse)) return ehc.tableReverse.ShallowClone()
for k, v := range ehc.tableReverse {
ipToHosts[k] = v
}
return ipToHosts
} }
// update table // update table
@ -205,29 +206,31 @@ func (ehc *EtcHostsContainer) updateTable(table map[string][]net.IP, host string
ok = true ok = true
} }
if ok { if ok {
log.Debug("etchostscontainer: added %s -> %s", ipAddr, host) log.Debug("etchosts: added %s -> %s", ipAddr, host)
} }
} }
// updateTableRev updates the reverse address table. // updateTableRev updates the reverse address table.
func (ehc *EtcHostsContainer) updateTableRev(tableRev map[string][]string, newHost string, ipAddr net.IP) { func (ehc *EtcHostsContainer) updateTableRev(tableRev *IPMap, newHost string, ip net.IP) {
ipStr := ipAddr.String() v, ok := tableRev.Get(ip)
hosts, ok := tableRev[ipStr]
if !ok { if !ok {
tableRev[ipStr] = []string{newHost} tableRev.Set(ip, []string{newHost})
log.Debug("etchostscontainer: added reverse-address %s -> %s", ipStr, newHost) log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
return return
} }
hosts, _ := v.([]string)
for _, host := range hosts { for _, host := range hosts {
if host == newHost { if host == newHost {
return return
} }
} }
tableRev[ipStr] = append(tableRev[ipStr], newHost) hosts = append(hosts, newHost)
log.Debug("etchostscontainer: added reverse-address %s -> %s", ipStr, newHost) tableRev.Set(ip, hosts)
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
} }
// parseHostsLine parses hosts from the fields. // parseHostsLine parses hosts from the fields.
@ -255,12 +258,12 @@ func parseHostsLine(fields []string) (hosts []string) {
// line for one IP are supported. // line for one IP are supported.
func (ehc *EtcHostsContainer) load( func (ehc *EtcHostsContainer) load(
table map[string][]net.IP, table map[string][]net.IP,
tableRev map[string][]string, tableRev *IPMap,
fn string, fn string,
) { ) {
f, err := os.Open(fn) f, err := os.Open(fn)
if err != nil { if err != nil {
log.Error("etchostscontainer: %s", err) log.Error("etchosts: %s", err)
return return
} }
@ -268,11 +271,11 @@ func (ehc *EtcHostsContainer) load(
defer func() { defer func() {
derr := f.Close() derr := f.Close()
if derr != nil { if derr != nil {
log.Error("etchostscontainer: closing file: %s", err) log.Error("etchosts: closing file: %s", err)
} }
}() }()
log.Debug("etchostscontainer: loading hosts from file %s", fn) log.Debug("etchosts: loading hosts from file %s", fn)
s := bufio.NewScanner(f) s := bufio.NewScanner(f)
for s.Scan() { for s.Scan() {
@ -296,7 +299,7 @@ func (ehc *EtcHostsContainer) load(
err = s.Err() err = s.Err()
if err != nil { if err != nil {
log.Error("etchostscontainer: %s", err) log.Error("etchosts: %s", err)
} }
} }
@ -334,7 +337,7 @@ func (ehc *EtcHostsContainer) watcherLoop() {
} }
if event.Op&fsnotify.Write == fsnotify.Write { if event.Op&fsnotify.Write == fsnotify.Write {
log.Debug("etchostscontainer: modified: %s", event.Name) log.Debug("etchosts: modified: %s", event.Name)
ehc.updateHosts() ehc.updateHosts()
} }
@ -342,7 +345,7 @@ func (ehc *EtcHostsContainer) watcherLoop() {
if !ok { if !ok {
return return
} }
log.Error("etchostscontainer: %s", err) log.Error("etchosts: %s", err)
} }
} }
} }
@ -350,7 +353,7 @@ func (ehc *EtcHostsContainer) watcherLoop() {
// updateHosts - loads system hosts // updateHosts - loads system hosts
func (ehc *EtcHostsContainer) updateHosts() { func (ehc *EtcHostsContainer) updateHosts() {
table := make(map[string][]net.IP) table := make(map[string][]net.IP)
tableRev := make(map[string][]string) tableRev := NewIPMap(0)
ehc.load(table, tableRev, ehc.hostsFn) ehc.load(table, tableRev, ehc.hostsFn)
@ -358,7 +361,7 @@ func (ehc *EtcHostsContainer) updateHosts() {
des, err := os.ReadDir(dir) des, err := os.ReadDir(dir)
if err != nil { if err != nil {
if !errors.Is(err, os.ErrNotExist) { if !errors.Is(err, os.ErrNotExist) {
log.Error("etchostscontainer: Opening directory: %q: %s", dir, err) log.Error("etchosts: Opening directory: %q: %s", dir, err)
} }
continue continue

View File

@ -70,7 +70,7 @@ func TestEtcHostsContainerResolution(t *testing.T) {
}) })
t.Run("hosts_file", func(t *testing.T) { t.Run("hosts_file", func(t *testing.T) {
names, ok := ehc.List()["127.0.0.1"] names, ok := ehc.List().Get(net.IP{127, 0, 0, 1})
require.True(t, ok) require.True(t, ok)
assert.Equal(t, []string{"host", "localhost"}, names) assert.Equal(t, []string{"host", "localhost"}, names)
}) })

112
internal/aghnet/ipmap.go Normal file
View File

@ -0,0 +1,112 @@
package aghnet
import (
"fmt"
"net"
)
// ipArr is a representation of an IP address as an array of bytes.
type ipArr [16]byte
// String implements the fmt.Stringer interface for ipArr.
func (a ipArr) String() (s string) {
return net.IP(a[:]).String()
}
// IPMap is a map of IP addresses.
type IPMap struct {
m map[ipArr]interface{}
}
// NewIPMap returns a new empty IP map using hint as a size hint for the
// underlying map.
func NewIPMap(hint int) (m *IPMap) {
return &IPMap{
m: make(map[ipArr]interface{}, hint),
}
}
// ipToArr converts a net.IP into an ipArr.
//
// TODO(a.garipov): Use the slice-to-array conversion in Go 1.17.
func ipToArr(ip net.IP) (a ipArr) {
copy(a[:], ip.To16())
return a
}
// Del deletes ip from the map. Calling Del on a nil *IPMap has no effect, just
// like delete on an empty map doesn't.
func (m *IPMap) Del(ip net.IP) {
if m != nil {
delete(m.m, ipToArr(ip))
}
}
// Get returns the value from the map. Calling Get on a nil *IPMap returns nil
// and false, just like indexing on an empty map does.
func (m *IPMap) Get(ip net.IP) (v interface{}, ok bool) {
if m != nil {
v, ok = m.m[ipToArr(ip)]
return v, ok
}
return nil, false
}
// Len returns the length of the map. A nil *IPMap has a length of zero, just
// like an empty map.
func (m *IPMap) Len() (n int) {
if m == nil {
return 0
}
return len(m.m)
}
// Range calls f for each key and value present in the map in an undefined
// order. If cont is false, range stops the iteration. Calling Range on a nil
// *IPMap has no effect, just like ranging over a nil map.
func (m *IPMap) Range(f func(ip net.IP, v interface{}) (cont bool)) {
if m == nil {
return
}
for k, v := range m.m {
if !f(net.IP(k[:]), v) {
break
}
}
}
// Set sets the value. Set panics if the m is a nil *IPMap, just like a nil map
// does.
func (m *IPMap) Set(ip net.IP, v interface{}) {
m.m[ipToArr(ip)] = v
}
// ShallowClone returns a shallow clone of the map.
func (m *IPMap) ShallowClone() (sclone *IPMap) {
if m == nil {
return nil
}
sclone = NewIPMap(m.Len())
m.Range(func(ip net.IP, v interface{}) (cont bool) {
sclone.Set(ip, v)
return true
})
return sclone
}
// String implements the fmt.Stringer interface for *IPMap.
func (m *IPMap) String() (s string) {
if m == nil {
return "<nil>"
}
return fmt.Sprint(m.m)
}

View File

@ -0,0 +1,142 @@
package aghnet
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIPMap_allocs(t *testing.T) {
ip4 := net.IP{1, 2, 3, 4}
m := NewIPMap(0)
m.Set(ip4, 42)
t.Run("get", func(t *testing.T) {
var v interface{}
var ok bool
allocs := testing.AllocsPerRun(100, func() {
v, ok = m.Get(ip4)
})
require.True(t, ok)
require.Equal(t, 42, v)
assert.Equal(t, float64(0), allocs)
})
t.Run("len", func(t *testing.T) {
var n int
allocs := testing.AllocsPerRun(100, func() {
n = m.Len()
})
require.Equal(t, 1, n)
assert.Equal(t, float64(0), allocs)
})
}
func TestIPMap(t *testing.T) {
ip4 := net.IP{1, 2, 3, 4}
ip6 := net.IP{
0x12, 0x34, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x56, 0x78,
}
val := 42
t.Run("nil", func(t *testing.T) {
var m *IPMap
assert.NotPanics(t, func() {
m.Del(ip4)
m.Del(ip6)
})
assert.NotPanics(t, func() {
v, ok := m.Get(ip4)
assert.Nil(t, v)
assert.False(t, ok)
v, ok = m.Get(ip6)
assert.Nil(t, v)
assert.False(t, ok)
})
assert.NotPanics(t, func() {
assert.Equal(t, 0, m.Len())
})
assert.NotPanics(t, func() {
n := 0
m.Range(func(_ net.IP, _ interface{}) (cont bool) {
n++
return true
})
assert.Equal(t, 0, n)
})
assert.Panics(t, func() {
m.Set(ip4, val)
})
assert.Panics(t, func() {
m.Set(ip6, val)
})
assert.NotPanics(t, func() {
sclone := m.ShallowClone()
assert.Nil(t, sclone)
})
})
testIPMap := func(t *testing.T, ip net.IP, s string) {
m := NewIPMap(0)
assert.Equal(t, 0, m.Len())
v, ok := m.Get(ip)
assert.Nil(t, v)
assert.False(t, ok)
m.Set(ip, val)
v, ok = m.Get(ip)
assert.Equal(t, val, v)
assert.True(t, ok)
n := 0
m.Range(func(ipKey net.IP, v interface{}) (cont bool) {
assert.Equal(t, ip.To16(), ipKey)
assert.Equal(t, val, v)
n++
return false
})
assert.Equal(t, 1, n)
sclone := m.ShallowClone()
assert.Equal(t, m, sclone)
assert.Equal(t, s, m.String())
m.Del(ip)
v, ok = m.Get(ip)
assert.Nil(t, v)
assert.False(t, ok)
assert.Equal(t, 0, m.Len())
}
t.Run("ipv4", func(t *testing.T) {
testIPMap(t, ip4, "map[1.2.3.4:42]")
})
t.Run("ipv6", func(t *testing.T) {
testIPMap(t, ip6, "map[1234::5678:42]")
})
}

View File

@ -6,138 +6,163 @@ import (
"net" "net"
"net/http" "net/http"
"strings" "strings"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter" "github.com/AdguardTeam/urlfilter"
"github.com/AdguardTeam/urlfilter/filterlist" "github.com/AdguardTeam/urlfilter/filterlist"
) )
// accessCtx controls IP and client blocking that takes place before all other
// processing. An accessCtx is safe for concurrent use.
type accessCtx struct { type accessCtx struct {
lock sync.Mutex allowedIPs *aghnet.IPMap
blockedIPs *aghnet.IPMap
// allowedClients are the IP addresses of clients in the allowlist. allowedClientIDs *aghstrings.Set
allowedClients *aghstrings.Set blockedClientIDs *aghstrings.Set
// disallowedClients are the IP addresses of clients in the blocklist. blockedHostsEng *urlfilter.DNSEngine
disallowedClients *aghstrings.Set
allowedClientsIPNet []net.IPNet // CIDRs of whitelist clients // TODO(a.garipov): Create a type for a set of IP networks.
disallowedClientsIPNet []net.IPNet // CIDRs of clients that should be blocked // aghnet.IPNetSet?
allowedNets []*net.IPNet
blockedHostsEngine *urlfilter.DNSEngine // finds hosts that should be blocked blockedNets []*net.IPNet
} }
func newAccessCtx(allowedClients, disallowedClients, blockedHosts []string) (a *accessCtx, err error) { // unit is a convenient alias for struct{}
a = &accessCtx{ type unit = struct{}
allowedClients: aghstrings.NewSet(),
disallowedClients: aghstrings.NewSet(), // processAccessClients is a helper for processing a list of client strings,
// which may be an IP address, a CIDR, or a ClientID.
func processAccessClients(
clientStrs []string,
ips *aghnet.IPMap,
nets *[]*net.IPNet,
clientIDs *aghstrings.Set,
) (err error) {
for i, s := range clientStrs {
if ip := net.ParseIP(s); ip != nil {
ips.Set(ip, unit{})
} else if cidrIP, ipnet, cidrErr := net.ParseCIDR(s); cidrErr == nil {
ipnet.IP = cidrIP
*nets = append(*nets, ipnet)
} else {
idErr := ValidateClientID(s)
if idErr != nil {
return fmt.Errorf(
"value %q at index %d: bad ip, cidr, or clientid",
s,
i,
)
} }
err = processIPCIDRArray(a.allowedClients, &a.allowedClientsIPNet, allowedClients) clientIDs.Add(s)
if err != nil {
return nil, fmt.Errorf("processing allowed clients: %w", err)
} }
err = processIPCIDRArray(a.disallowedClients, &a.disallowedClientsIPNet, disallowedClients)
if err != nil {
return nil, fmt.Errorf("processing disallowed clients: %w", err)
}
b := &strings.Builder{}
for _, s := range blockedHosts {
aghstrings.WriteToBuilder(b, strings.ToLower(s), "\n")
}
listArray := []filterlist.RuleList{}
list := &filterlist.StringRuleList{
ID: int(0),
RulesText: b.String(),
IgnoreCosmetic: true,
}
listArray = append(listArray, list)
rulesStorage, err := filterlist.NewRuleStorage(listArray)
if err != nil {
return nil, fmt.Errorf("filterlist.NewRuleStorage(): %w", err)
}
a.blockedHostsEngine = urlfilter.NewDNSEngine(rulesStorage)
return a, nil
}
// Split array of IP or CIDR into 2 containers for fast search
func processIPCIDRArray(dst *aghstrings.Set, dstIPNet *[]net.IPNet, src []string) error {
for _, s := range src {
ip := net.ParseIP(s)
if ip != nil {
dst.Add(s)
continue
}
_, ipnet, err := net.ParseCIDR(s)
if err != nil {
return err
}
*dstIPNet = append(*dstIPNet, *ipnet)
} }
return nil return nil
} }
// IsBlockedIP - return TRUE if this client should be blocked // newAccessCtx creates a new accessCtx.
// Returns the item from the "disallowedClients" list that lead to blocking IP. func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessCtx, err error) {
// If it returns TRUE and an empty string, it means that the "allowedClients" is not empty, a = &accessCtx{
// but the ip does not belong to it. allowedIPs: aghnet.NewIPMap(0),
func (a *accessCtx) IsBlockedIP(ip net.IP) (bool, string) { blockedIPs: aghnet.NewIPMap(0),
ipStr := ip.String()
a.lock.Lock() allowedClientIDs: aghstrings.NewSet(),
defer a.lock.Unlock() blockedClientIDs: aghstrings.NewSet(),
if a.allowedClients.Len() != 0 || len(a.allowedClientsIPNet) != 0 {
if a.allowedClients.Has(ipStr) {
return false, ""
} }
if len(a.allowedClientsIPNet) != 0 { err = processAccessClients(allowed, a.allowedIPs, &a.allowedNets, a.allowedClientIDs)
for _, ipnet := range a.allowedClientsIPNet { if err != nil {
if ipnet.Contains(ip) { return nil, fmt.Errorf("adding allowed: %w", err)
return false, ""
}
}
} }
return true, "" err = processAccessClients(blocked, a.blockedIPs, &a.blockedNets, a.blockedClientIDs)
if err != nil {
return nil, fmt.Errorf("adding blocked: %w", err)
} }
if a.disallowedClients.Has(ipStr) { b := &strings.Builder{}
return true, ipStr for _, h := range blockedHosts {
aghstrings.WriteToBuilder(b, strings.ToLower(h), "\n")
} }
if len(a.disallowedClientsIPNet) != 0 { lists := []filterlist.RuleList{
for _, ipnet := range a.disallowedClientsIPNet { &filterlist.StringRuleList{
if ipnet.Contains(ip) { ID: int(0),
return true, ipnet.String() RulesText: b.String(),
} IgnoreCosmetic: true,
} },
} }
return false, "" rulesStrg, err := filterlist.NewRuleStorage(lists)
if err != nil {
return nil, fmt.Errorf("adding blocked hosts: %w", err)
} }
// IsBlockedDomain - return TRUE if this domain should be blocked a.blockedHostsEng = urlfilter.NewDNSEngine(rulesStrg)
func (a *accessCtx) IsBlockedDomain(host string) (ok bool) {
a.lock.Lock()
defer a.lock.Unlock()
_, ok = a.blockedHostsEngine.Match(strings.ToLower(host)) return a, nil
}
// allowlistMode returns true if this *accessCtx is in the allowlist mode.
func (a *accessCtx) allowlistMode() (ok bool) {
return a.allowedIPs.Len() != 0 || a.allowedClientIDs.Len() != 0 || len(a.allowedNets) != 0
}
// isBlockedClientID returns true if the ClientID should be blocked.
func (a *accessCtx) isBlockedClientID(id string) (ok bool) {
allowlistMode := a.allowlistMode()
if id == "" {
// In allowlist mode, consider requests without client IDs
// blocked by default.
return allowlistMode
}
if allowlistMode {
return !a.allowedClientIDs.Has(id)
}
return a.blockedClientIDs.Has(id)
}
// isBlockedHost returns true if host should be blocked.
func (a *accessCtx) isBlockedHost(host string) (ok bool) {
_, ok = a.blockedHostsEng.Match(strings.ToLower(host))
return ok return ok
} }
// isBlockedIP returns the status of the IP address blocking as well as the rule
// that blocked it.
func (a *accessCtx) isBlockedIP(ip net.IP) (blocked bool, rule string) {
blocked = true
ips := a.blockedIPs
ipnets := a.blockedNets
if a.allowlistMode() {
// Enable allowlist mode and use the allowlist sets.
blocked = false
ips = a.allowedIPs
ipnets = a.allowedNets
}
if _, ok := ips.Get(ip); ok {
return blocked, ip.String()
}
for _, ipnet := range ipnets {
if ipnet.Contains(ip) {
return blocked, ipnet.String()
}
}
return !blocked, ""
}
type accessListJSON struct { type accessListJSON struct {
AllowedClients []string `json:"allowed_clients"` AllowedClients []string `json:"allowed_clients"`
DisallowedClients []string `json:"disallowed_clients"` DisallowedClients []string `json:"disallowed_clients"`
@ -161,62 +186,43 @@ func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(j) err := json.NewEncoder(w).Encode(j)
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err) httpError(r, w, http.StatusInternalServerError, "encoding response: %s", err)
return return
} }
} }
func checkIPCIDRArray(src []string) error {
for _, s := range src {
ip := net.ParseIP(s)
if ip != nil {
continue
}
_, _, err := net.ParseCIDR(s)
if err != nil {
return err
}
}
return nil
}
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) { func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
j := accessListJSON{} list := accessListJSON{}
err := json.NewDecoder(r.Body).Decode(&j) err := json.NewDecoder(r.Body).Decode(&list)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) httpError(r, w, http.StatusBadRequest, "decoding request: %s", err)
return
}
err = checkIPCIDRArray(j.AllowedClients)
if err == nil {
err = checkIPCIDRArray(j.DisallowedClients)
}
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
return return
} }
var a *accessCtx var a *accessCtx
a, err = newAccessCtx(j.AllowedClients, j.DisallowedClients, j.BlockedHosts) a, err = newAccessCtx(list.AllowedClients, list.DisallowedClients, list.BlockedHosts)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "creating access ctx: %s", err) httpError(r, w, http.StatusBadRequest, "creating access ctx: %s", err)
return return
} }
defer log.Debug("Access: updated lists: %d, %d, %d", defer log.Debug(
len(j.AllowedClients), len(j.DisallowedClients), len(j.BlockedHosts)) "access: updated lists: %d, %d, %d",
len(list.AllowedClients),
len(list.DisallowedClients),
len(list.BlockedHosts),
)
defer s.conf.ConfigModified() defer s.conf.ConfigModified()
s.serverLock.Lock() s.serverLock.Lock()
defer s.serverLock.Unlock() defer s.serverLock.Unlock()
s.conf.AllowedClients = j.AllowedClients s.conf.AllowedClients = list.AllowedClients
s.conf.DisallowedClients = j.DisallowedClients s.conf.DisallowedClients = list.DisallowedClients
s.conf.BlockedHosts = j.BlockedHosts s.conf.BlockedHosts = list.BlockedHosts
s.access = a s.access = a
} }

View File

@ -8,99 +8,23 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestIsBlockedIP(t *testing.T) { func TestIsBlockedClientID(t *testing.T) {
const ( clientID := "client-1"
ip int = iota clients := []string{clientID}
cidr
)
rules := []string{ a, err := newAccessCtx(clients, nil, nil)
ip: "1.1.1.1",
cidr: "2.2.0.0/16",
}
testCases := []struct {
name string
allowed bool
ip net.IP
wantDis bool
wantRule string
}{{
name: "allow_ip",
allowed: true,
ip: net.IPv4(1, 1, 1, 1),
wantDis: false,
wantRule: "",
}, {
name: "disallow_ip",
allowed: true,
ip: net.IPv4(1, 1, 1, 2),
wantDis: true,
wantRule: "",
}, {
name: "allow_cidr",
allowed: true,
ip: net.IPv4(2, 2, 1, 1),
wantDis: false,
wantRule: "",
}, {
name: "disallow_cidr",
allowed: true,
ip: net.IPv4(2, 3, 1, 1),
wantDis: true,
wantRule: "",
}, {
name: "allow_ip",
allowed: false,
ip: net.IPv4(1, 1, 1, 1),
wantDis: true,
wantRule: rules[ip],
}, {
name: "disallow_ip",
allowed: false,
ip: net.IPv4(1, 1, 1, 2),
wantDis: false,
wantRule: "",
}, {
name: "allow_cidr",
allowed: false,
ip: net.IPv4(2, 2, 1, 1),
wantDis: true,
wantRule: rules[cidr],
}, {
name: "disallow_cidr",
allowed: false,
ip: net.IPv4(2, 3, 1, 1),
wantDis: false,
wantRule: "",
}}
for _, tc := range testCases {
prefix := "allowed_"
if !tc.allowed {
prefix = "disallowed_"
}
t.Run(prefix+tc.name, func(t *testing.T) {
allowedRules := rules
var disallowedRules []string
if !tc.allowed {
allowedRules, disallowedRules = disallowedRules, allowedRules
}
aCtx, err := newAccessCtx(allowedRules, disallowedRules, nil)
require.NoError(t, err) require.NoError(t, err)
disallowed, rule := aCtx.IsBlockedIP(tc.ip) assert.False(t, a.isBlockedClientID(clientID))
assert.Equal(t, tc.wantDis, disallowed)
assert.Equal(t, tc.wantRule, rule) a, err = newAccessCtx(nil, clients, nil)
}) require.NoError(t, err)
}
assert.True(t, a.isBlockedClientID(clientID))
} }
func TestIsBlockedDomain(t *testing.T) { func TestIsBlockedHost(t *testing.T) {
aCtx, err := newAccessCtx(nil, nil, []string{ a, err := newAccessCtx(nil, nil, []string{
"host1", "host1",
"*.host.com", "*.host.com",
"||host3.com^", "||host3.com^",
@ -109,49 +33,105 @@ func TestIsBlockedDomain(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
domain string host string
want bool want bool
}{{ }{{
name: "plain_match", name: "plain_match",
domain: "host1", host: "host1",
want: true, want: true,
}, { }, {
name: "plain_mismatch", name: "plain_mismatch",
domain: "host2", host: "host2",
want: false, want: false,
}, { }, {
name: "wildcard_type-1_match_short", name: "subdomain_match_short",
domain: "asdf.host.com", host: "asdf.host.com",
want: true, want: true,
}, { }, {
name: "wildcard_type-1_match_long", name: "subdomain_match_long",
domain: "qwer.asdf.host.com", host: "qwer.asdf.host.com",
want: true, want: true,
}, { }, {
name: "wildcard_type-1_mismatch_no-lead", name: "subdomain_mismatch_no_lead",
domain: "host.com", host: "host.com",
want: false, want: false,
}, { }, {
name: "wildcard_type-1_mismatch_bad-asterisk", name: "subdomain_mismatch_bad_asterisk",
domain: "asdf.zhost.com", host: "asdf.zhost.com",
want: false, want: false,
}, { }, {
name: "wildcard_type-2_match_simple", name: "rule_match_simple",
domain: "host3.com", host: "host3.com",
want: true, want: true,
}, { }, {
name: "wildcard_type-2_match_complex", name: "rule_match_complex",
domain: "asdf.host3.com", host: "asdf.host3.com",
want: true, want: true,
}, { }, {
name: "wildcard_type-2_mismatch", name: "rule_mismatch",
domain: ".host3.com", host: ".host3.com",
want: false, want: false,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, aCtx.IsBlockedDomain(tc.domain)) assert.Equal(t, tc.want, a.isBlockedHost(tc.host))
}) })
} }
} }
func TestIsBlockedIP(t *testing.T) {
clients := []string{
"1.2.3.4",
"5.6.7.8/24",
}
allowCtx, err := newAccessCtx(clients, nil, nil)
require.NoError(t, err)
blockCtx, err := newAccessCtx(nil, clients, nil)
require.NoError(t, err)
testCases := []struct {
name string
wantRule string
ip net.IP
wantBlocked bool
}{{
name: "match_ip",
wantRule: "1.2.3.4",
ip: net.IP{1, 2, 3, 4},
wantBlocked: true,
}, {
name: "match_cidr",
wantRule: "5.6.7.8/24",
ip: net.IP{5, 6, 7, 100},
wantBlocked: true,
}, {
name: "no_match_ip",
wantRule: "",
ip: net.IP{9, 2, 3, 4},
wantBlocked: false,
}, {
name: "no_match_cidr",
wantRule: "",
ip: net.IP{9, 6, 7, 100},
wantBlocked: false,
}}
t.Run("allow", func(t *testing.T) {
for _, tc := range testCases {
blocked, rule := allowCtx.isBlockedIP(tc.ip)
assert.Equal(t, !tc.wantBlocked, blocked)
assert.Equal(t, tc.wantRule, rule)
}
})
t.Run("block", func(t *testing.T) {
for _, tc := range testCases {
blocked, rule := blockCtx.isBlockedIP(tc.ip)
assert.Equal(t, tc.wantBlocked, blocked)
assert.Equal(t, tc.wantRule, rule)
}
})
}

View File

@ -2,6 +2,7 @@ package dnsforward
import ( import (
"crypto/tls" "crypto/tls"
"encoding/binary"
"fmt" "fmt"
"path" "path"
"strings" "strings"
@ -50,15 +51,15 @@ func clientIDFromClientServerName(hostSrvName, cliSrvName string, strict bool) (
return clientID, nil return clientID, nil
} }
// processClientIDHTTPS extracts the client's ID from the path of the // clientIDFromDNSContextHTTPS extracts the client's ID from the path of the
// client's DNS-over-HTTPS request. // client's DNS-over-HTTPS request.
func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) { func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err error) {
pctx := ctx.proxyCtx
r := pctx.HTTPRequest r := pctx.HTTPRequest
if r == nil { if r == nil {
ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto) return "", fmt.Errorf(
"proxy ctx http request of proto %s is nil",
return resultCodeError pctx.Proto,
)
} }
origPath := r.URL.Path origPath := r.URL.Path
@ -68,34 +69,25 @@ func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
} }
if len(parts) == 0 || parts[0] != "dns-query" { if len(parts) == 0 || parts[0] != "dns-query" {
ctx.err = fmt.Errorf("client id check: invalid path %q", origPath) return "", fmt.Errorf("client id check: invalid path %q", origPath)
return resultCodeError
} }
clientID := ""
switch len(parts) { switch len(parts) {
case 1: case 1:
// Just /dns-query, no client ID. // Just /dns-query, no client ID.
return resultCodeSuccess return "", nil
case 2: case 2:
clientID = parts[1] clientID = parts[1]
default: default:
ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath) return "", fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
return resultCodeError
} }
err := ValidateClientID(clientID) err = ValidateClientID(clientID)
if err != nil { if err != nil {
ctx.err = fmt.Errorf("client id check: %w", err) return "", fmt.Errorf("client id check: %w", err)
return resultCodeError
} }
ctx.clientID = clientID return clientID, nil
return resultCodeSuccess
} }
// tlsConn is a narrow interface for *tls.Conn to simplify testing. // tlsConn is a narrow interface for *tls.Conn to simplify testing.
@ -108,53 +100,73 @@ type quicSession interface {
ConnectionState() (cs quic.ConnectionState) ConnectionState() (cs quic.ConnectionState)
} }
// processClientID extracts the client's ID from the server name of the client's // clientIDFromDNSContext extracts the client's ID from the server name of the
// DoT or DoQ request or the path of the client's DoH. // client's DoT or DoQ request or the path of the client's DoH. If the protocol
func processClientID(dctx *dnsContext) (rc resultCode) { // is not one of these, clientID is an empty string and err is nil.
pctx := dctx.proxyCtx func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
proto := pctx.Proto proto := pctx.Proto
if proto == proxy.ProtoHTTPS { if proto == proxy.ProtoHTTPS {
return processClientIDHTTPS(dctx) return clientIDFromDNSContextHTTPS(pctx)
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC { } else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
return resultCodeSuccess return "", nil
} }
srvConf := dctx.srv.conf hostSrvName := s.conf.ServerName
hostSrvName := srvConf.TLSConfig.ServerName
if hostSrvName == "" { if hostSrvName == "" {
return resultCodeSuccess return "", nil
} }
cliSrvName := "" cliSrvName := ""
if proto == proxy.ProtoTLS { switch proto {
case proxy.ProtoTLS:
conn := pctx.Conn conn := pctx.Conn
tc, ok := conn.(tlsConn) tc, ok := conn.(tlsConn)
if !ok { if !ok {
dctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn) return "", fmt.Errorf(
"proxy ctx conn of proto %s is %T, want *tls.Conn",
return resultCodeError proto,
conn,
)
} }
cliSrvName = tc.ConnectionState().ServerName cliSrvName = tc.ConnectionState().ServerName
} else if proto == proxy.ProtoQUIC { case proxy.ProtoQUIC:
qs, ok := pctx.QUICSession.(quicSession) qs, ok := pctx.QUICSession.(quicSession)
if !ok { if !ok {
dctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession) return "", fmt.Errorf(
"proxy ctx quic session of proto %s is %T, want quic.Session",
return resultCodeError proto,
pctx.QUICSession,
)
} }
cliSrvName = qs.ConnectionState().TLS.ServerName cliSrvName = qs.ConnectionState().TLS.ServerName
} }
clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName, srvConf.StrictSNICheck) clientID, err = clientIDFromClientServerName(
hostSrvName,
cliSrvName,
s.conf.StrictSNICheck,
)
if err != nil { if err != nil {
dctx.err = fmt.Errorf("client id check: %w", err) return "", fmt.Errorf("client id check: %w", err)
return resultCodeError
} }
dctx.clientID = clientID return clientID, nil
}
// processClientID puts the clientID into the DNS context, if there is one.
func (s *Server) processClientID(dctx *dnsContext) (rc resultCode) {
pctx := dctx.proxyCtx
var key [8]byte
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
clientIDData := s.clientIDCache.Get(key[:])
if clientIDData == nil {
return resultCodeSuccess
}
dctx.clientID = string(clientIDData)
return resultCodeSuccess return resultCodeSuccess
} }

View File

@ -45,15 +45,14 @@ func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
return cs return cs
} }
func TestProcessClientID(t *testing.T) { func TestServer_clientIDFromDNSContext(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
proto string proto proxy.Proto
hostSrvName string hostSrvName string
cliSrvName string cliSrvName string
wantClientID string wantClientID string
wantErrMsg string wantErrMsg string
wantRes resultCode
strictSNI bool strictSNI bool
}{{ }{{
name: "udp", name: "udp",
@ -62,7 +61,6 @@ func TestProcessClientID(t *testing.T) {
cliSrvName: "", cliSrvName: "",
wantClientID: "", wantClientID: "",
wantErrMsg: "", wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: false, strictSNI: false,
}, { }, {
name: "tls_no_client_id", name: "tls_no_client_id",
@ -71,7 +69,6 @@ func TestProcessClientID(t *testing.T) {
cliSrvName: "example.com", cliSrvName: "example.com",
wantClientID: "", wantClientID: "",
wantErrMsg: "", wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: true, strictSNI: true,
}, { }, {
name: "tls_no_client_server_name", name: "tls_no_client_server_name",
@ -81,7 +78,6 @@ func TestProcessClientID(t *testing.T) {
wantClientID: "", wantClientID: "",
wantErrMsg: `client id check: client server name "" ` + wantErrMsg: `client id check: client server name "" ` +
`doesn't match host server name "example.com"`, `doesn't match host server name "example.com"`,
wantRes: resultCodeError,
strictSNI: true, strictSNI: true,
}, { }, {
name: "tls_no_client_server_name_no_strict", name: "tls_no_client_server_name_no_strict",
@ -90,7 +86,6 @@ func TestProcessClientID(t *testing.T) {
cliSrvName: "", cliSrvName: "",
wantClientID: "", wantClientID: "",
wantErrMsg: "", wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: false, strictSNI: false,
}, { }, {
name: "tls_client_id", name: "tls_client_id",
@ -99,7 +94,6 @@ func TestProcessClientID(t *testing.T) {
cliSrvName: "cli.example.com", cliSrvName: "cli.example.com",
wantClientID: "cli", wantClientID: "cli",
wantErrMsg: "", wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: true, strictSNI: true,
}, { }, {
name: "tls_client_id_hostname_error", name: "tls_client_id_hostname_error",
@ -109,7 +103,6 @@ func TestProcessClientID(t *testing.T) {
wantClientID: "", wantClientID: "",
wantErrMsg: `client id check: client server name "cli.example.net" ` + wantErrMsg: `client id check: client server name "cli.example.net" ` +
`doesn't match host server name "example.com"`, `doesn't match host server name "example.com"`,
wantRes: resultCodeError,
strictSNI: true, strictSNI: true,
}, { }, {
name: "tls_invalid_client_id", name: "tls_invalid_client_id",
@ -119,7 +112,6 @@ func TestProcessClientID(t *testing.T) {
wantClientID: "", wantClientID: "",
wantErrMsg: `client id check: invalid client id "!!!": ` + wantErrMsg: `client id check: invalid client id "!!!": ` +
`invalid char '!' at index 0`, `invalid char '!' at index 0`,
wantRes: resultCodeError,
strictSNI: true, strictSNI: true,
}, { }, {
name: "tls_client_id_too_long", name: "tls_client_id_too_long",
@ -131,7 +123,6 @@ func TestProcessClientID(t *testing.T) {
wantErrMsg: `client id check: invalid client id "abcdefghijklmno` + wantErrMsg: `client id check: invalid client id "abcdefghijklmno` +
`pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789": ` + `pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789": ` +
`label is too long, max: 63`, `label is too long, max: 63`,
wantRes: resultCodeError,
strictSNI: true, strictSNI: true,
}, { }, {
name: "quic_client_id", name: "quic_client_id",
@ -140,7 +131,6 @@ func TestProcessClientID(t *testing.T) {
cliSrvName: "cli.example.com", cliSrvName: "cli.example.com",
wantClientID: "cli", wantClientID: "cli",
wantErrMsg: "", wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: true, strictSNI: true,
}} }}
@ -150,6 +140,7 @@ func TestProcessClientID(t *testing.T) {
ServerName: tc.hostSrvName, ServerName: tc.hostSrvName,
StrictSNICheck: tc.strictSNI, StrictSNICheck: tc.strictSNI,
} }
srv := &Server{ srv := &Server{
conf: ServerConfig{TLSConfig: tlsConf}, conf: ServerConfig{TLSConfig: tlsConf},
} }
@ -168,79 +159,68 @@ func TestProcessClientID(t *testing.T) {
} }
} }
dctx := &dnsContext{ pctx := &proxy.DNSContext{
srv: srv,
proxyCtx: &proxy.DNSContext{
Proto: tc.proto, Proto: tc.proto,
Conn: conn, Conn: conn,
QUICSession: qs, QUICSession: qs,
},
} }
res := processClientID(dctx) clientID, err := srv.clientIDFromDNSContext(pctx)
assert.Equal(t, tc.wantRes, res) assert.Equal(t, tc.wantClientID, clientID)
assert.Equal(t, tc.wantClientID, dctx.clientID)
if tc.wantErrMsg == "" { if tc.wantErrMsg == "" {
assert.NoError(t, dctx.err) assert.NoError(t, err)
} else { } else {
require.Error(t, dctx.err) require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
assert.Equal(t, tc.wantErrMsg, err.Error())
} }
}) })
} }
} }
func TestProcessClientID_https(t *testing.T) { func TestClientIDFromDNSContextHTTPS(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
path string path string
wantClientID string wantClientID string
wantErrMsg string wantErrMsg string
wantRes resultCode
}{{ }{{
name: "no_client_id", name: "no_client_id",
path: "/dns-query", path: "/dns-query",
wantClientID: "", wantClientID: "",
wantErrMsg: "", wantErrMsg: "",
wantRes: resultCodeSuccess,
}, { }, {
name: "no_client_id_slash", name: "no_client_id_slash",
path: "/dns-query/", path: "/dns-query/",
wantClientID: "", wantClientID: "",
wantErrMsg: "", wantErrMsg: "",
wantRes: resultCodeSuccess,
}, { }, {
name: "client_id", name: "client_id",
path: "/dns-query/cli", path: "/dns-query/cli",
wantClientID: "cli", wantClientID: "cli",
wantErrMsg: "", wantErrMsg: "",
wantRes: resultCodeSuccess,
}, { }, {
name: "client_id_slash", name: "client_id_slash",
path: "/dns-query/cli/", path: "/dns-query/cli/",
wantClientID: "cli", wantClientID: "cli",
wantErrMsg: "", wantErrMsg: "",
wantRes: resultCodeSuccess,
}, { }, {
name: "bad_url", name: "bad_url",
path: "/foo", path: "/foo",
wantClientID: "", wantClientID: "",
wantErrMsg: `client id check: invalid path "/foo"`, wantErrMsg: `client id check: invalid path "/foo"`,
wantRes: resultCodeError,
}, { }, {
name: "extra", name: "extra",
path: "/dns-query/cli/foo", path: "/dns-query/cli/foo",
wantClientID: "", wantClientID: "",
wantErrMsg: `client id check: invalid path "/dns-query/cli/foo": extra parts`, wantErrMsg: `client id check: invalid path "/dns-query/cli/foo": extra parts`,
wantRes: resultCodeError,
}, { }, {
name: "invalid_client_id", name: "invalid_client_id",
path: "/dns-query/!!!", path: "/dns-query/!!!",
wantClientID: "", wantClientID: "",
wantErrMsg: `client id check: invalid client id "!!!": ` + wantErrMsg: `client id check: invalid client id "!!!": ` +
`invalid char '!' at index 0`, `invalid char '!' at index 0`,
wantRes: resultCodeError,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
@ -251,23 +231,20 @@ func TestProcessClientID_https(t *testing.T) {
}, },
} }
dctx := &dnsContext{ pctx := &proxy.DNSContext{
proxyCtx: &proxy.DNSContext{
Proto: proxy.ProtoHTTPS, Proto: proxy.ProtoHTTPS,
HTTPRequest: r, HTTPRequest: r,
},
} }
res := processClientID(dctx) clientID, err := clientIDFromDNSContextHTTPS(pctx)
assert.Equal(t, tc.wantRes, res) assert.Equal(t, tc.wantClientID, clientID)
assert.Equal(t, tc.wantClientID, dctx.clientID)
if tc.wantErrMsg == "" { if tc.wantErrMsg == "" {
assert.NoError(t, dctx.err) assert.NoError(t, err)
} else { } else {
require.Error(t, dctx.err) require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, dctx.err.Error()) assert.Equal(t, tc.wantErrMsg, err.Error())
} }
}) })
} }

View File

@ -331,7 +331,7 @@ func (s *Server) prepareUpstreamSettings() error {
upstreams = aghstrings.FilterOut(upstreams, aghstrings.IsCommentOrEmpty) upstreams = aghstrings.FilterOut(upstreams, aghstrings.IsCommentOrEmpty)
upstreamConfig, err := proxy.ParseUpstreamsConfig( upstreamConfig, err := proxy.ParseUpstreamsConfig(
upstreams, upstreams,
upstream.Options{ &upstream.Options{
Bootstrap: s.conf.BootstrapDNS, Bootstrap: s.conf.BootstrapDNS,
Timeout: s.conf.UpstreamTimeout, Timeout: s.conf.UpstreamTimeout,
}, },
@ -342,10 +342,10 @@ func (s *Server) prepareUpstreamSettings() error {
if len(upstreamConfig.Upstreams) == 0 { if len(upstreamConfig.Upstreams) == 0 {
log.Info("warning: no default upstream servers specified, using %v", defaultDNS) log.Info("warning: no default upstream servers specified, using %v", defaultDNS)
var uc proxy.UpstreamConfig var uc *proxy.UpstreamConfig
uc, err = proxy.ParseUpstreamsConfig( uc, err = proxy.ParseUpstreamsConfig(
defaultDNS, defaultDNS,
upstream.Options{ &upstream.Options{
Bootstrap: s.conf.BootstrapDNS, Bootstrap: s.conf.BootstrapDNS,
Timeout: s.conf.UpstreamTimeout, Timeout: s.conf.UpstreamTimeout,
}, },
@ -356,7 +356,8 @@ func (s *Server) prepareUpstreamSettings() error {
upstreamConfig.Upstreams = uc.Upstreams upstreamConfig.Upstreams = uc.Upstreams
} }
s.conf.UpstreamConfig = &upstreamConfig s.conf.UpstreamConfig = upstreamConfig
return nil return nil
} }

View File

@ -89,7 +89,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
s.processInternalHosts, s.processInternalHosts,
s.processRestrictLocal, s.processRestrictLocal,
s.processInternalIPAddrs, s.processInternalIPAddrs,
processClientID, s.processClientID,
processFilteringBeforeRequest, processFilteringBeforeRequest,
s.processLocalPTR, s.processLocalPTR,
s.processUpstream, s.processUpstream,
@ -165,7 +165,7 @@ func (s *Server) setTableHostToIP(t hostToIPTable) {
s.tableHostToIP = t s.tableHostToIP = t
} }
func (s *Server) setTableIPToHost(t ipToHostTable) { func (s *Server) setTableIPToHost(t *aghnet.IPMap) {
s.tableIPToHostLock.Lock() s.tableIPToHostLock.Lock()
defer s.tableIPToHostLock.Unlock() defer s.tableIPToHostLock.Unlock()
@ -188,13 +188,13 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
} }
var hostToIP hostToIPTable var hostToIP hostToIPTable
var ipToHost ipToHostTable var ipToHost *aghnet.IPMap
if add { if add {
hostToIP = make(hostToIPTable)
ipToHost = make(ipToHostTable)
ll := s.dhcpServer.Leases(dhcpd.LeasesAll) ll := s.dhcpServer.Leases(dhcpd.LeasesAll)
hostToIP = make(hostToIPTable, len(ll))
ipToHost = aghnet.NewIPMap(len(ll))
for _, l := range ll { for _, l := range ll {
// TODO(a.garipov): Remove this after we're finished // TODO(a.garipov): Remove this after we're finished
// with the client hostname validations in the DHCP // with the client hostname validations in the DHCP
@ -210,14 +210,14 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
lowhost := strings.ToLower(l.Hostname) lowhost := strings.ToLower(l.Hostname)
ipToHost[l.IP.String()] = lowhost ipToHost.Set(l.IP, lowhost)
ip := make(net.IP, 4) ip := make(net.IP, 4)
copy(ip, l.IP.To4()) copy(ip, l.IP.To4())
hostToIP[lowhost] = ip hostToIP[lowhost] = ip
} }
log.Debug("dns: added %d A/PTR entries from DHCP", len(ipToHost)) log.Debug("dns: added %d A/PTR entries from DHCP", ipToHost.Len())
} }
s.setTableHostToIP(hostToIP) s.setTableHostToIP(hostToIP)
@ -377,7 +377,15 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) {
return "", false return "", false
} }
host, ok = s.tableIPToHost[ip.String()] var v interface{}
v, ok = s.tableIPToHost.Get(ip)
var typOK bool
if host, typOK = v.(string); !typOK {
log.Error("dns: bad type %T in tableIPToHost for %s", v, ip)
return "", false
}
return host, ok return host, ok
} }

View File

@ -18,6 +18,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -26,6 +27,11 @@ import (
// DefaultTimeout is the default upstream timeout // DefaultTimeout is the default upstream timeout
const DefaultTimeout = 10 * time.Second const DefaultTimeout = 10 * time.Second
// defaultClientIDCacheCount is the default count of items in the LRU client ID
// cache. The assumption here is that there won't be more than this many
// requests between the BeforeRequestHandler stage and the actual processing.
const defaultClientIDCacheCount = 1024
const ( const (
safeBrowsingBlockHost = "standard-block.dns.adguard.com" safeBrowsingBlockHost = "standard-block.dns.adguard.com"
parentalBlockHost = "family-block.dns.adguard.com" parentalBlockHost = "family-block.dns.adguard.com"
@ -44,12 +50,6 @@ var webRegistered bool
// hostToIPTable is an alias for the type of Server.tableHostToIP. // hostToIPTable is an alias for the type of Server.tableHostToIP.
type hostToIPTable = map[string]net.IP type hostToIPTable = map[string]net.IP
// ipToHostTable is an alias for the type of Server.tableIPToHost.
//
// TODO(a.garipov): Define an IPMap type in aghnet and use here and in other
// places?
type ipToHostTable = map[string]string
// Server is the main way to start a DNS server. // Server is the main way to start a DNS server.
// //
// Example: // Example:
@ -81,9 +81,13 @@ type Server struct {
tableHostToIP hostToIPTable tableHostToIP hostToIPTable
tableHostToIPLock sync.Mutex tableHostToIPLock sync.Mutex
tableIPToHost ipToHostTable tableIPToHost *aghnet.IPMap
tableIPToHostLock sync.Mutex tableIPToHostLock sync.Mutex
// clientIDCache is a temporary storage for clientIDs that were
// extracted during the BeforeRequestHandler stage.
clientIDCache cache.Cache
// DNS proxy instance for internal usage // DNS proxy instance for internal usage
// We don't Start() it and so no listen port is required. // We don't Start() it and so no listen port is required.
internalProxy *proxy.Proxy internalProxy *proxy.Proxy
@ -152,6 +156,10 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
subnetDetector: p.SubnetDetector, subnetDetector: p.SubnetDetector,
localDomainSuffix: localDomainSuffix, localDomainSuffix: localDomainSuffix,
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum), recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
clientIDCache: cache.New(cache.Config{
EnableLRU: true,
MaxCount: defaultClientIDCacheCount,
}),
} }
// TODO(e.burkov): Enable the refresher after the actual implementation // TODO(e.burkov): Enable the refresher after the actual implementation
@ -414,19 +422,22 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
log.Debug("upstreams to resolve PTR for local addresses: %v", localAddrs) log.Debug("upstreams to resolve PTR for local addresses: %v", localAddrs)
var upsConfig proxy.UpstreamConfig var upsConfig *proxy.UpstreamConfig
upsConfig, err = proxy.ParseUpstreamsConfig(localAddrs, upstream.Options{ upsConfig, err = proxy.ParseUpstreamsConfig(
localAddrs,
&upstream.Options{
Bootstrap: bootstraps, Bootstrap: bootstraps,
Timeout: defaultLocalTimeout, Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's ceritificates? // TODO(e.burkov): Should we verify server's ceritificates?
}) },
)
if err != nil { if err != nil {
return fmt.Errorf("parsing upstreams: %w", err) return fmt.Errorf("parsing upstreams: %w", err)
} }
s.localResolvers = &proxy.Proxy{ s.localResolvers = &proxy.Proxy{
Config: proxy.Config{ Config: proxy.Config{
UpstreamConfig: &upsConfig, UpstreamConfig: upsConfig,
}, },
} }
@ -577,11 +588,33 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
// IsBlockedIP - return TRUE if this client should be blocked // IsBlockedClient returns true if the client is blocked by the current access
func (s *Server) IsBlockedIP(ip net.IP) (bool, string) { // settings.
if ip == nil { func (s *Server) IsBlockedClient(ip net.IP, clientID string) (blocked bool, rule string) {
return false, "" s.serverLock.RLock()
defer s.serverLock.RUnlock()
allowlistMode := s.access.allowlistMode()
blockedByIP, rule := s.access.isBlockedIP(ip)
blockedByClientID := s.access.isBlockedClientID(clientID)
// Allow if at least one of the checks allows in allowlist mode, but
// block if at least one of the checks blocks in blocklist mode.
if allowlistMode && blockedByIP && blockedByClientID {
log.Debug("client %s (id %q) is not in access allowlist", ip, clientID)
// Return now without substituting the empty rule for the
// clientID because the rule can't be empty here.
return true, rule
} else if !allowlistMode && (blockedByIP || blockedByClientID) {
log.Debug("client %s (id %q) is in access blocklist", ip, clientID)
blocked = true
} }
return s.access.IsBlockedIP(ip) if rule == "" {
rule = clientID
}
return blocked, rule
} }

View File

@ -257,19 +257,22 @@ func TestServer(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
proto string net string
proto proxy.Proto
}{{ }{{
name: "message_over_udp", name: "message_over_udp",
net: "",
proto: proxy.ProtoUDP, proto: proxy.ProtoUDP,
}, { }, {
name: "message_over_tcp", name: "message_over_tcp",
net: "tcp",
proto: proxy.ProtoTCP, proto: proxy.ProtoTCP,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
addr := s.dnsProxy.Addr(tc.proto) addr := s.dnsProxy.Addr(tc.proto)
client := dns.Client{Net: tc.proto} client := dns.Client{Net: tc.net}
reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String()) reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String())
require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err) require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err)
@ -324,7 +327,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
// Message over UDP. // Message over UDP.
req := createGoogleATestMessage() req := createGoogleATestMessage()
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
client := dns.Client{Net: proxy.ProtoUDP} client := &dns.Client{}
reply, _, err := client.Exchange(req, addr.String()) reply, _, err := client.Exchange(req, addr.String())
require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err) require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err)
@ -376,7 +379,7 @@ func TestDoQServer(t *testing.T) {
// Create a DNS-over-QUIC upstream. // Create a DNS-over-QUIC upstream.
addr := s.dnsProxy.Addr(proxy.ProtoQUIC) addr := s.dnsProxy.Addr(proxy.ProtoQUIC)
opts := upstream.Options{InsecureSkipVerify: true} opts := &upstream.Options{InsecureSkipVerify: true}
u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts) u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts)
require.NoError(t, err) require.NoError(t, err)
@ -420,7 +423,7 @@ func TestServerRace(t *testing.T) {
// Message over UDP. // Message over UDP.
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
conn, err := dns.Dial(proxy.ProtoUDP, addr.String()) conn, err := dns.Dial("udp", addr.String())
require.NoErrorf(t, err, "cannot connect to the proxy: %s", err) require.NoErrorf(t, err, "cannot connect to the proxy: %s", err)
sendTestMessagesAsync(t, conn) sendTestMessagesAsync(t, conn)
@ -445,7 +448,7 @@ func TestSafeSearch(t *testing.T) {
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
client := dns.Client{Net: proxy.ProtoUDP} client := &dns.Client{}
yandexIP := net.IP{213, 180, 193, 56} yandexIP := net.IP{213, 180, 193, 56}
googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com") googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com")
@ -507,7 +510,6 @@ func TestInvalidRequest(t *testing.T) {
// Send a DNS request without question. // Send a DNS request without question.
_, _, err := (&dns.Client{ _, _, err := (&dns.Client{
Net: proxy.ProtoUDP,
Timeout: 500 * time.Millisecond, Timeout: 500 * time.Millisecond,
}).Exchange(&req, addr) }).Exchange(&req, addr)

View File

@ -1,6 +1,7 @@
package dnsforward package dnsforward
import ( import (
"encoding/binary"
"fmt" "fmt"
"strings" "strings"
@ -11,23 +12,39 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) { // beforeRequestHandler is the handler that is called before any other
ip := aghnet.IPFromAddr(d.Addr) // processing, including logs. It performs access checks and puts the client
disallowed, _ := s.access.IsBlockedIP(ip) // ID, if there is one, into the server's cache.
if disallowed { func (s *Server) beforeRequestHandler(
log.Tracef("Client IP %s is blocked by settings", ip) _ *proxy.Proxy,
pctx *proxy.DNSContext,
) (reply bool, err error) {
ip := aghnet.IPFromAddr(pctx.Addr)
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return false, fmt.Errorf("getting clientid: %w", err)
}
blocked, _ := s.IsBlockedClient(ip, clientID)
if blocked {
return false, nil return false, nil
} }
if len(d.Req.Question) == 1 { if len(pctx.Req.Question) == 1 {
host := strings.TrimSuffix(d.Req.Question[0].Name, ".") host := strings.TrimSuffix(pctx.Req.Question[0].Name, ".")
if s.access.IsBlockedDomain(host) { if s.access.isBlockedHost(host) {
log.Tracef("domain %s is blocked by access settings", host) log.Debug("host %s is in access blocklist", host)
return false, nil return false, nil
} }
} }
if clientID != "" {
key := [8]byte{}
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
s.clientIDCache.Set(key[:], []byte(clientID))
}
return true, nil return true, nil
} }

View File

@ -167,7 +167,7 @@ func (req *dnsConfig) checkBootstrap() (string, error) {
return boot, fmt.Errorf("invalid bootstrap server address: empty") return boot, fmt.Errorf("invalid bootstrap server address: empty")
} }
if _, err := upstream.NewResolver(boot, upstream.Options{Timeout: 0}); err != nil { if _, err := upstream.NewResolver(boot, nil); err != nil {
return boot, fmt.Errorf("invalid bootstrap server address: %w", err) return boot, fmt.Errorf("invalid bootstrap server address: %w", err)
} }
} }
@ -348,7 +348,7 @@ func ValidateUpstreams(upstreams []string) (err error) {
_, err = proxy.ParseUpstreamsConfig( _, err = proxy.ParseUpstreamsConfig(
upstreams, upstreams,
upstream.Options{ &upstream.Options{
Bootstrap: []string{}, Bootstrap: []string{},
Timeout: DefaultTimeout, Timeout: DefaultTimeout,
}, },
@ -546,7 +546,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
log.Debug("checking if dns server %q works...", input) log.Debug("checking if dns server %q works...", input)
var u upstream.Upstream var u upstream.Upstream
u, err = upstream.AddressToUpstream(input, upstream.Options{ u, err = upstream.AddressToUpstream(input, &upstream.Options{
Bootstrap: bootstrap, Bootstrap: bootstrap,
Timeout: timeout, Timeout: timeout,
}) })

View File

@ -46,7 +46,7 @@ func (l *testStats) Update(e stats.Entry) {
func TestProcessQueryLogsAndStats(t *testing.T) { func TestProcessQueryLogsAndStats(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
proto string proto proxy.Proto
addr net.Addr addr net.Addr
clientID string clientID string
wantLogProto querylog.ClientProto wantLogProto querylog.ClientProto
@ -156,7 +156,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
wantStatResult: stats.RParental, wantStatResult: stats.RParental,
}} }}
ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{}) ups, err := upstream.AddressToUpstream("1.1.1.1", nil)
require.Nil(t, err) require.Nil(t, err)
for _, tc := range testCases { for _, tc := range testCases {

View File

@ -49,7 +49,7 @@ func (d *DNSFilter) initSecurityServices() error {
var err error var err error
d.safeBrowsingServer = defaultSafebrowsingServer d.safeBrowsingServer = defaultSafebrowsingServer
d.parentalServer = defaultParentalServer d.parentalServer = defaultParentalServer
opts := upstream.Options{ opts := &upstream.Options{
Timeout: dnsTimeout, Timeout: dnsTimeout,
ServerIPAddrs: []net.IP{ ServerIPAddrs: []net.IP{
{94, 140, 14, 15}, {94, 140, 14, 15},

View File

@ -80,7 +80,10 @@ type clientsContainer struct {
// different types (string, net.IP, and so on). // different types (string, net.IP, and so on).
list map[string]*Client // name -> client list map[string]*Client // name -> client
idIndex map[string]*Client // ID -> client idIndex map[string]*Client // ID -> client
ipToRC map[string]*RuntimeClient // IP -> runtime client
// ipToRC is the IP address to *RuntimeClient map.
ipToRC *aghnet.IPMap
lock sync.Mutex lock sync.Mutex
allTags *aghstrings.Set allTags *aghstrings.Set
@ -109,7 +112,7 @@ func (clients *clientsContainer) Init(
} }
clients.list = make(map[string]*Client) clients.list = make(map[string]*Client)
clients.idIndex = make(map[string]*Client) clients.idIndex = make(map[string]*Client)
clients.ipToRC = make(map[string]*RuntimeClient) clients.ipToRC = aghnet.NewIPMap(0)
clients.allTags = aghstrings.NewSet(clientTags...) clients.allTags = aghstrings.NewSet(clientTags...)
@ -250,18 +253,17 @@ func (clients *clientsContainer) onHostsChanged() {
clients.addFromHostsFile() clients.addFromHostsFile()
} }
// Exists checks if client with this ID already exists. // Exists checks if client with this IP address already exists.
func (clients *clientsContainer) Exists(id string, source clientSource) (ok bool) { func (clients *clientsContainer) Exists(ip net.IP, source clientSource) (ok bool) {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
_, ok = clients.findLocked(id) _, ok = clients.findLocked(ip.String())
if ok { if ok {
return true return true
} }
var rc *RuntimeClient rc, ok := clients.findRuntimeClientLocked(ip)
rc, ok = clients.ipToRC[id]
if !ok { if !ok {
return false return false
} }
@ -288,13 +290,14 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
for _, id := range ids { for _, id := range ids {
var name string var name string
whois := &querylog.ClientWHOIS{} whois := &querylog.ClientWHOIS{}
ip := net.ParseIP(id)
c, ok := clients.Find(id) c, ok := clients.Find(id)
if ok { if ok {
name = c.Name name = c.Name
} else { } else if ip != nil {
var rc RuntimeClient var rc *RuntimeClient
rc, ok = clients.FindRuntimeClient(id) rc, ok = clients.FindRuntimeClient(ip)
if !ok { if !ok {
continue continue
} }
@ -303,8 +306,7 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
whois = toQueryLogWHOIS(rc.WHOISInfo) whois = toQueryLogWHOIS(rc.WHOISInfo)
} }
ip := net.ParseIP(id) disallowed, disallowedRule := clients.dnsServer.IsBlockedClient(ip, id)
disallowed, disallowedRule := clients.dnsServer.IsBlockedIP(ip)
return &querylog.Client{ return &querylog.Client{
Name: name, Name: name,
@ -356,10 +358,10 @@ func (clients *clientsContainer) findUpstreams(
return c.upstreamConfig, nil return c.upstreamConfig, nil
} }
var conf proxy.UpstreamConfig var conf *proxy.UpstreamConfig
conf, err = proxy.ParseUpstreamsConfig( conf, err = proxy.ParseUpstreamsConfig(
upstreams, upstreams,
upstream.Options{ &upstream.Options{
Bootstrap: config.DNS.BootstrapDNS, Bootstrap: config.DNS.BootstrapDNS,
Timeout: config.DNS.UpstreamTimeout.Duration, Timeout: config.DNS.UpstreamTimeout.Duration,
}, },
@ -368,9 +370,9 @@ func (clients *clientsContainer) findUpstreams(
return nil, err return nil, err
} }
c.upstreamConfig = &conf c.upstreamConfig = conf
return &conf, nil return conf, nil
} }
// findLocked searches for a client by its ID. For internal use only. // findLocked searches for a client by its ID. For internal use only.
@ -423,22 +425,35 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
return nil, false return nil, false
} }
// findRuntimeClientLocked finds a runtime client by their IP address. For
// internal use only.
func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *RuntimeClient, ok bool) {
var v interface{}
v, ok = clients.ipToRC.Get(ip)
if !ok {
return nil, false
}
rc, ok = v.(*RuntimeClient)
if !ok {
log.Error("clients: bad type %T in ipToRC for %s", v, ip)
return nil, false
}
return rc, true
}
// FindRuntimeClient finds a runtime client by their IP. // FindRuntimeClient finds a runtime client by their IP.
func (clients *clientsContainer) FindRuntimeClient(ip string) (RuntimeClient, bool) { func (clients *clientsContainer) FindRuntimeClient(ip net.IP) (rc *RuntimeClient, ok bool) {
ipAddr := net.ParseIP(ip) if ip == nil {
if ipAddr == nil { return nil, false
return RuntimeClient{}, false
} }
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
rc, ok := clients.ipToRC[ip] return clients.findRuntimeClientLocked(ip)
if ok {
return *rc, true
}
return RuntimeClient{}, false
} }
// check validates the client. // check validates the client.
@ -621,17 +636,17 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) {
} }
// SetWHOISInfo sets the WHOIS information for a client. // SetWHOISInfo sets the WHOIS information for a client.
func (clients *clientsContainer) SetWHOISInfo(ip string, wi *RuntimeClientWHOISInfo) { func (clients *clientsContainer) SetWHOISInfo(ip net.IP, wi *RuntimeClientWHOISInfo) {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
_, ok := clients.findLocked(ip) _, ok := clients.findLocked(ip.String())
if ok { if ok {
log.Debug("clients: client for %s is already created, ignore whois info", ip) log.Debug("clients: client for %s is already created, ignore whois info", ip)
return return
} }
rc, ok := clients.ipToRC[ip] rc, ok := clients.findRuntimeClientLocked(ip)
if ok { if ok {
rc.WHOISInfo = wi rc.WHOISInfo = wi
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi) log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
@ -646,14 +661,15 @@ func (clients *clientsContainer) SetWHOISInfo(ip string, wi *RuntimeClientWHOISI
} }
rc.WHOISInfo = wi rc.WHOISInfo = wi
clients.ipToRC[ip] = rc
clients.ipToRC.Set(ip, rc)
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi) log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
} }
// AddHost adds a new IP-hostname pairing. The priorities of the sources is // AddHost adds a new IP-hostname pairing. The priorities of the sources is
// taken into account. ok is true if the pairing was added. // taken into account. ok is true if the pairing was added.
func (clients *clientsContainer) AddHost(ip, host string, src clientSource) (ok bool, err error) { func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
@ -663,9 +679,9 @@ func (clients *clientsContainer) AddHost(ip, host string, src clientSource) (ok
} }
// addHostLocked adds a new IP-hostname pairing. For internal use only. // addHostLocked adds a new IP-hostname pairing. For internal use only.
func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource) (ok bool) { func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clientSource) (ok bool) {
var rc *RuntimeClient var rc *RuntimeClient
rc, ok = clients.ipToRC[ip] rc, ok = clients.findRuntimeClientLocked(ip)
if ok { if ok {
if rc.Source > src { if rc.Source > src {
return false return false
@ -679,10 +695,10 @@ func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource
WHOISInfo: &RuntimeClientWHOISInfo{}, WHOISInfo: &RuntimeClientWHOISInfo{},
} }
clients.ipToRC[ip] = rc clients.ipToRC.Set(ip, rc)
} }
log.Debug("clients: added %q -> %q [%d]", ip, host, len(clients.ipToRC)) log.Debug("clients: added %s -> %q [%d]", ip, host, clients.ipToRC.Len())
return true return true
} }
@ -690,12 +706,21 @@ func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource
// rmHostsBySrc removes all entries that match the specified source. // rmHostsBySrc removes all entries that match the specified source.
func (clients *clientsContainer) rmHostsBySrc(src clientSource) { func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
n := 0 n := 0
for k, v := range clients.ipToRC { clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
if v.Source == src { rc, ok := v.(*RuntimeClient)
delete(clients.ipToRC, k) if !ok {
log.Error("clients: bad type %T in ipToRC for %s", v, ip)
return true
}
if rc.Source == src {
clients.ipToRC.Del(ip)
n++ n++
} }
}
return true
})
log.Debug("clients: removed %d client aliases", n) log.Debug("clients: removed %d client aliases", n)
} }
@ -715,16 +740,23 @@ func (clients *clientsContainer) addFromHostsFile() {
clients.rmHostsBySrc(ClientSourceHostsFile) clients.rmHostsBySrc(ClientSourceHostsFile)
n := 0 n := 0
for ip, names := range hosts { hosts.Range(func(ip net.IP, v interface{}) (cont bool) {
names, ok := v.([]string)
if !ok {
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
}
for _, name := range names { for _, name := range names {
ok := clients.addHostLocked(ip, name, ClientSourceHostsFile) ok = clients.addHostLocked(ip, name, ClientSourceHostsFile)
if ok { if ok {
n++ n++
} }
} }
}
log.Debug("Clients: added %d client aliases from system hosts-file", n) return true
})
log.Debug("clients: added %d client aliases from system hosts-file", n)
} }
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a // addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
@ -752,15 +784,16 @@ func (clients *clientsContainer) addFromSystemARP() {
// TODO(a.garipov): Rewrite to use bufio.Scanner. // TODO(a.garipov): Rewrite to use bufio.Scanner.
lines := strings.Split(string(data), "\n") lines := strings.Split(string(data), "\n")
for _, ln := range lines { for _, ln := range lines {
open := strings.Index(ln, " (") lparen := strings.Index(ln, " (")
close := strings.Index(ln, ") ") rparen := strings.Index(ln, ") ")
if open == -1 || close == -1 || open >= close { if lparen == -1 || rparen == -1 || lparen >= rparen {
continue continue
} }
host := ln[:open] host := ln[:lparen]
ip := ln[open+2 : close] ipStr := ln[lparen+2 : rparen]
if aghnet.ValidateDomainName(host) != nil || net.ParseIP(ip) == nil { ip := net.ParseIP(ipStr)
if aghnet.ValidateDomainName(host) != nil || ip == nil {
continue continue
} }
@ -796,7 +829,7 @@ func (clients *clientsContainer) updateFromDHCP(add bool) {
continue continue
} }
ok := clients.addHostLocked(l.IP.String(), l.Hostname, ClientSourceDHCP) ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
if ok { if ok {
n++ n++
} }

View File

@ -26,6 +26,7 @@ func TestClients(t *testing.T) {
ok, err := clients.Add(c) ok, err := clients.Add(c)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
c = &Client{ c = &Client{
@ -35,23 +36,27 @@ func TestClients(t *testing.T) {
ok, err = clients.Add(c) ok, err = clients.Add(c)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
c, ok = clients.Find("1.1.1.1") c, ok = clients.Find("1.1.1.1")
require.True(t, ok) require.True(t, ok)
assert.Equal(t, "client1", c.Name) assert.Equal(t, "client1", c.Name)
c, ok = clients.Find("1:2:3::4") c, ok = clients.Find("1:2:3::4")
require.True(t, ok) require.True(t, ok)
assert.Equal(t, "client1", c.Name) assert.Equal(t, "client1", c.Name)
c, ok = clients.Find("2.2.2.2") c, ok = clients.Find("2.2.2.2")
require.True(t, ok) require.True(t, ok)
assert.Equal(t, "client2", c.Name) assert.Equal(t, "client2", c.Name)
assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile)) assert.False(t, clients.Exists(net.IP{1, 2, 3, 4}, ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) assert.True(t, clients.Exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile)) assert.True(t, clients.Exists(net.IP{2, 2, 2, 2}, ClientSourceHostsFile))
}) })
t.Run("add_fail_name", func(t *testing.T) { t.Run("add_fail_name", func(t *testing.T) {
@ -101,8 +106,8 @@ func TestClients(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) assert.False(t, clients.Exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) assert.True(t, clients.Exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
err = clients.Update("client1", &Client{ err = clients.Update("client1", &Client{
IDs: []string{"1.1.1.2"}, IDs: []string{"1.1.1.2"},
@ -113,21 +118,25 @@ func TestClients(t *testing.T) {
c, ok := clients.Find("1.1.1.2") c, ok := clients.Find("1.1.1.2")
require.True(t, ok) require.True(t, ok)
assert.Equal(t, "client1-renamed", c.Name) assert.Equal(t, "client1-renamed", c.Name)
assert.True(t, c.UseOwnSettings) assert.True(t, c.UseOwnSettings)
nilCli, ok := clients.list["client1"] nilCli, ok := clients.list["client1"]
require.False(t, ok) require.False(t, ok)
assert.Nil(t, nilCli) assert.Nil(t, nilCli)
require.Len(t, c.IDs, 1) require.Len(t, c.IDs, 1)
assert.Equal(t, "1.1.1.2", c.IDs[0]) assert.Equal(t, "1.1.1.2", c.IDs[0])
}) })
t.Run("del_success", func(t *testing.T) { t.Run("del_success", func(t *testing.T) {
ok := clients.Del("client1-renamed") ok := clients.Del("client1-renamed")
require.True(t, ok) require.True(t, ok)
assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
assert.False(t, clients.Exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
}) })
t.Run("del_fail", func(t *testing.T) { t.Run("del_fail", func(t *testing.T) {
@ -136,37 +145,44 @@ func TestClients(t *testing.T) {
}) })
t.Run("addhost_success", func(t *testing.T) { t.Run("addhost_success", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP) ip := net.IP{1, 1, 1, 1}
ok, err := clients.AddHost(ip, "host", ClientSourceARP)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP) ok, err = clients.AddHost(ip, "host2", ClientSourceARP)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile) ok, err = clients.AddHost(ip, "host3", ClientSourceHostsFile)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) assert.True(t, clients.Exists(ip, ClientSourceHostsFile))
}) })
t.Run("dhcp_replaces_arp", func(t *testing.T) { t.Run("dhcp_replaces_arp", func(t *testing.T) {
ok, err := clients.AddHost("1.2.3.4", "from_arp", ClientSourceARP) ip := net.IP{1, 2, 3, 4}
ok, err := clients.AddHost(ip, "from_arp", ClientSourceARP)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
assert.True(t, clients.Exists(ip, ClientSourceARP))
assert.True(t, clients.Exists("1.2.3.4", ClientSourceARP)) ok, err = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP)
ok, err = clients.AddHost("1.2.3.4", "from_dhcp", ClientSourceDHCP)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok)
assert.True(t, clients.Exists("1.2.3.4", ClientSourceDHCP)) assert.True(t, ok)
assert.True(t, clients.Exists(ip, ClientSourceDHCP))
}) })
t.Run("addhost_fail", func(t *testing.T) { t.Run("addhost_fail", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS) ok, err := clients.AddHost(net.IP{1, 1, 1, 1}, "host1", ClientSourceRDNS)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, ok) assert.False(t, ok)
}) })
@ -183,31 +199,39 @@ func TestClientsWHOIS(t *testing.T) {
} }
t.Run("new_client", func(t *testing.T) { t.Run("new_client", func(t *testing.T) {
clients.SetWHOISInfo("1.1.1.255", whois) ip := net.IP{1, 1, 1, 255}
clients.SetWHOISInfo(ip, whois)
v, _ := clients.ipToRC.Get(ip)
require.NotNil(t, v)
require.NotNil(t, clients.ipToRC["1.1.1.255"]) rc, ok := v.(*RuntimeClient)
require.True(t, ok)
require.NotNil(t, rc)
h := clients.ipToRC["1.1.1.255"] assert.Equal(t, rc.WHOISInfo, whois)
require.NotNil(t, h)
assert.Equal(t, h.WHOISInfo, whois)
}) })
t.Run("existing_auto-client", func(t *testing.T) { t.Run("existing_auto-client", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS) ip := net.IP{1, 1, 1, 1}
ok, err := clients.AddHost(ip, "host", ClientSourceRDNS)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
clients.SetWHOISInfo("1.1.1.1", whois) clients.SetWHOISInfo(ip, whois)
v, _ := clients.ipToRC.Get(ip)
require.NotNil(t, v)
require.NotNil(t, clients.ipToRC["1.1.1.1"]) rc, ok := v.(*RuntimeClient)
h := clients.ipToRC["1.1.1.1"] require.True(t, ok)
require.NotNil(t, h) require.NotNil(t, rc)
assert.Equal(t, h.WHOISInfo, whois) assert.Equal(t, rc.WHOISInfo, whois)
}) })
t.Run("can't_set_manually-added", func(t *testing.T) { t.Run("can't_set_manually-added", func(t *testing.T) {
ip := net.IP{1, 1, 1, 2}
ok, err := clients.Add(&Client{ ok, err := clients.Add(&Client{
IDs: []string{"1.1.1.2"}, IDs: []string{"1.1.1.2"},
Name: "client1", Name: "client1",
@ -215,8 +239,10 @@ func TestClientsWHOIS(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
clients.SetWHOISInfo("1.1.1.2", whois) clients.SetWHOISInfo(ip, whois)
require.Nil(t, clients.ipToRC["1.1.1.2"]) v, _ := clients.ipToRC.Get(ip)
require.Nil(t, v)
assert.True(t, clients.Del("client1")) assert.True(t, clients.Del("client1"))
}) })
} }
@ -228,16 +254,18 @@ func TestClientsAddExisting(t *testing.T) {
clients.Init(nil, nil, nil) clients.Init(nil, nil, nil)
t.Run("simple", func(t *testing.T) { t.Run("simple", func(t *testing.T) {
ip := net.IP{1, 1, 1, 1}
// Add a client. // Add a client.
ok, err := clients.Add(&Client{ ok, err := clients.Add(&Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"}, IDs: []string{ip.String(), "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
Name: "client1", Name: "client1",
}) })
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
// Now add an auto-client with the same IP. // Now add an auto-client with the same IP.
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS) ok, err = clients.AddHost(ip, "test", ClientSourceRDNS)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
}) })
@ -245,7 +273,7 @@ func TestClientsAddExisting(t *testing.T) {
t.Run("complicated", func(t *testing.T) { t.Run("complicated", func(t *testing.T) {
var err error var err error
testIP := net.IP{1, 2, 3, 4} ip := net.IP{1, 2, 3, 4}
// First, init a DHCP server with a single static lease. // First, init a DHCP server with a single static lease.
config := dhcpd.ServerConfig{ config := dhcpd.ServerConfig{
@ -267,7 +295,7 @@ func TestClientsAddExisting(t *testing.T) {
err = clients.dhcpServer.AddStaticLease(&dhcpd.Lease{ err = clients.dhcpServer.AddStaticLease(&dhcpd.Lease{
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: testIP, IP: ip,
Hostname: "testhost", Hostname: "testhost",
Expiry: time.Now().Add(time.Hour), Expiry: time.Now().Add(time.Hour),
}) })
@ -275,7 +303,7 @@ func TestClientsAddExisting(t *testing.T) {
// Add a new client with the same IP as for a client with MAC. // Add a new client with the same IP as for a client with MAC.
ok, err := clients.Add(&Client{ ok, err := clients.Add(&Client{
IDs: []string{testIP.String()}, IDs: []string{ip.String()},
Name: "client2", Name: "client2",
}) })
require.NoError(t, err) require.NoError(t, err)

View File

@ -5,6 +5,8 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"github.com/AdguardTeam/golibs/log"
) )
// clientJSON is a common structure used by several handlers to deal with // clientJSON is a common structure used by several handlers to deal with
@ -44,13 +46,13 @@ type clientJSON struct {
type runtimeClientJSON struct { type runtimeClientJSON struct {
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"` WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
IP string `json:"ip"`
Name string `json:"name"` Name string `json:"name"`
Source string `json:"source"` Source string `json:"source"`
IP net.IP `json:"ip"`
} }
type clientListJSON struct { type clientListJSON struct {
Clients []clientJSON `json:"clients"` Clients []*clientJSON `json:"clients"`
RuntimeClients []runtimeClientJSON `json:"auto_clients"` RuntimeClients []runtimeClientJSON `json:"auto_clients"`
Tags []string `json:"supported_tags"` Tags []string `json:"supported_tags"`
} }
@ -66,11 +68,20 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http
cj := clientToJSON(c) cj := clientToJSON(c)
data.Clients = append(data.Clients, cj) data.Clients = append(data.Clients, cj)
} }
for ip, rc := range clients.ipToRC {
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
rc, ok := v.(*RuntimeClient)
if !ok {
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
return true
}
cj := runtimeClientJSON{ cj := runtimeClientJSON{
IP: ip,
Name: rc.Host,
WHOISInfo: rc.WHOISInfo, WHOISInfo: rc.WHOISInfo,
Name: rc.Host,
IP: ip,
} }
cj.Source = "etc/hosts" cj.Source = "etc/hosts"
@ -86,7 +97,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http
} }
data.RuntimeClients = append(data.RuntimeClients, cj) data.RuntimeClients = append(data.RuntimeClients, cj)
}
return true
})
data.Tags = clientTags data.Tags = clientTags
@ -118,8 +131,8 @@ func jsonToClient(cj clientJSON) (c *Client) {
} }
// Convert Client object to JSON // Convert Client object to JSON
func clientToJSON(c *Client) clientJSON { func clientToJSON(c *Client) (cj *clientJSON) {
cj := clientJSON{ return &clientJSON{
Name: c.Name, Name: c.Name,
IDs: c.IDs, IDs: c.IDs,
Tags: c.Tags, Tags: c.Tags,
@ -134,19 +147,6 @@ func clientToJSON(c *Client) clientJSON {
Upstreams: c.Upstreams, Upstreams: c.Upstreams,
} }
return cj
}
// runtimeClientToJSON converts a RuntimeClient into a JSON struct.
func runtimeClientToJSON(ip string, rc RuntimeClient) (cj clientJSON) {
cj = clientJSON{
Name: rc.Host,
IDs: []string{ip},
WHOISInfo: rc.WHOISInfo,
}
return cj
} }
// Add a new client // Add a new client
@ -230,7 +230,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
// Get the list of clients by IP address list // Get the list of clients by IP address list
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) { func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query() q := r.URL.Query()
data := []map[string]clientJSON{} data := []map[string]*clientJSON{}
for i := 0; i < len(q); i++ { for i := 0; i < len(q); i++ {
idStr := q.Get(fmt.Sprintf("ip%d", i)) idStr := q.Get(fmt.Sprintf("ip%d", i))
if idStr == "" { if idStr == "" {
@ -239,20 +239,16 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
ip := net.ParseIP(idStr) ip := net.ParseIP(idStr)
c, ok := clients.Find(idStr) c, ok := clients.Find(idStr)
var cj clientJSON var cj *clientJSON
if !ok { if !ok {
var found bool cj = clients.findRuntime(ip, idStr)
cj, found = clients.findRuntime(ip, idStr)
if !found {
continue
}
} else { } else {
cj = clientToJSON(c) cj = clientToJSON(c)
disallowed, rule := clients.dnsServer.IsBlockedIP(ip) disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
} }
data = append(data, map[string]clientJSON{ data = append(data, map[string]*clientJSON{
idStr: cj, idStr: cj,
}) })
} }
@ -265,39 +261,37 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
} }
// findRuntime looks up the IP in runtime and temporary storages, like // findRuntime looks up the IP in runtime and temporary storages, like
// /etc/hosts tables, DHCP leases, or blocklists. // /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj clientJSON, found bool) { // non-nil.
if ip == nil { func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) {
return cj, false rc, ok := clients.FindRuntimeClient(ip)
}
rc, ok := clients.FindRuntimeClient(idStr)
if !ok { if !ok {
// It is still possible that the IP used to be in the runtime // It is still possible that the IP used to be in the runtime
// clients list, but then the server was reloaded. So, check // clients list, but then the server was reloaded. So, check
// the DNS server's blocked IP list. // the DNS server's blocked IP list.
// //
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428. // See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
disallowed, rule := clients.dnsServer.IsBlockedIP(ip) disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
if rule == "" { cj = &clientJSON{
return clientJSON{}, false
}
cj = clientJSON{
IDs: []string{idStr}, IDs: []string{idStr},
Disallowed: &disallowed, Disallowed: &disallowed,
DisallowedRule: &rule, DisallowedRule: &rule,
WHOISInfo: &RuntimeClientWHOISInfo{}, WHOISInfo: &RuntimeClientWHOISInfo{},
} }
return cj, true return cj
} }
cj = runtimeClientToJSON(idStr, rc) cj = &clientJSON{
disallowed, rule := clients.dnsServer.IsBlockedIP(ip) Name: rc.Host,
IDs: []string{idStr},
WHOISInfo: rc.WHOISInfo,
}
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
return cj, true return cj
} }
// RegisterClientsHandlers registers HTTP handlers // RegisterClientsHandlers registers HTTP handlers

View File

@ -105,8 +105,8 @@ func isRunning() bool {
return Context.dnsServer != nil && Context.dnsServer.IsRunning() return Context.dnsServer != nil && Context.dnsServer.IsRunning()
} }
func onDNSRequest(d *proxy.DNSContext) { func onDNSRequest(pctx *proxy.DNSContext) {
ip := aghnet.IPFromAddr(d.Addr) ip := aghnet.IPFromAddr(pctx.Addr)
if ip == nil { if ip == nil {
// This would be quite weird if we get here. // This would be quite weird if we get here.
return return

View File

@ -503,7 +503,7 @@ Please note, that this is crucial for a server to be able to use privileged port
You have two options: You have two options:
1. Run AdGuard Home with root privileges 1. Run AdGuard Home with root privileges
2. On Linux you can grant the CAP_NET_BIND_SERVICE capability: 2. On Linux you can grant the CAP_NET_BIND_SERVICE capability:
https://github.com/AdguardTeam/AdGuardHome/internal/wiki/Getting-Started#running-without-superuser` https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser`
log.Fatal(msg) log.Fatal(msg)
} }

View File

@ -102,12 +102,7 @@ func (r *RDNS) isCached(ip net.IP) (ok bool) {
func (r *RDNS) Begin(ip net.IP) { func (r *RDNS) Begin(ip net.IP) {
r.ensurePrivateCache() r.ensurePrivateCache()
if r.isCached(ip) { if r.isCached(ip) || r.clients.Exists(ip, ClientSourceRDNS) {
return
}
id := ip.String()
if r.clients.Exists(id, ClientSourceRDNS) {
return return
} }
@ -138,6 +133,6 @@ func (r *RDNS) workerLoop() {
// Don't handle any errors since AddHost doesn't return non-nil // Don't handle any errors since AddHost doesn't return non-nil
// errors for now. // errors for now.
_, _ = r.clients.AddHost(ip.String(), host, ClientSourceRDNS) _, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
} }
} }

View File

@ -8,6 +8,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
@ -84,7 +85,7 @@ func TestRDNS_Begin(t *testing.T) {
clients: &clientsContainer{ clients: &clientsContainer{
list: map[string]*Client{}, list: map[string]*Client{},
idIndex: tc.cliIDIndex, idIndex: tc.cliIDIndex,
ipToRC: map[string]*RuntimeClient{}, ipToRC: aghnet.NewIPMap(0),
allTags: aghstrings.NewSet(), allTags: aghstrings.NewSet(),
}, },
} }
@ -204,7 +205,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
cc := &clientsContainer{ cc := &clientsContainer{
list: map[string]*Client{}, list: map[string]*Client{},
idIndex: map[string]*Client{}, idIndex: map[string]*Client{},
ipToRC: map[string]*RuntimeClient{}, ipToRC: aghnet.NewIPMap(0),
allTags: aghstrings.NewSet(), allTags: aghstrings.NewSet(),
} }
ch := make(chan net.IP) ch := make(chan net.IP)
@ -236,7 +237,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
return return
} }
assert.True(t, cc.Exists(tc.cliIP.String(), ClientSourceRDNS)) assert.True(t, cc.Exists(tc.cliIP, ClientSourceRDNS))
}) })
} }
} }

View File

@ -252,7 +252,6 @@ func (w *WHOIS) workerLoop() {
continue continue
} }
id := ip.String() w.clients.SetWHOISInfo(ip, info)
w.clients.SetWHOISInfo(id, info)
} }
} }

View File

@ -720,7 +720,10 @@ func (s *statsCtx) GetTopClientsIP(maxCount uint) []net.IP {
a := convertMapToSlice(m, int(maxCount)) a := convertMapToSlice(m, int(maxCount))
d := []net.IP{} d := []net.IP{}
for _, it := range a { for _, it := range a {
d = append(d, net.ParseIP(it.Name)) ip := net.ParseIP(it.Name)
if ip != nil {
d = append(d, ip)
}
} }
return d return d
} }

View File

@ -4,6 +4,11 @@
## v0.107: API changes ## v0.107: API changes
### Client IDs in Access Settings
* The `POST /control/access/set` HTTP API now accepts client IDs in
`"allowed_clients"` and `"disallowed_clients"` fields.
### The new field `"unicode_name"` in `DNSQuestion` ### The new field `"unicode_name"` in `DNSQuestion`
* The new optional field `"unicode_name"` is the Unicode representation of * The new optional field `"unicode_name"` is the Unicode representation of
@ -17,7 +22,7 @@
### Disabling Statistics ### Disabling Statistics
* The API `POST /control/stats_config` HTTP API allows disabling statistics by * The `POST /control/stats_config` HTTP API allows disabling statistics by
setting `"interval"` to `0`. setting `"interval"` to `0`.
### `POST /control/dhcp/reset_leases` ### `POST /control/dhcp/reset_leases`

View File

@ -1957,10 +1957,7 @@
'disallowed_rule': 'disallowed_rule':
'type': 'string' 'type': 'string'
'description': > 'description': >
The rule due to which the client is disallowed. If disallowed is The rule due to which the client is allowed or blocked.
set to true, and this string is empty, then the client IP is
disallowed by the "allowed IP list", that is it is not included in
the allowed list.
'name': 'name':
'description': > 'description': >
Persistent client's name or an empty string if this is a runtime Persistent client's name or an empty string if this is a runtime
@ -2352,17 +2349,19 @@
'description': 'Client and host access list' 'description': 'Client and host access list'
'properties': 'properties':
'allowed_clients': 'allowed_clients':
'description': 'Allowlist of clients.' 'description': >
The allowlist of clients: IP addresses, CIDRs, or client IDs.
'items': 'items':
'type': 'string' 'type': 'string'
'type': 'array' 'type': 'array'
'disallowed_clients': 'disallowed_clients':
'description': 'Blocklist of clients.' 'description': >
The blocklist of clients: IP addresses, CIDRs, or client IDs.
'items': 'items':
'type': 'string' 'type': 'string'
'type': 'array' 'type': 'array'
'blocked_hosts': 'blocked_hosts':
'description': 'Blocklist of hosts.' 'description': 'The blocklist of hosts.'
'items': 'items':
'type': 'string' 'type': 'string'
'type': 'array' 'type': 'array'