diff --git a/app.go b/app.go index 57e9ea56..b55e9863 100644 --- a/app.go +++ b/app.go @@ -119,7 +119,7 @@ func main() { runStatsCollectors() runFilterRefreshers() - http.Handle("/", http.FileServer(box)) + http.Handle("/", optionalAuthHandler(http.FileServer(box))) registerControlHandlers() err = startDNSServer() diff --git a/config.go b/config.go index 20142b42..02c0d147 100644 --- a/config.go +++ b/config.go @@ -21,6 +21,8 @@ type configuration struct { BindHost string `yaml:"bind_host"` BindPort int `yaml:"bind_port"` + AuthName string `yaml:"auth_name"` + AuthPass string `yaml:"auth_pass"` CoreDNS coreDNSConfig `yaml:"coredns"` Filters []filter `yaml:"filters"` UserRules []string `yaml:"user_rules"` diff --git a/control.go b/control.go index 070a52d2..e372cd59 100644 --- a/control.go +++ b/control.go @@ -1298,33 +1298,33 @@ func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) { } func registerControlHandlers() { - http.HandleFunc("/control/start", ensurePOST(handleStart)) - http.HandleFunc("/control/stop", ensurePOST(handleStop)) - http.HandleFunc("/control/restart", ensurePOST(handleRestart)) - http.HandleFunc("/control/status", ensureGET(handleStatus)) - http.HandleFunc("/control/stats", ensureGET(handleStats)) - http.HandleFunc("/control/stats_history", ensureGET(handleStatsHistory)) - http.HandleFunc("/control/stats_top", ensureGET(handleStatsTop)) - http.HandleFunc("/control/querylog", handleQueryLog) - http.HandleFunc("/control/querylog_enable", ensurePOST(handleQueryLogEnable)) - http.HandleFunc("/control/querylog_disable", ensurePOST(handleQueryLogDisable)) - http.HandleFunc("/control/set_upstream_dns", ensurePOST(handleSetUpstreamDNS)) - http.HandleFunc("/control/filtering/enable", ensurePOST(handleFilteringEnable)) - http.HandleFunc("/control/filtering/disable", ensurePOST(handleFilteringDisable)) - http.HandleFunc("/control/filtering/status", ensureGET(handleFilteringStatus)) - http.HandleFunc("/control/filtering/add_url", ensurePUT(handleFilteringAddURL)) - http.HandleFunc("/control/filtering/remove_url", ensureDELETE(handleFilteringRemoveURL)) - http.HandleFunc("/control/filtering/enable_url", ensurePOST(handleFilteringEnableURL)) - http.HandleFunc("/control/filtering/disable_url", ensurePOST(handleFilteringDisableURL)) - http.HandleFunc("/control/filtering/set_rules", ensurePUT(handleFilteringSetRules)) - http.HandleFunc("/control/filtering/refresh", ensurePOST(handleFilteringRefresh)) - http.HandleFunc("/control/safebrowsing/enable", ensurePOST(handleSafeBrowsingEnable)) - http.HandleFunc("/control/safebrowsing/disable", ensurePOST(handleSafeBrowsingDisable)) - http.HandleFunc("/control/safebrowsing/status", ensureGET(handleSafeBrowsingStatus)) - http.HandleFunc("/control/parental/enable", ensurePOST(handleParentalEnable)) - http.HandleFunc("/control/parental/disable", ensurePOST(handleParentalDisable)) - http.HandleFunc("/control/parental/status", ensureGET(handleParentalStatus)) - http.HandleFunc("/control/safesearch/enable", ensurePOST(handleSafeSearchEnable)) - http.HandleFunc("/control/safesearch/disable", ensurePOST(handleSafeSearchDisable)) - http.HandleFunc("/control/safesearch/status", ensureGET(handleSafeSearchStatus)) + http.HandleFunc("/control/start", optionalAuth(ensurePOST(handleStart))) + http.HandleFunc("/control/stop", optionalAuth(ensurePOST(handleStop))) + http.HandleFunc("/control/restart", optionalAuth(ensurePOST(handleRestart))) + http.HandleFunc("/control/status", optionalAuth(ensureGET(handleStatus))) + http.HandleFunc("/control/stats", optionalAuth(ensureGET(handleStats))) + http.HandleFunc("/control/stats_history", optionalAuth(ensureGET(handleStatsHistory))) + http.HandleFunc("/control/stats_top", optionalAuth(ensureGET(handleStatsTop))) + http.HandleFunc("/control/querylog", optionalAuth(ensureGET(handleQueryLog))) + http.HandleFunc("/control/querylog_enable", optionalAuth(ensurePOST(handleQueryLogEnable))) + http.HandleFunc("/control/querylog_disable", optionalAuth(ensurePOST(handleQueryLogDisable))) + http.HandleFunc("/control/set_upstream_dns", optionalAuth(ensurePOST(handleSetUpstreamDNS))) + http.HandleFunc("/control/filtering/enable", optionalAuth(ensurePOST(handleFilteringEnable))) + http.HandleFunc("/control/filtering/disable", optionalAuth(ensurePOST(handleFilteringDisable))) + http.HandleFunc("/control/filtering/status", optionalAuth(ensureGET(handleFilteringStatus))) + http.HandleFunc("/control/filtering/add_url", optionalAuth(ensurePUT(handleFilteringAddURL))) + http.HandleFunc("/control/filtering/remove_url", optionalAuth(ensureDELETE(handleFilteringRemoveURL))) + http.HandleFunc("/control/filtering/enable_url", optionalAuth(ensurePOST(handleFilteringEnableURL))) + http.HandleFunc("/control/filtering/disable_url", optionalAuth(ensurePOST(handleFilteringDisableURL))) + http.HandleFunc("/control/filtering/set_rules", optionalAuth(ensurePUT(handleFilteringSetRules))) + http.HandleFunc("/control/filtering/refresh", optionalAuth(ensurePOST(handleFilteringRefresh))) + http.HandleFunc("/control/safebrowsing/enable", optionalAuth(ensurePOST(handleSafeBrowsingEnable))) + http.HandleFunc("/control/safebrowsing/disable", optionalAuth(ensurePOST(handleSafeBrowsingDisable))) + http.HandleFunc("/control/safebrowsing/status", optionalAuth(ensureGET(handleSafeBrowsingStatus))) + http.HandleFunc("/control/parental/enable", optionalAuth(ensurePOST(handleParentalEnable))) + http.HandleFunc("/control/parental/disable", optionalAuth(ensurePOST(handleParentalDisable))) + http.HandleFunc("/control/parental/status", optionalAuth(ensureGET(handleParentalStatus))) + http.HandleFunc("/control/safesearch/enable", optionalAuth(ensurePOST(handleSafeSearchEnable))) + http.HandleFunc("/control/safesearch/disable", optionalAuth(ensurePOST(handleSafeSearchDisable))) + http.HandleFunc("/control/safesearch/status", optionalAuth(ensureGET(handleSafeSearchStatus))) } diff --git a/helpers.go b/helpers.go index 62e260e6..a7e60f77 100644 --- a/helpers.go +++ b/helpers.go @@ -51,6 +51,46 @@ func ensureDELETE(handler func(http.ResponseWriter, *http.Request)) func(http.Re return ensure("DELETE", handler) } +func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if config.AuthName == "" || config.AuthPass == "" { + handler(w, r) + return + } + user, pass, ok := r.BasicAuth() + if !ok || user != config.AuthName || pass != config.AuthPass { + w.Header().Set("WWW-Authenticate", `Basic realm="dnsfilter"`) + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("Unauthorised.\n")) + return + } + handler(w, r) + } +} + +type authHandler struct { + handler http.Handler +} + +func (a *authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if config.AuthName == "" || config.AuthPass == "" { + a.handler.ServeHTTP(w, r) + return + } + user, pass, ok := r.BasicAuth() + if !ok || user != config.AuthName || pass != config.AuthPass { + w.Header().Set("WWW-Authenticate", `Basic realm="dnsfilter"`) + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("Unauthorised.\n")) + return + } + a.handler.ServeHTTP(w, r) +} + +func optionalAuthHandler(handler http.Handler) http.Handler { + return &authHandler{handler} +} + // -------------------------- // helper functions for stats // --------------------------