From 452a668a5ba66efa265621104fe3be17ae156ed5 Mon Sep 17 00:00:00 2001 From: Simon Zolin Date: Wed, 21 Aug 2019 14:39:37 +0300 Subject: [PATCH] * control: refactor: all handlers are registered via httpRegister() * move some code to home/control_filtering.go --- home/blocked_services.go | 8 +- home/clients.go | 13 +- home/control.go | 351 +++++--------------------------------- home/control_access.go | 3 - home/control_filtering.go | 237 +++++++++++++++++++++++++ home/control_install.go | 3 - home/control_tls.go | 9 +- home/control_update.go | 4 +- home/dhcp.go | 6 - home/dns_rewrites.go | 9 +- home/helpers.go | 6 +- 11 files changed, 301 insertions(+), 348 deletions(-) create mode 100644 home/control_filtering.go diff --git a/home/blocked_services.go b/home/blocked_services.go index 23b74c51..5f9e58b9 100644 --- a/home/blocked_services.go +++ b/home/blocked_services.go @@ -73,8 +73,6 @@ func ApplyBlockedServices(setts *dnsfilter.RequestFilteringSettings, list []stri } func handleBlockedServicesList(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - config.RLock() list := config.DNS.BlockedServices config.RUnlock() @@ -88,8 +86,6 @@ func handleBlockedServicesList(w http.ResponseWriter, r *http.Request) { } func handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - list := []string{} err := json.NewDecoder(r.Body).Decode(&list) if err != nil { @@ -114,6 +110,6 @@ func handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) { // RegisterBlockedServicesHandlers - register HTTP handlers func RegisterBlockedServicesHandlers() { - http.HandleFunc("/control/blocked_services/list", postInstall(optionalAuth(ensureGET(handleBlockedServicesList)))) - http.HandleFunc("/control/blocked_services/set", postInstall(optionalAuth(ensurePOST(handleBlockedServicesSet)))) + httpRegister(http.MethodGet, "/control/blocked_services/list", handleBlockedServicesList) + httpRegister(http.MethodPost, "/control/blocked_services/set", handleBlockedServicesSet) } diff --git a/home/clients.go b/home/clients.go index 66253844..1de77a79 100644 --- a/home/clients.go +++ b/home/clients.go @@ -382,8 +382,6 @@ type clientListJSON struct { // respond with information about configured clients func handleGetClients(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - data := clientListJSON{} config.clients.lock.Lock() @@ -456,7 +454,6 @@ func jsonToClient(cj clientJSON) (*Client, error) { // Add a new client func handleAddClient(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) body, err := ioutil.ReadAll(r.Body) if err != nil { httpError(w, http.StatusBadRequest, "failed to read request body: %s", err) @@ -491,7 +488,6 @@ func handleAddClient(w http.ResponseWriter, r *http.Request) { // Remove client func handleDelClient(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) body, err := ioutil.ReadAll(r.Body) if err != nil { httpError(w, http.StatusBadRequest, "failed to read request body: %s", err) @@ -521,7 +517,6 @@ type updateJSON struct { // Update client's properties func handleUpdateClient(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) body, err := ioutil.ReadAll(r.Body) if err != nil { httpError(w, http.StatusBadRequest, "failed to read request body: %s", err) @@ -557,8 +552,8 @@ func handleUpdateClient(w http.ResponseWriter, r *http.Request) { // RegisterClientsHandlers registers HTTP handlers func RegisterClientsHandlers() { - http.HandleFunc("/control/clients", postInstall(optionalAuth(ensureGET(handleGetClients)))) - http.HandleFunc("/control/clients/add", postInstall(optionalAuth(ensurePOST(handleAddClient)))) - http.HandleFunc("/control/clients/delete", postInstall(optionalAuth(ensurePOST(handleDelClient)))) - http.HandleFunc("/control/clients/update", postInstall(optionalAuth(ensurePOST(handleUpdateClient)))) + httpRegister(http.MethodGet, "/control/clients", handleGetClients) + httpRegister(http.MethodPost, "/control/clients/add", handleAddClient) + httpRegister(http.MethodPost, "/control/clients/delete", handleDelClient) + httpRegister(http.MethodPost, "/control/clients/update", handleUpdateClient) } diff --git a/home/control.go b/home/control.go index 8f528d77..6601b364 100644 --- a/home/control.go +++ b/home/control.go @@ -4,10 +4,8 @@ import ( "bytes" "encoding/json" "fmt" - "io/ioutil" "net" "net/http" - "os" "sort" "strconv" "strings" @@ -19,7 +17,6 @@ import ( "github.com/AdguardTeam/golibs/utils" "github.com/NYTimes/gziphandler" "github.com/miekg/dns" - govalidator "gopkg.in/asaskevich/govalidator.v4" ) const updatePeriod = time.Hour * 24 @@ -114,8 +111,6 @@ func getDNSAddresses() []string { } func handleStatus(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - data := map[string]interface{}{ "dns_addresses": getDNSAddresses(), "http_port": config.BindPort, @@ -144,13 +139,11 @@ func handleStatus(w http.ResponseWriter, r *http.Request) { } func handleProtectionEnable(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) config.DNS.ProtectionEnabled = true httpUpdateConfigReloadDNSReturnOK(w, r) } func handleProtectionDisable(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) config.DNS.ProtectionEnabled = false httpUpdateConfigReloadDNSReturnOK(w, r) } @@ -159,19 +152,16 @@ func handleProtectionDisable(w http.ResponseWriter, r *http.Request) { // stats // ----- func handleQueryLogEnable(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) config.DNS.QueryLogEnabled = true httpUpdateConfigReloadDNSReturnOK(w, r) } func handleQueryLogDisable(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) config.DNS.QueryLogEnabled = false httpUpdateConfigReloadDNSReturnOK(w, r) } func handleQueryLog(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) data := config.dnsServer.GetQueryLog() jsonVal, err := json.Marshal(data) @@ -188,7 +178,6 @@ func handleQueryLog(w http.ResponseWriter, r *http.Request) { } func handleStatsTop(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) s := config.dnsServer.GetStatsTop() // use manual json marshalling because we want maps to be sorted by value @@ -235,7 +224,6 @@ func handleStatsTop(w http.ResponseWriter, r *http.Request) { // handleStatsReset resets the stats caches func handleStatsReset(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) config.dnsServer.PurgeStats() _, err := fmt.Fprintf(w, "OK\n") if err != nil { @@ -245,7 +233,6 @@ func handleStatsReset(w http.ResponseWriter, r *http.Request) { // handleStats returns aggregated stats data for the 24 hours func handleStats(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) summed := config.dnsServer.GetAggregatedStats() statsJSON, err := json.Marshal(summed) @@ -263,7 +250,6 @@ func handleStats(w http.ResponseWriter, r *http.Request) { // HandleStatsHistory returns historical stats data for the 24 hours func handleStatsHistory(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) // handle time unit and prepare our time window size timeUnitString := r.URL.Query().Get("time_unit") var timeUnit time.Duration @@ -346,7 +332,6 @@ type upstreamConfig struct { } func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) newconfig := upstreamConfig{} err := json.NewDecoder(r.Body).Decode(&newconfig) if err != nil { @@ -484,7 +469,6 @@ func checkPlainDNS(upstream string) error { } func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) upstreamConfig := upstreamConfig{} err := json.NewDecoder(r.Body).Decode(&upstreamConfig) if err != nil { @@ -571,261 +555,21 @@ func checkDNS(input string, bootstrap []string) error { return nil } -// --------- -// filtering -// --------- - -func handleFilteringEnable(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - config.DNS.FilteringEnabled = true - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -func handleFilteringDisable(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - config.DNS.FilteringEnabled = false - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -func handleFilteringStatus(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - data := map[string]interface{}{ - "enabled": config.DNS.FilteringEnabled, - } - - config.RLock() - data["filters"] = config.Filters - data["user_rules"] = config.UserRules - jsonVal, err := json.Marshal(data) - config.RUnlock() - - if err != nil { - httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err) - return - } - - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(jsonVal) - if err != nil { - httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err) - return - } -} - -func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - f := filter{} - err := json.NewDecoder(r.Body).Decode(&f) - if err != nil { - httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err) - return - } - - if len(f.URL) == 0 { - http.Error(w, "URL parameter was not specified", http.StatusBadRequest) - return - } - - if valid := govalidator.IsRequestURL(f.URL); !valid { - http.Error(w, "URL parameter is not valid request URL", http.StatusBadRequest) - return - } - - // Check for duplicates - if filterExists(f.URL) { - httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", f.URL) - return - } - - // Set necessary properties - f.ID = assignUniqueFilterID() - f.Enabled = true - - // Download the filter contents - ok, err := f.update() - if err != nil { - httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", f.URL, err) - return - } - if f.RulesCount == 0 { - httpError(w, http.StatusBadRequest, "Filter at the url %s has no rules (maybe it points to blank page?)", f.URL) - return - } - if !ok { - httpError(w, http.StatusBadRequest, "Filter at the url %s is invalid (maybe it points to blank page?)", f.URL) - return - } - - // Save the filter contents - err = f.save() - if err != nil { - httpError(w, http.StatusBadRequest, "Failed to save filter %d due to %s", f.ID, err) - return - } - - // URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it - // TODO: since we directly feed filters in-memory, revisit if writing configs is always necessary - if !filterAdd(f) { - httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", f.URL) - return - } - - err = writeAllConfigs() - if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't write config file: %s", err) - return - } - - err = reconfigureDNSServer() - if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't reconfigure the DNS server: %s", err) - return - } - - _, err = fmt.Fprintf(w, "OK %d rules\n", f.RulesCount) - if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) - } -} - -func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - - type request struct { - URL string `json:"url"` - } - req := request{} - err := json.NewDecoder(r.Body).Decode(&req) - if err != nil { - httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err) - return - } - - if valid := govalidator.IsRequestURL(req.URL); !valid { - http.Error(w, "URL parameter is not valid request URL", http.StatusBadRequest) - return - } - - // Stop DNS server: - // we close urlfilter object which in turn closes file descriptors to filter files. - // Otherwise, Windows won't allow us to remove the file which is being currently used. - _ = config.dnsServer.Stop() - - // go through each element and delete if url matches - config.Lock() - newFilters := config.Filters[:0] - for _, filter := range config.Filters { - if filter.URL != req.URL { - newFilters = append(newFilters, filter) - } else { - // Remove the filter file - err := os.Remove(filter.Path()) - if err != nil && !os.IsNotExist(err) { - config.Unlock() - httpError(w, http.StatusInternalServerError, "Couldn't remove the filter file: %s", err) - return - } - log.Debug("os.Remove(%s)", filter.Path()) - } - } - // Update the configuration after removing filter files - config.Filters = newFilters - config.Unlock() - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - parameters, err := parseParametersFromBody(r.Body) - if err != nil { - httpError(w, http.StatusBadRequest, "failed to parse parameters from body: %s", err) - return - } - - url, ok := parameters["url"] - if !ok { - http.Error(w, "URL parameter was not specified", http.StatusBadRequest) - return - } - - if valid := govalidator.IsRequestURL(url); !valid { - http.Error(w, "URL parameter is not valid request URL", http.StatusBadRequest) - return - } - - found := filterEnable(url, true) - if !found { - http.Error(w, "URL parameter was not previously added", http.StatusBadRequest) - return - } - - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - parameters, err := parseParametersFromBody(r.Body) - if err != nil { - httpError(w, http.StatusBadRequest, "failed to parse parameters from body: %s", err) - return - } - - url, ok := parameters["url"] - if !ok { - http.Error(w, "URL parameter was not specified", http.StatusBadRequest) - return - } - - if valid := govalidator.IsRequestURL(url); !valid { - http.Error(w, "URL parameter is not valid request URL", http.StatusBadRequest) - return - } - - found := filterEnable(url, false) - if !found { - http.Error(w, "URL parameter was not previously added", http.StatusBadRequest) - return - } - - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - body, err := ioutil.ReadAll(r.Body) - if err != nil { - httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err) - return - } - - config.UserRules = strings.Split(string(body), "\n") - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - updated := refreshFiltersIfNecessary(true) - fmt.Fprintf(w, "OK %d filters updated\n", updated) -} - // ------------ // safebrowsing // ------------ func handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) config.DNS.SafeBrowsingEnabled = true httpUpdateConfigReloadDNSReturnOK(w, r) } func handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) config.DNS.SafeBrowsingEnabled = false httpUpdateConfigReloadDNSReturnOK(w, r) } func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) data := map[string]interface{}{ "enabled": config.DNS.SafeBrowsingEnabled, } @@ -846,7 +590,6 @@ func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) { // parental // -------- func handleParentalEnable(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) parameters, err := parseParametersFromBody(r.Body) if err != nil { httpError(w, http.StatusBadRequest, "failed to parse parameters from body: %s", err) @@ -891,13 +634,11 @@ func handleParentalEnable(w http.ResponseWriter, r *http.Request) { } func handleParentalDisable(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) config.DNS.ParentalEnabled = false httpUpdateConfigReloadDNSReturnOK(w, r) } func handleParentalStatus(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) data := map[string]interface{}{ "enabled": config.DNS.ParentalEnabled, } @@ -923,19 +664,16 @@ func handleParentalStatus(w http.ResponseWriter, r *http.Request) { // ------------ func handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) config.DNS.SafeSearchEnabled = true httpUpdateConfigReloadDNSReturnOK(w, r) } func handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) config.DNS.SafeSearchEnabled = false httpUpdateConfigReloadDNSReturnOK(w, r) } func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) data := map[string]interface{}{ "enabled": config.DNS.SafeSearchEnabled, } @@ -957,7 +695,6 @@ func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) { // DNS-over-HTTPS // -------------- func handleDOH(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) if r.TLS == nil { httpError(w, http.StatusNotFound, "Not Found") return @@ -975,49 +712,49 @@ func handleDOH(w http.ResponseWriter, r *http.Request) { // registration of handlers // ------------------------ func registerControlHandlers() { - http.HandleFunc("/control/status", postInstall(optionalAuth(ensureGET(handleStatus)))) - http.HandleFunc("/control/enable_protection", postInstall(optionalAuth(ensurePOST(handleProtectionEnable)))) - http.HandleFunc("/control/disable_protection", postInstall(optionalAuth(ensurePOST(handleProtectionDisable)))) - http.Handle("/control/querylog", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureGETHandler(handleQueryLog))))) - http.HandleFunc("/control/querylog_enable", postInstall(optionalAuth(ensurePOST(handleQueryLogEnable)))) - http.HandleFunc("/control/querylog_disable", postInstall(optionalAuth(ensurePOST(handleQueryLogDisable)))) - http.HandleFunc("/control/set_upstreams_config", postInstall(optionalAuth(ensurePOST(handleSetUpstreamConfig)))) - http.HandleFunc("/control/test_upstream_dns", postInstall(optionalAuth(ensurePOST(handleTestUpstreamDNS)))) - http.HandleFunc("/control/i18n/change_language", postInstall(optionalAuth(ensurePOST(handleI18nChangeLanguage)))) - http.HandleFunc("/control/i18n/current_language", postInstall(optionalAuth(ensureGET(handleI18nCurrentLanguage)))) - http.HandleFunc("/control/stats_top", postInstall(optionalAuth(ensureGET(handleStatsTop)))) - http.HandleFunc("/control/stats", postInstall(optionalAuth(ensureGET(handleStats)))) - http.HandleFunc("/control/stats_history", postInstall(optionalAuth(ensureGET(handleStatsHistory)))) - http.HandleFunc("/control/stats_reset", postInstall(optionalAuth(ensurePOST(handleStatsReset)))) + httpRegister(http.MethodGet, "/control/status", handleStatus) + httpRegister(http.MethodPost, "/control/enable_protection", handleProtectionEnable) + httpRegister(http.MethodPost, "/control/disable_protection", handleProtectionDisable) + httpRegister(http.MethodGet, "/control/querylog", handleQueryLog) + httpRegister(http.MethodPost, "/control/querylog_enable", handleQueryLogEnable) + httpRegister(http.MethodPost, "/control/querylog_disable", handleQueryLogDisable) + httpRegister(http.MethodPost, "/control/set_upstreams_config", handleSetUpstreamConfig) + httpRegister(http.MethodPost, "/control/test_upstream_dns", handleTestUpstreamDNS) + httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage) + httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage) + httpRegister(http.MethodGet, "/control/stats_top", handleStatsTop) + httpRegister(http.MethodGet, "/control/stats", handleStats) + httpRegister(http.MethodGet, "/control/stats_history", handleStatsHistory) + httpRegister(http.MethodPost, "/control/stats_reset", handleStatsReset) http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON))) - http.HandleFunc("/control/update", postInstall(optionalAuth(ensurePOST(handleUpdate)))) - http.HandleFunc("/control/filtering/enable", postInstall(optionalAuth(ensurePOST(handleFilteringEnable)))) - http.HandleFunc("/control/filtering/disable", postInstall(optionalAuth(ensurePOST(handleFilteringDisable)))) - http.HandleFunc("/control/filtering/add_url", postInstall(optionalAuth(ensurePOST(handleFilteringAddURL)))) - http.HandleFunc("/control/filtering/remove_url", postInstall(optionalAuth(ensurePOST(handleFilteringRemoveURL)))) - http.HandleFunc("/control/filtering/enable_url", postInstall(optionalAuth(ensurePOST(handleFilteringEnableURL)))) - http.HandleFunc("/control/filtering/disable_url", postInstall(optionalAuth(ensurePOST(handleFilteringDisableURL)))) - http.HandleFunc("/control/filtering/refresh", postInstall(optionalAuth(ensurePOST(handleFilteringRefresh)))) - http.HandleFunc("/control/filtering/status", postInstall(optionalAuth(ensureGET(handleFilteringStatus)))) - http.HandleFunc("/control/filtering/set_rules", postInstall(optionalAuth(ensurePOST(handleFilteringSetRules)))) - http.HandleFunc("/control/safebrowsing/enable", postInstall(optionalAuth(ensurePOST(handleSafeBrowsingEnable)))) - http.HandleFunc("/control/safebrowsing/disable", postInstall(optionalAuth(ensurePOST(handleSafeBrowsingDisable)))) - http.HandleFunc("/control/safebrowsing/status", postInstall(optionalAuth(ensureGET(handleSafeBrowsingStatus)))) - http.HandleFunc("/control/parental/enable", postInstall(optionalAuth(ensurePOST(handleParentalEnable)))) - http.HandleFunc("/control/parental/disable", postInstall(optionalAuth(ensurePOST(handleParentalDisable)))) - http.HandleFunc("/control/parental/status", postInstall(optionalAuth(ensureGET(handleParentalStatus)))) - http.HandleFunc("/control/safesearch/enable", postInstall(optionalAuth(ensurePOST(handleSafeSearchEnable)))) - http.HandleFunc("/control/safesearch/disable", postInstall(optionalAuth(ensurePOST(handleSafeSearchDisable)))) - http.HandleFunc("/control/safesearch/status", postInstall(optionalAuth(ensureGET(handleSafeSearchStatus)))) - http.HandleFunc("/control/dhcp/status", postInstall(optionalAuth(ensureGET(handleDHCPStatus)))) - http.HandleFunc("/control/dhcp/interfaces", postInstall(optionalAuth(ensureGET(handleDHCPInterfaces)))) - http.HandleFunc("/control/dhcp/set_config", postInstall(optionalAuth(ensurePOST(handleDHCPSetConfig)))) - http.HandleFunc("/control/dhcp/find_active_dhcp", postInstall(optionalAuth(ensurePOST(handleDHCPFindActiveServer)))) - http.HandleFunc("/control/dhcp/add_static_lease", postInstall(optionalAuth(ensurePOST(handleDHCPAddStaticLease)))) - http.HandleFunc("/control/dhcp/remove_static_lease", postInstall(optionalAuth(ensurePOST(handleDHCPRemoveStaticLease)))) + httpRegister(http.MethodPost, "/control/update", handleUpdate) + httpRegister(http.MethodPost, "/control/filtering/enable", handleFilteringEnable) + httpRegister(http.MethodPost, "/control/filtering/disable", handleFilteringDisable) + httpRegister(http.MethodPost, "/control/filtering/add_url", handleFilteringAddURL) + httpRegister(http.MethodPost, "/control/filtering/remove_url", handleFilteringRemoveURL) + httpRegister(http.MethodPost, "/control/filtering/enable_url", handleFilteringEnableURL) + httpRegister(http.MethodPost, "/control/filtering/disable_url", handleFilteringDisableURL) + httpRegister(http.MethodPost, "/control/filtering/refresh", handleFilteringRefresh) + httpRegister(http.MethodGet, "/control/filtering/status", handleFilteringStatus) + httpRegister(http.MethodPost, "/control/filtering/set_rules", handleFilteringSetRules) + httpRegister(http.MethodPost, "/control/safebrowsing/enable", handleSafeBrowsingEnable) + httpRegister(http.MethodPost, "/control/safebrowsing/disable", handleSafeBrowsingDisable) + httpRegister(http.MethodGet, "/control/safebrowsing/status", handleSafeBrowsingStatus) + httpRegister(http.MethodPost, "/control/parental/enable", handleParentalEnable) + httpRegister(http.MethodPost, "/control/parental/disable", handleParentalDisable) + httpRegister(http.MethodGet, "/control/parental/status", handleParentalStatus) + httpRegister(http.MethodPost, "/control/safesearch/enable", handleSafeSearchEnable) + httpRegister(http.MethodPost, "/control/safesearch/disable", handleSafeSearchDisable) + httpRegister(http.MethodGet, "/control/safesearch/status", handleSafeSearchStatus) + httpRegister(http.MethodGet, "/control/dhcp/status", handleDHCPStatus) + httpRegister(http.MethodGet, "/control/dhcp/interfaces", handleDHCPInterfaces) + httpRegister(http.MethodPost, "/control/dhcp/set_config", handleDHCPSetConfig) + httpRegister(http.MethodPost, "/control/dhcp/find_active_dhcp", handleDHCPFindActiveServer) + httpRegister(http.MethodPost, "/control/dhcp/add_static_lease", handleDHCPAddStaticLease) + httpRegister(http.MethodPost, "/control/dhcp/remove_static_lease", handleDHCPRemoveStaticLease) - http.HandleFunc("/control/access/list", postInstall(optionalAuth(ensureGET(handleAccessList)))) - http.HandleFunc("/control/access/set", postInstall(optionalAuth(ensurePOST(handleAccessSet)))) + httpRegister(http.MethodGet, "/control/access/list", handleAccessList) + httpRegister(http.MethodPost, "/control/access/set", handleAccessSet) RegisterTLSHandlers() RegisterClientsHandlers() @@ -1026,3 +763,9 @@ func registerControlHandlers() { http.HandleFunc("/dns-query", postInstall(handleDOH)) } + +type httpHandlerType func(http.ResponseWriter, *http.Request) + +func httpRegister(method string, url string, handler httpHandlerType) { + http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler))))) +} diff --git a/home/control_access.go b/home/control_access.go index f33df8b4..abc51101 100644 --- a/home/control_access.go +++ b/home/control_access.go @@ -15,8 +15,6 @@ type accessListJSON struct { } func handleAccessList(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - config.controlLock.Lock() j := accessListJSON{ AllowedClients: config.DNS.AllowedClients, @@ -50,7 +48,6 @@ func checkIPCIDRArray(src []string) error { } func handleAccessSet(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) j := accessListJSON{} err := json.NewDecoder(r.Body).Decode(&j) diff --git a/home/control_filtering.go b/home/control_filtering.go new file mode 100644 index 00000000..8eded166 --- /dev/null +++ b/home/control_filtering.go @@ -0,0 +1,237 @@ +package home + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "os" + "strings" + + "github.com/AdguardTeam/golibs/log" + "github.com/asaskevich/govalidator" +) + +func handleFilteringEnable(w http.ResponseWriter, r *http.Request) { + config.DNS.FilteringEnabled = true + httpUpdateConfigReloadDNSReturnOK(w, r) +} + +func handleFilteringDisable(w http.ResponseWriter, r *http.Request) { + config.DNS.FilteringEnabled = false + httpUpdateConfigReloadDNSReturnOK(w, r) +} + +func handleFilteringStatus(w http.ResponseWriter, r *http.Request) { + data := map[string]interface{}{ + "enabled": config.DNS.FilteringEnabled, + } + + config.RLock() + data["filters"] = config.Filters + data["user_rules"] = config.UserRules + jsonVal, err := json.Marshal(data) + config.RUnlock() + + if err != nil { + httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err) + return + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(jsonVal) + if err != nil { + httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err) + return + } +} + +func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { + f := filter{} + err := json.NewDecoder(r.Body).Decode(&f) + if err != nil { + httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err) + return + } + + if len(f.URL) == 0 { + http.Error(w, "URL parameter was not specified", http.StatusBadRequest) + return + } + + if valid := govalidator.IsRequestURL(f.URL); !valid { + http.Error(w, "URL parameter is not valid request URL", http.StatusBadRequest) + return + } + + // Check for duplicates + if filterExists(f.URL) { + httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", f.URL) + return + } + + // Set necessary properties + f.ID = assignUniqueFilterID() + f.Enabled = true + + // Download the filter contents + ok, err := f.update() + if err != nil { + httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", f.URL, err) + return + } + if f.RulesCount == 0 { + httpError(w, http.StatusBadRequest, "Filter at the url %s has no rules (maybe it points to blank page?)", f.URL) + return + } + if !ok { + httpError(w, http.StatusBadRequest, "Filter at the url %s is invalid (maybe it points to blank page?)", f.URL) + return + } + + // Save the filter contents + err = f.save() + if err != nil { + httpError(w, http.StatusBadRequest, "Failed to save filter %d due to %s", f.ID, err) + return + } + + // URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it + // TODO: since we directly feed filters in-memory, revisit if writing configs is always necessary + if !filterAdd(f) { + httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", f.URL) + return + } + + err = writeAllConfigs() + if err != nil { + httpError(w, http.StatusInternalServerError, "Couldn't write config file: %s", err) + return + } + + err = reconfigureDNSServer() + if err != nil { + httpError(w, http.StatusInternalServerError, "Couldn't reconfigure the DNS server: %s", err) + return + } + + _, err = fmt.Fprintf(w, "OK %d rules\n", f.RulesCount) + if err != nil { + httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) + } +} + +func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { + + type request struct { + URL string `json:"url"` + } + req := request{} + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err) + return + } + + if valid := govalidator.IsRequestURL(req.URL); !valid { + http.Error(w, "URL parameter is not valid request URL", http.StatusBadRequest) + return + } + + // Stop DNS server: + // we close urlfilter object which in turn closes file descriptors to filter files. + // Otherwise, Windows won't allow us to remove the file which is being currently used. + _ = config.dnsServer.Stop() + + // go through each element and delete if url matches + config.Lock() + newFilters := config.Filters[:0] + for _, filter := range config.Filters { + if filter.URL != req.URL { + newFilters = append(newFilters, filter) + } else { + // Remove the filter file + err := os.Remove(filter.Path()) + if err != nil && !os.IsNotExist(err) { + config.Unlock() + httpError(w, http.StatusInternalServerError, "Couldn't remove the filter file: %s", err) + return + } + log.Debug("os.Remove(%s)", filter.Path()) + } + } + // Update the configuration after removing filter files + config.Filters = newFilters + config.Unlock() + httpUpdateConfigReloadDNSReturnOK(w, r) +} + +func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) { + parameters, err := parseParametersFromBody(r.Body) + if err != nil { + httpError(w, http.StatusBadRequest, "failed to parse parameters from body: %s", err) + return + } + + url, ok := parameters["url"] + if !ok { + http.Error(w, "URL parameter was not specified", http.StatusBadRequest) + return + } + + if valid := govalidator.IsRequestURL(url); !valid { + http.Error(w, "URL parameter is not valid request URL", http.StatusBadRequest) + return + } + + found := filterEnable(url, true) + if !found { + http.Error(w, "URL parameter was not previously added", http.StatusBadRequest) + return + } + + httpUpdateConfigReloadDNSReturnOK(w, r) +} + +func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) { + parameters, err := parseParametersFromBody(r.Body) + if err != nil { + httpError(w, http.StatusBadRequest, "failed to parse parameters from body: %s", err) + return + } + + url, ok := parameters["url"] + if !ok { + http.Error(w, "URL parameter was not specified", http.StatusBadRequest) + return + } + + if valid := govalidator.IsRequestURL(url); !valid { + http.Error(w, "URL parameter is not valid request URL", http.StatusBadRequest) + return + } + + found := filterEnable(url, false) + if !found { + http.Error(w, "URL parameter was not previously added", http.StatusBadRequest) + return + } + + httpUpdateConfigReloadDNSReturnOK(w, r) +} + +func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { + body, err := ioutil.ReadAll(r.Body) + if err != nil { + httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err) + return + } + + config.UserRules = strings.Split(string(body), "\n") + httpUpdateConfigReloadDNSReturnOK(w, r) +} + +func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { + updated := refreshFiltersIfNecessary(true) + fmt.Fprintf(w, "OK %d filters updated\n", updated) +} diff --git a/home/control_install.go b/home/control_install.go index 30c48cca..e1193f6b 100644 --- a/home/control_install.go +++ b/home/control_install.go @@ -20,7 +20,6 @@ type firstRunData struct { // Get initial installation settings func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) data := firstRunData{} data.WebPort = 80 data.DNSPort = 53 @@ -65,7 +64,6 @@ type checkConfigResp struct { // Check if ports are available, respond with results func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) reqData := checkConfigReq{} respData := checkConfigResp{} err := json.NewDecoder(r.Body).Decode(&reqData) @@ -190,7 +188,6 @@ func copyInstallSettings(dst *configuration, src *configuration) { // Apply new configuration, start DNS server, restart Web server func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) newSettings := applyConfigReq{} err := json.NewDecoder(r.Body).Decode(&newSettings) if err != nil { diff --git a/home/control_tls.go b/home/control_tls.go index ea83dba1..f1f3d98d 100644 --- a/home/control_tls.go +++ b/home/control_tls.go @@ -25,18 +25,16 @@ import ( // RegisterTLSHandlers registers HTTP handlers for TLS configuration func RegisterTLSHandlers() { - http.HandleFunc("/control/tls/status", postInstall(optionalAuth(ensureGET(handleTLSStatus)))) - http.HandleFunc("/control/tls/configure", postInstall(optionalAuth(ensurePOST(handleTLSConfigure)))) - http.HandleFunc("/control/tls/validate", postInstall(optionalAuth(ensurePOST(handleTLSValidate)))) + httpRegister(http.MethodGet, "/control/tls/status", handleTLSStatus) + httpRegister(http.MethodPost, "/control/tls/configure", handleTLSConfigure) + httpRegister(http.MethodPost, "/control/tls/validate", handleTLSValidate) } func handleTLSStatus(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) marshalTLS(w, config.TLS) } func handleTLSValidate(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) data, err := unmarshalTLS(r) if err != nil { httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) @@ -62,7 +60,6 @@ func handleTLSValidate(w http.ResponseWriter, r *http.Request) { } func handleTLSConfigure(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) data, err := unmarshalTLS(r) if err != nil { httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) diff --git a/home/control_update.go b/home/control_update.go index 3e02d71d..163c3acd 100644 --- a/home/control_update.go +++ b/home/control_update.go @@ -57,10 +57,9 @@ type getVersionJSONRequest struct { // Get the latest available version from the Internet func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) if config.disableUpdate { - log.Tracef("New app version check is disabled by user") + httpError(w, http.StatusInternalServerError, "New app version check is disabled by user") return } @@ -499,7 +498,6 @@ func finishUpdate(u *updateInfo) { // Perform an update procedure to the latest available version func handleUpdate(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) if len(config.versionCheckJSON) == 0 { httpError(w, http.StatusBadRequest, "/update request isn't allowed now") diff --git a/home/dhcp.go b/home/dhcp.go index 6d828d01..c184b661 100644 --- a/home/dhcp.go +++ b/home/dhcp.go @@ -38,7 +38,6 @@ func convertLeases(inputLeases []dhcpd.Lease, includeExpires bool) []map[string] } func handleDHCPStatus(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) leases := convertLeases(config.dhcpServer.Leases(), true) staticLeases := convertLeases(config.dhcpServer.StaticLeases(), false) status := map[string]interface{}{ @@ -67,7 +66,6 @@ type dhcpServerConfigJSON struct { } func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) newconfig := dhcpServerConfigJSON{} err := json.NewDecoder(r.Body).Decode(&newconfig) if err != nil { @@ -116,7 +114,6 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { } func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) response := map[string]interface{}{} ifaces, err := getValidNetInterfaces() @@ -181,7 +178,6 @@ func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { // . Check if a static IP is configured for the network interface // Respond with results func handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) body, err := ioutil.ReadAll(r.Body) if err != nil { errorText := fmt.Sprintf("failed to read request body: %s", err) @@ -366,7 +362,6 @@ func setStaticIP(ifaceName string) error { } func handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) lj := leaseJSON{} err := json.NewDecoder(r.Body).Decode(&lj) @@ -397,7 +392,6 @@ func handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) { } func handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) lj := leaseJSON{} err := json.NewDecoder(r.Body).Decode(&lj) diff --git a/home/dns_rewrites.go b/home/dns_rewrites.go index 816739ab..e58c50d7 100644 --- a/home/dns_rewrites.go +++ b/home/dns_rewrites.go @@ -14,7 +14,6 @@ type rewriteEntryJSON struct { } func handleRewriteList(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) arr := []*rewriteEntryJSON{} @@ -37,7 +36,6 @@ func handleRewriteList(w http.ResponseWriter, r *http.Request) { } func handleRewriteAdd(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) jsent := rewriteEntryJSON{} err := json.NewDecoder(r.Body).Decode(&jsent) @@ -66,7 +64,6 @@ func handleRewriteAdd(w http.ResponseWriter, r *http.Request) { } func handleRewriteDelete(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) jsent := rewriteEntryJSON{} err := json.NewDecoder(r.Body).Decode(&jsent) @@ -101,7 +98,7 @@ func handleRewriteDelete(w http.ResponseWriter, r *http.Request) { } func registerRewritesHandlers() { - http.HandleFunc("/control/rewrite/list", postInstall(optionalAuth(ensureGET(handleRewriteList)))) - http.HandleFunc("/control/rewrite/add", postInstall(optionalAuth(ensurePOST(handleRewriteAdd)))) - http.HandleFunc("/control/rewrite/delete", postInstall(optionalAuth(ensurePOST(handleRewriteDelete)))) + httpRegister(http.MethodGet, "/control/rewrite/list", handleRewriteList) + httpRegister(http.MethodPost, "/control/rewrite/add", handleRewriteAdd) + httpRegister(http.MethodPost, "/control/rewrite/delete", handleRewriteDelete) } diff --git a/home/helpers.go b/home/helpers.go index 7ff08004..23744fc0 100644 --- a/home/helpers.go +++ b/home/helpers.go @@ -29,6 +29,8 @@ import ( // ---------------------------------- func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { + log.Debug("%s %v", r.Method, r.URL) + if r.Method != method { http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed) return @@ -60,9 +62,9 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.handler(w, r) } -func ensureGETHandler(handler func(http.ResponseWriter, *http.Request)) http.Handler { +func ensureHandler(method string, handler func(http.ResponseWriter, *http.Request)) http.Handler { h := httpHandler{} - h.handler = ensure("GET", handler) + h.handler = ensure(method, handler) return &h }