Merge: + Reload configuration command; refactoring
Close #1302 * commit 'dd7d9dc334181b3923b1532aaae51cae77d5ed55': * move getDNSAddresses() and "/dns-query" handler to DNS module * TLS is now a separate module (logically) * move HTTP server code * refactor: move blocked-services functions to dnsfilter * refactor + service: support "-s reload" command + clients: re-read auto-clients from /etc/hosts when SIGHUP is received
This commit is contained in:
commit
32baa907b6
|
@ -71,7 +71,6 @@ Contents:
|
||||||
![](doc/agh-arch.png)
|
![](doc/agh-arch.png)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## First startup
|
## First startup
|
||||||
|
|
||||||
The first application startup is detected when there's no .yaml configuration file.
|
The first application startup is detected when there's no .yaml configuration file.
|
||||||
|
|
|
@ -18,6 +18,8 @@ import (
|
||||||
const defaultDiscoverTime = time.Second * 3
|
const defaultDiscoverTime = time.Second * 3
|
||||||
const leaseExpireStatic = 1
|
const leaseExpireStatic = 1
|
||||||
|
|
||||||
|
var webHandlersRegistered = false
|
||||||
|
|
||||||
// Lease contains the necessary information about a DHCP lease
|
// Lease contains the necessary information about a DHCP lease
|
||||||
// field ordering is important -- yaml fields will mirror ordering from here
|
// field ordering is important -- yaml fields will mirror ordering from here
|
||||||
type Lease struct {
|
type Lease struct {
|
||||||
|
@ -121,9 +123,6 @@ func Create(config ServerConfig) *Server {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if s.conf.HTTPRegister != nil {
|
|
||||||
s.registerHandlers()
|
|
||||||
}
|
|
||||||
|
|
||||||
// we can't delay database loading until DHCP server is started,
|
// we can't delay database loading until DHCP server is started,
|
||||||
// because we need static leases functionality available beforehand
|
// because we need static leases functionality available beforehand
|
||||||
|
@ -221,6 +220,11 @@ func (s *Server) setConfig(config ServerConfig) error {
|
||||||
// Start will listen on port 67 and serve DHCP requests.
|
// Start will listen on port 67 and serve DHCP requests.
|
||||||
func (s *Server) Start() error {
|
func (s *Server) Start() error {
|
||||||
|
|
||||||
|
if !webHandlersRegistered && s.conf.HTTPRegister != nil {
|
||||||
|
webHandlersRegistered = true
|
||||||
|
s.registerHandlers()
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: don't close if interface and addresses are the same
|
// TODO: don't close if interface and addresses are the same
|
||||||
if s.conn != nil {
|
if s.conn != nil {
|
||||||
s.closeConn()
|
s.closeConn()
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
package home
|
package dnsfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/urlfilter/rules"
|
"github.com/AdguardTeam/urlfilter/rules"
|
||||||
)
|
)
|
||||||
|
@ -119,7 +118,7 @@ var serviceRulesArray = []svc{
|
||||||
}
|
}
|
||||||
|
|
||||||
// convert array to map
|
// convert array to map
|
||||||
func initServices() {
|
func initBlockedServices() {
|
||||||
serviceRules = make(map[string][]*rules.NetworkRule)
|
serviceRules = make(map[string][]*rules.NetworkRule)
|
||||||
for _, s := range serviceRulesArray {
|
for _, s := range serviceRulesArray {
|
||||||
netRules := []*rules.NetworkRule{}
|
netRules := []*rules.NetworkRule{}
|
||||||
|
@ -135,15 +134,20 @@ func initServices() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return TRUE if a blocked service name is known
|
// BlockedSvcKnown - return TRUE if a blocked service name is known
|
||||||
func blockedSvcKnown(s string) bool {
|
func BlockedSvcKnown(s string) bool {
|
||||||
_, ok := serviceRules[s]
|
_, ok := serviceRules[s]
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// ApplyBlockedServices - set blocked services settings for this DNS request
|
// ApplyBlockedServices - set blocked services settings for this DNS request
|
||||||
func ApplyBlockedServices(setts *dnsfilter.RequestFilteringSettings, list []string) {
|
func (d *Dnsfilter) ApplyBlockedServices(setts *RequestFilteringSettings, list []string, global bool) {
|
||||||
setts.ServicesRules = []dnsfilter.ServiceEntry{}
|
setts.ServicesRules = []ServiceEntry{}
|
||||||
|
if global {
|
||||||
|
d.confLock.RLock()
|
||||||
|
defer d.confLock.RUnlock()
|
||||||
|
list = d.Config.BlockedServices
|
||||||
|
}
|
||||||
for _, name := range list {
|
for _, name := range list {
|
||||||
rules, ok := serviceRules[name]
|
rules, ok := serviceRules[name]
|
||||||
|
|
||||||
|
@ -152,51 +156,45 @@ func ApplyBlockedServices(setts *dnsfilter.RequestFilteringSettings, list []stri
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
s := dnsfilter.ServiceEntry{}
|
s := ServiceEntry{}
|
||||||
s.Name = name
|
s.Name = name
|
||||||
s.Rules = rules
|
s.Rules = rules
|
||||||
setts.ServicesRules = append(setts.ServicesRules, s)
|
setts.ServicesRules = append(setts.ServicesRules, s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleBlockedServicesList(w http.ResponseWriter, r *http.Request) {
|
func (d *Dnsfilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Request) {
|
||||||
config.RLock()
|
d.confLock.RLock()
|
||||||
list := config.DNS.BlockedServices
|
list := d.Config.BlockedServices
|
||||||
config.RUnlock()
|
d.confLock.RUnlock()
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
err := json.NewEncoder(w).Encode(list)
|
err := json.NewEncoder(w).Encode(list)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusInternalServerError, "json.Encode: %s", err)
|
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) {
|
func (d *Dnsfilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) {
|
||||||
list := []string{}
|
list := []string{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&list)
|
err := json.NewDecoder(r.Body).Decode(&list)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadRequest, "json.Decode: %s", err)
|
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
config.Lock()
|
d.confLock.Lock()
|
||||||
config.DNS.BlockedServices = list
|
d.Config.BlockedServices = list
|
||||||
config.Unlock()
|
d.confLock.Unlock()
|
||||||
|
|
||||||
log.Debug("Updated blocked services list: %d", len(list))
|
log.Debug("Updated blocked services list: %d", len(list))
|
||||||
|
|
||||||
err = writeAllConfigsAndReloadDNS()
|
d.ConfigModified()
|
||||||
if err != nil {
|
|
||||||
httpError(w, http.StatusBadRequest, "%s", err)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
httpOK(r, w)
|
// registerBlockedServicesHandlers - register HTTP handlers
|
||||||
}
|
func (d *Dnsfilter) registerBlockedServicesHandlers() {
|
||||||
|
d.Config.HTTPRegister("GET", "/control/blocked_services/list", d.handleBlockedServicesList)
|
||||||
// RegisterBlockedServicesHandlers - register HTTP handlers
|
d.Config.HTTPRegister("POST", "/control/blocked_services/set", d.handleBlockedServicesSet)
|
||||||
func RegisterBlockedServicesHandlers() {
|
|
||||||
httpRegister(http.MethodGet, "/control/blocked_services/list", handleBlockedServicesList)
|
|
||||||
httpRegister(http.MethodPost, "/control/blocked_services/set", handleBlockedServicesSet)
|
|
||||||
}
|
}
|
|
@ -49,6 +49,10 @@ type Config struct {
|
||||||
|
|
||||||
Rewrites []RewriteEntry `yaml:"rewrites"`
|
Rewrites []RewriteEntry `yaml:"rewrites"`
|
||||||
|
|
||||||
|
// Names of services to block (globally).
|
||||||
|
// Per-client settings can override this configuration.
|
||||||
|
BlockedServices []string `yaml:"blocked_services"`
|
||||||
|
|
||||||
// Called when the configuration is changed by HTTP request
|
// Called when the configuration is changed by HTTP request
|
||||||
ConfigModified func() `yaml:"-"`
|
ConfigModified func() `yaml:"-"`
|
||||||
|
|
||||||
|
@ -175,6 +179,7 @@ func (d *Dnsfilter) WriteDiskConfig(c *Config) {
|
||||||
d.confLock.Lock()
|
d.confLock.Lock()
|
||||||
*c = d.Config
|
*c = d.Config
|
||||||
c.Rewrites = rewriteArrayDup(d.Config.Rewrites)
|
c.Rewrites = rewriteArrayDup(d.Config.Rewrites)
|
||||||
|
// BlockedServices
|
||||||
d.confLock.Unlock()
|
d.confLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -633,6 +638,18 @@ func New(c *Config, blockFilters []Filter) *Dnsfilter {
|
||||||
d.prepareRewrites()
|
d.prepareRewrites()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bsvcs := []string{}
|
||||||
|
for _, s := range d.BlockedServices {
|
||||||
|
if !BlockedSvcKnown(s) {
|
||||||
|
log.Debug("skipping unknown blocked-service '%s'", s)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
bsvcs = append(bsvcs, s)
|
||||||
|
}
|
||||||
|
d.BlockedServices = bsvcs
|
||||||
|
|
||||||
|
initBlockedServices()
|
||||||
|
|
||||||
if blockFilters != nil {
|
if blockFilters != nil {
|
||||||
err := d.initFiltering(nil, blockFilters)
|
err := d.initFiltering(nil, blockFilters)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -655,6 +672,7 @@ func (d *Dnsfilter) Start() {
|
||||||
if d.Config.HTTPRegister != nil { // for tests
|
if d.Config.HTTPRegister != nil { // for tests
|
||||||
d.registerSecurityHandlers()
|
d.registerSecurityHandlers()
|
||||||
d.registerRewritesHandlers()
|
d.registerRewritesHandlers()
|
||||||
|
d.registerBlockedServicesHandlers()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -179,6 +179,7 @@ type ServerConfig struct {
|
||||||
|
|
||||||
FilteringConfig
|
FilteringConfig
|
||||||
TLSConfig
|
TLSConfig
|
||||||
|
TLSAllowUnencryptedDOH bool
|
||||||
|
|
||||||
TLSv12Roots *x509.CertPool // list of root CAs for TLSv1.2
|
TLSv12Roots *x509.CertPool // list of root CAs for TLSv1.2
|
||||||
|
|
||||||
|
|
|
@ -376,6 +376,20 @@ func checkDNS(input string, bootstrap []string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleDOH(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !s.conf.TLSAllowUnencryptedDOH && r.TLS == nil {
|
||||||
|
httpError(r, w, http.StatusNotFound, "Not Found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.IsRunning() {
|
||||||
|
httpError(r, w, http.StatusInternalServerError, "DNS server is not running")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) registerHandlers() {
|
func (s *Server) registerHandlers() {
|
||||||
s.conf.HTTPRegister("GET", "/control/dns_info", s.handleGetConfig)
|
s.conf.HTTPRegister("GET", "/control/dns_info", s.handleGetConfig)
|
||||||
s.conf.HTTPRegister("POST", "/control/dns_config", s.handleSetConfig)
|
s.conf.HTTPRegister("POST", "/control/dns_config", s.handleSetConfig)
|
||||||
|
@ -384,4 +398,6 @@ func (s *Server) registerHandlers() {
|
||||||
|
|
||||||
s.conf.HTTPRegister("GET", "/control/access/list", s.handleAccessList)
|
s.conf.HTTPRegister("GET", "/control/access/list", s.handleAccessList)
|
||||||
s.conf.HTTPRegister("POST", "/control/access/set", s.handleAccessSet)
|
s.conf.HTTPRegister("POST", "/control/access/set", s.handleAccessSet)
|
||||||
|
|
||||||
|
s.conf.HTTPRegister("", "/dns-query", s.handleDOH)
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
@ -24,6 +25,8 @@ const (
|
||||||
clientsUpdatePeriod = 1 * time.Hour
|
clientsUpdatePeriod = 1 * time.Hour
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var webHandlersRegistered = false
|
||||||
|
|
||||||
// Client information
|
// Client information
|
||||||
type Client struct {
|
type Client struct {
|
||||||
IDs []string
|
IDs []string
|
||||||
|
@ -98,13 +101,27 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd.
|
||||||
clients.addFromConfig(objects)
|
clients.addFromConfig(objects)
|
||||||
|
|
||||||
if !clients.testing {
|
if !clients.testing {
|
||||||
go clients.periodicUpdate()
|
|
||||||
|
|
||||||
clients.addFromDHCP()
|
clients.addFromDHCP()
|
||||||
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
|
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start - start the module
|
||||||
|
func (clients *clientsContainer) Start() {
|
||||||
|
if !clients.testing {
|
||||||
|
if !webHandlersRegistered {
|
||||||
|
webHandlersRegistered = true
|
||||||
clients.registerWebHandlers()
|
clients.registerWebHandlers()
|
||||||
}
|
}
|
||||||
|
go clients.periodicUpdate()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload - reload auto-clients
|
||||||
|
func (clients *clientsContainer) Reload() {
|
||||||
|
clients.addFromHostsFile()
|
||||||
|
clients.addFromSystemARP()
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientObject struct {
|
type clientObject struct {
|
||||||
|
@ -145,7 +162,7 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, s := range cy.BlockedServices {
|
for _, s := range cy.BlockedServices {
|
||||||
if !blockedSvcKnown(s) {
|
if !dnsfilter.BlockedSvcKnown(s) {
|
||||||
log.Debug("Clients: skipping unknown blocked-service '%s'", s)
|
log.Debug("Clients: skipping unknown blocked-service '%s'", s)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -194,8 +211,7 @@ func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) {
|
||||||
|
|
||||||
func (clients *clientsContainer) periodicUpdate() {
|
func (clients *clientsContainer) periodicUpdate() {
|
||||||
for {
|
for {
|
||||||
clients.addFromHostsFile()
|
clients.Reload()
|
||||||
clients.addFromSystemARP()
|
|
||||||
time.Sleep(clientsUpdatePeriod)
|
time.Sleep(clientsUpdatePeriod)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,6 @@ package home
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -29,14 +28,6 @@ type logSettings struct {
|
||||||
Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled
|
Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTPSServer - HTTPS Server
|
|
||||||
type HTTPSServer struct {
|
|
||||||
server *http.Server
|
|
||||||
cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
|
|
||||||
sync.Mutex // protects config.TLS
|
|
||||||
shutdown bool // if TRUE, don't restart the server
|
|
||||||
}
|
|
||||||
|
|
||||||
// configuration is loaded from YAML
|
// configuration is loaded from YAML
|
||||||
// field ordering is important -- yaml fields will mirror ordering from here
|
// field ordering is important -- yaml fields will mirror ordering from here
|
||||||
type configuration struct {
|
type configuration struct {
|
||||||
|
@ -60,7 +51,7 @@ type configuration struct {
|
||||||
WebSessionTTLHours uint32 `yaml:"web_session_ttl"`
|
WebSessionTTLHours uint32 `yaml:"web_session_ttl"`
|
||||||
|
|
||||||
DNS dnsConfig `yaml:"dns"`
|
DNS dnsConfig `yaml:"dns"`
|
||||||
TLS tlsConfig `yaml:"tls"`
|
TLS tlsConfigSettings `yaml:"tls"`
|
||||||
|
|
||||||
Filters []filter `yaml:"filters"`
|
Filters []filter `yaml:"filters"`
|
||||||
WhitelistFilters []filter `yaml:"whitelist_filters"`
|
WhitelistFilters []filter `yaml:"whitelist_filters"`
|
||||||
|
@ -95,10 +86,6 @@ type dnsConfig struct {
|
||||||
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
|
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
|
||||||
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
|
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
|
||||||
DnsfilterConf dnsfilter.Config `yaml:",inline"`
|
DnsfilterConf dnsfilter.Config `yaml:",inline"`
|
||||||
|
|
||||||
// Names of services to block (globally).
|
|
||||||
// Per-client settings can override this configuration.
|
|
||||||
BlockedServices []string `yaml:"blocked_services"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type tlsConfigSettings struct {
|
type tlsConfigSettings struct {
|
||||||
|
@ -114,33 +101,6 @@ type tlsConfigSettings struct {
|
||||||
dnsforward.TLSConfig `yaml:",inline" json:",inline"`
|
dnsforward.TLSConfig `yaml:",inline" json:",inline"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// field ordering is not important -- these are for API and are recalculated on each run
|
|
||||||
type tlsConfigStatus struct {
|
|
||||||
ValidCert bool `yaml:"-" json:"valid_cert"` // ValidCert is true if the specified certificates chain is a valid chain of X509 certificates
|
|
||||||
ValidChain bool `yaml:"-" json:"valid_chain"` // ValidChain is true if the specified certificates chain is verified and issued by a known CA
|
|
||||||
Subject string `yaml:"-" json:"subject,omitempty"` // Subject is the subject of the first certificate in the chain
|
|
||||||
Issuer string `yaml:"-" json:"issuer,omitempty"` // Issuer is the issuer of the first certificate in the chain
|
|
||||||
NotBefore time.Time `yaml:"-" json:"not_before,omitempty"` // NotBefore is the NotBefore field of the first certificate in the chain
|
|
||||||
NotAfter time.Time `yaml:"-" json:"not_after,omitempty"` // NotAfter is the NotAfter field of the first certificate in the chain
|
|
||||||
DNSNames []string `yaml:"-" json:"dns_names"` // DNSNames is the value of SubjectAltNames field of the first certificate in the chain
|
|
||||||
|
|
||||||
// key status
|
|
||||||
ValidKey bool `yaml:"-" json:"valid_key"` // ValidKey is true if the key is a valid private key
|
|
||||||
KeyType string `yaml:"-" json:"key_type,omitempty"` // KeyType is one of RSA or ECDSA
|
|
||||||
|
|
||||||
// is usable? set by validator
|
|
||||||
ValidPair bool `yaml:"-" json:"valid_pair"` // ValidPair is true if both certificate and private key are correct
|
|
||||||
|
|
||||||
// warnings
|
|
||||||
WarningValidation string `yaml:"-" json:"warning_validation,omitempty"` // WarningValidation is a validation warning message with the issue description
|
|
||||||
}
|
|
||||||
|
|
||||||
// field ordering is important -- yaml fields will mirror ordering from here
|
|
||||||
type tlsConfig struct {
|
|
||||||
tlsConfigSettings `yaml:",inline" json:",inline"`
|
|
||||||
tlsConfigStatus `yaml:"-" json:",inline"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// initialize to default values, will be changed later when reading config or parsing command line
|
// initialize to default values, will be changed later when reading config or parsing command line
|
||||||
var config = configuration{
|
var config = configuration{
|
||||||
BindPort: 3000,
|
BindPort: 3000,
|
||||||
|
@ -160,12 +120,10 @@ var config = configuration{
|
||||||
FilteringEnabled: true, // whether or not use filter lists
|
FilteringEnabled: true, // whether or not use filter lists
|
||||||
FiltersUpdateIntervalHours: 24,
|
FiltersUpdateIntervalHours: 24,
|
||||||
},
|
},
|
||||||
TLS: tlsConfig{
|
TLS: tlsConfigSettings{
|
||||||
tlsConfigSettings: tlsConfigSettings{
|
|
||||||
PortHTTPS: 443,
|
PortHTTPS: 443,
|
||||||
PortDNSOverTLS: 853, // needs to be passed through to dnsproxy
|
PortDNSOverTLS: 853, // needs to be passed through to dnsproxy
|
||||||
},
|
},
|
||||||
},
|
|
||||||
DHCP: dhcpd.ServerConfig{
|
DHCP: dhcpd.ServerConfig{
|
||||||
LeaseDuration: 86400,
|
LeaseDuration: 86400,
|
||||||
ICMPTimeout: 1000,
|
ICMPTimeout: 1000,
|
||||||
|
@ -234,26 +192,10 @@ func parseConfig() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
bsvcs := []string{}
|
|
||||||
for _, s := range config.DNS.BlockedServices {
|
|
||||||
if !blockedSvcKnown(s) {
|
|
||||||
log.Debug("skipping unknown blocked-service '%s'", s)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
bsvcs = append(bsvcs, s)
|
|
||||||
}
|
|
||||||
config.DNS.BlockedServices = bsvcs
|
|
||||||
|
|
||||||
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
|
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
|
||||||
config.DNS.FiltersUpdateIntervalHours = 24
|
config.DNS.FiltersUpdateIntervalHours = 24
|
||||||
}
|
}
|
||||||
|
|
||||||
status := tlsConfigStatus{}
|
|
||||||
if !tlsLoadConfig(&config.TLS, &status) {
|
|
||||||
log.Error("%s", status.WarningValidation)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -282,6 +224,11 @@ func (c *configuration) write() error {
|
||||||
if Context.auth != nil {
|
if Context.auth != nil {
|
||||||
config.Users = Context.auth.GetUsers()
|
config.Users = Context.auth.GetUsers()
|
||||||
}
|
}
|
||||||
|
if Context.tls != nil {
|
||||||
|
tlsConf := tlsConfigSettings{}
|
||||||
|
Context.tls.WriteDiskConfig(&tlsConf)
|
||||||
|
config.TLS = tlsConf
|
||||||
|
}
|
||||||
|
|
||||||
if Context.stats != nil {
|
if Context.stats != nil {
|
||||||
sdc := stats.DiskConfig{}
|
sdc := stats.DiskConfig{}
|
||||||
|
@ -331,13 +278,3 @@ func (c *configuration) write() error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeAllConfigs() error {
|
|
||||||
err := config.write()
|
|
||||||
if err != nil {
|
|
||||||
log.Error("Couldn't write config: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -9,8 +9,6 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/util"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/NYTimes/gziphandler"
|
"github.com/NYTimes/gziphandler"
|
||||||
|
@ -27,9 +25,6 @@ func returnOK(w http.ResponseWriter) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func httpOK(r *http.Request, w http.ResponseWriter) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) {
|
func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||||
text := fmt.Sprintf(format, args...)
|
text := fmt.Sprintf(format, args...)
|
||||||
log.Info(text)
|
log.Info(text)
|
||||||
|
@ -39,15 +34,6 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface
|
||||||
// ---------------
|
// ---------------
|
||||||
// dns run control
|
// dns run control
|
||||||
// ---------------
|
// ---------------
|
||||||
func writeAllConfigsAndReloadDNS() error {
|
|
||||||
err := writeAllConfigs()
|
|
||||||
if err != nil {
|
|
||||||
log.Error("Couldn't write all configs: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return reconfigureDNSServer()
|
|
||||||
}
|
|
||||||
|
|
||||||
func addDNSAddress(dnsAddresses *[]string, addr string) {
|
func addDNSAddress(dnsAddresses *[]string, addr string) {
|
||||||
if config.DNS.Port != 53 {
|
if config.DNS.Port != 53 {
|
||||||
addr = fmt.Sprintf("%s:%d", addr, config.DNS.Port)
|
addr = fmt.Sprintf("%s:%d", addr, config.DNS.Port)
|
||||||
|
@ -55,46 +41,6 @@ func addDNSAddress(dnsAddresses *[]string, addr string) {
|
||||||
*dnsAddresses = append(*dnsAddresses, addr)
|
*dnsAddresses = append(*dnsAddresses, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the list of DNS addresses the server is listening on
|
|
||||||
func getDNSAddresses() []string {
|
|
||||||
dnsAddresses := []string{}
|
|
||||||
|
|
||||||
if config.DNS.BindHost == "0.0.0.0" {
|
|
||||||
ifaces, e := util.GetValidNetInterfacesForWeb()
|
|
||||||
if e != nil {
|
|
||||||
log.Error("Couldn't get network interfaces: %v", e)
|
|
||||||
return []string{}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, iface := range ifaces {
|
|
||||||
for _, addr := range iface.Addresses {
|
|
||||||
addDNSAddress(&dnsAddresses, addr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
addDNSAddress(&dnsAddresses, config.DNS.BindHost)
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.TLS.Enabled && len(config.TLS.ServerName) != 0 {
|
|
||||||
|
|
||||||
if config.TLS.PortHTTPS != 0 {
|
|
||||||
addr := config.TLS.ServerName
|
|
||||||
if config.TLS.PortHTTPS != 443 {
|
|
||||||
addr = fmt.Sprintf("%s:%d", addr, config.TLS.PortHTTPS)
|
|
||||||
}
|
|
||||||
addr = fmt.Sprintf("https://%s/dns-query", addr)
|
|
||||||
dnsAddresses = append(dnsAddresses, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.TLS.PortDNSOverTLS != 0 {
|
|
||||||
addr := fmt.Sprintf("tls://%s:%d", config.TLS.ServerName, config.TLS.PortDNSOverTLS)
|
|
||||||
dnsAddresses = append(dnsAddresses, addr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return dnsAddresses
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleStatus(w http.ResponseWriter, r *http.Request) {
|
func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
c := dnsforward.FilteringConfig{}
|
c := dnsforward.FilteringConfig{}
|
||||||
if Context.dnsServer != nil {
|
if Context.dnsServer != nil {
|
||||||
|
@ -144,23 +90,6 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||||
_, _ = w.Write(data)
|
_, _ = w.Write(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --------------
|
|
||||||
// DNS-over-HTTPS
|
|
||||||
// --------------
|
|
||||||
func handleDOH(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if !config.TLS.AllowUnencryptedDOH && r.TLS == nil {
|
|
||||||
httpError(w, http.StatusNotFound, "Not Found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isRunning() {
|
|
||||||
httpError(w, http.StatusInternalServerError, "DNS server is not running")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
Context.dnsServer.ServeHTTP(w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ------------------------
|
// ------------------------
|
||||||
// registration of handlers
|
// registration of handlers
|
||||||
// ------------------------
|
// ------------------------
|
||||||
|
@ -172,15 +101,16 @@ func registerControlHandlers() {
|
||||||
httpRegister(http.MethodPost, "/control/update", handleUpdate)
|
httpRegister(http.MethodPost, "/control/update", handleUpdate)
|
||||||
|
|
||||||
httpRegister("GET", "/control/profile", handleGetProfile)
|
httpRegister("GET", "/control/profile", handleGetProfile)
|
||||||
|
|
||||||
RegisterTLSHandlers()
|
|
||||||
RegisterBlockedServicesHandlers()
|
|
||||||
RegisterAuthHandlers()
|
RegisterAuthHandlers()
|
||||||
|
|
||||||
http.HandleFunc("/dns-query", postInstall(handleDOH))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) {
|
func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) {
|
||||||
|
if len(method) == 0 {
|
||||||
|
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
|
||||||
|
http.HandleFunc(url, postInstall(handler))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
|
http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -258,14 +188,16 @@ func preInstallHandler(handler http.Handler) http.Handler {
|
||||||
// it also enforces HTTPS if it is enabled and configured
|
// it also enforces HTTPS if it is enabled and configured
|
||||||
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
if Context.firstRun &&
|
if Context.firstRun &&
|
||||||
!strings.HasPrefix(r.URL.Path, "/install.") &&
|
!strings.HasPrefix(r.URL.Path, "/install.") &&
|
||||||
r.URL.Path != "/favicon.png" {
|
r.URL.Path != "/favicon.png" {
|
||||||
http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable
|
http.Redirect(w, r, "/install.html", http.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// enforce https?
|
// enforce https?
|
||||||
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil {
|
if r.TLS == nil && Context.web.forceHTTPS && Context.web.httpsServer.server != nil {
|
||||||
// yes, and we want host from host:port
|
// yes, and we want host from host:port
|
||||||
host, _, err := net.SplitHostPort(r.Host)
|
host, _, err := net.SplitHostPort(r.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -275,13 +207,14 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res
|
||||||
// construct new URL to redirect to
|
// construct new URL to redirect to
|
||||||
newURL := url.URL{
|
newURL := url.URL{
|
||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
Host: net.JoinHostPort(host, strconv.Itoa(config.TLS.PortHTTPS)),
|
Host: net.JoinHostPort(host, strconv.Itoa(Context.web.portHTTPS)),
|
||||||
Path: r.URL.Path,
|
Path: r.URL.Path,
|
||||||
RawQuery: r.URL.RawQuery,
|
RawQuery: r.URL.RawQuery,
|
||||||
}
|
}
|
||||||
http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect)
|
http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
handler(w, r)
|
handler(w, r)
|
||||||
}
|
}
|
||||||
|
|
|
@ -349,7 +349,7 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
setts := Context.dnsFilter.GetConfig()
|
setts := Context.dnsFilter.GetConfig()
|
||||||
setts.FilteringEnabled = true
|
setts.FilteringEnabled = true
|
||||||
ApplyBlockedServices(&setts, config.DNS.BlockedServices)
|
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
|
||||||
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
|
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err)
|
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err)
|
||||||
|
|
|
@ -35,7 +35,7 @@ type netInterfaceJSON struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get initial installation settings
|
// Get initial installation settings
|
||||||
func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
|
func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
|
||||||
data := firstRunData{}
|
data := firstRunData{}
|
||||||
data.WebPort = 80
|
data.WebPort = 80
|
||||||
data.DNSPort = 53
|
data.DNSPort = 53
|
||||||
|
@ -93,7 +93,7 @@ type checkConfigResp struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if ports are available, respond with results
|
// Check if ports are available, respond with results
|
||||||
func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
|
func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
|
||||||
reqData := checkConfigReq{}
|
reqData := checkConfigReq{}
|
||||||
respData := checkConfigResp{}
|
respData := checkConfigResp{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&reqData)
|
err := json.NewDecoder(r.Body).Decode(&reqData)
|
||||||
|
@ -275,7 +275,7 @@ func copyInstallSettings(dst *configuration, src *configuration) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply new configuration, start DNS server, restart Web server
|
// Apply new configuration, start DNS server, restart Web server
|
||||||
func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||||
newSettings := applyConfigReq{}
|
newSettings := applyConfigReq{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&newSettings)
|
err := json.NewDecoder(r.Body).Decode(&newSettings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -325,22 +325,11 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||||
config.DNS.BindHost = newSettings.DNS.IP
|
config.DNS.BindHost = newSettings.DNS.IP
|
||||||
config.DNS.Port = newSettings.DNS.Port
|
config.DNS.Port = newSettings.DNS.Port
|
||||||
|
|
||||||
err = initDNSServer()
|
err = StartMods()
|
||||||
var err2 error
|
if err != nil {
|
||||||
if err == nil {
|
|
||||||
err2 = startDNSServer()
|
|
||||||
if err2 != nil {
|
|
||||||
closeDNSServer()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil || err2 != nil {
|
|
||||||
Context.firstRun = true
|
Context.firstRun = true
|
||||||
copyInstallSettings(&config, &curConfig)
|
copyInstallSettings(&config, &curConfig)
|
||||||
if err != nil {
|
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||||
httpError(w, http.StatusInternalServerError, "Couldn't initialize DNS server: %s", err)
|
|
||||||
} else {
|
|
||||||
httpError(w, http.StatusInternalServerError, "Couldn't start DNS server: %s", err2)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -356,19 +345,21 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
registerControlHandlers()
|
||||||
|
|
||||||
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
|
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
|
||||||
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
|
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
|
||||||
if restartHTTP {
|
if restartHTTP {
|
||||||
go func() {
|
go func() {
|
||||||
_ = Context.httpServer.Shutdown(context.TODO())
|
_ = Context.web.httpServer.Shutdown(context.TODO())
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
returnOK(w)
|
returnOK(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerInstallHandlers() {
|
func (web *Web) registerInstallHandlers() {
|
||||||
http.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(handleInstallGetAddresses)))
|
http.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
|
||||||
http.HandleFunc("/control/install/check_config", preInstall(ensurePOST(handleInstallCheckConfig)))
|
http.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
|
||||||
http.HandleFunc("/control/install/configure", preInstall(ensurePOST(handleInstallConfigure)))
|
http.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))
|
||||||
}
|
}
|
||||||
|
|
|
@ -476,7 +476,6 @@ func doUpdate(u *updateInfo) error {
|
||||||
func finishUpdate(u *updateInfo) {
|
func finishUpdate(u *updateInfo) {
|
||||||
log.Info("Stopping all tasks")
|
log.Info("Stopping all tasks")
|
||||||
cleanup()
|
cleanup()
|
||||||
stopHTTPServer()
|
|
||||||
cleanupAlways()
|
cleanupAlways()
|
||||||
|
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
|
|
83
home/dns.go
83
home/dns.go
|
@ -3,13 +3,13 @@ package home
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||||
"github.com/AdguardTeam/AdGuardHome/querylog"
|
"github.com/AdguardTeam/AdGuardHome/querylog"
|
||||||
"github.com/AdguardTeam/AdGuardHome/stats"
|
"github.com/AdguardTeam/AdGuardHome/stats"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/util"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
@ -25,13 +25,9 @@ func onConfigModified() {
|
||||||
// Please note that we must do it even if we don't start it
|
// Please note that we must do it even if we don't start it
|
||||||
// so that we had access to the query log and the stats
|
// so that we had access to the query log and the stats
|
||||||
func initDNSServer() error {
|
func initDNSServer() error {
|
||||||
|
var err error
|
||||||
baseDir := Context.getDataDir()
|
baseDir := Context.getDataDir()
|
||||||
|
|
||||||
err := os.MkdirAll(baseDir, 0755)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Cannot create DNS data dir at %s: %s", baseDir, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
statsConf := stats.Config{
|
statsConf := stats.Config{
|
||||||
Filename: filepath.Join(baseDir, "stats.db"),
|
Filename: filepath.Join(baseDir, "stats.db"),
|
||||||
LimitDays: config.DNS.StatsInterval,
|
LimitDays: config.DNS.StatsInterval,
|
||||||
|
@ -70,14 +66,6 @@ func initDNSServer() error {
|
||||||
return fmt.Errorf("dnsServer.Prepare: %s", err)
|
return fmt.Errorf("dnsServer.Prepare: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sessFilename := filepath.Join(baseDir, "sessions.db")
|
|
||||||
Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
|
|
||||||
if Context.auth == nil {
|
|
||||||
closeDNSServer()
|
|
||||||
return fmt.Errorf("Couldn't initialize Auth module")
|
|
||||||
}
|
|
||||||
config.Users = nil
|
|
||||||
|
|
||||||
Context.rdns = InitRDNS(Context.dnsServer, &Context.clients)
|
Context.rdns = InitRDNS(Context.dnsServer, &Context.clients)
|
||||||
Context.whois = initWhois(&Context.clients)
|
Context.whois = initWhois(&Context.clients)
|
||||||
|
|
||||||
|
@ -169,26 +157,74 @@ func generateServerConfig() dnsforward.ServerConfig {
|
||||||
OnDNSRequest: onDNSRequest,
|
OnDNSRequest: onDNSRequest,
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.TLS.Enabled {
|
tlsConf := tlsConfigSettings{}
|
||||||
newconfig.TLSConfig = config.TLS.TLSConfig
|
Context.tls.WriteDiskConfig(&tlsConf)
|
||||||
if config.TLS.PortDNSOverTLS != 0 {
|
if tlsConf.Enabled {
|
||||||
newconfig.TLSListenAddr = &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.TLS.PortDNSOverTLS}
|
newconfig.TLSConfig = tlsConf.TLSConfig
|
||||||
|
if tlsConf.PortDNSOverTLS != 0 {
|
||||||
|
newconfig.TLSListenAddr = &net.TCPAddr{
|
||||||
|
IP: net.ParseIP(config.DNS.BindHost),
|
||||||
|
Port: tlsConf.PortDNSOverTLS,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
newconfig.TLSv12Roots = Context.tlsRoots
|
newconfig.TLSv12Roots = Context.tlsRoots
|
||||||
|
newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
|
||||||
|
|
||||||
newconfig.FilterHandler = applyAdditionalFiltering
|
newconfig.FilterHandler = applyAdditionalFiltering
|
||||||
newconfig.GetUpstreamsByClient = getUpstreamsByClient
|
newconfig.GetUpstreamsByClient = getUpstreamsByClient
|
||||||
return newconfig
|
return newconfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the list of DNS addresses the server is listening on
|
||||||
|
func getDNSAddresses() []string {
|
||||||
|
dnsAddresses := []string{}
|
||||||
|
|
||||||
|
if config.DNS.BindHost == "0.0.0.0" {
|
||||||
|
ifaces, e := util.GetValidNetInterfacesForWeb()
|
||||||
|
if e != nil {
|
||||||
|
log.Error("Couldn't get network interfaces: %v", e)
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, iface := range ifaces {
|
||||||
|
for _, addr := range iface.Addresses {
|
||||||
|
addDNSAddress(&dnsAddresses, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
addDNSAddress(&dnsAddresses, config.DNS.BindHost)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConf := tlsConfigSettings{}
|
||||||
|
Context.tls.WriteDiskConfig(&tlsConf)
|
||||||
|
if tlsConf.Enabled && len(tlsConf.ServerName) != 0 {
|
||||||
|
|
||||||
|
if tlsConf.PortHTTPS != 0 {
|
||||||
|
addr := tlsConf.ServerName
|
||||||
|
if tlsConf.PortHTTPS != 443 {
|
||||||
|
addr = fmt.Sprintf("%s:%d", addr, tlsConf.PortHTTPS)
|
||||||
|
}
|
||||||
|
addr = fmt.Sprintf("https://%s/dns-query", addr)
|
||||||
|
dnsAddresses = append(dnsAddresses, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tlsConf.PortDNSOverTLS != 0 {
|
||||||
|
addr := fmt.Sprintf("tls://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverTLS)
|
||||||
|
dnsAddresses = append(dnsAddresses, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return dnsAddresses
|
||||||
|
}
|
||||||
|
|
||||||
func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
|
func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
|
||||||
return Context.clients.FindUpstreams(clientAddr)
|
return Context.clients.FindUpstreams(clientAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If a client has his own settings, apply them
|
// If a client has his own settings, apply them
|
||||||
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
||||||
ApplyBlockedServices(setts, config.DNS.BlockedServices)
|
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
|
||||||
|
|
||||||
if len(clientAddr) == 0 {
|
if len(clientAddr) == 0 {
|
||||||
return
|
return
|
||||||
|
@ -202,7 +238,7 @@ func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteri
|
||||||
log.Debug("Using settings for client with IP %s", clientAddr)
|
log.Debug("Using settings for client with IP %s", clientAddr)
|
||||||
|
|
||||||
if c.UseOwnBlockedServices {
|
if c.UseOwnBlockedServices {
|
||||||
ApplyBlockedServices(setts, c.BlockedServices)
|
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
setts.ClientTags = c.Tags
|
setts.ClientTags = c.Tags
|
||||||
|
@ -224,6 +260,8 @@ func startDNSServer() error {
|
||||||
|
|
||||||
enableFilters(false)
|
enableFilters(false)
|
||||||
|
|
||||||
|
Context.clients.Start()
|
||||||
|
|
||||||
err := Context.dnsServer.Start()
|
err := Context.dnsServer.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||||
|
@ -295,11 +333,6 @@ func closeDNSServer() {
|
||||||
Context.queryLog = nil
|
Context.queryLog = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if Context.auth != nil {
|
|
||||||
Context.auth.Close()
|
|
||||||
Context.auth = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
Context.filters.Close()
|
Context.filters.Close()
|
||||||
|
|
||||||
log.Debug("Closed all DNS modules")
|
log.Debug("Closed all DNS modules")
|
||||||
|
|
188
home/home.go
188
home/home.go
|
@ -34,8 +34,6 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/querylog"
|
"github.com/AdguardTeam/AdGuardHome/querylog"
|
||||||
"github.com/AdguardTeam/AdGuardHome/stats"
|
"github.com/AdguardTeam/AdGuardHome/stats"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/NYTimes/gziphandler"
|
|
||||||
"github.com/gobuffalo/packr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -67,9 +65,9 @@ type homeContext struct {
|
||||||
dnsFilter *dnsfilter.Dnsfilter // DNS filtering module
|
dnsFilter *dnsfilter.Dnsfilter // DNS filtering module
|
||||||
dhcpServer *dhcpd.Server // DHCP module
|
dhcpServer *dhcpd.Server // DHCP module
|
||||||
auth *Auth // HTTP authentication module
|
auth *Auth // HTTP authentication module
|
||||||
httpServer *http.Server // HTTP module
|
filters Filtering // DNS filtering module
|
||||||
httpsServer HTTPSServer // HTTPS module
|
web *Web // Web (HTTP, HTTPS) module
|
||||||
filters Filtering
|
tls *TLSMod // TLS module
|
||||||
|
|
||||||
// Runtime properties
|
// Runtime properties
|
||||||
// --
|
// --
|
||||||
|
@ -116,10 +114,20 @@ func Main(version string, channel string, armVer string) {
|
||||||
Context.appSignalChannel = make(chan os.Signal)
|
Context.appSignalChannel = make(chan os.Signal)
|
||||||
signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
|
signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
|
||||||
go func() {
|
go func() {
|
||||||
<-Context.appSignalChannel
|
for {
|
||||||
|
sig := <-Context.appSignalChannel
|
||||||
|
log.Info("Received signal '%s'", sig)
|
||||||
|
switch sig {
|
||||||
|
case syscall.SIGHUP:
|
||||||
|
Context.clients.Reload()
|
||||||
|
Context.tls.Reload()
|
||||||
|
|
||||||
|
default:
|
||||||
cleanup()
|
cleanup()
|
||||||
cleanupAlways()
|
cleanupAlways()
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// run the protection
|
// run the protection
|
||||||
|
@ -163,7 +171,6 @@ func run(args options) {
|
||||||
}
|
}
|
||||||
|
|
||||||
initConfig()
|
initConfig()
|
||||||
initServices()
|
|
||||||
|
|
||||||
Context.tlsRoots = util.LoadSystemRootCAs()
|
Context.tlsRoots = util.LoadSystemRootCAs()
|
||||||
Context.transport = &http.Transport{
|
Context.transport = &http.Transport{
|
||||||
|
@ -218,6 +225,9 @@ func run(args options) {
|
||||||
if args.bindPort != 0 {
|
if args.bindPort != 0 {
|
||||||
config.BindPort = args.bindPort
|
config.BindPort = args.bindPort
|
||||||
}
|
}
|
||||||
|
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
|
||||||
|
Context.pidFileName = args.pidFile
|
||||||
|
}
|
||||||
|
|
||||||
if !Context.firstRun {
|
if !Context.firstRun {
|
||||||
// Save the updated config
|
// Save the updated config
|
||||||
|
@ -225,11 +235,42 @@ func run(args options) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = initDNSServer()
|
err := os.MkdirAll(Context.getDataDir(), 0755)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Cannot create DNS data dir at %s: %s", Context.getDataDir(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
|
||||||
|
Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
|
||||||
|
if Context.auth == nil {
|
||||||
|
log.Fatalf("Couldn't initialize Auth module")
|
||||||
|
}
|
||||||
|
config.Users = nil
|
||||||
|
|
||||||
|
Context.tls = tlsCreate(config.TLS)
|
||||||
|
if Context.tls == nil {
|
||||||
|
log.Fatalf("Can't initialize TLS module")
|
||||||
|
}
|
||||||
|
|
||||||
|
webConf := WebConfig{
|
||||||
|
firstRun: Context.firstRun,
|
||||||
|
BindHost: config.BindHost,
|
||||||
|
BindPort: config.BindPort,
|
||||||
|
}
|
||||||
|
Context.web = CreateWeb(&webConf)
|
||||||
|
if Context.web == nil {
|
||||||
|
log.Fatalf("Can't initialize Web module")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !Context.firstRun {
|
||||||
|
err := initDNSServer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("%s", err)
|
log.Fatalf("%s", err)
|
||||||
}
|
}
|
||||||
|
Context.tls.Start()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := startDNSServer()
|
err := startDNSServer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -243,101 +284,27 @@ func run(args options) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
|
Context.web.Start()
|
||||||
Context.pidFileName = args.pidFile
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize and run the admin Web interface
|
|
||||||
box := packr.NewBox("../build/static")
|
|
||||||
|
|
||||||
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
|
|
||||||
http.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box)))))
|
|
||||||
registerControlHandlers()
|
|
||||||
|
|
||||||
// add handlers for /install paths, we only need them when we're not configured yet
|
|
||||||
if Context.firstRun {
|
|
||||||
log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
|
|
||||||
http.Handle("/install.html", preInstallHandler(http.FileServer(box)))
|
|
||||||
registerInstallHandlers()
|
|
||||||
}
|
|
||||||
|
|
||||||
Context.httpsServer.cond = sync.NewCond(&Context.httpsServer.Mutex)
|
|
||||||
|
|
||||||
// for https, we have a separate goroutine loop
|
|
||||||
go httpServerLoop()
|
|
||||||
|
|
||||||
// this loop is used as an ability to change listening host and/or port
|
|
||||||
for !Context.httpsServer.shutdown {
|
|
||||||
printHTTPAddresses("http")
|
|
||||||
|
|
||||||
// we need to have new instance, because after Shutdown() the Server is not usable
|
|
||||||
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
|
|
||||||
Context.httpServer = &http.Server{
|
|
||||||
Addr: address,
|
|
||||||
}
|
|
||||||
err := Context.httpServer.ListenAndServe()
|
|
||||||
if err != http.ErrServerClosed {
|
|
||||||
cleanupAlways()
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
// We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait indefinitely for other go-routines to complete their job
|
// wait indefinitely for other go-routines to complete their job
|
||||||
select {}
|
select {}
|
||||||
}
|
}
|
||||||
|
|
||||||
func httpServerLoop() {
|
// StartMods - initialize and start DNS after installation
|
||||||
for !Context.httpsServer.shutdown {
|
func StartMods() error {
|
||||||
Context.httpsServer.cond.L.Lock()
|
err := initDNSServer()
|
||||||
// this mechanism doesn't let us through until all conditions are met
|
|
||||||
for config.TLS.Enabled == false ||
|
|
||||||
config.TLS.PortHTTPS == 0 ||
|
|
||||||
len(config.TLS.PrivateKeyData) == 0 ||
|
|
||||||
len(config.TLS.CertificateChainData) == 0 { // sleep until necessary data is supplied
|
|
||||||
Context.httpsServer.cond.Wait()
|
|
||||||
}
|
|
||||||
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS))
|
|
||||||
// validate current TLS config and update warnings (it could have been loaded from file)
|
|
||||||
data := validateCertificates(string(config.TLS.CertificateChainData), string(config.TLS.PrivateKeyData), config.TLS.ServerName)
|
|
||||||
if !data.ValidPair {
|
|
||||||
cleanupAlways()
|
|
||||||
log.Fatal(data.WarningValidation)
|
|
||||||
}
|
|
||||||
config.Lock()
|
|
||||||
config.TLS.tlsConfigStatus = data // update warnings
|
|
||||||
config.Unlock()
|
|
||||||
|
|
||||||
// prepare certs for HTTPS server
|
|
||||||
// important -- they have to be copies, otherwise changing the contents in config.TLS will break encryption for in-flight requests
|
|
||||||
certchain := make([]byte, len(config.TLS.CertificateChainData))
|
|
||||||
copy(certchain, config.TLS.CertificateChainData)
|
|
||||||
privatekey := make([]byte, len(config.TLS.PrivateKeyData))
|
|
||||||
copy(privatekey, config.TLS.PrivateKeyData)
|
|
||||||
cert, err := tls.X509KeyPair(certchain, privatekey)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cleanupAlways()
|
return err
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
Context.httpsServer.cond.L.Unlock()
|
|
||||||
|
|
||||||
// prepare HTTPS server
|
|
||||||
Context.httpsServer.server = &http.Server{
|
|
||||||
Addr: address,
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
Certificates: []tls.Certificate{cert},
|
|
||||||
MinVersion: tls.VersionTLS12,
|
|
||||||
RootCAs: Context.tlsRoots,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
printHTTPAddresses("https")
|
Context.tls.Start()
|
||||||
err = Context.httpsServer.server.ListenAndServeTLS("", "")
|
|
||||||
if err != http.ErrServerClosed {
|
err = startDNSServer()
|
||||||
cleanupAlways()
|
if err != nil {
|
||||||
log.Fatal(err)
|
closeDNSServer()
|
||||||
}
|
return err
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the current user has root (administrator) rights
|
// Check if the current user has root (administrator) rights
|
||||||
|
@ -449,6 +416,15 @@ func configureLogger(args options) {
|
||||||
func cleanup() {
|
func cleanup() {
|
||||||
log.Info("Stopping AdGuard Home")
|
log.Info("Stopping AdGuard Home")
|
||||||
|
|
||||||
|
if Context.web != nil {
|
||||||
|
Context.web.Close()
|
||||||
|
Context.web = nil
|
||||||
|
}
|
||||||
|
if Context.auth != nil {
|
||||||
|
Context.auth.Close()
|
||||||
|
Context.auth = nil
|
||||||
|
}
|
||||||
|
|
||||||
err := stopDNSServer()
|
err := stopDNSServer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Couldn't stop DNS server: %s", err)
|
log.Error("Couldn't stop DNS server: %s", err)
|
||||||
|
@ -457,17 +433,11 @@ func cleanup() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Couldn't stop DHCP server: %s", err)
|
log.Error("Couldn't stop DHCP server: %s", err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Stop HTTP server, possibly waiting for all active connections to be closed
|
if Context.tls != nil {
|
||||||
func stopHTTPServer() {
|
Context.tls.Close()
|
||||||
log.Info("Stopping HTTP server...")
|
Context.tls = nil
|
||||||
Context.httpsServer.shutdown = true
|
|
||||||
if Context.httpsServer.server != nil {
|
|
||||||
_ = Context.httpsServer.server.Shutdown(context.TODO())
|
|
||||||
}
|
}
|
||||||
_ = Context.httpServer.Shutdown(context.TODO())
|
|
||||||
log.Info("Stopped HTTP server")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// This function is called before application exits
|
// This function is called before application exits
|
||||||
|
@ -519,7 +489,7 @@ func loadOptions() options {
|
||||||
}
|
}
|
||||||
o.bindPort = v
|
o.bindPort = v
|
||||||
}, nil},
|
}, nil},
|
||||||
{"service", "s", "Service control action: status, install, uninstall, start, stop, restart", func(value string) {
|
{"service", "s", "Service control action: status, install, uninstall, start, stop, restart, reload (configuration)", func(value string) {
|
||||||
o.serviceControlAction = value
|
o.serviceControlAction = value
|
||||||
}, nil},
|
}, nil},
|
||||||
{"logfile", "l", "Path to log file. If empty: write to stdout; if 'syslog': write to system log", func(value string) {
|
{"logfile", "l", "Path to log file. If empty: write to stdout; if 'syslog': write to system log", func(value string) {
|
||||||
|
@ -588,11 +558,13 @@ func loadOptions() options {
|
||||||
func printHTTPAddresses(proto string) {
|
func printHTTPAddresses(proto string) {
|
||||||
var address string
|
var address string
|
||||||
|
|
||||||
if proto == "https" && config.TLS.ServerName != "" {
|
tlsConf := tlsConfigSettings{}
|
||||||
if config.TLS.PortHTTPS == 443 {
|
Context.tls.WriteDiskConfig(&tlsConf)
|
||||||
log.Printf("Go to https://%s", config.TLS.ServerName)
|
if proto == "https" && tlsConf.ServerName != "" {
|
||||||
|
if tlsConf.PortHTTPS == 443 {
|
||||||
|
log.Printf("Go to https://%s", tlsConf.ServerName)
|
||||||
} else {
|
} else {
|
||||||
log.Printf("Go to https://%s:%d", config.TLS.ServerName, config.TLS.PortHTTPS)
|
log.Printf("Go to https://%s:%d", tlsConf.ServerName, tlsConf.PortHTTPS)
|
||||||
}
|
}
|
||||||
} else if config.BindHost == "0.0.0.0" {
|
} else if config.BindHost == "0.0.0.0" {
|
||||||
log.Println("AdGuard Home is available on the following addresses:")
|
log.Println("AdGuard Home is available on the following addresses:")
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
package home
|
package home
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
@ -71,6 +73,48 @@ func svcAction(s service.Service, action string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Send SIGHUP to a process with ID taken from our pid-file
|
||||||
|
// If pid-file doesn't exist, find our PID using 'ps' command
|
||||||
|
func sendSigReload() {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
log.Error("Not implemented on Windows")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pidfile := fmt.Sprintf("/var/run/%s.pid", serviceName)
|
||||||
|
data, err := ioutil.ReadFile(pidfile)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
code, psdata, err := util.RunCommand("ps", "-C", serviceName, "-o", "pid=")
|
||||||
|
if err != nil || code != 0 {
|
||||||
|
log.Error("Can't find AdGuardHome process: %s code:%d", err, code)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = []byte(psdata)
|
||||||
|
|
||||||
|
} else if err != nil {
|
||||||
|
log.Error("Can't read PID file %s: %s", pidfile, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.SplitN(string(data), "\n", 2)
|
||||||
|
if len(parts) == 0 {
|
||||||
|
log.Error("Can't read PID file %s: bad value", pidfile)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pid, err := strconv.Atoi(parts[0])
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Can't read PID file %s: %s", pidfile, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = util.SendProcessSignal(pid, syscall.SIGHUP)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Can't send signal to PID %d: %s", pid, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debug("Sent signal to PID %d", pid)
|
||||||
|
}
|
||||||
|
|
||||||
// handleServiceControlAction one of the possible control actions:
|
// handleServiceControlAction one of the possible control actions:
|
||||||
// install -- installs a service/daemon
|
// install -- installs a service/daemon
|
||||||
// uninstall -- uninstalls it
|
// uninstall -- uninstalls it
|
||||||
|
@ -84,6 +128,11 @@ func svcAction(s service.Service, action string) error {
|
||||||
func handleServiceControlAction(action string) {
|
func handleServiceControlAction(action string) {
|
||||||
log.Printf("Service control action: %s", action)
|
log.Printf("Service control action: %s", action)
|
||||||
|
|
||||||
|
if action == "reload" {
|
||||||
|
sendSigReload()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
pwd, err := os.Getwd()
|
pwd, err := os.Getwd()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("Unable to find the path to the current directory")
|
log.Fatal("Unable to find the path to the current directory")
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
// Control: TLS configuring handlers
|
|
||||||
|
|
||||||
package home
|
package home
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
@ -16,18 +13,125 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/util"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/joomcode/errorx"
|
"github.com/joomcode/errorx"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var tlsWebHandlersRegistered = false
|
||||||
|
|
||||||
|
// TLSMod - TLS module object
|
||||||
|
type TLSMod struct {
|
||||||
|
certLastMod time.Time // last modification time of the certificate file
|
||||||
|
conf tlsConfigSettings
|
||||||
|
confLock sync.Mutex
|
||||||
|
status tlsConfigStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create TLS module
|
||||||
|
func tlsCreate(conf tlsConfigSettings) *TLSMod {
|
||||||
|
t := &TLSMod{}
|
||||||
|
t.conf = conf
|
||||||
|
if t.conf.Enabled {
|
||||||
|
if !t.load() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
t.setCertFileTime()
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TLSMod) load() bool {
|
||||||
|
if !tlsLoadConfig(&t.conf, &t.status) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate current TLS config and update warnings (it could have been loaded from file)
|
||||||
|
data := validateCertificates(string(t.conf.CertificateChainData), string(t.conf.PrivateKeyData), t.conf.ServerName)
|
||||||
|
if !data.ValidPair {
|
||||||
|
log.Error(data.WarningValidation)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
t.status = data
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close - close module
|
||||||
|
func (t *TLSMod) Close() {
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteDiskConfig - write config
|
||||||
|
func (t *TLSMod) WriteDiskConfig(conf *tlsConfigSettings) {
|
||||||
|
t.confLock.Lock()
|
||||||
|
*conf = t.conf
|
||||||
|
t.confLock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TLSMod) setCertFileTime() {
|
||||||
|
if len(t.conf.CertificatePath) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fi, err := os.Stat(t.conf.CertificatePath)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("TLS: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.certLastMod = fi.ModTime().UTC()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start - start the module
|
||||||
|
func (t *TLSMod) Start() {
|
||||||
|
if !tlsWebHandlersRegistered {
|
||||||
|
tlsWebHandlersRegistered = true
|
||||||
|
t.registerWebHandlers()
|
||||||
|
}
|
||||||
|
|
||||||
|
t.confLock.Lock()
|
||||||
|
tlsConf := t.conf
|
||||||
|
t.confLock.Unlock()
|
||||||
|
Context.web.TLSConfigChanged(tlsConf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload - reload certificate file
|
||||||
|
func (t *TLSMod) Reload() {
|
||||||
|
t.confLock.Lock()
|
||||||
|
tlsConf := t.conf
|
||||||
|
t.confLock.Unlock()
|
||||||
|
|
||||||
|
if !tlsConf.Enabled || len(tlsConf.CertificatePath) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fi, err := os.Stat(tlsConf.CertificatePath)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("TLS: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if fi.ModTime().UTC().Equal(t.certLastMod) {
|
||||||
|
log.Debug("TLS: certificate file isn't modified")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debug("TLS: certificate file is modified")
|
||||||
|
|
||||||
|
t.confLock.Lock()
|
||||||
|
r := t.load()
|
||||||
|
t.confLock.Unlock()
|
||||||
|
if !r {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.certLastMod = fi.ModTime().UTC()
|
||||||
|
|
||||||
|
_ = reconfigureDNSServer()
|
||||||
|
Context.web.TLSConfigChanged(tlsConf)
|
||||||
|
}
|
||||||
|
|
||||||
// Set certificate and private key data
|
// Set certificate and private key data
|
||||||
func tlsLoadConfig(tls *tlsConfig, status *tlsConfigStatus) bool {
|
func tlsLoadConfig(tls *tlsConfigSettings, status *tlsConfigStatus) bool {
|
||||||
tls.CertificateChainData = []byte(tls.CertificateChain)
|
tls.CertificateChainData = []byte(tls.CertificateChain)
|
||||||
tls.PrivateKeyData = []byte(tls.PrivateKey)
|
tls.PrivateKeyData = []byte(tls.PrivateKey)
|
||||||
|
|
||||||
|
@ -61,98 +165,115 @@ func tlsLoadConfig(tls *tlsConfig, status *tlsConfigStatus) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterTLSHandlers registers HTTP handlers for TLS configuration
|
type tlsConfigStatus struct {
|
||||||
func RegisterTLSHandlers() {
|
ValidCert bool `json:"valid_cert"` // ValidCert is true if the specified certificates chain is a valid chain of X509 certificates
|
||||||
httpRegister(http.MethodGet, "/control/tls/status", handleTLSStatus)
|
ValidChain bool `json:"valid_chain"` // ValidChain is true if the specified certificates chain is verified and issued by a known CA
|
||||||
httpRegister(http.MethodPost, "/control/tls/configure", handleTLSConfigure)
|
Subject string `json:"subject,omitempty"` // Subject is the subject of the first certificate in the chain
|
||||||
httpRegister(http.MethodPost, "/control/tls/validate", handleTLSValidate)
|
Issuer string `json:"issuer,omitempty"` // Issuer is the issuer of the first certificate in the chain
|
||||||
|
NotBefore time.Time `json:"not_before,omitempty"` // NotBefore is the NotBefore field of the first certificate in the chain
|
||||||
|
NotAfter time.Time `json:"not_after,omitempty"` // NotAfter is the NotAfter field of the first certificate in the chain
|
||||||
|
DNSNames []string `json:"dns_names"` // DNSNames is the value of SubjectAltNames field of the first certificate in the chain
|
||||||
|
|
||||||
|
// key status
|
||||||
|
ValidKey bool `json:"valid_key"` // ValidKey is true if the key is a valid private key
|
||||||
|
KeyType string `json:"key_type,omitempty"` // KeyType is one of RSA or ECDSA
|
||||||
|
|
||||||
|
// is usable? set by validator
|
||||||
|
ValidPair bool `json:"valid_pair"` // ValidPair is true if both certificate and private key are correct
|
||||||
|
|
||||||
|
// warnings
|
||||||
|
WarningValidation string `json:"warning_validation,omitempty"` // WarningValidation is a validation warning message with the issue description
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
// field ordering is important -- yaml fields will mirror ordering from here
|
||||||
marshalTLS(w, config.TLS)
|
type tlsConfig struct {
|
||||||
|
tlsConfigSettings `json:",inline"`
|
||||||
|
tlsConfigStatus `json:",inline"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
data, err := unmarshalTLS(r)
|
t.confLock.Lock()
|
||||||
if err != nil {
|
data := tlsConfig{
|
||||||
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
tlsConfigSettings: t.conf,
|
||||||
return
|
tlsConfigStatus: t.status,
|
||||||
}
|
}
|
||||||
|
t.confLock.Unlock()
|
||||||
// check if port is available
|
|
||||||
// BUT: if we are already using this port, no need
|
|
||||||
alreadyRunning := false
|
|
||||||
if Context.httpsServer.server != nil {
|
|
||||||
alreadyRunning = true
|
|
||||||
}
|
|
||||||
if !alreadyRunning {
|
|
||||||
err = util.CheckPortAvailable(config.BindHost, data.PortHTTPS)
|
|
||||||
if err != nil {
|
|
||||||
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
status := tlsConfigStatus{}
|
|
||||||
if tlsLoadConfig(&data, &status) {
|
|
||||||
status = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName)
|
|
||||||
}
|
|
||||||
data.tlsConfigStatus = status
|
|
||||||
|
|
||||||
marshalTLS(w, data)
|
marshalTLS(w, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||||
|
setts, err := unmarshalTLS(r)
|
||||||
|
if err != nil {
|
||||||
|
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !WebCheckPortAvailable(setts.PortHTTPS) {
|
||||||
|
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", setts.PortHTTPS)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
status := tlsConfigStatus{}
|
||||||
|
if tlsLoadConfig(&setts, &status) {
|
||||||
|
status = validateCertificates(string(setts.CertificateChainData), string(setts.PrivateKeyData), setts.ServerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
data := tlsConfig{
|
||||||
|
tlsConfigSettings: setts,
|
||||||
|
tlsConfigStatus: status,
|
||||||
|
}
|
||||||
|
marshalTLS(w, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||||
data, err := unmarshalTLS(r)
|
data, err := unmarshalTLS(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if port is available
|
if !WebCheckPortAvailable(data.PortHTTPS) {
|
||||||
// BUT: if we are already using this port, no need
|
|
||||||
alreadyRunning := false
|
|
||||||
if Context.httpsServer.server != nil {
|
|
||||||
alreadyRunning = true
|
|
||||||
}
|
|
||||||
if !alreadyRunning {
|
|
||||||
err = util.CheckPortAvailable(config.BindHost, data.PortHTTPS)
|
|
||||||
if err != nil {
|
|
||||||
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
|
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
status := tlsConfigStatus{}
|
status := tlsConfigStatus{}
|
||||||
if !tlsLoadConfig(&data, &status) {
|
if !tlsLoadConfig(&data, &status) {
|
||||||
data.tlsConfigStatus = status
|
data2 := tlsConfig{
|
||||||
marshalTLS(w, data)
|
tlsConfigSettings: data,
|
||||||
|
tlsConfigStatus: t.status,
|
||||||
|
}
|
||||||
|
marshalTLS(w, data2)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
data.tlsConfigStatus = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName)
|
status = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName)
|
||||||
restartHTTPS := false
|
restartHTTPS := false
|
||||||
if !reflect.DeepEqual(config.TLS.tlsConfigSettings, data.tlsConfigSettings) {
|
t.confLock.Lock()
|
||||||
|
if !reflect.DeepEqual(t.conf, data) {
|
||||||
log.Printf("tls config settings have changed, will restart HTTPS server")
|
log.Printf("tls config settings have changed, will restart HTTPS server")
|
||||||
restartHTTPS = true
|
restartHTTPS = true
|
||||||
}
|
}
|
||||||
config.TLS = data
|
t.conf = data
|
||||||
err = writeAllConfigsAndReloadDNS()
|
t.status = status
|
||||||
|
t.confLock.Unlock()
|
||||||
|
t.setCertFileTime()
|
||||||
|
onConfigModified()
|
||||||
|
err = reconfigureDNSServer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusInternalServerError, "Couldn't write config file: %s", err)
|
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
marshalTLS(w, data)
|
data2 := tlsConfig{
|
||||||
|
tlsConfigSettings: data,
|
||||||
|
tlsConfigStatus: t.status,
|
||||||
|
}
|
||||||
|
marshalTLS(w, data2)
|
||||||
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
|
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
|
||||||
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
|
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
|
||||||
if restartHTTPS {
|
if restartHTTPS {
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server
|
time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server
|
||||||
Context.httpsServer.cond.L.Lock()
|
Context.web.TLSConfigChanged(data)
|
||||||
Context.httpsServer.cond.Broadcast()
|
|
||||||
if Context.httpsServer.server != nil {
|
|
||||||
Context.httpsServer.server.Shutdown(context.TODO())
|
|
||||||
}
|
|
||||||
Context.httpsServer.cond.L.Unlock()
|
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -337,8 +458,8 @@ func parsePrivateKey(der []byte) (crypto.PrivateKey, string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// unmarshalTLS handles base64-encoded certificates transparently
|
// unmarshalTLS handles base64-encoded certificates transparently
|
||||||
func unmarshalTLS(r *http.Request) (tlsConfig, error) {
|
func unmarshalTLS(r *http.Request) (tlsConfigSettings, error) {
|
||||||
data := tlsConfig{}
|
data := tlsConfigSettings{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&data)
|
err := json.NewDecoder(r.Body).Decode(&data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return data, errorx.Decorate(err, "Failed to parse new TLS config json")
|
return data, errorx.Decorate(err, "Failed to parse new TLS config json")
|
||||||
|
@ -389,3 +510,10 @@ func marshalTLS(w http.ResponseWriter, data tlsConfig) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// registerWebHandlers registers HTTP handlers for TLS configuration
|
||||||
|
func (t *TLSMod) registerWebHandlers() {
|
||||||
|
httpRegister("GET", "/control/tls/status", t.handleTLSStatus)
|
||||||
|
httpRegister("POST", "/control/tls/configure", t.handleTLSConfigure)
|
||||||
|
httpRegister("POST", "/control/tls/validate", t.handleTLSValidate)
|
||||||
|
}
|
|
@ -0,0 +1,189 @@
|
||||||
|
package home
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/util"
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"github.com/NYTimes/gziphandler"
|
||||||
|
"github.com/gobuffalo/packr"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WebConfig struct {
|
||||||
|
firstRun bool
|
||||||
|
BindHost string
|
||||||
|
BindPort int
|
||||||
|
PortHTTPS int
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPSServer - HTTPS Server
|
||||||
|
type HTTPSServer struct {
|
||||||
|
server *http.Server
|
||||||
|
cond *sync.Cond
|
||||||
|
condLock sync.Mutex
|
||||||
|
shutdown bool // if TRUE, don't restart the server
|
||||||
|
enabled bool
|
||||||
|
cert tls.Certificate
|
||||||
|
}
|
||||||
|
|
||||||
|
// Web - module object
|
||||||
|
type Web struct {
|
||||||
|
conf *WebConfig
|
||||||
|
forceHTTPS bool
|
||||||
|
portHTTPS int
|
||||||
|
httpServer *http.Server // HTTP module
|
||||||
|
httpsServer HTTPSServer // HTTPS module
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateWeb - create module
|
||||||
|
func CreateWeb(conf *WebConfig) *Web {
|
||||||
|
w := Web{}
|
||||||
|
w.conf = conf
|
||||||
|
|
||||||
|
// Initialize and run the admin Web interface
|
||||||
|
box := packr.NewBox("../build/static")
|
||||||
|
|
||||||
|
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
|
||||||
|
http.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box)))))
|
||||||
|
|
||||||
|
// add handlers for /install paths, we only need them when we're not configured yet
|
||||||
|
if conf.firstRun {
|
||||||
|
log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
|
||||||
|
http.Handle("/install.html", preInstallHandler(http.FileServer(box)))
|
||||||
|
w.registerInstallHandlers()
|
||||||
|
} else {
|
||||||
|
registerControlHandlers()
|
||||||
|
}
|
||||||
|
|
||||||
|
w.httpsServer.cond = sync.NewCond(&w.httpsServer.condLock)
|
||||||
|
return &w
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebCheckPortAvailable - check if port is available
|
||||||
|
// BUT: if we are already using this port, no need
|
||||||
|
func WebCheckPortAvailable(port int) bool {
|
||||||
|
alreadyRunning := false
|
||||||
|
if Context.web.httpsServer.server != nil {
|
||||||
|
alreadyRunning = true
|
||||||
|
}
|
||||||
|
if !alreadyRunning {
|
||||||
|
err := util.CheckPortAvailable(config.BindHost, port)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSConfigChanged - called when TLS configuration has changed
|
||||||
|
func (w *Web) TLSConfigChanged(tlsConf tlsConfigSettings) {
|
||||||
|
log.Debug("Web: applying new TLS configuration")
|
||||||
|
w.conf.PortHTTPS = tlsConf.PortHTTPS
|
||||||
|
w.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)
|
||||||
|
w.portHTTPS = tlsConf.PortHTTPS
|
||||||
|
|
||||||
|
enabled := tlsConf.Enabled &&
|
||||||
|
tlsConf.PortHTTPS != 0 &&
|
||||||
|
len(tlsConf.PrivateKeyData) != 0 &&
|
||||||
|
len(tlsConf.CertificateChainData) != 0
|
||||||
|
var cert tls.Certificate
|
||||||
|
var err error
|
||||||
|
if enabled {
|
||||||
|
cert, err = tls.X509KeyPair(tlsConf.CertificateChainData, tlsConf.PrivateKeyData)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
w.httpsServer.cond.L.Lock()
|
||||||
|
if w.httpsServer.server != nil {
|
||||||
|
w.httpsServer.server.Shutdown(context.TODO())
|
||||||
|
}
|
||||||
|
w.httpsServer.enabled = enabled
|
||||||
|
w.httpsServer.cert = cert
|
||||||
|
w.httpsServer.cond.Broadcast()
|
||||||
|
w.httpsServer.cond.L.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start - start serving HTTP requests
|
||||||
|
func (w *Web) Start() {
|
||||||
|
// for https, we have a separate goroutine loop
|
||||||
|
go w.httpServerLoop()
|
||||||
|
|
||||||
|
// this loop is used as an ability to change listening host and/or port
|
||||||
|
for !w.httpsServer.shutdown {
|
||||||
|
printHTTPAddresses("http")
|
||||||
|
|
||||||
|
// we need to have new instance, because after Shutdown() the Server is not usable
|
||||||
|
address := net.JoinHostPort(w.conf.BindHost, strconv.Itoa(w.conf.BindPort))
|
||||||
|
w.httpServer = &http.Server{
|
||||||
|
Addr: address,
|
||||||
|
}
|
||||||
|
err := w.httpServer.ListenAndServe()
|
||||||
|
if err != http.ErrServerClosed {
|
||||||
|
cleanupAlways()
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
// We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close - stop HTTP server, possibly waiting for all active connections to be closed
|
||||||
|
func (w *Web) Close() {
|
||||||
|
log.Info("Stopping HTTP server...")
|
||||||
|
w.httpsServer.cond.L.Lock()
|
||||||
|
w.httpsServer.shutdown = true
|
||||||
|
w.httpsServer.cond.L.Unlock()
|
||||||
|
if w.httpsServer.server != nil {
|
||||||
|
_ = w.httpsServer.server.Shutdown(context.TODO())
|
||||||
|
}
|
||||||
|
if w.httpServer != nil {
|
||||||
|
_ = w.httpServer.Shutdown(context.TODO())
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Stopped HTTP server")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Web) httpServerLoop() {
|
||||||
|
for {
|
||||||
|
w.httpsServer.cond.L.Lock()
|
||||||
|
if w.httpsServer.shutdown {
|
||||||
|
w.httpsServer.cond.L.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// this mechanism doesn't let us through until all conditions are met
|
||||||
|
for !w.httpsServer.enabled { // sleep until necessary data is supplied
|
||||||
|
w.httpsServer.cond.Wait()
|
||||||
|
if w.httpsServer.shutdown {
|
||||||
|
w.httpsServer.cond.L.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
w.httpsServer.cond.L.Unlock()
|
||||||
|
|
||||||
|
// prepare HTTPS server
|
||||||
|
address := net.JoinHostPort(w.conf.BindHost, strconv.Itoa(w.conf.PortHTTPS))
|
||||||
|
w.httpsServer.server = &http.Server{
|
||||||
|
Addr: address,
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{w.httpsServer.cert},
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
RootCAs: Context.tlsRoots,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
printHTTPAddresses("https")
|
||||||
|
err := w.httpsServer.server.ListenAndServeTLS("", "")
|
||||||
|
if err != http.ErrServerClosed {
|
||||||
|
cleanupAlways()
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -25,3 +25,8 @@ func SetRlimit(val uint) {
|
||||||
func HaveAdminRights() (bool, error) {
|
func HaveAdminRights() (bool, error) {
|
||||||
return os.Getuid() == 0, nil
|
return os.Getuid() == 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SendProcessSignal - send signal to a process
|
||||||
|
func SendProcessSignal(pid int, sig syscall.Signal) error {
|
||||||
|
return syscall.Kill(pid, sig)
|
||||||
|
}
|
||||||
|
|
|
@ -25,3 +25,8 @@ func SetRlimit(val uint) {
|
||||||
func HaveAdminRights() (bool, error) {
|
func HaveAdminRights() (bool, error) {
|
||||||
return os.Getuid() == 0, nil
|
return os.Getuid() == 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SendProcessSignal - send signal to a process
|
||||||
|
func SendProcessSignal(pid int, sig syscall.Signal) error {
|
||||||
|
return syscall.Kill(pid, sig)
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,11 @@
|
||||||
package util
|
package util
|
||||||
|
|
||||||
import "golang.org/x/sys/windows"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
// Set user-specified limit of how many fd's we can use
|
// Set user-specified limit of how many fd's we can use
|
||||||
func SetRlimit(val uint) {
|
func SetRlimit(val uint) {
|
||||||
|
@ -26,3 +31,7 @@ func HaveAdminRights() (bool, error) {
|
||||||
}
|
}
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SendProcessSignal(pid int, sig syscall.Signal) error {
|
||||||
|
return fmt.Errorf("not supported on Windows")
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue