diff --git a/control.go b/control.go index edc2647a..e6431941 100644 --- a/control.go +++ b/control.go @@ -1320,6 +1320,23 @@ func marshalTLS(w http.ResponseWriter, data tlsConfig) { } } +// -------------- +// DNS-over-HTTPS +// -------------- +func handleDOH(w http.ResponseWriter, r *http.Request) { + if r.TLS == nil { + httpError(w, http.StatusNotFound, "Not Found") + return + } + + if !isRunning() { + httpError(w, http.StatusInternalServerError, "DNS server is not running") + return + } + + dnsServer.ServeHTTP(w, r) +} + // ------------------------ // registration of handlers // ------------------------ @@ -1370,4 +1387,6 @@ func registerControlHandlers() { http.HandleFunc("/control/tls/status", postInstall(optionalAuth(ensureGET(handleTLSStatus)))) http.HandleFunc("/control/tls/configure", postInstall(optionalAuth(ensurePOST(handleTLSConfigure)))) http.HandleFunc("/control/tls/validate", postInstall(optionalAuth(ensurePOST(handleTLSValidate)))) + + http.HandleFunc("/dns-query", postInstall(handleDOH)) } diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index e035f8ed..331f0b9c 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "net/http" "strings" "sync" "time" @@ -259,24 +260,38 @@ func (s *Server) Reconfigure(config *ServerConfig) error { return nil } +// ServeHTTP is a HTTP handler method we use to provide DNS-over-HTTPS +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.RLock() + s.dnsProxy.ServeHTTP(w, r) + s.RUnlock() +} + // GetQueryLog returns a map with the current query log ready to be converted to a JSON func (s *Server) GetQueryLog() []map[string]interface{} { + s.RLock() + defer s.RUnlock() return s.queryLog.getQueryLog() } // GetStatsTop returns the current stop stats func (s *Server) GetStatsTop() *StatsTop { + s.RLock() + defer s.RUnlock() return s.queryLog.runningTop.getStatsTop() } // PurgeStats purges current server stats func (s *Server) PurgeStats() { - // TODO: Locks? + s.Lock() + defer s.Unlock() s.stats.purgeStats() } // GetAggregatedStats returns aggregated stats data for the 24 hours func (s *Server) GetAggregatedStats() map[string]interface{} { + s.RLock() + defer s.RUnlock() return s.stats.getAggregatedStats() } @@ -286,6 +301,8 @@ func (s *Server) GetAggregatedStats() map[string]interface{} { // end is end of the time range // returns nil if time unit is not supported func (s *Server) GetStatsHistory(timeUnit time.Duration, startTime time.Time, endTime time.Time) (map[string]interface{}, error) { + s.RLock() + defer s.RUnlock() return s.stats.getStatsHistory(timeUnit, startTime, endTime) }