Merge: + Reload configuration command; refactoring

Close #1302

* commit 'dd7d9dc334181b3923b1532aaae51cae77d5ed55':
  * move getDNSAddresses() and "/dns-query" handler to DNS module
  * TLS is now a separate module (logically)
  * move HTTP server code
  * refactor: move blocked-services functions to dnsfilter
  * refactor
  + service: support "-s reload" command
  + clients: re-read auto-clients from /etc/hosts when SIGHUP is received
This commit is contained in:
Simon Zolin 2020-03-17 17:19:03 +03:00
commit 32baa907b6
20 changed files with 735 additions and 433 deletions

View File

@ -71,7 +71,6 @@ Contents:
![](doc/agh-arch.png) ![](doc/agh-arch.png)
## First startup ## First startup
The first application startup is detected when there's no .yaml configuration file. The first application startup is detected when there's no .yaml configuration file.

View File

@ -18,6 +18,8 @@ import (
const defaultDiscoverTime = time.Second * 3 const defaultDiscoverTime = time.Second * 3
const leaseExpireStatic = 1 const leaseExpireStatic = 1
var webHandlersRegistered = false
// Lease contains the necessary information about a DHCP lease // Lease contains the necessary information about a DHCP lease
// field ordering is important -- yaml fields will mirror ordering from here // field ordering is important -- yaml fields will mirror ordering from here
type Lease struct { type Lease struct {
@ -121,9 +123,6 @@ func Create(config ServerConfig) *Server {
return nil return nil
} }
} }
if s.conf.HTTPRegister != nil {
s.registerHandlers()
}
// we can't delay database loading until DHCP server is started, // we can't delay database loading until DHCP server is started,
// because we need static leases functionality available beforehand // because we need static leases functionality available beforehand
@ -221,6 +220,11 @@ func (s *Server) setConfig(config ServerConfig) error {
// Start will listen on port 67 and serve DHCP requests. // Start will listen on port 67 and serve DHCP requests.
func (s *Server) Start() error { func (s *Server) Start() error {
if !webHandlersRegistered && s.conf.HTTPRegister != nil {
webHandlersRegistered = true
s.registerHandlers()
}
// TODO: don't close if interface and addresses are the same // TODO: don't close if interface and addresses are the same
if s.conn != nil { if s.conn != nil {
s.closeConn() s.closeConn()

View File

@ -1,10 +1,9 @@
package home package dnsfilter
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
) )
@ -119,7 +118,7 @@ var serviceRulesArray = []svc{
} }
// convert array to map // convert array to map
func initServices() { func initBlockedServices() {
serviceRules = make(map[string][]*rules.NetworkRule) serviceRules = make(map[string][]*rules.NetworkRule)
for _, s := range serviceRulesArray { for _, s := range serviceRulesArray {
netRules := []*rules.NetworkRule{} netRules := []*rules.NetworkRule{}
@ -135,15 +134,20 @@ func initServices() {
} }
} }
// Return TRUE if a blocked service name is known // BlockedSvcKnown - return TRUE if a blocked service name is known
func blockedSvcKnown(s string) bool { func BlockedSvcKnown(s string) bool {
_, ok := serviceRules[s] _, ok := serviceRules[s]
return ok return ok
} }
// ApplyBlockedServices - set blocked services settings for this DNS request // ApplyBlockedServices - set blocked services settings for this DNS request
func ApplyBlockedServices(setts *dnsfilter.RequestFilteringSettings, list []string) { func (d *Dnsfilter) ApplyBlockedServices(setts *RequestFilteringSettings, list []string, global bool) {
setts.ServicesRules = []dnsfilter.ServiceEntry{} setts.ServicesRules = []ServiceEntry{}
if global {
d.confLock.RLock()
defer d.confLock.RUnlock()
list = d.Config.BlockedServices
}
for _, name := range list { for _, name := range list {
rules, ok := serviceRules[name] rules, ok := serviceRules[name]
@ -152,51 +156,45 @@ func ApplyBlockedServices(setts *dnsfilter.RequestFilteringSettings, list []stri
continue continue
} }
s := dnsfilter.ServiceEntry{} s := ServiceEntry{}
s.Name = name s.Name = name
s.Rules = rules s.Rules = rules
setts.ServicesRules = append(setts.ServicesRules, s) setts.ServicesRules = append(setts.ServicesRules, s)
} }
} }
func handleBlockedServicesList(w http.ResponseWriter, r *http.Request) { func (d *Dnsfilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Request) {
config.RLock() d.confLock.RLock()
list := config.DNS.BlockedServices list := d.Config.BlockedServices
config.RUnlock() d.confLock.RUnlock()
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(list) err := json.NewEncoder(w).Encode(list)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "json.Encode: %s", err) httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
return return
} }
} }
func handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) { func (d *Dnsfilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) {
list := []string{} list := []string{}
err := json.NewDecoder(r.Body).Decode(&list) err := json.NewDecoder(r.Body).Decode(&list)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "json.Decode: %s", err) httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return return
} }
config.Lock() d.confLock.Lock()
config.DNS.BlockedServices = list d.Config.BlockedServices = list
config.Unlock() d.confLock.Unlock()
log.Debug("Updated blocked services list: %d", len(list)) log.Debug("Updated blocked services list: %d", len(list))
err = writeAllConfigsAndReloadDNS() d.ConfigModified()
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
} }
httpOK(r, w) // registerBlockedServicesHandlers - register HTTP handlers
} func (d *Dnsfilter) registerBlockedServicesHandlers() {
d.Config.HTTPRegister("GET", "/control/blocked_services/list", d.handleBlockedServicesList)
// RegisterBlockedServicesHandlers - register HTTP handlers d.Config.HTTPRegister("POST", "/control/blocked_services/set", d.handleBlockedServicesSet)
func RegisterBlockedServicesHandlers() {
httpRegister(http.MethodGet, "/control/blocked_services/list", handleBlockedServicesList)
httpRegister(http.MethodPost, "/control/blocked_services/set", handleBlockedServicesSet)
} }

View File

