diff --git a/AGHTechDoc.md b/AGHTechDoc.md index c686ecf7..5a8beee8 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -71,7 +71,6 @@ Contents: ![](doc/agh-arch.png) - ## First startup The first application startup is detected when there's no .yaml configuration file. diff --git a/dhcpd/dhcpd.go b/dhcpd/dhcpd.go index b7af5dca..b40a8ae8 100644 --- a/dhcpd/dhcpd.go +++ b/dhcpd/dhcpd.go @@ -18,6 +18,8 @@ import ( const defaultDiscoverTime = time.Second * 3 const leaseExpireStatic = 1 +var webHandlersRegistered = false + // Lease contains the necessary information about a DHCP lease // field ordering is important -- yaml fields will mirror ordering from here type Lease struct { @@ -121,9 +123,6 @@ func Create(config ServerConfig) *Server { return nil } } - if s.conf.HTTPRegister != nil { - s.registerHandlers() - } // we can't delay database loading until DHCP server is started, // because we need static leases functionality available beforehand @@ -221,6 +220,11 @@ func (s *Server) setConfig(config ServerConfig) error { // Start will listen on port 67 and serve DHCP requests. func (s *Server) Start() error { + if !webHandlersRegistered && s.conf.HTTPRegister != nil { + webHandlersRegistered = true + s.registerHandlers() + } + // TODO: don't close if interface and addresses are the same if s.conn != nil { s.closeConn() diff --git a/home/blocked_services.go b/dnsfilter/blocked_services.go similarity index 77% rename from home/blocked_services.go rename to dnsfilter/blocked_services.go index 5c14499d..4969b2d9 100644 --- a/home/blocked_services.go +++ b/dnsfilter/blocked_services.go @@ -1,10 +1,9 @@ -package home +package dnsfilter import ( "encoding/json" "net/http" - "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter/rules" ) @@ -119,7 +118,7 @@ var serviceRulesArray = []svc{ } // convert array to map -func initServices() { +func initBlockedServices() { serviceRules = make(map[string][]*rules.NetworkRule) for _, s := range serviceRulesArray { netRules := []*rules.NetworkRule{} @@ -135,15 +134,20 @@ func initServices() { } } -// Return TRUE if a blocked service name is known -func blockedSvcKnown(s string) bool { +// BlockedSvcKnown - return TRUE if a blocked service name is known +func BlockedSvcKnown(s string) bool { _, ok := serviceRules[s] return ok } // ApplyBlockedServices - set blocked services settings for this DNS request -func ApplyBlockedServices(setts *dnsfilter.RequestFilteringSettings, list []string) { - setts.ServicesRules = []dnsfilter.ServiceEntry{} +func (d *Dnsfilter) ApplyBlockedServices(setts *RequestFilteringSettings, list []string, global bool) { + setts.ServicesRules = []ServiceEntry{} + if global { + d.confLock.RLock() + defer d.confLock.RUnlock() + list = d.Config.BlockedServices + } for _, name := range list { rules, ok := serviceRules[name] @@ -152,51 +156,45 @@ func ApplyBlockedServices(setts *dnsfilter.RequestFilteringSettings, list []stri continue } - s := dnsfilter.ServiceEntry{} + s := ServiceEntry{} s.Name = name s.Rules = rules setts.ServicesRules = append(setts.ServicesRules, s) } } -func handleBlockedServicesList(w http.ResponseWriter, r *http.Request) { - config.RLock() - list := config.DNS.BlockedServices - config.RUnlock() +func (d *Dnsfilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Request) { + d.confLock.RLock() + list := d.Config.BlockedServices + d.confLock.RUnlock() w.Header().Set("Content-Type", "application/json") err := json.NewEncoder(w).Encode(list) if err != nil { - httpError(w, http.StatusInternalServerError, "json.Encode: %s", err) + httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err) return } } -func handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) { +func (d *Dnsfilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) { list := []string{} err := json.NewDecoder(r.Body).Decode(&list) if err != nil { - httpError(w, http.StatusBadRequest, "json.Decode: %s", err) + httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) return } - config.Lock() - config.DNS.BlockedServices = list - config.Unlock() + d.confLock.Lock() + d.Config.BlockedServices = list + d.confLock.Unlock() log.Debug("Updated blocked services list: %d", len(list)) - err = writeAllConfigsAndReloadDNS() - if err != nil { - httpError(w, http.StatusBadRequest, "%s", err) - return - } - - httpOK(r, w) + d.ConfigModified() } -// RegisterBlockedServicesHandlers - register HTTP handlers -func RegisterBlockedServicesHandlers() { - httpRegister(http.MethodGet, "/control/blocked_services/list", handleBlockedServicesList) - httpRegister(http.MethodPost, "/control/blocked_services/set", handleBlockedServicesSet) +// registerBlockedServicesHandlers - register HTTP handlers +func (d *Dnsfilter) registerBlockedServicesHandlers() { + d.Config.HTTPRegister("GET", "/control/blocked_services/list", d.handleBlockedServicesList) + d.Config.HTTPRegister("POST", "/control/blocked_services/set", d.handleBlockedServicesSet) } diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index adfabd9c..3e2b8f0c 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -49,6 +49,10 @@ type Config struct { Rewrites []RewriteEntry `yaml:"rewrites"` + // Names of services to block (globally). + // Per-client settings can override this configuration. + BlockedServices []string `yaml:"blocked_services"` + // Called when the configuration is changed by HTTP request ConfigModified func() `yaml:"-"` @@ -175,6 +179,7 @@ func (d *Dnsfilter) WriteDiskConfig(c *Config) { d.confLock.Lock() *c = d.Config c.Rewrites = rewriteArrayDup(d.Config.Rewrites) + // BlockedServices d.confLock.Unlock() } @@ -633,6 +638,18 @@ func New(c *Config, blockFilters []Filter) *Dnsfilter { d.prepareRewrites() } + bsvcs := []string{} + for _, s := range d.BlockedServices { + if !BlockedSvcKnown(s) { + log.Debug("skipping unknown blocked-service '%s'", s) + continue + } + bsvcs = append(bsvcs, s) + } + d.BlockedServices = bsvcs + + initBlockedServices() + if blockFilters != nil { err := d.initFiltering(nil, blockFilters) if err != nil { @@ -655,6 +672,7 @@ func (d *Dnsfilter) Start() { if d.Config.HTTPRegister != nil { // for tests d.registerSecurityHandlers() d.registerRewritesHandlers() + d.registerBlockedServicesHandlers() } } diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 6095af08..64949ce1 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -179,6 +179,7 @@ type ServerConfig struct { FilteringConfig TLSConfig + TLSAllowUnencryptedDOH bool TLSv12Roots *x509.CertPool // list of root CAs for TLSv1.2 diff --git a/dnsforward/dnsforward_http.go b/dnsforward/dnsforward_http.go index 088add56..111fefbd 100644 --- a/dnsforward/dnsforward_http.go +++ b/dnsforward/dnsforward_http.go @@ -376,6 +376,20 @@ func checkDNS(input string, bootstrap []string) error { return nil } +func (s *Server) handleDOH(w http.ResponseWriter, r *http.Request) { + if !s.conf.TLSAllowUnencryptedDOH && r.TLS == nil { + httpError(r, w, http.StatusNotFound, "Not Found") + return + } + + if !s.IsRunning() { + httpError(r, w, http.StatusInternalServerError, "DNS server is not running") + return + } + + s.ServeHTTP(w, r) +} + func (s *Server) registerHandlers() { s.conf.HTTPRegister("GET", "/control/dns_info", s.handleGetConfig) s.conf.HTTPRegister("POST", "/control/dns_config", s.handleSetConfig) @@ -384,4 +398,6 @@ func (s *Server) registerHandlers() { s.conf.HTTPRegister("GET", "/control/access/list", s.handleAccessList) s.conf.HTTPRegister("POST", "/control/access/set", s.handleAccessSet) + + s.conf.HTTPRegister("", "/dns-query", s.handleDOH) } diff --git a/home/clients.go b/home/clients.go index 1c2b795b..c7571a5f 100644 --- a/home/clients.go +++ b/home/clients.go @@ -14,6 +14,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/dhcpd" + "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" @@ -24,6 +25,8 @@ const ( clientsUpdatePeriod = 1 * time.Hour ) +var webHandlersRegistered = false + // Client information type Client struct { IDs []string @@ -98,15 +101,29 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd. clients.addFromConfig(objects) if !clients.testing { - go clients.periodicUpdate() - clients.addFromDHCP() clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged) - - clients.registerWebHandlers() } } +// Start - start the module +func (clients *clientsContainer) Start() { + if !clients.testing { + if !webHandlersRegistered { + webHandlersRegistered = true + clients.registerWebHandlers() + } + go clients.periodicUpdate() + } + +} + +// Reload - reload auto-clients +func (clients *clientsContainer) Reload() { + clients.addFromHostsFile() + clients.addFromSystemARP() +} + type clientObject struct { Name string `yaml:"name"` Tags []string `yaml:"tags"` @@ -145,7 +162,7 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) { } for _, s := range cy.BlockedServices { - if !blockedSvcKnown(s) { + if !dnsfilter.BlockedSvcKnown(s) { log.Debug("Clients: skipping unknown blocked-service '%s'", s) continue } @@ -194,8 +211,7 @@ func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) { func (clients *clientsContainer) periodicUpdate() { for { - clients.addFromHostsFile() - clients.addFromSystemARP() + clients.Reload() time.Sleep(clientsUpdatePeriod) } } diff --git a/home/config.go b/home/config.go index 657325e2..cb586131 100644 --- a/home/config.go +++ b/home/config.go @@ -2,7 +2,6 @@ package home import ( "io/ioutil" - "net/http" "os" "path/filepath" "sync" @@ -29,14 +28,6 @@ type logSettings struct { Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled } -// HTTPSServer - HTTPS Server -type HTTPSServer struct { - server *http.Server - cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey - sync.Mutex // protects config.TLS - shutdown bool // if TRUE, don't restart the server -} - // configuration is loaded from YAML // field ordering is important -- yaml fields will mirror ordering from here type configuration struct { @@ -59,8 +50,8 @@ type configuration struct { // An active session is automatically refreshed once a day. WebSessionTTLHours uint32 `yaml:"web_session_ttl"` - DNS dnsConfig `yaml:"dns"` - TLS tlsConfig `yaml:"tls"` + DNS dnsConfig `yaml:"dns"` + TLS tlsConfigSettings `yaml:"tls"` Filters []filter `yaml:"filters"` WhitelistFilters []filter `yaml:"whitelist_filters"` @@ -95,10 +86,6 @@ type dnsConfig struct { FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours) DnsfilterConf dnsfilter.Config `yaml:",inline"` - - // Names of services to block (globally). - // Per-client settings can override this configuration. - BlockedServices []string `yaml:"blocked_services"` } type tlsConfigSettings struct { @@ -114,33 +101,6 @@ type tlsConfigSettings struct { dnsforward.TLSConfig `yaml:",inline" json:",inline"` } -// field ordering is not important -- these are for API and are recalculated on each run -type tlsConfigStatus struct { - ValidCert bool `yaml:"-" json:"valid_cert"` // ValidCert is true if the specified certificates chain is a valid chain of X509 certificates - ValidChain bool `yaml:"-" json:"valid_chain"` // ValidChain is true if the specified certificates chain is verified and issued by a known CA - Subject string `yaml:"-" json:"subject,omitempty"` // Subject is the subject of the first certificate in the chain - Issuer string `yaml:"-" json:"issuer,omitempty"` // Issuer is the issuer of the first certificate in the chain - NotBefore time.Time `yaml:"-" json:"not_before,omitempty"` // NotBefore is the NotBefore field of the first certificate in the chain - NotAfter time.Time `yaml:"-" json:"not_after,omitempty"` // NotAfter is the NotAfter field of the first certificate in the chain - DNSNames []string `yaml:"-" json:"dns_names"` // DNSNames is the value of SubjectAltNames field of the first certificate in the chain - - // key status - ValidKey bool `yaml:"-" json:"valid_key"` // ValidKey is true if the key is a valid private key - KeyType string `yaml:"-" json:"key_type,omitempty"` // KeyType is one of RSA or ECDSA - - // is usable? set by validator - ValidPair bool `yaml:"-" json:"valid_pair"` // ValidPair is true if both certificate and private key are correct - - // warnings - WarningValidation string `yaml:"-" json:"warning_validation,omitempty"` // WarningValidation is a validation warning message with the issue description -} - -// field ordering is important -- yaml fields will mirror ordering from here -type tlsConfig struct { - tlsConfigSettings `yaml:",inline" json:",inline"` - tlsConfigStatus `yaml:"-" json:",inline"` -} - // initialize to default values, will be changed later when reading config or parsing command line var config = configuration{ BindPort: 3000, @@ -160,11 +120,9 @@ var config = configuration{ FilteringEnabled: true, // whether or not use filter lists FiltersUpdateIntervalHours: 24, }, - TLS: tlsConfig{ - tlsConfigSettings: tlsConfigSettings{ - PortHTTPS: 443, - PortDNSOverTLS: 853, // needs to be passed through to dnsproxy - }, + TLS: tlsConfigSettings{ + PortHTTPS: 443, + PortDNSOverTLS: 853, // needs to be passed through to dnsproxy }, DHCP: dhcpd.ServerConfig{ LeaseDuration: 86400, @@ -234,26 +192,10 @@ func parseConfig() error { return err } - bsvcs := []string{} - for _, s := range config.DNS.BlockedServices { - if !blockedSvcKnown(s) { - log.Debug("skipping unknown blocked-service '%s'", s) - continue - } - bsvcs = append(bsvcs, s) - } - config.DNS.BlockedServices = bsvcs - if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) { config.DNS.FiltersUpdateIntervalHours = 24 } - status := tlsConfigStatus{} - if !tlsLoadConfig(&config.TLS, &status) { - log.Error("%s", status.WarningValidation) - return err - } - return nil } @@ -282,6 +224,11 @@ func (c *configuration) write() error { if Context.auth != nil { config.Users = Context.auth.GetUsers() } + if Context.tls != nil { + tlsConf := tlsConfigSettings{} + Context.tls.WriteDiskConfig(&tlsConf) + config.TLS = tlsConf + } if Context.stats != nil { sdc := stats.DiskConfig{} @@ -331,13 +278,3 @@ func (c *configuration) write() error { return nil } - -func writeAllConfigs() error { - err := config.write() - if err != nil { - log.Error("Couldn't write config: %s", err) - return err - } - - return nil -} diff --git a/home/control.go b/home/control.go index ca8939ca..4d1dbe0c 100644 --- a/home/control.go +++ b/home/control.go @@ -9,8 +9,6 @@ import ( "strconv" "strings" - "github.com/AdguardTeam/AdGuardHome/util" - "github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/golibs/log" "github.com/NYTimes/gziphandler" @@ -27,9 +25,6 @@ func returnOK(w http.ResponseWriter) { } } -func httpOK(r *http.Request, w http.ResponseWriter) { -} - func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) { text := fmt.Sprintf(format, args...) log.Info(text) @@ -39,15 +34,6 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface // --------------- // dns run control // --------------- -func writeAllConfigsAndReloadDNS() error { - err := writeAllConfigs() - if err != nil { - log.Error("Couldn't write all configs: %s", err) - return err - } - return reconfigureDNSServer() -} - func addDNSAddress(dnsAddresses *[]string, addr string) { if config.DNS.Port != 53 { addr = fmt.Sprintf("%s:%d", addr, config.DNS.Port) @@ -55,46 +41,6 @@ func addDNSAddress(dnsAddresses *[]string, addr string) { *dnsAddresses = append(*dnsAddresses, addr) } -// Get the list of DNS addresses the server is listening on -func getDNSAddresses() []string { - dnsAddresses := []string{} - - if config.DNS.BindHost == "0.0.0.0" { - ifaces, e := util.GetValidNetInterfacesForWeb() - if e != nil { - log.Error("Couldn't get network interfaces: %v", e) - return []string{} - } - - for _, iface := range ifaces { - for _, addr := range iface.Addresses { - addDNSAddress(&dnsAddresses, addr) - } - } - } else { - addDNSAddress(&dnsAddresses, config.DNS.BindHost) - } - - if config.TLS.Enabled && len(config.TLS.ServerName) != 0 { - - if config.TLS.PortHTTPS != 0 { - addr := config.TLS.ServerName - if config.TLS.PortHTTPS != 443 { - addr = fmt.Sprintf("%s:%d", addr, config.TLS.PortHTTPS) - } - addr = fmt.Sprintf("https://%s/dns-query", addr) - dnsAddresses = append(dnsAddresses, addr) - } - - if config.TLS.PortDNSOverTLS != 0 { - addr := fmt.Sprintf("tls://%s:%d", config.TLS.ServerName, config.TLS.PortDNSOverTLS) - dnsAddresses = append(dnsAddresses, addr) - } - } - - return dnsAddresses -} - func handleStatus(w http.ResponseWriter, r *http.Request) { c := dnsforward.FilteringConfig{} if Context.dnsServer != nil { @@ -144,23 +90,6 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) { _, _ = w.Write(data) } -// -------------- -// DNS-over-HTTPS -// -------------- -func handleDOH(w http.ResponseWriter, r *http.Request) { - if !config.TLS.AllowUnencryptedDOH && r.TLS == nil { - httpError(w, http.StatusNotFound, "Not Found") - return - } - - if !isRunning() { - httpError(w, http.StatusInternalServerError, "DNS server is not running") - return - } - - Context.dnsServer.ServeHTTP(w, r) -} - // ------------------------ // registration of handlers // ------------------------ @@ -172,15 +101,16 @@ func registerControlHandlers() { httpRegister(http.MethodPost, "/control/update", handleUpdate) httpRegister("GET", "/control/profile", handleGetProfile) - - RegisterTLSHandlers() - RegisterBlockedServicesHandlers() RegisterAuthHandlers() - - http.HandleFunc("/dns-query", postInstall(handleDOH)) } func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) { + if len(method) == 0 { + // "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method + http.HandleFunc(url, postInstall(handler)) + return + } + http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler))))) } @@ -258,14 +188,16 @@ func preInstallHandler(handler http.Handler) http.Handler { // it also enforces HTTPS if it is enabled and configured func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { + if Context.firstRun && !strings.HasPrefix(r.URL.Path, "/install.") && r.URL.Path != "/favicon.png" { - http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable + http.Redirect(w, r, "/install.html", http.StatusFound) return } + // enforce https? - if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil { + if r.TLS == nil && Context.web.forceHTTPS && Context.web.httpsServer.server != nil { // yes, and we want host from host:port host, _, err := net.SplitHostPort(r.Host) if err != nil { @@ -275,13 +207,14 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res // construct new URL to redirect to newURL := url.URL{ Scheme: "https", - Host: net.JoinHostPort(host, strconv.Itoa(config.TLS.PortHTTPS)), + Host: net.JoinHostPort(host, strconv.Itoa(Context.web.portHTTPS)), Path: r.URL.Path, RawQuery: r.URL.RawQuery, } http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect) return } + w.Header().Set("Access-Control-Allow-Origin", "*") handler(w, r) } diff --git a/home/control_filtering.go b/home/control_filtering.go index 242bbf3d..837d2a45 100644 --- a/home/control_filtering.go +++ b/home/control_filtering.go @@ -349,7 +349,7 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) { setts := Context.dnsFilter.GetConfig() setts.FilteringEnabled = true - ApplyBlockedServices(&setts, config.DNS.BlockedServices) + Context.dnsFilter.ApplyBlockedServices(&setts, nil, true) result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts) if err != nil { httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err) diff --git a/home/control_install.go b/home/control_install.go index 17ffafa9..e3ada191 100644 --- a/home/control_install.go +++ b/home/control_install.go @@ -35,7 +35,7 @@ type netInterfaceJSON struct { } // Get initial installation settings -func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) { +func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) { data := firstRunData{} data.WebPort = 80 data.DNSPort = 53 @@ -93,7 +93,7 @@ type checkConfigResp struct { } // Check if ports are available, respond with results -func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) { +func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) { reqData := checkConfigReq{} respData := checkConfigResp{} err := json.NewDecoder(r.Body).Decode(&reqData) @@ -275,7 +275,7 @@ func copyInstallSettings(dst *configuration, src *configuration) { } // Apply new configuration, start DNS server, restart Web server -func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { +func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { newSettings := applyConfigReq{} err := json.NewDecoder(r.Body).Decode(&newSettings) if err != nil { @@ -325,22 +325,11 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { config.DNS.BindHost = newSettings.DNS.IP config.DNS.Port = newSettings.DNS.Port - err = initDNSServer() - var err2 error - if err == nil { - err2 = startDNSServer() - if err2 != nil { - closeDNSServer() - } - } - if err != nil || err2 != nil { + err = StartMods() + if err != nil { Context.firstRun = true copyInstallSettings(&config, &curConfig) - if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't initialize DNS server: %s", err) - } else { - httpError(w, http.StatusInternalServerError, "Couldn't start DNS server: %s", err2) - } + httpError(w, http.StatusInternalServerError, "%s", err) return } @@ -356,19 +345,21 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { return } + registerControlHandlers() + // this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block // until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely if restartHTTP { go func() { - _ = Context.httpServer.Shutdown(context.TODO()) + _ = Context.web.httpServer.Shutdown(context.TODO()) }() } returnOK(w) } -func registerInstallHandlers() { - http.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(handleInstallGetAddresses))) - http.HandleFunc("/control/install/check_config", preInstall(ensurePOST(handleInstallCheckConfig))) - http.HandleFunc("/control/install/configure", preInstall(ensurePOST(handleInstallConfigure))) +func (web *Web) registerInstallHandlers() { + http.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses))) + http.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig))) + http.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure))) } diff --git a/home/control_update.go b/home/control_update.go index 87fe4034..5730c3e0 100644 --- a/home/control_update.go +++ b/home/control_update.go @@ -476,7 +476,6 @@ func doUpdate(u *updateInfo) error { func finishUpdate(u *updateInfo) { log.Info("Stopping all tasks") cleanup() - stopHTTPServer() cleanupAlways() if runtime.GOOS == "windows" { diff --git a/home/dns.go b/home/dns.go index 6e53dc79..29cec636 100644 --- a/home/dns.go +++ b/home/dns.go @@ -3,13 +3,13 @@ package home import ( "fmt" "net" - "os" "path/filepath" "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/querylog" "github.com/AdguardTeam/AdGuardHome/stats" + "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" @@ -25,13 +25,9 @@ func onConfigModified() { // Please note that we must do it even if we don't start it // so that we had access to the query log and the stats func initDNSServer() error { + var err error baseDir := Context.getDataDir() - err := os.MkdirAll(baseDir, 0755) - if err != nil { - return fmt.Errorf("Cannot create DNS data dir at %s: %s", baseDir, err) - } - statsConf := stats.Config{ Filename: filepath.Join(baseDir, "stats.db"), LimitDays: config.DNS.StatsInterval, @@ -70,14 +66,6 @@ func initDNSServer() error { return fmt.Errorf("dnsServer.Prepare: %s", err) } - sessFilename := filepath.Join(baseDir, "sessions.db") - Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60) - if Context.auth == nil { - closeDNSServer() - return fmt.Errorf("Couldn't initialize Auth module") - } - config.Users = nil - Context.rdns = InitRDNS(Context.dnsServer, &Context.clients) Context.whois = initWhois(&Context.clients) @@ -169,26 +157,74 @@ func generateServerConfig() dnsforward.ServerConfig { OnDNSRequest: onDNSRequest, } - if config.TLS.Enabled { - newconfig.TLSConfig = config.TLS.TLSConfig - if config.TLS.PortDNSOverTLS != 0 { - newconfig.TLSListenAddr = &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.TLS.PortDNSOverTLS} + tlsConf := tlsConfigSettings{} + Context.tls.WriteDiskConfig(&tlsConf) + if tlsConf.Enabled { + newconfig.TLSConfig = tlsConf.TLSConfig + if tlsConf.PortDNSOverTLS != 0 { + newconfig.TLSListenAddr = &net.TCPAddr{ + IP: net.ParseIP(config.DNS.BindHost), + Port: tlsConf.PortDNSOverTLS, + } } } newconfig.TLSv12Roots = Context.tlsRoots + newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH newconfig.FilterHandler = applyAdditionalFiltering newconfig.GetUpstreamsByClient = getUpstreamsByClient return newconfig } +// Get the list of DNS addresses the server is listening on +func getDNSAddresses() []string { + dnsAddresses := []string{} + + if config.DNS.BindHost == "0.0.0.0" { + ifaces, e := util.GetValidNetInterfacesForWeb() + if e != nil { + log.Error("Couldn't get network interfaces: %v", e) + return []string{} + } + + for _, iface := range ifaces { + for _, addr := range iface.Addresses { + addDNSAddress(&dnsAddresses, addr) + } + } + } else { + addDNSAddress(&dnsAddresses, config.DNS.BindHost) + } + + tlsConf := tlsConfigSettings{} + Context.tls.WriteDiskConfig(&tlsConf) + if tlsConf.Enabled && len(tlsConf.ServerName) != 0 { + + if tlsConf.PortHTTPS != 0 { + addr := tlsConf.ServerName + if tlsConf.PortHTTPS != 443 { + addr = fmt.Sprintf("%s:%d", addr, tlsConf.PortHTTPS) + } + addr = fmt.Sprintf("https://%s/dns-query", addr) + dnsAddresses = append(dnsAddresses, addr) + } + + if tlsConf.PortDNSOverTLS != 0 { + addr := fmt.Sprintf("tls://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverTLS) + dnsAddresses = append(dnsAddresses, addr) + } + } + + return dnsAddresses +} + func getUpstreamsByClient(clientAddr string) []upstream.Upstream { return Context.clients.FindUpstreams(clientAddr) } // If a client has his own settings, apply them func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) { - ApplyBlockedServices(setts, config.DNS.BlockedServices) + Context.dnsFilter.ApplyBlockedServices(setts, nil, true) if len(clientAddr) == 0 { return @@ -202,7 +238,7 @@ func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteri log.Debug("Using settings for client with IP %s", clientAddr) if c.UseOwnBlockedServices { - ApplyBlockedServices(setts, c.BlockedServices) + Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false) } setts.ClientTags = c.Tags @@ -224,6 +260,8 @@ func startDNSServer() error { enableFilters(false) + Context.clients.Start() + err := Context.dnsServer.Start() if err != nil { return errorx.Decorate(err, "Couldn't start forwarding DNS server") @@ -295,11 +333,6 @@ func closeDNSServer() { Context.queryLog = nil } - if Context.auth != nil { - Context.auth.Close() - Context.auth = nil - } - Context.filters.Close() log.Debug("Closed all DNS modules") diff --git a/home/home.go b/home/home.go index 760d48ed..17d05d41 100644 --- a/home/home.go +++ b/home/home.go @@ -34,8 +34,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/querylog" "github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/golibs/log" - "github.com/NYTimes/gziphandler" - "github.com/gobuffalo/packr" ) const ( @@ -58,18 +56,18 @@ type homeContext struct { // Modules // -- - clients clientsContainer // per-client-settings module - stats stats.Stats // statistics module - queryLog querylog.QueryLog // query log module - dnsServer *dnsforward.Server // DNS module - rdns *RDNS // rDNS module - whois *Whois // WHOIS module - dnsFilter *dnsfilter.Dnsfilter // DNS filtering module - dhcpServer *dhcpd.Server // DHCP module - auth *Auth // HTTP authentication module - httpServer *http.Server // HTTP module - httpsServer HTTPSServer // HTTPS module - filters Filtering + clients clientsContainer // per-client-settings module + stats stats.Stats // statistics module + queryLog querylog.QueryLog // query log module + dnsServer *dnsforward.Server // DNS module + rdns *RDNS // rDNS module + whois *Whois // WHOIS module + dnsFilter *dnsfilter.Dnsfilter // DNS filtering module + dhcpServer *dhcpd.Server // DHCP module + auth *Auth // HTTP authentication module + filters Filtering // DNS filtering module + web *Web // Web (HTTP, HTTPS) module + tls *TLSMod // TLS module // Runtime properties // -- @@ -116,10 +114,20 @@ func Main(version string, channel string, armVer string) { Context.appSignalChannel = make(chan os.Signal) signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) go func() { - <-Context.appSignalChannel - cleanup() - cleanupAlways() - os.Exit(0) + for { + sig := <-Context.appSignalChannel + log.Info("Received signal '%s'", sig) + switch sig { + case syscall.SIGHUP: + Context.clients.Reload() + Context.tls.Reload() + + default: + cleanup() + cleanupAlways() + os.Exit(0) + } + } }() // run the protection @@ -163,7 +171,6 @@ func run(args options) { } initConfig() - initServices() Context.tlsRoots = util.LoadSystemRootCAs() Context.transport = &http.Transport{ @@ -218,6 +225,9 @@ func run(args options) { if args.bindPort != 0 { config.BindPort = args.bindPort } + if len(args.pidFile) != 0 && writePIDFile(args.pidFile) { + Context.pidFileName = args.pidFile + } if !Context.firstRun { // Save the updated config @@ -225,11 +235,42 @@ func run(args options) { if err != nil { log.Fatal(err) } + } - err = initDNSServer() + err := os.MkdirAll(Context.getDataDir(), 0755) + if err != nil { + log.Fatalf("Cannot create DNS data dir at %s: %s", Context.getDataDir(), err) + } + + sessFilename := filepath.Join(Context.getDataDir(), "sessions.db") + Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60) + if Context.auth == nil { + log.Fatalf("Couldn't initialize Auth module") + } + config.Users = nil + + Context.tls = tlsCreate(config.TLS) + if Context.tls == nil { + log.Fatalf("Can't initialize TLS module") + } + + webConf := WebConfig{ + firstRun: Context.firstRun, + BindHost: config.BindHost, + BindPort: config.BindPort, + } + Context.web = CreateWeb(&webConf) + if Context.web == nil { + log.Fatalf("Can't initialize Web module") + } + + if !Context.firstRun { + err := initDNSServer() if err != nil { log.Fatalf("%s", err) } + Context.tls.Start() + go func() { err := startDNSServer() if err != nil { @@ -243,101 +284,27 @@ func run(args options) { } } - if len(args.pidFile) != 0 && writePIDFile(args.pidFile) { - Context.pidFileName = args.pidFile - } - - // Initialize and run the admin Web interface - box := packr.NewBox("../build/static") - - // if not configured, redirect / to /install.html, otherwise redirect /install.html to / - http.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box))))) - registerControlHandlers() - - // add handlers for /install paths, we only need them when we're not configured yet - if Context.firstRun { - log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ") - http.Handle("/install.html", preInstallHandler(http.FileServer(box))) - registerInstallHandlers() - } - - Context.httpsServer.cond = sync.NewCond(&Context.httpsServer.Mutex) - - // for https, we have a separate goroutine loop - go httpServerLoop() - - // this loop is used as an ability to change listening host and/or port - for !Context.httpsServer.shutdown { - printHTTPAddresses("http") - - // we need to have new instance, because after Shutdown() the Server is not usable - address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort)) - Context.httpServer = &http.Server{ - Addr: address, - } - err := Context.httpServer.ListenAndServe() - if err != http.ErrServerClosed { - cleanupAlways() - log.Fatal(err) - } - // We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop - } + Context.web.Start() // wait indefinitely for other go-routines to complete their job select {} } -func httpServerLoop() { - for !Context.httpsServer.shutdown { - Context.httpsServer.cond.L.Lock() - // this mechanism doesn't let us through until all conditions are met - for config.TLS.Enabled == false || - config.TLS.PortHTTPS == 0 || - len(config.TLS.PrivateKeyData) == 0 || - len(config.TLS.CertificateChainData) == 0 { // sleep until necessary data is supplied - Context.httpsServer.cond.Wait() - } - address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS)) - // validate current TLS config and update warnings (it could have been loaded from file) - data := validateCertificates(string(config.TLS.CertificateChainData), string(config.TLS.PrivateKeyData), config.TLS.ServerName) - if !data.ValidPair { - cleanupAlways() - log.Fatal(data.WarningValidation) - } - config.Lock() - config.TLS.tlsConfigStatus = data // update warnings - config.Unlock() - - // prepare certs for HTTPS server - // important -- they have to be copies, otherwise changing the contents in config.TLS will break encryption for in-flight requests - certchain := make([]byte, len(config.TLS.CertificateChainData)) - copy(certchain, config.TLS.CertificateChainData) - privatekey := make([]byte, len(config.TLS.PrivateKeyData)) - copy(privatekey, config.TLS.PrivateKeyData) - cert, err := tls.X509KeyPair(certchain, privatekey) - if err != nil { - cleanupAlways() - log.Fatal(err) - } - Context.httpsServer.cond.L.Unlock() - - // prepare HTTPS server - Context.httpsServer.server = &http.Server{ - Addr: address, - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS12, - RootCAs: Context.tlsRoots, - }, - } - - printHTTPAddresses("https") - err = Context.httpsServer.server.ListenAndServeTLS("", "") - if err != http.ErrServerClosed { - cleanupAlways() - log.Fatal(err) - } +// StartMods - initialize and start DNS after installation +func StartMods() error { + err := initDNSServer() + if err != nil { + return err } + + Context.tls.Start() + + err = startDNSServer() + if err != nil { + closeDNSServer() + return err + } + return nil } // Check if the current user has root (administrator) rights @@ -449,6 +416,15 @@ func configureLogger(args options) { func cleanup() { log.Info("Stopping AdGuard Home") + if Context.web != nil { + Context.web.Close() + Context.web = nil + } + if Context.auth != nil { + Context.auth.Close() + Context.auth = nil + } + err := stopDNSServer() if err != nil { log.Error("Couldn't stop DNS server: %s", err) @@ -457,17 +433,11 @@ func cleanup() { if err != nil { log.Error("Couldn't stop DHCP server: %s", err) } -} -// Stop HTTP server, possibly waiting for all active connections to be closed -func stopHTTPServer() { - log.Info("Stopping HTTP server...") - Context.httpsServer.shutdown = true - if Context.httpsServer.server != nil { - _ = Context.httpsServer.server.Shutdown(context.TODO()) + if Context.tls != nil { + Context.tls.Close() + Context.tls = nil } - _ = Context.httpServer.Shutdown(context.TODO()) - log.Info("Stopped HTTP server") } // This function is called before application exits @@ -519,7 +489,7 @@ func loadOptions() options { } o.bindPort = v }, nil}, - {"service", "s", "Service control action: status, install, uninstall, start, stop, restart", func(value string) { + {"service", "s", "Service control action: status, install, uninstall, start, stop, restart, reload (configuration)", func(value string) { o.serviceControlAction = value }, nil}, {"logfile", "l", "Path to log file. If empty: write to stdout; if 'syslog': write to system log", func(value string) { @@ -588,11 +558,13 @@ func loadOptions() options { func printHTTPAddresses(proto string) { var address string - if proto == "https" && config.TLS.ServerName != "" { - if config.TLS.PortHTTPS == 443 { - log.Printf("Go to https://%s", config.TLS.ServerName) + tlsConf := tlsConfigSettings{} + Context.tls.WriteDiskConfig(&tlsConf) + if proto == "https" && tlsConf.ServerName != "" { + if tlsConf.PortHTTPS == 443 { + log.Printf("Go to https://%s", tlsConf.ServerName) } else { - log.Printf("Go to https://%s:%d", config.TLS.ServerName, config.TLS.PortHTTPS) + log.Printf("Go to https://%s:%d", tlsConf.ServerName, tlsConf.PortHTTPS) } } else if config.BindHost == "0.0.0.0" { log.Println("AdGuard Home is available on the following addresses:") diff --git a/home/service.go b/home/service.go index d066e118..2cb31ecb 100644 --- a/home/service.go +++ b/home/service.go @@ -1,9 +1,11 @@ package home import ( + "fmt" "io/ioutil" "os" "runtime" + "strconv" "strings" "syscall" @@ -71,6 +73,48 @@ func svcAction(s service.Service, action string) error { return err } +// Send SIGHUP to a process with ID taken from our pid-file +// If pid-file doesn't exist, find our PID using 'ps' command +func sendSigReload() { + if runtime.GOOS == "windows" { + log.Error("Not implemented on Windows") + return + } + + pidfile := fmt.Sprintf("/var/run/%s.pid", serviceName) + data, err := ioutil.ReadFile(pidfile) + if os.IsNotExist(err) { + code, psdata, err := util.RunCommand("ps", "-C", serviceName, "-o", "pid=") + if err != nil || code != 0 { + log.Error("Can't find AdGuardHome process: %s code:%d", err, code) + return + } + data = []byte(psdata) + + } else if err != nil { + log.Error("Can't read PID file %s: %s", pidfile, err) + return + } + + parts := strings.SplitN(string(data), "\n", 2) + if len(parts) == 0 { + log.Error("Can't read PID file %s: bad value", pidfile) + return + } + + pid, err := strconv.Atoi(parts[0]) + if err != nil { + log.Error("Can't read PID file %s: %s", pidfile, err) + return + } + err = util.SendProcessSignal(pid, syscall.SIGHUP) + if err != nil { + log.Error("Can't send signal to PID %d: %s", pid, err) + return + } + log.Debug("Sent signal to PID %d", pid) +} + // handleServiceControlAction one of the possible control actions: // install -- installs a service/daemon // uninstall -- uninstalls it @@ -84,6 +128,11 @@ func svcAction(s service.Service, action string) error { func handleServiceControlAction(action string) { log.Printf("Service control action: %s", action) + if action == "reload" { + sendSigReload() + return + } + pwd, err := os.Getwd() if err != nil { log.Fatal("Unable to find the path to the current directory") diff --git a/home/control_tls.go b/home/tls.go similarity index 59% rename from home/control_tls.go rename to home/tls.go index e102ecbd..317e7dd8 100644 --- a/home/control_tls.go +++ b/home/tls.go @@ -1,9 +1,6 @@ -// Control: TLS configuring handlers - package home import ( - "context" "crypto" "crypto/ecdsa" "crypto/rsa" @@ -16,18 +13,125 @@ import ( "fmt" "io/ioutil" "net/http" + "os" "reflect" "strings" + "sync" "time" - "github.com/AdguardTeam/AdGuardHome/util" - "github.com/AdguardTeam/golibs/log" "github.com/joomcode/errorx" ) +var tlsWebHandlersRegistered = false + +// TLSMod - TLS module object +type TLSMod struct { + certLastMod time.Time // last modification time of the certificate file + conf tlsConfigSettings + confLock sync.Mutex + status tlsConfigStatus +} + +// Create TLS module +func tlsCreate(conf tlsConfigSettings) *TLSMod { + t := &TLSMod{} + t.conf = conf + if t.conf.Enabled { + if !t.load() { + return nil + } + t.setCertFileTime() + } + return t +} + +func (t *TLSMod) load() bool { + if !tlsLoadConfig(&t.conf, &t.status) { + return false + } + + // validate current TLS config and update warnings (it could have been loaded from file) + data := validateCertificates(string(t.conf.CertificateChainData), string(t.conf.PrivateKeyData), t.conf.ServerName) + if !data.ValidPair { + log.Error(data.WarningValidation) + return false + } + t.status = data + return true +} + +// Close - close module +func (t *TLSMod) Close() { +} + +// WriteDiskConfig - write config +func (t *TLSMod) WriteDiskConfig(conf *tlsConfigSettings) { + t.confLock.Lock() + *conf = t.conf + t.confLock.Unlock() +} + +func (t *TLSMod) setCertFileTime() { + if len(t.conf.CertificatePath) == 0 { + return + } + fi, err := os.Stat(t.conf.CertificatePath) + if err != nil { + log.Error("TLS: %s", err) + return + } + t.certLastMod = fi.ModTime().UTC() +} + +// Start - start the module +func (t *TLSMod) Start() { + if !tlsWebHandlersRegistered { + tlsWebHandlersRegistered = true + t.registerWebHandlers() + } + + t.confLock.Lock() + tlsConf := t.conf + t.confLock.Unlock() + Context.web.TLSConfigChanged(tlsConf) +} + +// Reload - reload certificate file +func (t *TLSMod) Reload() { + t.confLock.Lock() + tlsConf := t.conf + t.confLock.Unlock() + + if !tlsConf.Enabled || len(tlsConf.CertificatePath) == 0 { + return + } + fi, err := os.Stat(tlsConf.CertificatePath) + if err != nil { + log.Error("TLS: %s", err) + return + } + if fi.ModTime().UTC().Equal(t.certLastMod) { + log.Debug("TLS: certificate file isn't modified") + return + } + log.Debug("TLS: certificate file is modified") + + t.confLock.Lock() + r := t.load() + t.confLock.Unlock() + if !r { + return + } + + t.certLastMod = fi.ModTime().UTC() + + _ = reconfigureDNSServer() + Context.web.TLSConfigChanged(tlsConf) +} + // Set certificate and private key data -func tlsLoadConfig(tls *tlsConfig, status *tlsConfigStatus) bool { +func tlsLoadConfig(tls *tlsConfigSettings, status *tlsConfigStatus) bool { tls.CertificateChainData = []byte(tls.CertificateChain) tls.PrivateKeyData = []byte(tls.PrivateKey) @@ -61,98 +165,115 @@ func tlsLoadConfig(tls *tlsConfig, status *tlsConfigStatus) bool { return true } -// RegisterTLSHandlers registers HTTP handlers for TLS configuration -func RegisterTLSHandlers() { - httpRegister(http.MethodGet, "/control/tls/status", handleTLSStatus) - httpRegister(http.MethodPost, "/control/tls/configure", handleTLSConfigure) - httpRegister(http.MethodPost, "/control/tls/validate", handleTLSValidate) +type tlsConfigStatus struct { + ValidCert bool `json:"valid_cert"` // ValidCert is true if the specified certificates chain is a valid chain of X509 certificates + ValidChain bool `json:"valid_chain"` // ValidChain is true if the specified certificates chain is verified and issued by a known CA + Subject string `json:"subject,omitempty"` // Subject is the subject of the first certificate in the chain + Issuer string `json:"issuer,omitempty"` // Issuer is the issuer of the first certificate in the chain + NotBefore time.Time `json:"not_before,omitempty"` // NotBefore is the NotBefore field of the first certificate in the chain + NotAfter time.Time `json:"not_after,omitempty"` // NotAfter is the NotAfter field of the first certificate in the chain + DNSNames []string `json:"dns_names"` // DNSNames is the value of SubjectAltNames field of the first certificate in the chain + + // key status + ValidKey bool `json:"valid_key"` // ValidKey is true if the key is a valid private key + KeyType string `json:"key_type,omitempty"` // KeyType is one of RSA or ECDSA + + // is usable? set by validator + ValidPair bool `json:"valid_pair"` // ValidPair is true if both certificate and private key are correct + + // warnings + WarningValidation string `json:"warning_validation,omitempty"` // WarningValidation is a validation warning message with the issue description } -func handleTLSStatus(w http.ResponseWriter, r *http.Request) { - marshalTLS(w, config.TLS) +// field ordering is important -- yaml fields will mirror ordering from here +type tlsConfig struct { + tlsConfigSettings `json:",inline"` + tlsConfigStatus `json:",inline"` } -func handleTLSValidate(w http.ResponseWriter, r *http.Request) { - data, err := unmarshalTLS(r) - if err != nil { - httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) - return +func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, r *http.Request) { + t.confLock.Lock() + data := tlsConfig{ + tlsConfigSettings: t.conf, + tlsConfigStatus: t.status, } - - // check if port is available - // BUT: if we are already using this port, no need - alreadyRunning := false - if Context.httpsServer.server != nil { - alreadyRunning = true - } - if !alreadyRunning { - err = util.CheckPortAvailable(config.BindHost, data.PortHTTPS) - if err != nil { - httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS) - return - } - } - - status := tlsConfigStatus{} - if tlsLoadConfig(&data, &status) { - status = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName) - } - data.tlsConfigStatus = status - + t.confLock.Unlock() marshalTLS(w, data) } -func handleTLSConfigure(w http.ResponseWriter, r *http.Request) { +func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) { + setts, err := unmarshalTLS(r) + if err != nil { + httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) + return + } + + if !WebCheckPortAvailable(setts.PortHTTPS) { + httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", setts.PortHTTPS) + return + } + + status := tlsConfigStatus{} + if tlsLoadConfig(&setts, &status) { + status = validateCertificates(string(setts.CertificateChainData), string(setts.PrivateKeyData), setts.ServerName) + } + + data := tlsConfig{ + tlsConfigSettings: setts, + tlsConfigStatus: status, + } + marshalTLS(w, data) +} + +func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { data, err := unmarshalTLS(r) if err != nil { httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) return } - // check if port is available - // BUT: if we are already using this port, no need - alreadyRunning := false - if Context.httpsServer.server != nil { - alreadyRunning = true - } - if !alreadyRunning { - err = util.CheckPortAvailable(config.BindHost, data.PortHTTPS) - if err != nil { - httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS) - return - } + if !WebCheckPortAvailable(data.PortHTTPS) { + httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS) + return } status := tlsConfigStatus{} if !tlsLoadConfig(&data, &status) { - data.tlsConfigStatus = status - marshalTLS(w, data) + data2 := tlsConfig{ + tlsConfigSettings: data, + tlsConfigStatus: t.status, + } + marshalTLS(w, data2) return } - data.tlsConfigStatus = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName) + status = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName) restartHTTPS := false - if !reflect.DeepEqual(config.TLS.tlsConfigSettings, data.tlsConfigSettings) { + t.confLock.Lock() + if !reflect.DeepEqual(t.conf, data) { log.Printf("tls config settings have changed, will restart HTTPS server") restartHTTPS = true } - config.TLS = data - err = writeAllConfigsAndReloadDNS() + t.conf = data + t.status = status + t.confLock.Unlock() + t.setCertFileTime() + onConfigModified() + err = reconfigureDNSServer() if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't write config file: %s", err) + httpError(w, http.StatusInternalServerError, "%s", err) return } - marshalTLS(w, data) + data2 := tlsConfig{ + tlsConfigSettings: data, + tlsConfigStatus: t.status, + } + marshalTLS(w, data2) // this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block // until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely if restartHTTPS { go func() { time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server - Context.httpsServer.cond.L.Lock() - Context.httpsServer.cond.Broadcast() - if Context.httpsServer.server != nil { - Context.httpsServer.server.Shutdown(context.TODO()) - } - Context.httpsServer.cond.L.Unlock() + Context.web.TLSConfigChanged(data) }() } } @@ -337,8 +458,8 @@ func parsePrivateKey(der []byte) (crypto.PrivateKey, string, error) { } // unmarshalTLS handles base64-encoded certificates transparently -func unmarshalTLS(r *http.Request) (tlsConfig, error) { - data := tlsConfig{} +func unmarshalTLS(r *http.Request) (tlsConfigSettings, error) { + data := tlsConfigSettings{} err := json.NewDecoder(r.Body).Decode(&data) if err != nil { return data, errorx.Decorate(err, "Failed to parse new TLS config json") @@ -389,3 +510,10 @@ func marshalTLS(w http.ResponseWriter, data tlsConfig) { return } } + +// registerWebHandlers registers HTTP handlers for TLS configuration +func (t *TLSMod) registerWebHandlers() { + httpRegister("GET", "/control/tls/status", t.handleTLSStatus) + httpRegister("POST", "/control/tls/configure", t.handleTLSConfigure) + httpRegister("POST", "/control/tls/validate", t.handleTLSValidate) +} diff --git a/home/web.go b/home/web.go new file mode 100644 index 00000000..28db3948 --- /dev/null +++ b/home/web.go @@ -0,0 +1,189 @@ +package home + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "strconv" + "sync" + + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/log" + "github.com/NYTimes/gziphandler" + "github.com/gobuffalo/packr" +) + +type WebConfig struct { + firstRun bool + BindHost string + BindPort int + PortHTTPS int +} + +// HTTPSServer - HTTPS Server +type HTTPSServer struct { + server *http.Server + cond *sync.Cond + condLock sync.Mutex + shutdown bool // if TRUE, don't restart the server + enabled bool + cert tls.Certificate +} + +// Web - module object +type Web struct { + conf *WebConfig + forceHTTPS bool + portHTTPS int + httpServer *http.Server // HTTP module + httpsServer HTTPSServer // HTTPS module +} + +// CreateWeb - create module +func CreateWeb(conf *WebConfig) *Web { + w := Web{} + w.conf = conf + + // Initialize and run the admin Web interface + box := packr.NewBox("../build/static") + + // if not configured, redirect / to /install.html, otherwise redirect /install.html to / + http.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box))))) + + // add handlers for /install paths, we only need them when we're not configured yet + if conf.firstRun { + log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ") + http.Handle("/install.html", preInstallHandler(http.FileServer(box))) + w.registerInstallHandlers() + } else { + registerControlHandlers() + } + + w.httpsServer.cond = sync.NewCond(&w.httpsServer.condLock) + return &w +} + +// WebCheckPortAvailable - check if port is available +// BUT: if we are already using this port, no need +func WebCheckPortAvailable(port int) bool { + alreadyRunning := false + if Context.web.httpsServer.server != nil { + alreadyRunning = true + } + if !alreadyRunning { + err := util.CheckPortAvailable(config.BindHost, port) + if err != nil { + return false + } + } + return true +} + +// TLSConfigChanged - called when TLS configuration has changed +func (w *Web) TLSConfigChanged(tlsConf tlsConfigSettings) { + log.Debug("Web: applying new TLS configuration") + w.conf.PortHTTPS = tlsConf.PortHTTPS + w.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0) + w.portHTTPS = tlsConf.PortHTTPS + + enabled := tlsConf.Enabled && + tlsConf.PortHTTPS != 0 && + len(tlsConf.PrivateKeyData) != 0 && + len(tlsConf.CertificateChainData) != 0 + var cert tls.Certificate + var err error + if enabled { + cert, err = tls.X509KeyPair(tlsConf.CertificateChainData, tlsConf.PrivateKeyData) + if err != nil { + log.Fatal(err) + } + } + + w.httpsServer.cond.L.Lock() + if w.httpsServer.server != nil { + w.httpsServer.server.Shutdown(context.TODO()) + } + w.httpsServer.enabled = enabled + w.httpsServer.cert = cert + w.httpsServer.cond.Broadcast() + w.httpsServer.cond.L.Unlock() +} + +// Start - start serving HTTP requests +func (w *Web) Start() { + // for https, we have a separate goroutine loop + go w.httpServerLoop() + + // this loop is used as an ability to change listening host and/or port + for !w.httpsServer.shutdown { + printHTTPAddresses("http") + + // we need to have new instance, because after Shutdown() the Server is not usable + address := net.JoinHostPort(w.conf.BindHost, strconv.Itoa(w.conf.BindPort)) + w.httpServer = &http.Server{ + Addr: address, + } + err := w.httpServer.ListenAndServe() + if err != http.ErrServerClosed { + cleanupAlways() + log.Fatal(err) + } + // We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop + } +} + +// Close - stop HTTP server, possibly waiting for all active connections to be closed +func (w *Web) Close() { + log.Info("Stopping HTTP server...") + w.httpsServer.cond.L.Lock() + w.httpsServer.shutdown = true + w.httpsServer.cond.L.Unlock() + if w.httpsServer.server != nil { + _ = w.httpsServer.server.Shutdown(context.TODO()) + } + if w.httpServer != nil { + _ = w.httpServer.Shutdown(context.TODO()) + } + + log.Info("Stopped HTTP server") +} + +func (w *Web) httpServerLoop() { + for { + w.httpsServer.cond.L.Lock() + if w.httpsServer.shutdown { + w.httpsServer.cond.L.Unlock() + break + } + + // this mechanism doesn't let us through until all conditions are met + for !w.httpsServer.enabled { // sleep until necessary data is supplied + w.httpsServer.cond.Wait() + if w.httpsServer.shutdown { + w.httpsServer.cond.L.Unlock() + return + } + } + + w.httpsServer.cond.L.Unlock() + + // prepare HTTPS server + address := net.JoinHostPort(w.conf.BindHost, strconv.Itoa(w.conf.PortHTTPS)) + w.httpsServer.server = &http.Server{ + Addr: address, + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{w.httpsServer.cert}, + MinVersion: tls.VersionTLS12, + RootCAs: Context.tlsRoots, + }, + } + + printHTTPAddresses("https") + err := w.httpsServer.server.ListenAndServeTLS("", "") + if err != http.ErrServerClosed { + cleanupAlways() + log.Fatal(err) + } + } +} diff --git a/util/os_freebsd.go b/util/os_freebsd.go index 33311e16..8254d897 100644 --- a/util/os_freebsd.go +++ b/util/os_freebsd.go @@ -25,3 +25,8 @@ func SetRlimit(val uint) { func HaveAdminRights() (bool, error) { return os.Getuid() == 0, nil } + +// SendProcessSignal - send signal to a process +func SendProcessSignal(pid int, sig syscall.Signal) error { + return syscall.Kill(pid, sig) +} diff --git a/util/os_unix.go b/util/os_unix.go index 338edfa8..9826a68c 100644 --- a/util/os_unix.go +++ b/util/os_unix.go @@ -25,3 +25,8 @@ func SetRlimit(val uint) { func HaveAdminRights() (bool, error) { return os.Getuid() == 0, nil } + +// SendProcessSignal - send signal to a process +func SendProcessSignal(pid int, sig syscall.Signal) error { + return syscall.Kill(pid, sig) +} diff --git a/util/os_windows.go b/util/os_windows.go index e081f758..1c09e588 100644 --- a/util/os_windows.go +++ b/util/os_windows.go @@ -1,6 +1,11 @@ package util -import "golang.org/x/sys/windows" +import ( + "fmt" + "syscall" + + "golang.org/x/sys/windows" +) // Set user-specified limit of how many fd's we can use func SetRlimit(val uint) { @@ -26,3 +31,7 @@ func HaveAdminRights() (bool, error) { } return true, nil } + +func SendProcessSignal(pid int, sig syscall.Signal) error { + return fmt.Errorf("not supported on Windows") +}