diff --git a/dnsforward/access.go b/dnsforward/access.go new file mode 100644 index 00000000..7f2b4dae --- /dev/null +++ b/dnsforward/access.go @@ -0,0 +1,188 @@ +package dnsforward + +import ( + "encoding/json" + "net" + "net/http" + "sync" + + "github.com/AdguardTeam/golibs/log" +) + +type accessCtx struct { + lock sync.Mutex + + allowedClients map[string]bool // IP addresses of whitelist clients + disallowedClients map[string]bool // IP addresses of clients that should be blocked + + allowedClientsIPNet []net.IPNet // CIDRs of whitelist clients + disallowedClientsIPNet []net.IPNet // CIDRs of clients that should be blocked + + blockedHosts map[string]bool // hosts that should be blocked +} + +func (a *accessCtx) Init(allowedClients, disallowedClients, blockedHosts []string) error { + err := processIPCIDRArray(&a.allowedClients, &a.allowedClientsIPNet, allowedClients) + if err != nil { + return err + } + + err = processIPCIDRArray(&a.disallowedClients, &a.disallowedClientsIPNet, disallowedClients) + if err != nil { + return err + } + + convertArrayToMap(&a.blockedHosts, blockedHosts) + return nil +} + +func convertArrayToMap(dst *map[string]bool, src []string) { + *dst = make(map[string]bool) + for _, s := range src { + (*dst)[s] = true + } +} + +// Split array of IP or CIDR into 2 containers for fast search +func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []string) error { + *dst = make(map[string]bool) + + for _, s := range src { + ip := net.ParseIP(s) + if ip != nil { + (*dst)[s] = true + continue + } + + _, ipnet, err := net.ParseCIDR(s) + if err != nil { + return err + } + *dstIPNet = append(*dstIPNet, *ipnet) + } + + return nil +} + +// Return TRUE if this client should be blocked +func (a *accessCtx) IsBlockedIP(ip string) bool { + a.lock.Lock() + defer a.lock.Unlock() + + if len(a.allowedClients) != 0 || len(a.allowedClientsIPNet) != 0 { + _, ok := a.allowedClients[ip] + if ok { + return false + } + + if len(a.allowedClientsIPNet) != 0 { + ipAddr := net.ParseIP(ip) + for _, ipnet := range a.allowedClientsIPNet { + if ipnet.Contains(ipAddr) { + return false + } + } + } + + return true + } + + _, ok := a.disallowedClients[ip] + if ok { + return true + } + + if len(a.disallowedClientsIPNet) != 0 { + ipAddr := net.ParseIP(ip) + for _, ipnet := range a.disallowedClientsIPNet { + if ipnet.Contains(ipAddr) { + return true + } + } + } + + return false +} + +// Return TRUE if this domain should be blocked +func (a *accessCtx) IsBlockedDomain(host string) bool { + a.lock.Lock() + _, ok := a.blockedHosts[host] + a.lock.Unlock() + return ok +} + +type accessListJSON struct { + AllowedClients []string `json:"allowed_clients"` + DisallowedClients []string `json:"disallowed_clients"` + BlockedHosts []string `json:"blocked_hosts"` +} + +func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) { + s.Lock() + j := accessListJSON{ + AllowedClients: s.conf.AllowedClients, + DisallowedClients: s.conf.DisallowedClients, + BlockedHosts: s.conf.BlockedHosts, + } + s.Unlock() + + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(j) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err) + return + } +} + +func checkIPCIDRArray(src []string) error { + for _, s := range src { + ip := net.ParseIP(s) + if ip != nil { + continue + } + + _, _, err := net.ParseCIDR(s) + if err != nil { + return err + } + } + + return nil +} + +func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) { + j := accessListJSON{} + err := json.NewDecoder(r.Body).Decode(&j) + if err != nil { + httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) + return + } + + err = checkIPCIDRArray(j.AllowedClients) + if err == nil { + err = checkIPCIDRArray(j.DisallowedClients) + } + if err != nil { + httpError(r, w, http.StatusBadRequest, "%s", err) + return + } + + a := &accessCtx{} + err = a.Init(j.AllowedClients, j.DisallowedClients, j.BlockedHosts) + if err != nil { + httpError(r, w, http.StatusBadRequest, "access.Init: %s", err) + return + } + + s.Lock() + s.conf.AllowedClients = j.AllowedClients + s.conf.DisallowedClients = j.DisallowedClients + s.conf.BlockedHosts = j.BlockedHosts + s.access = a + s.Unlock() + s.conf.ConfigModified() + + log.Debug("Access: updated lists: %d, %d, %d", + len(j.AllowedClients), len(j.DisallowedClients), len(j.BlockedHosts)) +} diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index cace0cbc..73975d78 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -44,12 +44,7 @@ type Server struct { dnsFilter *dnsfilter.Dnsfilter // DNS filter instance queryLog querylog.QueryLog // Query log instance stats stats.Stats - - AllowedClients map[string]bool // IP addresses of whitelist clients - DisallowedClients map[string]bool // IP addresses of clients that should be blocked - AllowedClientsIPNet []net.IPNet // CIDRs of whitelist clients - DisallowedClientsIPNet []net.IPNet // CIDRs of clients that should be blocked - BlockedHosts map[string]bool // hosts that should be blocked + access *accessCtx webRegistered bool @@ -174,34 +169,6 @@ func (s *Server) Start(config *ServerConfig) error { return s.startInternal(config) } -func convertArrayToMap(dst *map[string]bool, src []string) { - *dst = make(map[string]bool) - for _, s := range src { - (*dst)[s] = true - } -} - -// Split array of IP or CIDR into 2 containers for fast search -func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []string) error { - *dst = make(map[string]bool) - - for _, s := range src { - ip := net.ParseIP(s) - if ip != nil { - (*dst)[s] = true - continue - } - - _, ipnet, err := net.ParseCIDR(s) - if err != nil { - return err - } - *dstIPNet = append(*dstIPNet, *ipnet) - } - - return nil -} - // startInternal starts without locking func (s *Server) startInternal(config *ServerConfig) error { if s.dnsProxy != nil { @@ -240,18 +207,12 @@ func (s *Server) startInternal(config *ServerConfig) error { AllServers: s.conf.AllServers, } - err := processIPCIDRArray(&s.AllowedClients, &s.AllowedClientsIPNet, s.conf.AllowedClients) + s.access = &accessCtx{} + err := s.access.Init(s.conf.AllowedClients, s.conf.DisallowedClients, s.conf.BlockedHosts) if err != nil { return err } - err = processIPCIDRArray(&s.DisallowedClients, &s.DisallowedClientsIPNet, s.conf.DisallowedClients) - if err != nil { - return err - } - - convertArrayToMap(&s.BlockedHosts, s.conf.BlockedHosts) - if s.conf.TLSListenAddr != nil && len(s.conf.CertificateChainData) != 0 && len(s.conf.PrivateKeyData) != 0 { proxyConfig.TLSListenAddr = s.conf.TLSListenAddr keypair, err := tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData) @@ -369,59 +330,16 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.RUnlock() } -// Return TRUE if this client should be blocked -func (s *Server) isBlockedIP(ip string) bool { - if len(s.AllowedClients) != 0 || len(s.AllowedClientsIPNet) != 0 { - _, ok := s.AllowedClients[ip] - if ok { - return false - } - - if len(s.AllowedClientsIPNet) != 0 { - ipAddr := net.ParseIP(ip) - for _, ipnet := range s.AllowedClientsIPNet { - if ipnet.Contains(ipAddr) { - return false - } - } - } - - return true - } - - _, ok := s.DisallowedClients[ip] - if ok { - return true - } - - if len(s.DisallowedClientsIPNet) != 0 { - ipAddr := net.ParseIP(ip) - for _, ipnet := range s.DisallowedClientsIPNet { - if ipnet.Contains(ipAddr) { - return true - } - } - } - - return false -} - -// Return TRUE if this domain should be blocked -func (s *Server) isBlockedDomain(host string) bool { - _, ok := s.BlockedHosts[host] - return ok -} - func (s *Server) beforeRequestHandler(p *proxy.Proxy, d *proxy.DNSContext) (bool, error) { ip, _, _ := net.SplitHostPort(d.Addr.String()) - if s.isBlockedIP(ip) { + if s.access.IsBlockedIP(ip) { log.Tracef("Client IP %s is blocked by settings", ip) return false, nil } if len(d.Req.Question) == 1 { host := strings.TrimSuffix(d.Req.Question[0].Name, ".") - if s.isBlockedDomain(host) { + if s.access.IsBlockedDomain(host) { log.Tracef("Domain %s is blocked by settings", host) return false, nil } diff --git a/dnsforward/dnsforward_http.go b/dnsforward/dnsforward_http.go index dec75713..13a9a68d 100644 --- a/dnsforward/dnsforward_http.go +++ b/dnsforward/dnsforward_http.go @@ -283,4 +283,8 @@ func (s *Server) registerHandlers() { s.conf.HTTPRegister("POST", "/control/disable_protection", s.handleProtectionDisable) s.conf.HTTPRegister("POST", "/control/set_upstreams_config", s.handleSetUpstreamConfig) s.conf.HTTPRegister("POST", "/control/test_upstream_dns", s.handleTestUpstreamDNS) + + s.conf.HTTPRegister("GET", "/control/access/list", s.handleAccessList) + s.conf.HTTPRegister("POST", "/control/access/set", s.handleAccessSet) + } diff --git a/home/control.go b/home/control.go index 97a8d232..a4ce2e0e 100644 --- a/home/control.go +++ b/home/control.go @@ -164,8 +164,6 @@ func registerControlHandlers() { http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON))) httpRegister(http.MethodPost, "/control/update", handleUpdate) - httpRegister(http.MethodGet, "/control/access/list", handleAccessList) - httpRegister(http.MethodPost, "/control/access/set", handleAccessSet) httpRegister("GET", "/control/profile", handleGetProfile) RegisterFilteringHandlers() diff --git a/home/control_access.go b/home/control_access.go deleted file mode 100644 index abc51101..00000000 --- a/home/control_access.go +++ /dev/null @@ -1,84 +0,0 @@ -package home - -import ( - "encoding/json" - "net" - "net/http" - - "github.com/AdguardTeam/golibs/log" -) - -type accessListJSON struct { - AllowedClients []string `json:"allowed_clients"` - DisallowedClients []string `json:"disallowed_clients"` - BlockedHosts []string `json:"blocked_hosts"` -} - -func handleAccessList(w http.ResponseWriter, r *http.Request) { - config.controlLock.Lock() - j := accessListJSON{ - AllowedClients: config.DNS.AllowedClients, - DisallowedClients: config.DNS.DisallowedClients, - BlockedHosts: config.DNS.BlockedHosts, - } - config.controlLock.Unlock() - - w.Header().Set("Content-Type", "application/json") - err := json.NewEncoder(w).Encode(j) - if err != nil { - httpError(w, http.StatusInternalServerError, "json.Encode: %s", err) - return - } -} - -func checkIPCIDRArray(src []string) error { - for _, s := range src { - ip := net.ParseIP(s) - if ip != nil { - continue - } - - _, _, err := net.ParseCIDR(s) - if err != nil { - return err - } - } - - return nil -} - -func handleAccessSet(w http.ResponseWriter, r *http.Request) { - - j := accessListJSON{} - err := json.NewDecoder(r.Body).Decode(&j) - if err != nil { - httpError(w, http.StatusBadRequest, "json.Decode: %s", err) - return - } - - err = checkIPCIDRArray(j.AllowedClients) - if err == nil { - err = checkIPCIDRArray(j.DisallowedClients) - } - if err != nil { - httpError(w, http.StatusBadRequest, "%s", err) - return - } - - config.Lock() - config.DNS.AllowedClients = j.AllowedClients - config.DNS.DisallowedClients = j.DisallowedClients - config.DNS.BlockedHosts = j.BlockedHosts - config.Unlock() - - log.Tracef("Update access lists: %d, %d, %d", - len(j.AllowedClients), len(j.DisallowedClients), len(j.BlockedHosts)) - - err = writeAllConfigsAndReloadDNS() - if err != nil { - httpError(w, http.StatusBadRequest, "%s", err) - return - } - - returnOK(w) -}