@ -49,6 +49,10 @@ type Config struct {
Rewrites []RewriteEntry `yaml:"rewrites"` Rewrites []RewriteEntry `yaml:"rewrites"`
// Names of services to block (globally).
// Per-client settings can override this configuration.
BlockedServices []string `yaml:"blocked_services"`
// Called when the configuration is changed by HTTP request // Called when the configuration is changed by HTTP request
ConfigModified func() `yaml:"-"` ConfigModified func() `yaml:"-"`
@ -175,6 +179,7 @@ func (d *Dnsfilter) WriteDiskConfig(c *Config) {
d.confLock.Lock() d.confLock.Lock()
*c = d.Config *c = d.Config
c.Rewrites = rewriteArrayDup(d.Config.Rewrites) c.Rewrites = rewriteArrayDup(d.Config.Rewrites)
// BlockedServices
d.confLock.Unlock() d.confLock.Unlock()
} }
@ -633,6 +638,18 @@ func New(c *Config, blockFilters []Filter) *Dnsfilter {
d.prepareRewrites() d.prepareRewrites()
} }
bsvcs := []string{}
for _, s := range d.BlockedServices {
if !BlockedSvcKnown(s) {
log.Debug("skipping unknown blocked-service '%s'", s)
continue
}
bsvcs = append(bsvcs, s)
}
d.BlockedServices = bsvcs
initBlockedServices()
if blockFilters != nil { if blockFilters != nil {
err := d.initFiltering(nil, blockFilters) err := d.initFiltering(nil, blockFilters)
if err != nil { if err != nil {
@ -655,6 +672,7 @@ func (d *Dnsfilter) Start() {
if d.Config.HTTPRegister != nil { // for tests if d.Config.HTTPRegister != nil { // for tests
d.registerSecurityHandlers() d.registerSecurityHandlers()
d.registerRewritesHandlers() d.registerRewritesHandlers()
d.registerBlockedServicesHandlers()
} }
} }

View File

@ -179,6 +179,7 @@ type ServerConfig struct {
FilteringConfig FilteringConfig
TLSConfig TLSConfig
TLSAllowUnencryptedDOH bool
TLSv12Roots *x509.CertPool // list of root CAs for TLSv1.2 TLSv12Roots *x509.CertPool // list of root CAs for TLSv1.2

View File

@ -376,6 +376,20 @@ func checkDNS(input string, bootstrap []string) error {
return nil return nil
} }
func (s *Server) handleDOH(w http.ResponseWriter, r *http.Request) {
if !s.conf.TLSAllowUnencryptedDOH && r.TLS == nil {
httpError(r, w, http.StatusNotFound, "Not Found")
return
}
if !s.IsRunning() {
httpError(r, w, http.StatusInternalServerError, "DNS server is not running")
return
}
s.ServeHTTP(w, r)
}
func (s *Server) registerHandlers() { func (s *Server) registerHandlers() {
s.conf.HTTPRegister("GET", "/control/dns_info", s.handleGetConfig) s.conf.HTTPRegister("GET", "/control/dns_info", s.handleGetConfig)
s.conf.HTTPRegister("POST", "/control/dns_config", s.handleSetConfig) s.conf.HTTPRegister("POST", "/control/dns_config", s.handleSetConfig)
@ -384,4 +398,6 @@ func (s *Server) registerHandlers() {
s.conf.HTTPRegister("GET", "/control/access/list", s.handleAccessList) s.conf.HTTPRegister("GET", "/control/access/list", s.handleAccessList)
s.conf.HTTPRegister("POST", "/control/access/set", s.handleAccessSet) s.conf.HTTPRegister("POST", "/control/access/set", s.handleAccessSet)
s.conf.HTTPRegister("", "/dns-query", s.handleDOH)
} }

View File

@ -14,6 +14,7 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -24,6 +25,8 @@ const (
clientsUpdatePeriod = 1 * time.Hour clientsUpdatePeriod = 1 * time.Hour
) )
var webHandlersRegistered = false
// Client information // Client information
type Client struct { type Client struct {
IDs []string IDs []string
@ -98,13 +101,27 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd.
clients.addFromConfig(objects) clients.addFromConfig(objects)
if !clients.testing { if !clients.testing {
go clients.periodicUpdate()
clients.addFromDHCP() clients.addFromDHCP()
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged) clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
}
}
// Start - start the module
func (clients *clientsContainer) Start() {
if !clients.testing {
if !webHandlersRegistered {
webHandlersRegistered = true
clients.registerWebHandlers() clients.registerWebHandlers()
} }
go clients.periodicUpdate()
}
}
// Reload - reload auto-clients
func (clients *clientsContainer) Reload() {
clients.addFromHostsFile()
clients.addFromSystemARP()
} }
type clientObject struct { type clientObject struct {
@ -145,7 +162,7 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) {
} }
for _, s := range cy.BlockedServices { for _, s := range cy.BlockedServices {
if !blockedSvcKnown(s) { if !dnsfilter.BlockedSvcKnown(s) {
log.Debug("Clients: skipping unknown blocked-service '%s'", s) log.Debug("Clients: skipping unknown blocked-service '%s'", s)
continue continue
} }
@ -194,8 +211,7 @@ func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) {
func (clients *clientsContainer) periodicUpdate() { func (clients *clientsContainer) periodicUpdate() {
for { for {
clients.addFromHostsFile() clients.Reload()
clients.addFromSystemARP()
time.Sleep(clientsUpdatePeriod) time.Sleep(clientsUpdatePeriod)
} }
} }

View File

@ -2,7 +2,6 @@ package home
import ( import (
"io/ioutil" "io/ioutil"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"sync" "sync"
@ -29,14 +28,6 @@ type logSettings struct {
Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled
} }
// HTTPSServer - HTTPS Server
type HTTPSServer struct {
server *http.Server
cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
sync.Mutex // protects config.TLS
shutdown bool // if TRUE, don't restart the server
}
// configuration is loaded from YAML // configuration is loaded from YAML
// field ordering is important -- yaml fields will mirror ordering from here // field ordering is important -- yaml fields will mirror ordering from here
type configuration struct { type configuration struct {
@ -60,7 +51,7 @@ type configuration struct {
WebSessionTTLHours uint32 `yaml:"web_session_ttl"` WebSessionTTLHours uint32 `yaml:"web_session_ttl"`
DNS dnsConfig `yaml:"dns"` DNS dnsConfig `yaml:"dns"`
TLS tlsConfig `yaml:"tls"` TLS tlsConfigSettings `yaml:"tls"`
Filters []filter `yaml:"filters"` Filters []filter `yaml:"filters"`
WhitelistFilters []filter `yaml:"whitelist_filters"` WhitelistFilters []filter `yaml:"whitelist_filters"`
@ -95,10 +86,6 @@ type dnsConfig struct {
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours) FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
DnsfilterConf dnsfilter.Config `yaml:",inline"` DnsfilterConf dnsfilter.Config `yaml:",inline"`
// Names of services to block (globally).
// Per-client settings can override this configuration.
BlockedServices []string `yaml:"blocked_services"`
} }
type tlsConfigSettings struct { type tlsConfigSettings struct {
@ -114,33 +101,6 @@ type tlsConfigSettings struct {
dnsforward.TLSConfig `yaml:",inline" json:",inline"` dnsforward.TLSConfig `yaml:",inline" json:",inline"`
} }
// field ordering is not important -- these are for API and are recalculated on each run
type tlsConfigStatus struct {
ValidCert bool `yaml:"-" json:"valid_cert"` // ValidCert is true if the specified certificates chain is a valid chain of X509 certificates
ValidChain bool `yaml:"-" json:"valid_chain"` // ValidChain is true if the specified certificates chain is verified and issued by a known CA
Subject string `yaml:"-" json:"subject,omitempty"` // Subject is the subject of the first certificate in the chain
Issuer string `yaml:"-" json:"issuer,omitempty"` // Issuer is the issuer of the first certificate in the chain
NotBefore time.Time `yaml:"-" json:"not_before,omitempty"` // NotBefore is the NotBefore field of the first certificate in the chain
NotAfter time.Time `yaml:"-" json:"not_after,omitempty"` // NotAfter is the NotAfter field of the first certificate in the chain
DNSNames []string `yaml:"-" json:"dns_names"` // DNSNames is the value of SubjectAltNames field of the first certificate in the chain
// key status
ValidKey bool `yaml:"-" json:"valid_key"` // ValidKey is true if the key is a valid private key
KeyType string `yaml:"-" json:"key_type,omitempty"` // KeyType is one of RSA or ECDSA
// is usable? set by validator
ValidPair bool `yaml:"-" json:"valid_pair"` // ValidPair is true if both certificate and private key are correct
// warnings
WarningValidation string `yaml:"-" json:"warning_validation,omitempty"` // WarningValidation is a validation warning message with the issue description
}
// field ordering is important -- yaml fields will mirror ordering from here
type tlsConfig struct {
tlsConfigSettings `yaml:",inline" json:",inline"`
tlsConfigStatus `yaml:"-" json:",inline"`
}
// initialize to default values, will be changed later when reading config or parsing command line // initialize to default values, will be changed later when reading config or parsing command line
var config = configuration{ var config = configuration{
BindPort: 3000, BindPort: 3000,
@ -160,12 +120,10 @@ var config = configuration{
FilteringEnabled: true, // whether or not use filter lists FilteringEnabled: true, // whether or not use filter lists
FiltersUpdateIntervalHours: 24, FiltersUpdateIntervalHours: 24,
}, },
TLS: tlsConfig{ TLS: tlsConfigSettings{
tlsConfigSettings: tlsConfigSettings{
PortHTTPS: 443, PortHTTPS: 443,
PortDNSOverTLS: 853, // needs to be passed through to dnsproxy PortDNSOverTLS: 853, // needs to be passed through to dnsproxy
}, },
},
DHCP: dhcpd.ServerConfig{ DHCP: dhcpd.ServerConfig{
LeaseDuration: 86400, LeaseDuration: 86400,
ICMPTimeout: 1000, ICMPTimeout: 1000,
@ -234,26 +192,10 @@ func parseConfig() error {
return err return err
} }
bsvcs := []string{}
for _, s := range config.DNS.BlockedServices {
if !blockedSvcKnown(s) {
log.Debug("skipping unknown blocked-service '%s'", s)
continue
}
bsvcs = append(bsvcs, s)
}
config.DNS.BlockedServices = bsvcs
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) { if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
config.DNS.FiltersUpdateIntervalHours = 24 config.DNS.FiltersUpdateIntervalHours = 24
} }
status := tlsConfigStatus{}
if !tlsLoadConfig(&config.TLS, &status) {
log.Error("%s", status.WarningValidation)
return err
}
return nil return nil
} }
@ -282,6 +224,11 @@ func (c *configuration) write() error {
if Context.auth != nil { if Context.auth != nil {
config.Users = Context.auth.GetUsers() config.Users = Context.auth.GetUsers()
} }
if Context.tls != nil {
tlsConf := tlsConfigSettings{}
Context.tls.WriteDiskConfig(&tlsConf)
config.TLS = tlsConf
}
if Context.stats != nil { if Context.stats != nil {
sdc := stats.DiskConfig{} sdc := stats.DiskConfig{}
@ -331,13 +278,3 @@ func (c *configuration) write() error {
return nil return nil
} }
func writeAllConfigs() error {
err := config.write()
if err != nil {
log.Error("Couldn't write config: %s", err)
return err
}
return nil
}

View File

@ -9,8 +9,6 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/NYTimes/gziphandler" "github.com/NYTimes/gziphandler"
@ -27,9 +25,6 @@ func returnOK(w http.ResponseWriter) {
} }
} }
func httpOK(r *http.Request, w http.ResponseWriter) {
}
func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) { func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...) text := fmt.Sprintf(format, args...)
log.Info(text) log.Info(text)
@ -39,15 +34,6 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface
// --------------- // ---------------
// dns run control // dns run control
// --------------- // ---------------
func writeAllConfigsAndReloadDNS() error {
err := writeAllConfigs()
if err != nil {
log.Error("Couldn't write all configs: %s", err)
return err
}
return reconfigureDNSServer()
}
func addDNSAddress(dnsAddresses *[]string, addr string) { func addDNSAddress(dnsAddresses *[]string, addr string) {
if config.DNS.Port != 53 { if config.DNS.Port != 53 {
addr = fmt.Sprintf("%s:%d", addr, config.DNS.Port) addr = fmt.Sprintf("%s:%d", addr, config.DNS.Port)
@ -55,46 +41,6 @@ func addDNSAddress(dnsAddresses *[]string, addr string) {
*dnsAddresses = append(*dnsAddresses, addr) *dnsAddresses = append(*dnsAddresses, addr)
} }
// Get the list of DNS addresses the server is listening on
func getDNSAddresses() []string {
dnsAddresses := []string{}
if config.DNS.BindHost == "0.0.0.0" {
ifaces, e := util.GetValidNetInterfacesForWeb()
if e != nil {
log.Error("Couldn't get network interfaces: %v", e)
return []string{}
}
for _, iface := range ifaces {
for _, addr := range iface.Addresses {
addDNSAddress(&dnsAddresses, addr)
}
}
} else {
addDNSAddress(&dnsAddresses, config.DNS.BindHost)
}
if config.TLS.Enabled && len(config.TLS.ServerName) != 0 {
if config.TLS.PortHTTPS != 0 {
addr := config.TLS.ServerName
if config.TLS.PortHTTPS != 443 {
addr = fmt.Sprintf("%s:%d", addr, config.TLS.PortHTTPS)
}
addr = fmt.Sprintf("https://%s/dns-query", addr)
dnsAddresses = append(dnsAddresses, addr)
}
if config.TLS.PortDNSOverTLS != 0 {
addr := fmt.Sprintf("tls://%s:%d", config.TLS.ServerName, config.TLS.PortDNSOverTLS)
dnsAddresses = append(dnsAddresses, addr)
}
}
return dnsAddresses
}
func handleStatus(w http.ResponseWriter, r *http.Request) { func handleStatus(w http.ResponseWriter, r *http.Request) {
c := dnsforward.FilteringConfig{} c := dnsforward.FilteringConfig{}
if Context.dnsServer != nil { if Context.dnsServer != nil {
@ -144,23 +90,6 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write(data) _, _ = w.Write(data)
} }
// --------------
// DNS-over-HTTPS
// --------------
func handleDOH(w http.ResponseWriter, r *http.Request) {
if !config.TLS.AllowUnencryptedDOH && r.TLS == nil {
httpError(w, http.StatusNotFound, "Not Found")
return
}
if !isRunning() {
httpError(w, http.StatusInternalServerError, "DNS server is not running")
return
}
Context.dnsServer.ServeHTTP(w, r)
}
// ------------------------ // ------------------------
// registration of handlers // registration of handlers
// ------------------------ // ------------------------
@ -172,15 +101,16 @@ func registerControlHandlers() {
httpRegister(http.MethodPost, "/control/update", handleUpdate) httpRegister(http.MethodPost, "/control/update", handleUpdate)
httpRegister("GET", "/control/profile", handleGetProfile) httpRegister("GET", "/control/profile", handleGetProfile)
RegisterTLSHandlers()
RegisterBlockedServicesHandlers()
RegisterAuthHandlers() RegisterAuthHandlers()
http.HandleFunc("/dns-query", postInstall(handleDOH))
} }
func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) { func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) {
if len(method) == 0 {
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
http.HandleFunc(url, postInstall(handler))
return
}
http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler))))) http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
} }
@ -258,14 +188,16 @@ func preInstallHandler(handler http.Handler) http.Handler {
// it also enforces HTTPS if it is enabled and configured // it also enforces HTTPS if it is enabled and configured
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if Context.firstRun && if Context.firstRun &&
!strings.HasPrefix(r.URL.Path, "/install.") && !strings.HasPrefix(r.URL.Path, "/install.") &&
r.URL.Path != "/favicon.png" { r.URL.Path != "/favicon.png" {
http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable http.Redirect(w, r, "/install.html", http.StatusFound)
return return
} }
// enforce https? // enforce https?
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil { if r.TLS == nil && Context.web.forceHTTPS && Context.web.httpsServer.server != nil {
// yes, and we want host from host:port // yes, and we want host from host:port
host, _, err := net.SplitHostPort(r.Host) host, _, err := net.SplitHostPort(r.Host)
if err != nil { if err != nil {
@ -275,13 +207,14 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res
// construct new URL to redirect to // construct new URL to redirect to
newURL := url.URL{ newURL := url.URL{
Scheme: "https", Scheme: "https",
Host: net.JoinHostPort(host, strconv.Itoa(config.TLS.PortHTTPS)), Host: net.JoinHostPort(host, strconv.Itoa(Context.web.portHTTPS)),
Path: r.URL.Path, Path: r.URL.Path,
RawQuery: r.URL.RawQuery, RawQuery: r.URL.RawQuery,
} }
http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect) http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect)
return return
} }
w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Origin", "*")
handler(w, r) handler(w, r)
} }

View File

@ -349,7 +349,7 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
setts := Context.dnsFilter.GetConfig() setts := Context.dnsFilter.GetConfig()
setts.FilteringEnabled = true setts.FilteringEnabled = true
ApplyBlockedServices(&setts, config.DNS.BlockedServices) Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts) result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err) httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err)

View File

@ -35,7 +35,7 @@ type netInterfaceJSON struct {
} }
// Get initial installation settings // Get initial installation settings
func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) { func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
data := firstRunData{} data := firstRunData{}
data.WebPort = 80 data.WebPort = 80
data.DNSPort = 53 data.DNSPort = 53
@ -93,7 +93,7 @@ type checkConfigResp struct {
} }
// Check if ports are available, respond with results // Check if ports are available, respond with results
func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) { func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
reqData := checkConfigReq{} reqData := checkConfigReq{}
respData := checkConfigResp{} respData := checkConfigResp{}
err := json.NewDecoder(r.Body).Decode(&reqData) err := json.NewDecoder(r.Body).Decode(&reqData)
@ -275,7 +275,7 @@ func copyInstallSettings(dst *configuration, src *configuration) {
} }
// Apply new configuration, start DNS server, restart Web server // Apply new configuration, start DNS server, restart Web server
func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
newSettings := applyConfigReq{} newSettings := applyConfigReq{}
err := json.NewDecoder(r.Body).Decode(&newSettings) err := json.NewDecoder(r.Body).Decode(&newSettings)
if err != nil { if err != nil {
@ -325,22 +325,11 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
config.DNS.BindHost = newSettings.DNS.IP config.DNS.BindHost = newSettings.DNS.IP
config.DNS.Port = newSettings.DNS.Port config.DNS.Port = newSettings.DNS.Port
err = initDNSServer() err = StartMods()
var err2 error if err != nil {
if err == nil {
err2 = startDNSServer()
if err2 != nil {
closeDNSServer()
}
}
if err != nil || err2 != nil {
Context.firstRun = true Context.firstRun = true
copyInstallSettings(&config, &curConfig) copyInstallSettings(&config, &curConfig)
if err != nil { httpError(w, http.StatusInternalServerError, "%s", err)
httpError(w, http.StatusInternalServerError, "Couldn't initialize DNS server: %s", err)
} else {
httpError(w, http.StatusInternalServerError, "Couldn't start DNS server: %s", err2)
}
return return
} }
@ -356,19 +345,21 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
return return
} }
registerControlHandlers()
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block // this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely // until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
if restartHTTP { if restartHTTP {
go func() { go func() {
_ = Context.httpServer.Shutdown(context.TODO()) _ = Context.web.httpServer.Shutdown(context.TODO())
}() }()
} }
returnOK(w) returnOK(w)
} }
func registerInstallHandlers() { func (web *Web) registerInstallHandlers() {
http.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(handleInstallGetAddresses))) http.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
http.HandleFunc("/control/install/check_config", preInstall(ensurePOST(handleInstallCheckConfig))) http.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
http.HandleFunc("/control/install/configure", preInstall(ensurePOST(handleInstallConfigure))) http.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))
} }

View File

@ -476,7 +476,6 @@ func doUpdate(u *updateInfo) error {
func finishUpdate(u *updateInfo) { func finishUpdate(u *updateInfo) {
log.Info("Stopping all tasks") log.Info("Stopping all tasks")
cleanup() cleanup()
stopHTTPServer()
cleanupAlways() cleanupAlways()
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {

View File

@ -3,13 +3,13 @@ package home
import ( import (
"fmt" "fmt"
"net" "net"
"os"
"path/filepath" "path/filepath"
"github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/AdGuardHome/querylog" "github.com/AdguardTeam/AdGuardHome/querylog"
"github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/AdGuardHome/stats"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -25,13 +25,9 @@ func onConfigModified() {
// Please note that we must do it even if we don't start it // Please note that we must do it even if we don't start it
// so that we had access to the query log and the stats // so that we had access to the query log and the stats
func initDNSServer() error { func initDNSServer() error {
var err error
baseDir := Context.getDataDir() baseDir := Context.getDataDir()
err := os.MkdirAll(baseDir, 0755)
if err != nil {
return fmt.Errorf("Cannot create DNS data dir at %s: %s", baseDir, err)
}
statsConf := stats.Config{ statsConf := stats.Config{
Filename: filepath.Join(baseDir, "stats.db"), Filename: filepath.Join(baseDir, "stats.db"),
LimitDays: config.DNS.StatsInterval, LimitDays: config.DNS.StatsInterval,
@ -70,14 +66,6 @@ func initDNSServer() error {
return fmt.Errorf("dnsServer.Prepare: %s", err) return fmt.Errorf("dnsServer.Prepare: %s", err)
} }
sessFilename := filepath.Join(baseDir, "sessions.db")
Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
if Context.auth == nil {
closeDNSServer()
return fmt.Errorf("Couldn't initialize Auth module")
}
config.Users = nil
Context.rdns = InitRDNS(Context.dnsServer, &Context.clients) Context.rdns = InitRDNS(Context.dnsServer, &Context.clients)
Context.whois = initWhois(&Context.clients) Context.whois = initWhois(&Context.clients)
@ -169,26 +157,74 @@ func generateServerConfig() dnsforward.ServerConfig {
OnDNSRequest: onDNSRequest, OnDNSRequest: onDNSRequest,
} }
if config.TLS.Enabled { tlsConf := tlsConfigSettings{}
newconfig.TLSConfig = config.TLS.TLSConfig Context.tls.WriteDiskConfig(&tlsConf)
if config.TLS.PortDNSOverTLS != 0 { if tlsConf.Enabled {
newconfig.TLSListenAddr = &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.TLS.PortDNSOverTLS} newconfig.TLSConfig = tlsConf.TLSConfig
if tlsConf.PortDNSOverTLS != 0 {
newconfig.TLSListenAddr = &net.TCPAddr{
IP: net.ParseIP(config.DNS.BindHost),
Port: tlsConf.PortDNSOverTLS,
}
} }
} }
newconfig.TLSv12Roots = Context.tlsRoots newconfig.TLSv12Roots = Context.tlsRoots
newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
newconfig.FilterHandler = applyAdditionalFiltering newconfig.FilterHandler = applyAdditionalFiltering
newconfig.GetUpstreamsByClient = getUpstreamsByClient newconfig.GetUpstreamsByClient = getUpstreamsByClient
return newconfig return newconfig
} }
// Get the list of DNS addresses the server is listening on
func getDNSAddresses() []string {
dnsAddresses := []string{}
if config.DNS.BindHost == "0.0.0.0" {
ifaces, e := util.GetValidNetInterfacesForWeb()
if e != nil {
log.Error("Couldn't get network interfaces: %v", e)
return []string{}
}
for _, iface := range ifaces {
for _, addr := range iface.Addresses {
addDNSAddress(&dnsAddresses, addr)
}
}
} else {
addDNSAddress(&dnsAddresses, config.DNS.BindHost)
}
tlsConf := tlsConfigSettings{}
Context.tls.WriteDiskConfig(&tlsConf)
if tlsConf.Enabled && len(tlsConf.ServerName) != 0 {
if tlsConf.PortHTTPS != 0 {
addr := tlsConf.ServerName
if tlsConf.PortHTTPS != 443 {
addr = fmt.Sprintf("%s:%d", addr, tlsConf.PortHTTPS)
}
addr = fmt.Sprintf("https://%s/dns-query", addr)
dnsAddresses = append(dnsAddresses, addr)
}
if tlsConf.PortDNSOverTLS != 0 {
addr := fmt.Sprintf("tls://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverTLS)
dnsAddresses = append(dnsAddresses, addr)
}
}
return dnsAddresses
}
func getUpstreamsByClient(clientAddr string) []upstream.Upstream { func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
return Context.clients.FindUpstreams(clientAddr) return Context.clients.FindUpstreams(clientAddr)
} }
// If a client has his own settings, apply them // If a client has his own settings, apply them
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) { func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
ApplyBlockedServices(setts, config.DNS.BlockedServices) Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
if len(clientAddr) == 0 { if len(clientAddr) == 0 {
return return
@ -202,7 +238,7 @@ func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteri
log.Debug("Using settings for client with IP %s", clientAddr) log.Debug("Using settings for client with IP %s", clientAddr)
if c.UseOwnBlockedServices { if c.UseOwnBlockedServices {
ApplyBlockedServices(setts, c.BlockedServices) Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
} }
setts.ClientTags = c.Tags setts.ClientTags = c.Tags
@ -224,6 +260,8 @@ func startDNSServer() error {
enableFilters(false) enableFilters(false)
Context.clients.Start()
err := Context.dnsServer.Start() err := Context.dnsServer.Start()
if err != nil { if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server") return errorx.Decorate(err, "Couldn't start forwarding DNS server")
@ -295,11 +333,6 @@ func closeDNSServer() {
Context.queryLog = nil Context.queryLog = nil
} }
if Context.auth != nil {
Context.auth.Close()
Context.auth = nil
}
Context.filters.Close() Context.filters.Close()
log.Debug("Closed all DNS modules") log.Debug("Closed all DNS modules")

View File

@ -34,8 +34,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/querylog" "github.com/AdguardTeam/AdGuardHome/querylog"
"github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/AdGuardHome/stats"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/NYTimes/gziphandler"
"github.com/gobuffalo/packr"
) )
const ( const (
@ -67,9 +65,9 @@ type homeContext struct {
dnsFilter *dnsfilter.Dnsfilter // DNS filtering module dnsFilter *dnsfilter.Dnsfilter // DNS filtering module
dhcpServer *dhcpd.Server // DHCP module dhcpServer *dhcpd.Server // DHCP module
auth *Auth // HTTP authentication module auth *Auth // HTTP authentication module
httpServer *http.Server // HTTP module filters Filtering // DNS filtering module
httpsServer HTTPSServer // HTTPS module web *Web // Web (HTTP, HTTPS) module
filters Filtering tls *TLSMod // TLS module
// Runtime properties // Runtime properties
// -- // --
@ -116,10 +114,20 @@ func Main(version string, channel string, armVer string) {
Context.appSignalChannel = make(chan os.Signal) Context.appSignalChannel = make(chan os.Signal)
signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() { go func() {
<-Context.appSignalChannel for {
sig := <-Context.appSignalChannel
log.Info("Received signal '%s'", sig)
switch sig {
case syscall.SIGHUP:
Context.clients.Reload()
Context.tls.Reload()
default:
cleanup() cleanup()
cleanupAlways() cleanupAlways()
os.Exit(0) os.Exit(0)
}
}
}() }()
// run the protection // run the protection
@ -163,7 +171,6 @@ func run(args options) {
} }
initConfig() initConfig()
initServices()
Context.tlsRoots = util.LoadSystemRootCAs() Context.tlsRoots = util.LoadSystemRootCAs()
Context.transport = &http.Transport{ Context.transport = &http.Transport{
@ -218,6 +225,9 @@ func run(args options) {
if args.bindPort != 0 { if args.bindPort != 0 {
config.BindPort = args.bindPort config.BindPort = args.bindPort
} }
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
Context.pidFileName = args.pidFile
}
if !Context.firstRun { if !Context.firstRun {
// Save the updated config // Save the updated config
@ -225,11 +235,42 @@ func run(args options) {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
}
err = initDNSServer() err := os.MkdirAll(Context.getDataDir(), 0755)
if err != nil {
log.Fatalf("Cannot create DNS data dir at %s: %s", Context.getDataDir(), err)
}
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
if Context.auth == nil {
log.Fatalf("Couldn't initialize Auth module")
}
config.Users = nil
Context.tls = tlsCreate(config.TLS)
if Context.tls == nil {
log.Fatalf("Can't initialize TLS module")
}
webConf := WebConfig{
firstRun: Context.firstRun,
BindHost: config.BindHost,
BindPort: config.BindPort,
}
Context.web = CreateWeb(&webConf)
if Context.web == nil {
log.Fatalf("Can't initialize Web module")
}
if !Context.firstRun {
err := initDNSServer()
if err != nil { if err != nil {
log.Fatalf("%s", err) log.Fatalf("%s", err)
} }
Context.tls.Start()
go func() { go func() {
err := startDNSServer() err := startDNSServer()
if err != nil { if err != nil {
@ -243,101 +284,27 @@ func run(args options) {
} }
} }
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) { Context.web.Start()
Context.pidFileName = args.pidFile
}
// Initialize and run the admin Web interface
box := packr.NewBox("../build/static")
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
http.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box)))))
registerControlHandlers()
// add handlers for /install paths, we only need them when we're not configured yet
if Context.firstRun {
log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
http.Handle("/install.html", preInstallHandler(http.FileServer(box)))
registerInstallHandlers()
}
Context.httpsServer.cond = sync.NewCond(&Context.httpsServer.Mutex)
// for https, we have a separate goroutine loop
go httpServerLoop()
// this loop is used as an ability to change listening host and/or port
for !Context.httpsServer.shutdown {
printHTTPAddresses("http")
// we need to have new instance, because after Shutdown() the Server is not usable
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
Context.httpServer = &http.Server{
Addr: address,
}
err := Context.httpServer.ListenAndServe()
if err != http.ErrServerClosed {
cleanupAlways()
log.Fatal(err)
}
// We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop
}
// wait indefinitely for other go-routines to complete their job // wait indefinitely for other go-routines to complete their job
select {} select {}
} }
func httpServerLoop() { // StartMods - initialize and start DNS after installation
for !Context.httpsServer.shutdown { func StartMods() error {
Context.httpsServer.cond.L.Lock() err := initDNSServer()
// this mechanism doesn't let us through until all conditions are met
for config.TLS.Enabled == false ||
config.TLS.PortHTTPS == 0 ||
len(config.TLS.PrivateKeyData) == 0 ||
len(config.TLS.CertificateChainData) == 0 { // sleep until necessary data is supplied
Context.httpsServer.cond.Wait()
}
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS))
// validate current TLS config and update warnings (it could have been loaded from file)
data := validateCertificates(string(config.TLS.CertificateChainData), string(config.TLS.PrivateKeyData), config.TLS.ServerName)
if !data.ValidPair {
cleanupAlways()
log.Fatal(data.WarningValidation)
}
config.Lock()
config.TLS.tlsConfigStatus = data // update warnings
config.Unlock()
// prepare certs for HTTPS server
// important -- they have to be copies, otherwise changing the contents in config.TLS will break encryption for in-flight requests
certchain := make([]byte, len(config.TLS.CertificateChainData))
copy(certchain, config.TLS.CertificateChainData)
privatekey := make([]byte, len(config.TLS.PrivateKeyData))
copy(privatekey, config.TLS.PrivateKeyData)
cert, err := tls.X509KeyPair(certchain, privatekey)
if err != nil { if err != nil {
cleanupAlways() return err
log.Fatal(err)
}
Context.httpsServer.cond.L.Unlock()
// prepare HTTPS server
Context.httpsServer.server = &http.Server{
Addr: address,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
RootCAs: Context.tlsRoots,
},
} }
printHTTPAddresses("https") Context.tls.Start()
err = Context.httpsServer.server.ListenAndServeTLS("", "")
if err != http.ErrServerClosed { err = startDNSServer()
cleanupAlways() if err != nil {
log.Fatal(err) closeDNSServer()
} return err
} }
return nil
} }
// Check if the current user has root (administrator) rights // Check if the current user has root (administrator) rights
@ -449,6 +416,15 @@ func configureLogger(args options) {
func cleanup() { func cleanup() {
log.Info("Stopping AdGuard Home") log.Info("Stopping AdGuard Home")
if Context.web != nil {
Context.web.Close()
Context.web = nil
}
if Context.auth != nil {
Context.auth.Close()
Context.auth = nil
}
err := stopDNSServer() err := stopDNSServer()
if err != nil { if err != nil {
log.Error("Couldn't stop DNS server: %s", err) log.Error("Couldn't stop DNS server: %s", err)
@ -457,17 +433,11 @@ func cleanup() {
if err != nil { if err != nil {
log.Error("Couldn't stop DHCP server: %s", err) log.Error("Couldn't stop DHCP server: %s", err)
} }
}
// Stop HTTP server, possibly waiting for all active connections to be closed if Context.tls != nil {
func stopHTTPServer() { Context.tls.Close()
log.Info("Stopping HTTP server...") Context.tls = nil
Context.httpsServer.shutdown = true
if Context.httpsServer.server != nil {
_ = Context.httpsServer.server.Shutdown(context.TODO())
} }
_ = Context.httpServer.Shutdown(context.TODO())
log.Info("Stopped HTTP server")
} }
// This function is called before application exits // This function is called before application exits
@ -519,7 +489,7 @@ func loadOptions() options {
} }
o.bindPort = v o.bindPort = v
}, nil}, }, nil},
{"service", "s", "Service control action: status, install, uninstall, start, stop, restart", func(value string) { {"service", "s", "Service control action: status, install, uninstall, start, stop, restart, reload (configuration)", func(value string) {
o.serviceControlAction = value o.serviceControlAction = value
}, nil}, }, nil},
{"logfile", "l", "Path to log file. If empty: write to stdout; if 'syslog': write to system log", func(value string) { {"logfile", "l", "Path to log file. If empty: write to stdout; if 'syslog': write to system log", func(value string) {
@ -588,11 +558,13 @@ func loadOptions() options {
func printHTTPAddresses(proto string) { func printHTTPAddresses(proto string) {
var address string var address string
if proto == "https" && config.TLS.ServerName != "" { tlsConf := tlsConfigSettings{}
if config.TLS.PortHTTPS == 443 { Context.tls.WriteDiskConfig(&tlsConf)
log.Printf("Go to https://%s", config.TLS.ServerName) if proto == "https" && tlsConf.ServerName != "" {
if tlsConf.PortHTTPS == 443 {
log.Printf("Go to https://%s", tlsConf.ServerName)
} else { } else {
log.Printf("Go to https://%s:%d", config.TLS.ServerName, config.TLS.PortHTTPS) log.Printf("Go to https://%s:%d", tlsConf.ServerName, tlsConf.PortHTTPS)
} }
} else if config.BindHost == "0.0.0.0" { } else if config.BindHost == "0.0.0.0" {
log.Println("AdGuard Home is available on the following addresses:") log.Println("AdGuard Home is available on the following addresses:")

View File

@ -1,9 +1,11 @@
package home package home
import ( import (
"fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"runtime" "runtime"
"strconv"
"strings" "strings"
"syscall" "syscall"
@ -71,6 +73,48 @@ func svcAction(s service.Service, action string) error {
return err return err
} }
// Send SIGHUP to a process with ID taken from our pid-file
// If pid-file doesn't exist, find our PID using 'ps' command
func sendSigReload() {
if runtime.GOOS == "windows" {
log.Error("Not implemented on Windows")
return
}
pidfile := fmt.Sprintf("/var/run/%s.pid", serviceName)
data, err := ioutil.ReadFile(pidfile)
if os.IsNotExist(err) {
code, psdata, err := util.RunCommand("ps", "-C", serviceName, "-o", "pid=")
if err != nil || code != 0 {
log.Error("Can't find AdGuardHome process: %s code:%d", err, code)
return
}
data = []byte(psdata)
} else if err != nil {
log.Error("Can't read PID file %s: %s", pidfile, err)
return
}
parts := strings.SplitN(string(data), "\n", 2)
if len(parts) == 0 {
log.Error("Can't read PID file %s: bad value", pidfile)
return
}
pid, err := strconv.Atoi(parts[0])
if err != nil {
log.Error("Can't read PID file %s: %s", pidfile, err)
return
}
err = util.SendProcessSignal(pid, syscall.SIGHUP)
if err != nil {
log.Error("Can't send signal to PID %d: %s", pid, err)
return
}
log.Debug("Sent signal to PID %d", pid)
}
// handleServiceControlAction one of the possible control actions: // handleServiceControlAction one of the possible control actions:
// install -- installs a service/daemon // install -- installs a service/daemon
// uninstall -- uninstalls it // uninstall -- uninstalls it
@ -84,6 +128,11 @@ func svcAction(s service.Service, action string) error {
func handleServiceControlAction(action string) { func handleServiceControlAction(action string) {
log.Printf("Service control action: %s", action) log.Printf("Service control action: %s", action)
if action == "reload" {
sendSigReload()
return
}
pwd, err := os.Getwd() pwd, err := os.Getwd()
if err != nil { if err != nil {
log.Fatal("Unable to find the path to the current directory") log.Fatal("Unable to find the path to the current directory")

View File

@ -1,9 +1,6 @@
// Control: TLS configuring handlers
package home package home
import ( import (
"context"
"crypto" "crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rsa" "crypto/rsa"
@ -16,18 +13,125 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os"
"reflect" "reflect"
"strings" "strings"
"sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx" "github.com/joomcode/errorx"
) )
var tlsWebHandlersRegistered = false
// TLSMod - TLS module object
type TLSMod struct {
certLastMod time.Time // last modification time of the certificate file
conf tlsConfigSettings
confLock sync.Mutex
status tlsConfigStatus
}
// Create TLS module
func tlsCreate(conf tlsConfigSettings) *TLSMod {
t := &TLSMod{}
t.conf = conf
if t.conf.Enabled {
if !t.load() {
return nil
}
t.setCertFileTime()
}
return t
}
func (t *TLSMod) load() bool {
if !tlsLoadConfig(&t.conf, &t.status) {
return false
}
// validate current TLS config and update warnings (it could have been loaded from file)
data := validateCertificates(string(t.conf.CertificateChainData), string(t.conf.PrivateKeyData), t.conf.ServerName)
if !data.ValidPair {
log.Error(data.WarningValidation)
return false
}
t.status = data
return true
}
// Close - close module
func (t *TLSMod) Close() {
}
// WriteDiskConfig - write config
func (t *TLSMod) WriteDiskConfig(conf *tlsConfigSettings) {
t.confLock.Lock()
*conf = t.conf
t.confLock.Unlock()
}
func (t *TLSMod) setCertFileTime() {
if len(t.conf.CertificatePath) == 0 {
return
}
fi, err := os.Stat(t.conf.CertificatePath)
if err != nil {
log.Error("TLS: %s", err)
return
}
t.certLastMod = fi.ModTime().UTC()
}
// Start - start the module
func (t *TLSMod) Start() {
if !tlsWebHandlersRegistered {
tlsWebHandlersRegistered = true
t.registerWebHandlers()
}
t.confLock.Lock()
tlsConf := t.conf
t.confLock.Unlock()
Context.web.TLSConfigChanged(tlsConf)
}
// Reload - reload certificate file
func (t *TLSMod) Reload() {
t.confLock.Lock()
tlsConf := t.conf
t.confLock.Unlock()
if !tlsConf.Enabled || len(tlsConf.CertificatePath) == 0 {
return
}
fi, err := os.Stat(tlsConf.CertificatePath)
if err != nil {
log.Error("TLS: %s", err)
return
}
if fi.ModTime().UTC().Equal(t.certLastMod) {
log.Debug("TLS: certificate file isn't modified")
return
}
log.Debug("TLS: certificate file is modified")
t.confLock.Lock()
r := t.load()
t.confLock.Unlock()
if !r {
return
}
t.certLastMod = fi.ModTime().UTC()
_ = reconfigureDNSServer()
Context.web.TLSConfigChanged(tlsConf)
}
// Set certificate and private key data // Set certificate and private key data
func tlsLoadConfig(tls *tlsConfig, status *tlsConfigStatus) bool { func tlsLoadConfig(tls *tlsConfigSettings, status *tlsConfigStatus) bool {
tls.CertificateChainData = []byte(tls.CertificateChain) tls.CertificateChainData = []byte(tls.CertificateChain)
tls.PrivateKeyData = []byte(tls.PrivateKey) tls.PrivateKeyData = []byte(tls.PrivateKey)
@ -61,98 +165,115 @@ func tlsLoadConfig(tls *tlsConfig, status *tlsConfigStatus) bool {
return true return true
} }
// RegisterTLSHandlers registers HTTP handlers for TLS configuration type tlsConfigStatus struct {
func RegisterTLSHandlers() { ValidCert bool `json:"valid_cert"` // ValidCert is true if the specified certificates chain is a valid chain of X509 certificates
httpRegister(http.MethodGet, "/control/tls/status", handleTLSStatus) ValidChain bool `json:"valid_chain"` // ValidChain is true if the specified certificates chain is verified and issued by a known CA
httpRegister(http.MethodPost, "/control/tls/configure", handleTLSConfigure) Subject string `json:"subject,omitempty"` // Subject is the subject of the first certificate in the chain
httpRegister(http.MethodPost, "/control/tls/validate", handleTLSValidate) Issuer string `json:"issuer,omitempty"` // Issuer is the issuer of the first certificate in the chain
NotBefore time.Time `json:"not_before,omitempty"` // NotBefore is the NotBefore field of the first certificate in the chain
NotAfter time.Time `json:"not_after,omitempty"` // NotAfter is the NotAfter field of the first certificate in the chain
DNSNames []string `json:"dns_names"` // DNSNames is the value of SubjectAltNames field of the first certificate in the chain
// key status
ValidKey bool `json:"valid_key"` // ValidKey is true if the key is a valid private key
KeyType string `json:"key_type,omitempty"` // KeyType is one of RSA or ECDSA
// is usable? set by validator
ValidPair bool `json:"valid_pair"` // ValidPair is true if both certificate and private key are correct
// warnings
WarningValidation string `json:"warning_validation,omitempty"` // WarningValidation is a validation warning message with the issue description
} }
func handleTLSStatus(w http.ResponseWriter, r *http.Request) { // field ordering is important -- yaml fields will mirror ordering from here
marshalTLS(w, config.TLS) type tlsConfig struct {
tlsConfigSettings `json:",inline"`
tlsConfigStatus `json:",inline"`
} }
func handleTLSValidate(w http.ResponseWriter, r *http.Request) { func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
data, err := unmarshalTLS(r) t.confLock.Lock()
if err != nil { data := tlsConfig{
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) tlsConfigSettings: t.conf,
return tlsConfigStatus: t.status,
} }
t.confLock.Unlock()
// check if port is available
// BUT: if we are already using this port, no need
alreadyRunning := false
if Context.httpsServer.server != nil {
alreadyRunning = true
}
if !alreadyRunning {
err = util.CheckPortAvailable(config.BindHost, data.PortHTTPS)
if err != nil {
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
return
}
}
status := tlsConfigStatus{}
if tlsLoadConfig(&data, &status) {
status = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName)
}
data.tlsConfigStatus = status
marshalTLS(w, data) marshalTLS(w, data)
} }
func handleTLSConfigure(w http.ResponseWriter, r *http.Request) { func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
setts, err := unmarshalTLS(r)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
return
}
if !WebCheckPortAvailable(setts.PortHTTPS) {
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", setts.PortHTTPS)
return
}
status := tlsConfigStatus{}
if tlsLoadConfig(&setts, &status) {
status = validateCertificates(string(setts.CertificateChainData), string(setts.PrivateKeyData), setts.ServerName)
}
data := tlsConfig{
tlsConfigSettings: setts,
tlsConfigStatus: status,
}
marshalTLS(w, data)
}
func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
data, err := unmarshalTLS(r) data, err := unmarshalTLS(r)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
return return
} }
// check if port is available if !WebCheckPortAvailable(data.PortHTTPS) {
// BUT: if we are already using this port, no need
alreadyRunning := false
if Context.httpsServer.server != nil {
alreadyRunning = true
}
if !alreadyRunning {
err = util.CheckPortAvailable(config.BindHost, data.PortHTTPS)
if err != nil {
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS) httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
return return
} }
}
status := tlsConfigStatus{} status := tlsConfigStatus{}
if !tlsLoadConfig(&data, &status) { if !tlsLoadConfig(&data, &status) {
data.tlsConfigStatus = status data2 := tlsConfig{
marshalTLS(w, data) tlsConfigSettings: data,
tlsConfigStatus: t.status,
}
marshalTLS(w, data2)
return return
} }
data.tlsConfigStatus = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName) status = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName)
restartHTTPS := false restartHTTPS := false
if !reflect.DeepEqual(config.TLS.tlsConfigSettings, data.tlsConfigSettings) { t.confLock.Lock()
if !reflect.DeepEqual(t.conf, data) {
log.Printf("tls config settings have changed, will restart HTTPS server") log.Printf("tls config settings have changed, will restart HTTPS server")
restartHTTPS = true restartHTTPS = true
} }
config.TLS = data t.conf = data
err = writeAllConfigsAndReloadDNS() t.status = status
t.confLock.Unlock()
t.setCertFileTime()
onConfigModified()
err = reconfigureDNSServer()
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write config file: %s", err) httpError(w, http.StatusInternalServerError, "%s", err)
return return
} }
marshalTLS(w, data) data2 := tlsConfig{
tlsConfigSettings: data,
tlsConfigStatus: t.status,
}
marshalTLS(w, data2)
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block // this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely // until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
if restartHTTPS { if restartHTTPS {
go func() { go func() {
time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server
Context.httpsServer.cond.L.Lock() Context.web.TLSConfigChanged(data)
Context.httpsServer.cond.Broadcast()
if Context.httpsServer.server != nil {
Context.httpsServer.server.Shutdown(context.TODO())
}
Context.httpsServer.cond.L.Unlock()
}() }()
} }
} }
@ -337,8 +458,8 @@ func parsePrivateKey(der []byte) (crypto.PrivateKey, string, error) {
} }
// unmarshalTLS handles base64-encoded certificates transparently // unmarshalTLS handles base64-encoded certificates transparently
func unmarshalTLS(r *http.Request) (tlsConfig, error) { func unmarshalTLS(r *http.Request) (tlsConfigSettings, error) {
data := tlsConfig{} data := tlsConfigSettings{}
err := json.NewDecoder(r.Body).Decode(&data) err := json.NewDecoder(r.Body).Decode(&data)
if err != nil { if err != nil {
return data, errorx.Decorate(err, "Failed to parse new TLS config json") return data, errorx.Decorate(err, "Failed to parse new TLS config json")
@ -389,3 +510,10 @@ func marshalTLS(w http.ResponseWriter, data tlsConfig) {
return return
} }
} }
// registerWebHandlers registers HTTP handlers for TLS configuration
func (t *TLSMod) registerWebHandlers() {
httpRegister("GET", "/control/tls/status", t.handleTLSStatus)
httpRegister("POST", "/control/tls/configure", t.handleTLSConfigure)
httpRegister("POST", "/control/tls/validate", t.handleTLSValidate)
}

189
home/web.go Normal file
View File

@ -0,0 +1,189 @@
package home
import (
"context"
"crypto/tls"
"net"
"net/http"
"strconv"
"sync"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/log"
"github.com/NYTimes/gziphandler"
"github.com/gobuffalo/packr"
)
type WebConfig struct {
firstRun bool
BindHost string
BindPort int
PortHTTPS int
}
// HTTPSServer - HTTPS Server
type HTTPSServer struct {
server *http.Server
cond *sync.Cond
condLock sync.Mutex
shutdown bool // if TRUE, don't restart the server
enabled bool
cert tls.Certificate
}
// Web - module object
type Web struct {
conf *WebConfig
forceHTTPS bool
portHTTPS int
httpServer *http.Server // HTTP module
httpsServer HTTPSServer // HTTPS module
}
// CreateWeb - create module
func CreateWeb(conf *WebConfig) *Web {
w := Web{}
w.conf = conf
// Initialize and run the admin Web interface
box := packr.NewBox("../build/static")
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
http.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box)))))
// add handlers for /install paths, we only need them when we're not configured yet
if conf.firstRun {
log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
http.Handle("/install.html", preInstallHandler(http.FileServer(box)))
w.registerInstallHandlers()
} else {
registerControlHandlers()
}
w.httpsServer.cond = sync.NewCond(&w.httpsServer.condLock)
return &w
}
// WebCheckPortAvailable - check if port is available
// BUT: if we are already using this port, no need
func WebCheckPortAvailable(port int) bool {
alreadyRunning := false
if Context.web.httpsServer.server != nil {
alreadyRunning = true
}
if !alreadyRunning {
err := util.CheckPortAvailable(config.BindHost, port)
if err != nil {
return false
}
}
return true
}
// TLSConfigChanged - called when TLS configuration has changed
func (w *Web) TLSConfigChanged(tlsConf tlsConfigSettings) {
log.Debug("Web: applying new TLS configuration")
w.conf.PortHTTPS = tlsConf.PortHTTPS
w.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)
w.portHTTPS = tlsConf.PortHTTPS
enabled := tlsConf.Enabled &&
tlsConf.PortHTTPS != 0 &&
len(tlsConf.PrivateKeyData) != 0 &&
len(tlsConf.CertificateChainData) != 0
var cert tls.Certificate
var err error
if enabled {
cert, err = tls.X509KeyPair(tlsConf.CertificateChainData, tlsConf.PrivateKeyData)
if err != nil {
log.Fatal(err)
}
}
w.httpsServer.cond.L.Lock()
if w.httpsServer.server != nil {
w.httpsServer.server.Shutdown(context.TODO())
}
w.httpsServer.enabled = enabled
w.httpsServer.cert = cert
w.httpsServer.cond.Broadcast()
w.httpsServer.cond.L.Unlock()
}
// Start - start serving HTTP requests
func (w *Web) Start() {
// for https, we have a separate goroutine loop
go w.httpServerLoop()
// this loop is used as an ability to change listening host and/or port
for !w.httpsServer.shutdown {
printHTTPAddresses("http")
// we need to have new instance, because after Shutdown() the Server is not usable
address := net.JoinHostPort(w.conf.BindHost, strconv.Itoa(w.conf.BindPort))
w.httpServer = &http.Server{
Addr: address,
}
err := w.httpServer.ListenAndServe()
if err != http.ErrServerClosed {
cleanupAlways()
log.Fatal(err)
}
// We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop
}
}
// Close - stop HTTP server, possibly waiting for all active connections to be closed
func (w *Web) Close() {
log.Info("Stopping HTTP server...")
w.httpsServer.cond.L.Lock()
w.httpsServer.shutdown = true
w.httpsServer.cond.L.Unlock()
if w.httpsServer.server != nil {
_ = w.httpsServer.server.Shutdown(context.TODO())
}
if w.httpServer != nil {
_ = w.httpServer.Shutdown(context.TODO())
}
log.Info("Stopped HTTP server")
}
func (w *Web) httpServerLoop() {
for {
w.httpsServer.cond.L.Lock()
if w.httpsServer.shutdown {
w.httpsServer.cond.L.Unlock()
break
}
// this mechanism doesn't let us through until all conditions are met
for !w.httpsServer.enabled { // sleep until necessary data is supplied
w.httpsServer.cond.Wait()
if w.httpsServer.shutdown {
w.httpsServer.cond.L.Unlock()
return
}
}
w.httpsServer.cond.L.Unlock()
// prepare HTTPS server
address := net.JoinHostPort(w.conf.BindHost, strconv.Itoa(w.conf.PortHTTPS))
w.httpsServer.server = &http.Server{
Addr: address,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{w.httpsServer.cert},
MinVersion: tls.VersionTLS12,
RootCAs: Context.tlsRoots,
},
}
printHTTPAddresses("https")
err := w.httpsServer.server.ListenAndServeTLS("", "")
if err != http.ErrServerClosed {
cleanupAlways()
log.Fatal(err)
}
}
}

View File

@ -25,3 +25,8 @@ func SetRlimit(val uint) {
func HaveAdminRights() (bool, error) { func HaveAdminRights() (bool, error) {
return os.Getuid() == 0, nil return os.Getuid() == 0, nil
} }
// SendProcessSignal - send signal to a process
func SendProcessSignal(pid int, sig syscall.Signal) error {
return syscall.Kill(pid, sig)
}

View File

@ -25,3 +25,8 @@ func SetRlimit(val uint) {
func HaveAdminRights() (bool, error) { func HaveAdminRights() (bool, error) {
return os.Getuid() == 0, nil return os.Getuid() == 0, nil
} }
// SendProcessSignal - send signal to a process
func SendProcessSignal(pid int, sig syscall.Signal) error {
return syscall.Kill(pid, sig)
}

View File

@ -1,6 +1,11 @@
package util package util
import "golang.org/x/sys/windows" import (
"fmt"
"syscall"
"golang.org/x/sys/windows"
)
// Set user-specified limit of how many fd's we can use // Set user-specified limit of how many fd's we can use
func SetRlimit(val uint) { func SetRlimit(val uint) {
@ -26,3 +31,7 @@ func HaveAdminRights() (bool, error) {
} }
return true, nil return true, nil
} }
func SendProcessSignal(pid int, sig syscall.Signal) error {
return fmt.Errorf("not supported on Windows")
}