Merge: + Reload configuration command; refactoring
Close #1302 * commit 'dd7d9dc334181b3923b1532aaae51cae77d5ed55': * move getDNSAddresses() and "/dns-query" handler to DNS module * TLS is now a separate module (logically) * move HTTP server code * refactor: move blocked-services functions to dnsfilter * refactor + service: support "-s reload" command + clients: re-read auto-clients from /etc/hosts when SIGHUP is received
This commit is contained in:
commit
32baa907b6
|
@ -71,7 +71,6 @@ Contents:
|
|||
![](doc/agh-arch.png)
|
||||
|
||||
|
||||
|
||||
## First startup
|
||||
|
||||
The first application startup is detected when there's no .yaml configuration file.
|
||||
|
|
|
@ -18,6 +18,8 @@ import (
|
|||
const defaultDiscoverTime = time.Second * 3
|
||||
const leaseExpireStatic = 1
|
||||
|
||||
var webHandlersRegistered = false
|
||||
|
||||
// Lease contains the necessary information about a DHCP lease
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type Lease struct {
|
||||
|
@ -121,9 +123,6 @@ func Create(config ServerConfig) *Server {
|
|||
return nil
|
||||
}
|
||||
}
|
||||
if s.conf.HTTPRegister != nil {
|
||||
s.registerHandlers()
|
||||
}
|
||||
|
||||
// we can't delay database loading until DHCP server is started,
|
||||
// because we need static leases functionality available beforehand
|
||||
|
@ -221,6 +220,11 @@ func (s *Server) setConfig(config ServerConfig) error {
|
|||
// Start will listen on port 67 and serve DHCP requests.
|
||||
func (s *Server) Start() error {
|
||||
|
||||
if !webHandlersRegistered && s.conf.HTTPRegister != nil {
|
||||
webHandlersRegistered = true
|
||||
s.registerHandlers()
|
||||
}
|
||||
|
||||
// TODO: don't close if interface and addresses are the same
|
||||
if s.conn != nil {
|
||||
s.closeConn()
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
package home
|
||||
package dnsfilter
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
)
|
||||
|
@ -119,7 +118,7 @@ var serviceRulesArray = []svc{
|
|||
}
|
||||
|
||||
// convert array to map
|
||||
func initServices() {
|
||||
func initBlockedServices() {
|
||||
serviceRules = make(map[string][]*rules.NetworkRule)
|
||||
for _, s := range serviceRulesArray {
|
||||
netRules := []*rules.NetworkRule{}
|
||||
|
@ -135,15 +134,20 @@ func initServices() {
|
|||
}
|
||||
}
|
||||
|
||||
// Return TRUE if a blocked service name is known
|
||||
func blockedSvcKnown(s string) bool {
|
||||
// BlockedSvcKnown - return TRUE if a blocked service name is known
|
||||
func BlockedSvcKnown(s string) bool {
|
||||
_, ok := serviceRules[s]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ApplyBlockedServices - set blocked services settings for this DNS request
|
||||
func ApplyBlockedServices(setts *dnsfilter.RequestFilteringSettings, list []string) {
|
||||
setts.ServicesRules = []dnsfilter.ServiceEntry{}
|
||||
func (d *Dnsfilter) ApplyBlockedServices(setts *RequestFilteringSettings, list []string, global bool) {
|
||||
setts.ServicesRules = []ServiceEntry{}
|
||||
if global {
|
||||
d.confLock.RLock()
|
||||
defer d.confLock.RUnlock()
|
||||
list = d.Config.BlockedServices
|
||||
}
|
||||
for _, name := range list {
|
||||
rules, ok := serviceRules[name]
|
||||
|
||||
|
@ -152,51 +156,45 @@ func ApplyBlockedServices(setts *dnsfilter.RequestFilteringSettings, list []stri
|
|||
continue
|
||||
}
|
||||
|
||||
s := dnsfilter.ServiceEntry{}
|
||||
s := ServiceEntry{}
|
||||
s.Name = name
|
||||
s.Rules = rules
|
||||
setts.ServicesRules = append(setts.ServicesRules, s)
|
||||
}
|
||||
}
|
||||
|
||||
func handleBlockedServicesList(w http.ResponseWriter, r *http.Request) {
|
||||
config.RLock()
|
||||
list := config.DNS.BlockedServices
|
||||
config.RUnlock()
|
||||
func (d *Dnsfilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Request) {
|
||||
d.confLock.RLock()
|
||||
list := d.Config.BlockedServices
|
||||
d.confLock.RUnlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(list)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "json.Encode: %s", err)
|
||||
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) {
|
||||
func (d *Dnsfilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) {
|
||||
list := []string{}
|
||||
err := json.NewDecoder(r.Body).Decode(&list)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
config.Lock()
|
||||
config.DNS.BlockedServices = list
|
||||
config.Unlock()
|
||||
d.confLock.Lock()
|
||||
d.Config.BlockedServices = list
|
||||
d.confLock.Unlock()
|
||||
|
||||
log.Debug("Updated blocked services list: %d", len(list))
|
||||
|
||||
err = writeAllConfigsAndReloadDNS()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
d.ConfigModified()
|
||||
}
|
||||
|
||||
httpOK(r, w)
|
||||
}
|
||||
|
||||
// RegisterBlockedServicesHandlers - register HTTP handlers
|
||||
func RegisterBlockedServicesHandlers() {
|
||||
httpRegister(http.MethodGet, "/control/blocked_services/list", handleBlockedServicesList)
|
||||
httpRegister(http.MethodPost, "/control/blocked_services/set", handleBlockedServicesSet)
|
||||
// registerBlockedServicesHandlers - register HTTP handlers
|
||||
func (d *Dnsfilter) registerBlockedServicesHandlers() {
|
||||
d.Config.HTTPRegister("GET", "/control/blocked_services/list", d.handleBlockedServicesList)
|
||||
d.Config.HTTPRegister("POST", "/control/blocked_services/set", d.handleBlockedServicesSet)
|
||||
}
|
|
@ -49,6 +49,10 @@ type Config struct {
|
|||
|
||||
Rewrites []RewriteEntry `yaml:"rewrites"`
|
||||
|
||||
// Names of services to block (globally).
|
||||
// Per-client settings can override this configuration.
|
||||
BlockedServices []string `yaml:"blocked_services"`
|
||||
|
||||
// Called when the configuration is changed by HTTP request
|
||||
ConfigModified func() `yaml:"-"`
|
||||
|
||||
|
@ -175,6 +179,7 @@ func (d *Dnsfilter) WriteDiskConfig(c *Config) {
|
|||
d.confLock.Lock()
|
||||
*c = d.Config
|
||||
c.Rewrites = rewriteArrayDup(d.Config.Rewrites)
|
||||
// BlockedServices
|
||||
d.confLock.Unlock()
|
||||
}
|
||||
|
||||
|
@ -633,6 +638,18 @@ func New(c *Config, blockFilters []Filter) *Dnsfilter {
|
|||
d.prepareRewrites()
|
||||
}
|
||||
|
||||
bsvcs := []string{}
|
||||
for _, s := range d.BlockedServices {
|
||||
if !BlockedSvcKnown(s) {
|
||||
log.Debug("skipping unknown blocked-service '%s'", s)
|
||||
continue
|
||||
}
|
||||
bsvcs = append(bsvcs, s)
|
||||
}
|
||||
d.BlockedServices = bsvcs
|
||||
|
||||
initBlockedServices()
|
||||
|
||||
if blockFilters != nil {
|
||||
err := d.initFiltering(nil, blockFilters)
|
||||
if err != nil {
|
||||
|
@ -655,6 +672,7 @@ func (d *Dnsfilter) Start() {
|
|||
if d.Config.HTTPRegister != nil { // for tests
|
||||
d.registerSecurityHandlers()
|
||||
d.registerRewritesHandlers()
|
||||
d.registerBlockedServicesHandlers()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -179,6 +179,7 @@ type ServerConfig struct {
|
|||
|
||||
FilteringConfig
|
||||
TLSConfig
|
||||
TLSAllowUnencryptedDOH bool
|
||||
|
||||
TLSv12Roots *x509.CertPool // list of root CAs for TLSv1.2
|
||||
|
||||
|
|
|
@ -376,6 +376,20 @@ func checkDNS(input string, bootstrap []string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleDOH(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.conf.TLSAllowUnencryptedDOH && r.TLS == nil {
|
||||
httpError(r, w, http.StatusNotFound, "Not Found")
|
||||
return
|
||||
}
|
||||
|
||||
if !s.IsRunning() {
|
||||
httpError(r, w, http.StatusInternalServerError, "DNS server is not running")
|
||||
return
|
||||
}
|
||||
|
||||
s.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (s *Server) registerHandlers() {
|
||||
s.conf.HTTPRegister("GET", "/control/dns_info", s.handleGetConfig)
|
||||
s.conf.HTTPRegister("POST", "/control/dns_config", s.handleSetConfig)
|
||||
|
@ -384,4 +398,6 @@ func (s *Server) registerHandlers() {
|
|||
|
||||
s.conf.HTTPRegister("GET", "/control/access/list", s.handleAccessList)
|
||||
s.conf.HTTPRegister("POST", "/control/access/set", s.handleAccessSet)
|
||||
|
||||
s.conf.HTTPRegister("", "/dns-query", s.handleDOH)
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
|
@ -24,6 +25,8 @@ const (
|
|||
clientsUpdatePeriod = 1 * time.Hour
|
||||
)
|
||||
|
||||
var webHandlersRegistered = false
|
||||
|
||||
// Client information
|
||||
type Client struct {
|
||||
IDs []string
|
||||
|
@ -98,13 +101,27 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd.
|
|||
clients.addFromConfig(objects)
|
||||
|
||||
if !clients.testing {
|
||||
go clients.periodicUpdate()
|
||||
|
||||
clients.addFromDHCP()
|
||||
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
|
||||
}
|
||||
}
|
||||
|
||||
// Start - start the module
|
||||
func (clients *clientsContainer) Start() {
|
||||
if !clients.testing {
|
||||
if !webHandlersRegistered {
|
||||
webHandlersRegistered = true
|
||||
clients.registerWebHandlers()
|
||||
}
|
||||
go clients.periodicUpdate()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Reload - reload auto-clients
|
||||
func (clients *clientsContainer) Reload() {
|
||||
clients.addFromHostsFile()
|
||||
clients.addFromSystemARP()
|
||||
}
|
||||
|
||||
type clientObject struct {
|
||||
|
@ -145,7 +162,7 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) {
|
|||
}
|
||||
|
||||
for _, s := range cy.BlockedServices {
|
||||
if !blockedSvcKnown(s) {
|
||||
if !dnsfilter.BlockedSvcKnown(s) {
|
||||
log.Debug("Clients: skipping unknown blocked-service '%s'", s)
|
||||
continue
|
||||
}
|
||||
|
@ -194,8 +211,7 @@ func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) {
|
|||
|
||||
func (clients *clientsContainer) periodicUpdate() {
|
||||
for {
|
||||
clients.addFromHostsFile()
|
||||
clients.addFromSystemARP()
|
||||
clients.Reload()
|
||||
time.Sleep(clientsUpdatePeriod)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package home
|
|||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
@ -29,14 +28,6 @@ type logSettings struct {
|
|||
Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled
|
||||
}
|
||||
|
||||
// HTTPSServer - HTTPS Server
|
||||
type HTTPSServer struct {
|
||||
server *http.Server
|
||||
cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
|
||||
sync.Mutex // protects config.TLS
|
||||
shutdown bool // if TRUE, don't restart the server
|
||||
}
|
||||
|
||||
// configuration is loaded from YAML
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type configuration struct {
|
||||
|
@ -60,7 +51,7 @@ type configuration struct {
|
|||
WebSessionTTLHours uint32 `yaml:"web_session_ttl"`
|
||||
|
||||
DNS dnsConfig `yaml:"dns"`
|
||||
TLS tlsConfig `yaml:"tls"`
|
||||
TLS tlsConfigSettings `yaml:"tls"`
|
||||
|
||||
Filters []filter `yaml:"filters"`
|
||||
WhitelistFilters []filter `yaml:"whitelist_filters"`
|
||||
|
@ -95,10 +86,6 @@ type dnsConfig struct {
|
|||
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
|
||||
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
|
||||
DnsfilterConf dnsfilter.Config `yaml:",inline"`
|
||||
|
||||
// Names of services to block (globally).
|
||||
// Per-client settings can override this configuration.
|
||||
BlockedServices []string `yaml:"blocked_services"`
|
||||
}
|
||||
|
||||
type tlsConfigSettings struct {
|
||||
|
@ -114,33 +101,6 @@ type tlsConfigSettings struct {
|
|||
dnsforward.TLSConfig `yaml:",inline" json:",inline"`
|
||||
}
|
||||
|
||||
// field ordering is not important -- these are for API and are recalculated on each run
|
||||
type tlsConfigStatus struct {
|
||||
ValidCert bool `yaml:"-" json:"valid_cert"` // ValidCert is true if the specified certificates chain is a valid chain of X509 certificates
|
||||
ValidChain bool `yaml:"-" json:"valid_chain"` // ValidChain is true if the specified certificates chain is verified and issued by a known CA
|
||||
Subject string `yaml:"-" json:"subject,omitempty"` // Subject is the subject of the first certificate in the chain
|
||||
Issuer string `yaml:"-" json:"issuer,omitempty"` // Issuer is the issuer of the first certificate in the chain
|
||||
NotBefore time.Time `yaml:"-" json:"not_before,omitempty"` // NotBefore is the NotBefore field of the first certificate in the chain
|
||||
NotAfter time.Time `yaml:"-" json:"not_after,omitempty"` // NotAfter is the NotAfter field of the first certificate in the chain
|
||||
DNSNames []string `yaml:"-" json:"dns_names"` // DNSNames is the value of SubjectAltNames field of the first certificate in the chain
|
||||
|
||||
// key status
|
||||
ValidKey bool `yaml:"-" json:"valid_key"` // ValidKey is true if the key is a valid private key
|
||||
KeyType string `yaml:"-" json:"key_type,omitempty"` // KeyType is one of RSA or ECDSA
|
||||
|
||||
// is usable? set by validator
|
||||
ValidPair bool `yaml:"-" json:"valid_pair"` // ValidPair is true if both certificate and private key are correct
|
||||
|
||||
// warnings
|
||||
WarningValidation string `yaml:"-" json:"warning_validation,omitempty"` // WarningValidation is a validation warning message with the issue description
|
||||
}
|
||||
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type tlsConfig struct {
|
||||
tlsConfigSettings `yaml:",inline" json:",inline"`
|
||||
tlsConfigStatus `yaml:"-" json:",inline"`
|
||||
}
|
||||
|
||||
// initialize to default values, will be changed later when reading config or parsing command line
|
||||
var config = configuration{
|
||||
BindPort: 3000,
|
||||
|
@ -160,12 +120,10 @@ var config = configuration{
|
|||
FilteringEnabled: true, // whether or not use filter lists
|
||||
FiltersUpdateIntervalHours: 24,
|
||||
},
|
||||
TLS: tlsConfig{
|
||||
tlsConfigSettings: tlsConfigSettings{
|
||||
TLS: tlsConfigSettings{
|
||||
PortHTTPS: 443,
|
||||
PortDNSOverTLS: 853, // needs to be passed through to dnsproxy
|
||||
},
|
||||
},
|
||||
DHCP: dhcpd.ServerConfig{
|
||||
LeaseDuration: 86400,
|
||||
ICMPTimeout: 1000,
|
||||
|
@ -234,26 +192,10 @@ func parseConfig() error {
|
|||
return err
|
||||
}
|
||||
|
||||
bsvcs := []string{}
|
||||
for _, s := range config.DNS.BlockedServices {
|
||||
if !blockedSvcKnown(s) {
|
||||
log.Debug("skipping unknown blocked-service '%s'", s)
|
||||
continue
|
||||
}
|
||||
bsvcs = append(bsvcs, s)
|
||||
}
|
||||
config.DNS.BlockedServices = bsvcs
|
||||
|
||||
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
|
||||
config.DNS.FiltersUpdateIntervalHours = 24
|
||||
}
|
||||
|
||||
status := tlsConfigStatus{}
|
||||
if !tlsLoadConfig(&config.TLS, &status) {
|
||||
log.Error("%s", status.WarningValidation)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -282,6 +224,11 @@ func (c *configuration) write() error {
|
|||
if Context.auth != nil {
|
||||
config.Users = Context.auth.GetUsers()
|
||||
}
|
||||
if Context.tls != nil {
|
||||
tlsConf := tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
config.TLS = tlsConf
|
||||
}
|
||||
|
||||
if Context.stats != nil {
|
||||
sdc := stats.DiskConfig{}
|
||||
|
@ -331,13 +278,3 @@ func (c *configuration) write() error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeAllConfigs() error {
|
||||
err := config.write()
|
||||
if err != nil {
|
||||
log.Error("Couldn't write config: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -9,8 +9,6 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/util"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
|
@ -27,9 +25,6 @@ func returnOK(w http.ResponseWriter) {
|
|||
}
|
||||
}
|
||||
|
||||
func httpOK(r *http.Request, w http.ResponseWriter) {
|
||||
}
|
||||
|
||||
func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Info(text)
|
||||
|
@ -39,15 +34,6 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface
|
|||
// ---------------
|
||||
// dns run control
|
||||
// ---------------
|
||||
func writeAllConfigsAndReloadDNS() error {
|
||||
err := writeAllConfigs()
|
||||
if err != nil {
|
||||
log.Error("Couldn't write all configs: %s", err)
|
||||
return err
|
||||
}
|
||||
return reconfigureDNSServer()
|
||||
}
|
||||
|
||||
func addDNSAddress(dnsAddresses *[]string, addr string) {
|
||||
if config.DNS.Port != 53 {
|
||||
addr = fmt.Sprintf("%s:%d", addr, config.DNS.Port)
|
||||
|
@ -55,46 +41,6 @@ func addDNSAddress(dnsAddresses *[]string, addr string) {
|
|||
*dnsAddresses = append(*dnsAddresses, addr)
|
||||
}
|
||||
|
||||
// Get the list of DNS addresses the server is listening on
|
||||
func getDNSAddresses() []string {
|
||||
dnsAddresses := []string{}
|
||||
|
||||
if config.DNS.BindHost == "0.0.0.0" {
|
||||
ifaces, e := util.GetValidNetInterfacesForWeb()
|
||||
if e != nil {
|
||||
log.Error("Couldn't get network interfaces: %v", e)
|
||||
return []string{}
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
for _, addr := range iface.Addresses {
|
||||
addDNSAddress(&dnsAddresses, addr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
addDNSAddress(&dnsAddresses, config.DNS.BindHost)
|
||||
}
|
||||
|
||||
if config.TLS.Enabled && len(config.TLS.ServerName) != 0 {
|
||||
|
||||
if config.TLS.PortHTTPS != 0 {
|
||||
addr := config.TLS.ServerName
|
||||
if config.TLS.PortHTTPS != 443 {
|
||||
addr = fmt.Sprintf("%s:%d", addr, config.TLS.PortHTTPS)
|
||||
}
|
||||
addr = fmt.Sprintf("https://%s/dns-query", addr)
|
||||
dnsAddresses = append(dnsAddresses, addr)
|
||||
}
|
||||
|
||||
if config.TLS.PortDNSOverTLS != 0 {
|
||||
addr := fmt.Sprintf("tls://%s:%d", config.TLS.ServerName, config.TLS.PortDNSOverTLS)
|
||||
dnsAddresses = append(dnsAddresses, addr)
|
||||
}
|
||||
}
|
||||
|
||||
return dnsAddresses
|
||||
}
|
||||
|
||||
func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
c := dnsforward.FilteringConfig{}
|
||||
if Context.dnsServer != nil {
|
||||
|
@ -144,23 +90,6 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
|||
_, _ = w.Write(data)
|
||||
}
|
||||
|
||||
// --------------
|
||||
// DNS-over-HTTPS
|
||||
// --------------
|
||||
func handleDOH(w http.ResponseWriter, r *http.Request) {
|
||||
if !config.TLS.AllowUnencryptedDOH && r.TLS == nil {
|
||||
httpError(w, http.StatusNotFound, "Not Found")
|
||||
return
|
||||
}
|
||||
|
||||
if !isRunning() {
|
||||
httpError(w, http.StatusInternalServerError, "DNS server is not running")
|
||||
return
|
||||
}
|
||||
|
||||
Context.dnsServer.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// ------------------------
|
||||
// registration of handlers
|
||||
// ------------------------
|
||||
|
@ -172,15 +101,16 @@ func registerControlHandlers() {
|
|||
httpRegister(http.MethodPost, "/control/update", handleUpdate)
|
||||
|
||||
httpRegister("GET", "/control/profile", handleGetProfile)
|
||||
|
||||
RegisterTLSHandlers()
|
||||
RegisterBlockedServicesHandlers()
|
||||
RegisterAuthHandlers()
|
||||
|
||||
http.HandleFunc("/dns-query", postInstall(handleDOH))
|
||||
}
|
||||
|
||||
func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) {
|
||||
if len(method) == 0 {
|
||||
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
|
||||
http.HandleFunc(url, postInstall(handler))
|
||||
return
|
||||
}
|
||||
|
||||
http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
|
||||
}
|
||||
|
||||
|
@ -258,14 +188,16 @@ func preInstallHandler(handler http.Handler) http.Handler {
|
|||
// it also enforces HTTPS if it is enabled and configured
|
||||
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if Context.firstRun &&
|
||||
!strings.HasPrefix(r.URL.Path, "/install.") &&
|
||||
r.URL.Path != "/favicon.png" {
|
||||
http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable
|
||||
http.Redirect(w, r, "/install.html", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// enforce https?
|
||||
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil {
|
||||
if r.TLS == nil && Context.web.forceHTTPS && Context.web.httpsServer.server != nil {
|
||||
// yes, and we want host from host:port
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
|
@ -275,13 +207,14 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res
|
|||
// construct new URL to redirect to
|
||||
newURL := url.URL{
|
||||
Scheme: "https",
|
||||
Host: net.JoinHostPort(host, strconv.Itoa(config.TLS.PortHTTPS)),
|
||||
Host: net.JoinHostPort(host, strconv.Itoa(Context.web.portHTTPS)),
|
||||
Path: r.URL.Path,
|
||||
RawQuery: r.URL.RawQuery,
|
||||
}
|
||||
http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
handler(w, r)
|
||||
}
|
||||
|
|
|
@ -349,7 +349,7 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
setts := Context.dnsFilter.GetConfig()
|
||||
setts.FilteringEnabled = true
|
||||
ApplyBlockedServices(&setts, config.DNS.BlockedServices)
|
||||
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
|
||||
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err)
|
||||
|
|
|
@ -35,7 +35,7 @@ type netInterfaceJSON struct {
|
|||
}
|
||||
|
||||
// Get initial installation settings
|
||||
func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
|
||||
func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
|
||||
data := firstRunData{}
|
||||
data.WebPort = 80
|
||||
data.DNSPort = 53
|
||||
|
@ -93,7 +93,7 @@ type checkConfigResp struct {
|
|||
}
|
||||
|
||||
// Check if ports are available, respond with results
|
||||
func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
|
||||
func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
|
||||
reqData := checkConfigReq{}
|
||||
respData := checkConfigResp{}
|
||||
err := json.NewDecoder(r.Body).Decode(&reqData)
|
||||
|
@ -275,7 +275,7 @@ func copyInstallSettings(dst *configuration, src *configuration) {
|
|||
}
|
||||
|
||||
// Apply new configuration, start DNS server, restart Web server
|
||||
func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
newSettings := applyConfigReq{}
|
||||
err := json.NewDecoder(r.Body).Decode(&newSettings)
|
||||
if err != nil {
|
||||
|
@ -325,22 +325,11 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
|||
config.DNS.BindHost = newSettings.DNS.IP
|
||||
config.DNS.Port = newSettings.DNS.Port
|
||||
|
||||
err = initDNSServer()
|
||||
var err2 error
|
||||
if err == nil {
|
||||
err2 = startDNSServer()
|
||||
if err2 != nil {
|
||||
closeDNSServer()
|
||||
}
|
||||
}
|
||||
if err != nil || err2 != nil {
|
||||
err = StartMods()
|
||||
if err != nil {
|
||||
Context.firstRun = true
|
||||
copyInstallSettings(&config, &curConfig)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't initialize DNS server: %s", err)
|
||||
} else {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't start DNS server: %s", err2)
|
||||
}
|
||||
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -356,19 +345,21 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
registerControlHandlers()
|
||||
|
||||
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
|
||||
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
|
||||
if restartHTTP {
|
||||
go func() {
|
||||
_ = Context.httpServer.Shutdown(context.TODO())
|
||||
_ = Context.web.httpServer.Shutdown(context.TODO())
|
||||
}()
|
||||
}
|
||||
|
||||
returnOK(w)
|
||||
}
|
||||
|
||||
func registerInstallHandlers() {
|
||||
http.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(handleInstallGetAddresses)))
|
||||
http.HandleFunc("/control/install/check_config", preInstall(ensurePOST(handleInstallCheckConfig)))
|
||||
http.HandleFunc("/control/install/configure", preInstall(ensurePOST(handleInstallConfigure)))
|
||||
func (web *Web) registerInstallHandlers() {
|
||||
http.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
|
||||
http.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
|
||||
http.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))
|
||||
}
|
||||
|
|
|
@ -476,7 +476,6 @@ func doUpdate(u *updateInfo) error {
|
|||
func finishUpdate(u *updateInfo) {
|
||||
log.Info("Stopping all tasks")
|
||||
cleanup()
|
||||
stopHTTPServer()
|
||||
cleanupAlways()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
|
|
83
home/dns.go
83
home/dns.go
|
@ -3,13 +3,13 @@ package home
|
|||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/util"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
|
@ -25,13 +25,9 @@ func onConfigModified() {
|
|||
// Please note that we must do it even if we don't start it
|
||||
// so that we had access to the query log and the stats
|
||||
func initDNSServer() error {
|
||||
var err error
|
||||
baseDir := Context.getDataDir()
|
||||
|
||||
err := os.MkdirAll(baseDir, 0755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Cannot create DNS data dir at %s: %s", baseDir, err)
|
||||
}
|
||||
|
||||
statsConf := stats.Config{
|
||||
Filename: filepath.Join(baseDir, "stats.db"),
|
||||
LimitDays: config.DNS.StatsInterval,
|
||||
|
@ -70,14 +66,6 @@ func initDNSServer() error {
|
|||
return fmt.Errorf("dnsServer.Prepare: %s", err)
|
||||
}
|
||||
|
||||
sessFilename := filepath.Join(baseDir, "sessions.db")
|
||||
Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
|
||||
if Context.auth == nil {
|
||||
closeDNSServer()
|
||||
return fmt.Errorf("Couldn't initialize Auth module")
|
||||
}
|
||||
config.Users = nil
|
||||
|
||||
Context.rdns = InitRDNS(Context.dnsServer, &Context.clients)
|
||||
Context.whois = initWhois(&Context.clients)
|
||||
|
||||
|
@ -169,26 +157,74 @@ func generateServerConfig() dnsforward.ServerConfig {
|
|||
OnDNSRequest: onDNSRequest,
|
||||
}
|
||||
|
||||
if config.TLS.Enabled {
|
||||
newconfig.TLSConfig = config.TLS.TLSConfig
|
||||
if config.TLS.PortDNSOverTLS != 0 {
|
||||
newconfig.TLSListenAddr = &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.TLS.PortDNSOverTLS}
|
||||
tlsConf := tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
if tlsConf.Enabled {
|
||||
newconfig.TLSConfig = tlsConf.TLSConfig
|
||||
if tlsConf.PortDNSOverTLS != 0 {
|
||||
newconfig.TLSListenAddr = &net.TCPAddr{
|
||||
IP: net.ParseIP(config.DNS.BindHost),
|
||||
Port: tlsConf.PortDNSOverTLS,
|
||||
}
|
||||
}
|
||||
}
|
||||
newconfig.TLSv12Roots = Context.tlsRoots
|
||||
newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
|
||||
|
||||
newconfig.FilterHandler = applyAdditionalFiltering
|
||||
newconfig.GetUpstreamsByClient = getUpstreamsByClient
|
||||
return newconfig
|
||||
}
|
||||
|
||||
// Get the list of DNS addresses the server is listening on
|
||||
func getDNSAddresses() []string {
|
||||
dnsAddresses := []string{}
|
||||
|
||||
if config.DNS.BindHost == "0.0.0.0" {
|
||||
ifaces, e := util.GetValidNetInterfacesForWeb()
|
||||
if e != nil {
|
||||
log.Error("Couldn't get network interfaces: %v", e)
|
||||
return []string{}
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
for _, addr := range iface.Addresses {
|
||||
addDNSAddress(&dnsAddresses, addr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
addDNSAddress(&dnsAddresses, config.DNS.BindHost)
|
||||
}
|
||||
|
||||
tlsConf := tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
if tlsConf.Enabled && len(tlsConf.ServerName) != 0 {
|
||||
|
||||
if tlsConf.PortHTTPS != 0 {
|
||||
addr := tlsConf.ServerName
|
||||
if tlsConf.PortHTTPS != 443 {
|
||||
addr = fmt.Sprintf("%s:%d", addr, tlsConf.PortHTTPS)
|
||||
}
|
||||
addr = fmt.Sprintf("https://%s/dns-query", addr)
|
||||
dnsAddresses = append(dnsAddresses, addr)
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSOverTLS != 0 {
|
||||
addr := fmt.Sprintf("tls://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverTLS)
|
||||
dnsAddresses = append(dnsAddresses, addr)
|
||||
}
|
||||
}
|
||||
|
||||
return dnsAddresses
|
||||
}
|
||||
|
||||
func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
|
||||
return Context.clients.FindUpstreams(clientAddr)
|
||||
}
|
||||
|
||||
// If a client has his own settings, apply them
|
||||
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
||||
ApplyBlockedServices(setts, config.DNS.BlockedServices)
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
|
||||
|
||||
if len(clientAddr) == 0 {
|
||||
return
|
||||
|
@ -202,7 +238,7 @@ func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteri
|
|||
log.Debug("Using settings for client with IP %s", clientAddr)
|
||||
|
||||
if c.UseOwnBlockedServices {
|
||||
ApplyBlockedServices(setts, c.BlockedServices)
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
|
||||
}
|
||||
|
||||
setts.ClientTags = c.Tags
|
||||
|
@ -224,6 +260,8 @@ func startDNSServer() error {
|
|||
|
||||
enableFilters(false)
|
||||
|
||||
Context.clients.Start()
|
||||
|
||||
err := Context.dnsServer.Start()
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||
|
@ -295,11 +333,6 @@ func closeDNSServer() {
|
|||
Context.queryLog = nil
|
||||
}
|
||||
|
||||
if Context.auth != nil {
|
||||
Context.auth.Close()
|
||||
Context.auth = nil
|
||||
}
|
||||
|
||||
Context.filters.Close()
|
||||
|
||||
log.Debug("Closed all DNS modules")
|
||||
|
|
188
home/home.go
188
home/home.go
|
@ -34,8 +34,6 @@ import (
|
|||
"github.com/AdguardTeam/AdGuardHome/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/stats"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/gobuffalo/packr"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -67,9 +65,9 @@ type homeContext struct {
|
|||
dnsFilter *dnsfilter.Dnsfilter // DNS filtering module
|
||||
dhcpServer *dhcpd.Server // DHCP module
|
||||
auth *Auth // HTTP authentication module
|
||||
httpServer *http.Server // HTTP module
|
||||
httpsServer HTTPSServer // HTTPS module
|
||||
filters Filtering
|
||||
filters Filtering // DNS filtering module
|
||||
web *Web // Web (HTTP, HTTPS) module
|
||||
tls *TLSMod // TLS module
|
||||
|
||||
// Runtime properties
|
||||
// --
|
||||
|
@ -116,10 +114,20 @@ func Main(version string, channel string, armVer string) {
|
|||
Context.appSignalChannel = make(chan os.Signal)
|
||||
signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
|
||||
go func() {
|
||||
<-Context.appSignalChannel
|
||||
for {
|
||||
sig := <-Context.appSignalChannel
|
||||
log.Info("Received signal '%s'", sig)
|
||||
switch sig {
|
||||
case syscall.SIGHUP:
|
||||
Context.clients.Reload()
|
||||
Context.tls.Reload()
|
||||
|
||||
default:
|
||||
cleanup()
|
||||
cleanupAlways()
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// run the protection
|
||||
|
@ -163,7 +171,6 @@ func run(args options) {
|
|||
}
|
||||
|
||||
initConfig()
|
||||
initServices()
|
||||
|
||||
Context.tlsRoots = util.LoadSystemRootCAs()
|
||||
Context.transport = &http.Transport{
|
||||
|
@ -218,6 +225,9 @@ func run(args options) {
|
|||
if args.bindPort != 0 {
|
||||
config.BindPort = args.bindPort
|
||||
}
|
||||
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
|
||||
Context.pidFileName = args.pidFile
|
||||
}
|
||||
|
||||
if !Context.firstRun {
|
||||
// Save the updated config
|
||||
|
@ -225,11 +235,42 @@ func run(args options) {
|
|||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err = initDNSServer()
|
||||
err := os.MkdirAll(Context.getDataDir(), 0755)
|
||||
if err != nil {
|
||||
log.Fatalf("Cannot create DNS data dir at %s: %s", Context.getDataDir(), err)
|
||||
}
|
||||
|
||||
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
|
||||
Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
|
||||
if Context.auth == nil {
|
||||
log.Fatalf("Couldn't initialize Auth module")
|
||||
}
|
||||
config.Users = nil
|
||||
|
||||
Context.tls = tlsCreate(config.TLS)
|
||||
if Context.tls == nil {
|
||||
log.Fatalf("Can't initialize TLS module")
|
||||
}
|
||||
|
||||
webConf := WebConfig{
|
||||
firstRun: Context.firstRun,
|
||||
BindHost: config.BindHost,
|
||||
BindPort: config.BindPort,
|
||||
}
|
||||
Context.web = CreateWeb(&webConf)
|
||||
if Context.web == nil {
|
||||
log.Fatalf("Can't initialize Web module")
|
||||
}
|
||||
|
||||
if !Context.firstRun {
|
||||
err := initDNSServer()
|
||||
if err != nil {
|
||||
log.Fatalf("%s", err)
|
||||
}
|
||||
Context.tls.Start()
|
||||
|
||||
go func() {
|
||||
err := startDNSServer()
|
||||
if err != nil {
|
||||
|
@ -243,101 +284,27 @@ func run(args options) {
|
|||
}
|
||||
}
|
||||
|
||||
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
|
||||
Context.pidFileName = args.pidFile
|
||||
}
|
||||
|
||||
// Initialize and run the admin Web interface
|
||||
box := packr.NewBox("../build/static")
|
||||
|
||||
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
|
||||
http.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box)))))
|
||||
registerControlHandlers()
|
||||
|
||||
// add handlers for /install paths, we only need them when we're not configured yet
|
||||
if Context.firstRun {
|
||||
log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
|
||||
http.Handle("/install.html", preInstallHandler(http.FileServer(box)))
|
||||
registerInstallHandlers()
|
||||
}
|
||||
|
||||
Context.httpsServer.cond = sync.NewCond(&Context.httpsServer.Mutex)
|
||||
|
||||
// for https, we have a separate goroutine loop
|
||||
go httpServerLoop()
|
||||
|
||||
// this loop is used as an ability to change listening host and/or port
|
||||
for !Context.httpsServer.shutdown {
|
||||
printHTTPAddresses("http")
|
||||
|
||||
// we need to have new instance, because after Shutdown() the Server is not usable
|
||||
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
|
||||
Context.httpServer = &http.Server{
|
||||
Addr: address,
|
||||
}
|
||||
err := Context.httpServer.ListenAndServe()
|
||||
if err != http.ErrServerClosed {
|
||||
cleanupAlways()
|
||||
log.Fatal(err)
|
||||
}
|
||||
// We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop
|
||||
}
|
||||
Context.web.Start()
|
||||
|
||||
// wait indefinitely for other go-routines to complete their job
|
||||
select {}
|
||||
}
|
||||
|
||||
func httpServerLoop() {
|
||||
for !Context.httpsServer.shutdown {
|
||||
Context.httpsServer.cond.L.Lock()
|
||||
// this mechanism doesn't let us through until all conditions are met
|
||||
for config.TLS.Enabled == false ||
|
||||
config.TLS.PortHTTPS == 0 ||
|
||||
len(config.TLS.PrivateKeyData) == 0 ||
|
||||
len(config.TLS.CertificateChainData) == 0 { // sleep until necessary data is supplied
|
||||
Context.httpsServer.cond.Wait()
|
||||
}
|
||||
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS))
|
||||
// validate current TLS config and update warnings (it could have been loaded from file)
|
||||
data := validateCertificates(string(config.TLS.CertificateChainData), string(config.TLS.PrivateKeyData), config.TLS.ServerName)
|
||||
if !data.ValidPair {
|
||||
cleanupAlways()
|
||||
log.Fatal(data.WarningValidation)
|
||||
}
|
||||
config.Lock()
|
||||
config.TLS.tlsConfigStatus = data // update warnings
|
||||
config.Unlock()
|
||||
|
||||
// prepare certs for HTTPS server
|
||||
// important -- they have to be copies, otherwise changing the contents in config.TLS will break encryption for in-flight requests
|
||||
certchain := make([]byte, len(config.TLS.CertificateChainData))
|
||||
copy(certchain, config.TLS.CertificateChainData)
|
||||
privatekey := make([]byte, len(config.TLS.PrivateKeyData))
|
||||
copy(privatekey, config.TLS.PrivateKeyData)
|
||||
cert, err := tls.X509KeyPair(certchain, privatekey)
|
||||
// StartMods - initialize and start DNS after installation
|
||||
func StartMods() error {
|
||||
err := initDNSServer()
|
||||
if err != nil {
|
||||
cleanupAlways()
|
||||
log.Fatal(err)
|
||||
}
|
||||
Context.httpsServer.cond.L.Unlock()
|
||||
|
||||
// prepare HTTPS server
|
||||
Context.httpsServer.server = &http.Server{
|
||||
Addr: address,
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
RootCAs: Context.tlsRoots,
|
||||
},
|
||||
return err
|
||||
}
|
||||
|
||||
printHTTPAddresses("https")
|
||||
err = Context.httpsServer.server.ListenAndServeTLS("", "")
|
||||
if err != http.ErrServerClosed {
|
||||
cleanupAlways()
|
||||
log.Fatal(err)
|
||||
}
|
||||
Context.tls.Start()
|
||||
|
||||
err = startDNSServer()
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the current user has root (administrator) rights
|
||||
|
@ -449,6 +416,15 @@ func configureLogger(args options) {
|
|||
func cleanup() {
|
||||
log.Info("Stopping AdGuard Home")
|
||||
|
||||
if Context.web != nil {
|
||||
Context.web.Close()
|
||||
Context.web = nil
|
||||
}
|
||||
if Context.auth != nil {
|
||||
Context.auth.Close()
|
||||
Context.auth = nil
|
||||
}
|
||||
|
||||
err := stopDNSServer()
|
||||
if err != nil {
|
||||
log.Error("Couldn't stop DNS server: %s", err)
|
||||
|
@ -457,17 +433,11 @@ func cleanup() {
|
|||
if err != nil {
|
||||
log.Error("Couldn't stop DHCP server: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop HTTP server, possibly waiting for all active connections to be closed
|
||||
func stopHTTPServer() {
|
||||
log.Info("Stopping HTTP server...")
|
||||
Context.httpsServer.shutdown = true
|
||||
if Context.httpsServer.server != nil {
|
||||
_ = Context.httpsServer.server.Shutdown(context.TODO())
|
||||
if Context.tls != nil {
|
||||
Context.tls.Close()
|
||||
Context.tls = nil
|
||||
}
|
||||
_ = Context.httpServer.Shutdown(context.TODO())
|
||||
log.Info("Stopped HTTP server")
|
||||
}
|
||||
|
||||
// This function is called before application exits
|
||||
|
@ -519,7 +489,7 @@ func loadOptions() options {
|
|||
}
|
||||
o.bindPort = v
|
||||
}, nil},
|
||||
{"service", "s", "Service control action: status, install, uninstall, start, stop, restart", func(value string) {
|
||||
{"service", "s", "Service control action: status, install, uninstall, start, stop, restart, reload (configuration)", func(value string) {
|
||||
o.serviceControlAction = value
|
||||
}, nil},
|
||||
{"logfile", "l", "Path to log file. If empty: write to stdout; if 'syslog': write to system log", func(value string) {
|
||||
|
@ -588,11 +558,13 @@ func loadOptions() options {
|
|||
func printHTTPAddresses(proto string) {
|
||||
var address string
|
||||
|
||||
if proto == "https" && config.TLS.ServerName != "" {
|
||||
if config.TLS.PortHTTPS == 443 {
|
||||
log.Printf("Go to https://%s", config.TLS.ServerName)
|
||||
tlsConf := tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
if proto == "https" && tlsConf.ServerName != "" {
|
||||
if tlsConf.PortHTTPS == 443 {
|
||||
log.Printf("Go to https://%s", tlsConf.ServerName)
|
||||
} else {
|
||||
log.Printf("Go to https://%s:%d", config.TLS.ServerName, config.TLS.PortHTTPS)
|
||||
log.Printf("Go to https://%s:%d", tlsConf.ServerName, tlsConf.PortHTTPS)
|
||||
}
|
||||
} else if config.BindHost == "0.0.0.0" {
|
||||
log.Println("AdGuard Home is available on the following addresses:")
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
|
@ -71,6 +73,48 @@ func svcAction(s service.Service, action string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// Send SIGHUP to a process with ID taken from our pid-file
|
||||
// If pid-file doesn't exist, find our PID using 'ps' command
|
||||
func sendSigReload() {
|
||||
if runtime.GOOS == "windows" {
|
||||
log.Error("Not implemented on Windows")
|
||||
return
|
||||
}
|
||||
|
||||
pidfile := fmt.Sprintf("/var/run/%s.pid", serviceName)
|
||||
data, err := ioutil.ReadFile(pidfile)
|
||||
if os.IsNotExist(err) {
|
||||
code, psdata, err := util.RunCommand("ps", "-C", serviceName, "-o", "pid=")
|
||||
if err != nil || code != 0 {
|
||||
log.Error("Can't find AdGuardHome process: %s code:%d", err, code)
|
||||
return
|
||||
}
|
||||
data = []byte(psdata)
|
||||
|
||||
} else if err != nil {
|
||||
log.Error("Can't read PID file %s: %s", pidfile, err)
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(string(data), "\n", 2)
|
||||
if len(parts) == 0 {
|
||||
log.Error("Can't read PID file %s: bad value", pidfile)
|
||||
return
|
||||
}
|
||||
|
||||
pid, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
log.Error("Can't read PID file %s: %s", pidfile, err)
|
||||
return
|
||||
}
|
||||
err = util.SendProcessSignal(pid, syscall.SIGHUP)
|
||||
if err != nil {
|
||||
log.Error("Can't send signal to PID %d: %s", pid, err)
|
||||
return
|
||||
}
|
||||
log.Debug("Sent signal to PID %d", pid)
|
||||
}
|
||||
|
||||
// handleServiceControlAction one of the possible control actions:
|
||||
// install -- installs a service/daemon
|
||||
// uninstall -- uninstalls it
|
||||
|
@ -84,6 +128,11 @@ func svcAction(s service.Service, action string) error {
|
|||
func handleServiceControlAction(action string) {
|
||||
log.Printf("Service control action: %s", action)
|
||||
|
||||
if action == "reload" {
|
||||
sendSigReload()
|
||||
return
|
||||
}
|
||||
|
||||
pwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatal("Unable to find the path to the current directory")
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
// Control: TLS configuring handlers
|
||||
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
|
@ -16,18 +13,125 @@ import (
|
|||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/util"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/joomcode/errorx"
|
||||
)
|
||||
|
||||
var tlsWebHandlersRegistered = false
|
||||
|
||||
// TLSMod - TLS module object
|
||||
type TLSMod struct {
|
||||
certLastMod time.Time // last modification time of the certificate file
|
||||
conf tlsConfigSettings
|
||||
confLock sync.Mutex
|
||||
status tlsConfigStatus
|
||||
}
|
||||
|
||||
// Create TLS module
|
||||
func tlsCreate(conf tlsConfigSettings) *TLSMod {
|
||||
t := &TLSMod{}
|
||||
t.conf = conf
|
||||
if t.conf.Enabled {
|
||||
if !t.load() {
|
||||
return nil
|
||||
}
|
||||
t.setCertFileTime()
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *TLSMod) load() bool {
|
||||
if !tlsLoadConfig(&t.conf, &t.status) {
|
||||
return false
|
||||
}
|
||||
|
||||
// validate current TLS config and update warnings (it could have been loaded from file)
|
||||
data := validateCertificates(string(t.conf.CertificateChainData), string(t.conf.PrivateKeyData), t.conf.ServerName)
|
||||
if !data.ValidPair {
|
||||
log.Error(data.WarningValidation)
|
||||
return false
|
||||
}
|
||||
t.status = data
|
||||
return true
|
||||
}
|
||||
|
||||
// Close - close module
|
||||
func (t *TLSMod) Close() {
|
||||
}
|
||||
|
||||
// WriteDiskConfig - write config
|
||||
func (t *TLSMod) WriteDiskConfig(conf *tlsConfigSettings) {
|
||||
t.confLock.Lock()
|
||||
*conf = t.conf
|
||||
t.confLock.Unlock()
|
||||
}
|
||||
|
||||
func (t *TLSMod) setCertFileTime() {
|
||||
if len(t.conf.CertificatePath) == 0 {
|
||||
return
|
||||
}
|
||||
fi, err := os.Stat(t.conf.CertificatePath)
|
||||
if err != nil {
|
||||
log.Error("TLS: %s", err)
|
||||
return
|
||||
}
|
||||
t.certLastMod = fi.ModTime().UTC()
|
||||
}
|
||||
|
||||
// Start - start the module
|
||||
func (t *TLSMod) Start() {
|
||||
if !tlsWebHandlersRegistered {
|
||||
tlsWebHandlersRegistered = true
|
||||
t.registerWebHandlers()
|
||||
}
|
||||
|
||||
t.confLock.Lock()
|
||||
tlsConf := t.conf
|
||||
t.confLock.Unlock()
|
||||
Context.web.TLSConfigChanged(tlsConf)
|
||||
}
|
||||
|
||||
// Reload - reload certificate file
|
||||
func (t *TLSMod) Reload() {
|
||||
t.confLock.Lock()
|
||||
tlsConf := t.conf
|
||||
t.confLock.Unlock()
|
||||
|
||||
if !tlsConf.Enabled || len(tlsConf.CertificatePath) == 0 {
|
||||
return
|
||||
}
|
||||
fi, err := os.Stat(tlsConf.CertificatePath)
|
||||
if err != nil {
|
||||
log.Error("TLS: %s", err)
|
||||
return
|
||||
}
|
||||
if fi.ModTime().UTC().Equal(t.certLastMod) {
|
||||
log.Debug("TLS: certificate file isn't modified")
|
||||
return
|
||||
}
|
||||
log.Debug("TLS: certificate file is modified")
|
||||
|
||||
t.confLock.Lock()
|
||||
r := t.load()
|
||||
t.confLock.Unlock()
|
||||
if !r {
|
||||
return
|
||||
}
|
||||
|
||||
t.certLastMod = fi.ModTime().UTC()
|
||||
|
||||
_ = reconfigureDNSServer()
|
||||
Context.web.TLSConfigChanged(tlsConf)
|
||||
}
|
||||
|
||||
// Set certificate and private key data
|
||||
func tlsLoadConfig(tls *tlsConfig, status *tlsConfigStatus) bool {
|
||||
func tlsLoadConfig(tls *tlsConfigSettings, status *tlsConfigStatus) bool {
|
||||
tls.CertificateChainData = []byte(tls.CertificateChain)
|
||||
tls.PrivateKeyData = []byte(tls.PrivateKey)
|
||||
|
||||
|
@ -61,98 +165,115 @@ func tlsLoadConfig(tls *tlsConfig, status *tlsConfigStatus) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// RegisterTLSHandlers registers HTTP handlers for TLS configuration
|
||||
func RegisterTLSHandlers() {
|
||||
httpRegister(http.MethodGet, "/control/tls/status", handleTLSStatus)
|
||||
httpRegister(http.MethodPost, "/control/tls/configure", handleTLSConfigure)
|
||||
httpRegister(http.MethodPost, "/control/tls/validate", handleTLSValidate)
|
||||
type tlsConfigStatus struct {
|
||||
ValidCert bool `json:"valid_cert"` // ValidCert is true if the specified certificates chain is a valid chain of X509 certificates
|
||||
ValidChain bool `json:"valid_chain"` // ValidChain is true if the specified certificates chain is verified and issued by a known CA
|
||||
Subject string `json:"subject,omitempty"` // Subject is the subject of the first certificate in the chain
|
||||
Issuer string `json:"issuer,omitempty"` // Issuer is the issuer of the first certificate in the chain
|
||||
NotBefore time.Time `json:"not_before,omitempty"` // NotBefore is the NotBefore field of the first certificate in the chain
|
||||
NotAfter time.Time `json:"not_after,omitempty"` // NotAfter is the NotAfter field of the first certificate in the chain
|
||||
DNSNames []string `json:"dns_names"` // DNSNames is the value of SubjectAltNames field of the first certificate in the chain
|
||||
|
||||
// key status
|
||||
ValidKey bool `json:"valid_key"` // ValidKey is true if the key is a valid private key
|
||||
KeyType string `json:"key_type,omitempty"` // KeyType is one of RSA or ECDSA
|
||||
|
||||
// is usable? set by validator
|
||||
ValidPair bool `json:"valid_pair"` // ValidPair is true if both certificate and private key are correct
|
||||
|
||||
// warnings
|
||||
WarningValidation string `json:"warning_validation,omitempty"` // WarningValidation is a validation warning message with the issue description
|
||||
}
|
||||
|
||||
func handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||
marshalTLS(w, config.TLS)
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type tlsConfig struct {
|
||||
tlsConfigSettings `json:",inline"`
|
||||
tlsConfigStatus `json:",inline"`
|
||||
}
|
||||
|
||||
func handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := unmarshalTLS(r)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||
return
|
||||
func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||
t.confLock.Lock()
|
||||
data := tlsConfig{
|
||||
tlsConfigSettings: t.conf,
|
||||
tlsConfigStatus: t.status,
|
||||
}
|
||||
|
||||
// check if port is available
|
||||
// BUT: if we are already using this port, no need
|
||||
alreadyRunning := false
|
||||
if Context.httpsServer.server != nil {
|
||||
alreadyRunning = true
|
||||
}
|
||||
if !alreadyRunning {
|
||||
err = util.CheckPortAvailable(config.BindHost, data.PortHTTPS)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
status := tlsConfigStatus{}
|
||||
if tlsLoadConfig(&data, &status) {
|
||||
status = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName)
|
||||
}
|
||||
data.tlsConfigStatus = status
|
||||
|
||||
t.confLock.Unlock()
|
||||
marshalTLS(w, data)
|
||||
}
|
||||
|
||||
func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
setts, err := unmarshalTLS(r)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !WebCheckPortAvailable(setts.PortHTTPS) {
|
||||
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", setts.PortHTTPS)
|
||||
return
|
||||
}
|
||||
|
||||
status := tlsConfigStatus{}
|
||||
if tlsLoadConfig(&setts, &status) {
|
||||
status = validateCertificates(string(setts.CertificateChainData), string(setts.PrivateKeyData), setts.ServerName)
|
||||
}
|
||||
|
||||
data := tlsConfig{
|
||||
tlsConfigSettings: setts,
|
||||
tlsConfigStatus: status,
|
||||
}
|
||||
marshalTLS(w, data)
|
||||
}
|
||||
|
||||
func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := unmarshalTLS(r)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// check if port is available
|
||||
// BUT: if we are already using this port, no need
|
||||
alreadyRunning := false
|
||||
if Context.httpsServer.server != nil {
|
||||
alreadyRunning = true
|
||||
}
|
||||
if !alreadyRunning {
|
||||
err = util.CheckPortAvailable(config.BindHost, data.PortHTTPS)
|
||||
if err != nil {
|
||||
if !WebCheckPortAvailable(data.PortHTTPS) {
|
||||
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
status := tlsConfigStatus{}
|
||||
if !tlsLoadConfig(&data, &status) {
|
||||
data.tlsConfigStatus = status
|
||||
marshalTLS(w, data)
|
||||
data2 := tlsConfig{
|
||||
tlsConfigSettings: data,
|
||||
tlsConfigStatus: t.status,
|
||||
}
|
||||
marshalTLS(w, data2)
|
||||
return
|
||||
}
|
||||
data.tlsConfigStatus = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName)
|
||||
status = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName)
|
||||
restartHTTPS := false
|
||||
if !reflect.DeepEqual(config.TLS.tlsConfigSettings, data.tlsConfigSettings) {
|
||||
t.confLock.Lock()
|
||||
if !reflect.DeepEqual(t.conf, data) {
|
||||
log.Printf("tls config settings have changed, will restart HTTPS server")
|
||||
restartHTTPS = true
|
||||
}
|
||||
config.TLS = data
|
||||
err = writeAllConfigsAndReloadDNS()
|
||||
t.conf = data
|
||||
t.status = status
|
||||
t.confLock.Unlock()
|
||||
t.setCertFileTime()
|
||||
onConfigModified()
|
||||
err = reconfigureDNSServer()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write config file: %s", err)
|
||||
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||
return
|
||||
}
|
||||
marshalTLS(w, data)
|
||||
data2 := tlsConfig{
|
||||
tlsConfigSettings: data,
|
||||
tlsConfigStatus: t.status,
|
||||
}
|
||||
marshalTLS(w, data2)
|
||||
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
|
||||
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
|
||||
if restartHTTPS {
|
||||
go func() {
|
||||
time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server
|
||||
Context.httpsServer.cond.L.Lock()
|
||||
Context.httpsServer.cond.Broadcast()
|
||||
if Context.httpsServer.server != nil {
|
||||
Context.httpsServer.server.Shutdown(context.TODO())
|
||||
}
|
||||
Context.httpsServer.cond.L.Unlock()
|
||||
Context.web.TLSConfigChanged(data)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
@ -337,8 +458,8 @@ func parsePrivateKey(der []byte) (crypto.PrivateKey, string, error) {
|
|||
}
|
||||
|
||||
// unmarshalTLS handles base64-encoded certificates transparently
|
||||
func unmarshalTLS(r *http.Request) (tlsConfig, error) {
|
||||
data := tlsConfig{}
|
||||
func unmarshalTLS(r *http.Request) (tlsConfigSettings, error) {
|
||||
data := tlsConfigSettings{}
|
||||
err := json.NewDecoder(r.Body).Decode(&data)
|
||||
if err != nil {
|
||||
return data, errorx.Decorate(err, "Failed to parse new TLS config json")
|
||||
|
@ -389,3 +510,10 @@ func marshalTLS(w http.ResponseWriter, data tlsConfig) {
|
|||
return
|
||||
}
|
||||
}
|
||||
|
||||
// registerWebHandlers registers HTTP handlers for TLS configuration
|
||||
func (t *TLSMod) registerWebHandlers() {
|
||||
httpRegister("GET", "/control/tls/status", t.handleTLSStatus)
|
||||
httpRegister("POST", "/control/tls/configure", t.handleTLSConfigure)
|
||||
httpRegister("POST", "/control/tls/validate", t.handleTLSValidate)
|
||||
}
|
|
@ -0,0 +1,189 @@
|
|||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/util"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/gobuffalo/packr"
|
||||
)
|
||||
|
||||
type WebConfig struct {
|
||||
firstRun bool
|
||||
BindHost string
|
||||
BindPort int
|
||||
PortHTTPS int
|
||||
}
|
||||
|
||||
// HTTPSServer - HTTPS Server
|
||||
type HTTPSServer struct {
|
||||
server *http.Server
|
||||
cond *sync.Cond
|
||||
condLock sync.Mutex
|
||||
shutdown bool // if TRUE, don't restart the server
|
||||
enabled bool
|
||||
cert tls.Certificate
|
||||
}
|
||||
|
||||
// Web - module object
|
||||
type Web struct {
|
||||
conf *WebConfig
|
||||
forceHTTPS bool
|
||||
portHTTPS int
|
||||
httpServer *http.Server // HTTP module
|
||||
httpsServer HTTPSServer // HTTPS module
|
||||
}
|
||||
|
||||
// CreateWeb - create module
|
||||
func CreateWeb(conf *WebConfig) *Web {
|
||||
w := Web{}
|
||||
w.conf = conf
|
||||
|
||||
// Initialize and run the admin Web interface
|
||||
box := packr.NewBox("../build/static")
|
||||
|
||||
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
|
||||
http.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box)))))
|
||||
|
||||
// add handlers for /install paths, we only need them when we're not configured yet
|
||||
if conf.firstRun {
|
||||
log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
|
||||
http.Handle("/install.html", preInstallHandler(http.FileServer(box)))
|
||||
w.registerInstallHandlers()
|
||||
} else {
|
||||
registerControlHandlers()
|
||||
}
|
||||
|
||||
w.httpsServer.cond = sync.NewCond(&w.httpsServer.condLock)
|
||||
return &w
|
||||
}
|
||||
|
||||
// WebCheckPortAvailable - check if port is available
|
||||
// BUT: if we are already using this port, no need
|
||||
func WebCheckPortAvailable(port int) bool {
|
||||
alreadyRunning := false
|
||||
if Context.web.httpsServer.server != nil {
|
||||
alreadyRunning = true
|
||||
}
|
||||
if !alreadyRunning {
|
||||
err := util.CheckPortAvailable(config.BindHost, port)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TLSConfigChanged - called when TLS configuration has changed
|
||||
func (w *Web) TLSConfigChanged(tlsConf tlsConfigSettings) {
|
||||
log.Debug("Web: applying new TLS configuration")
|
||||
w.conf.PortHTTPS = tlsConf.PortHTTPS
|
||||
w.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)
|
||||
w.portHTTPS = tlsConf.PortHTTPS
|
||||
|
||||
enabled := tlsConf.Enabled &&
|
||||
tlsConf.PortHTTPS != 0 &&
|
||||
len(tlsConf.PrivateKeyData) != 0 &&
|
||||
len(tlsConf.CertificateChainData) != 0
|
||||
var cert tls.Certificate
|
||||
var err error
|
||||
if enabled {
|
||||
cert, err = tls.X509KeyPair(tlsConf.CertificateChainData, tlsConf.PrivateKeyData)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
w.httpsServer.cond.L.Lock()
|
||||
if w.httpsServer.server != nil {
|
||||
w.httpsServer.server.Shutdown(context.TODO())
|
||||
}
|
||||
w.httpsServer.enabled = enabled
|
||||
w.httpsServer.cert = cert
|
||||
w.httpsServer.cond.Broadcast()
|
||||
w.httpsServer.cond.L.Unlock()
|
||||
}
|
||||
|
||||
// Start - start serving HTTP requests
|
||||
func (w *Web) Start() {
|
||||
// for https, we have a separate goroutine loop
|
||||
go w.httpServerLoop()
|
||||
|
||||
// this loop is used as an ability to change listening host and/or port
|
||||
for !w.httpsServer.shutdown {
|
||||
printHTTPAddresses("http")
|
||||
|
||||
// we need to have new instance, because after Shutdown() the Server is not usable
|
||||
address := net.JoinHostPort(w.conf.BindHost, strconv.Itoa(w.conf.BindPort))
|
||||
w.httpServer = &http.Server{
|
||||
Addr: address,
|
||||
}
|
||||
err := w.httpServer.ListenAndServe()
|
||||
if err != http.ErrServerClosed {
|
||||
cleanupAlways()
|
||||
log.Fatal(err)
|
||||
}
|
||||
// We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop
|
||||
}
|
||||
}
|
||||
|
||||
// Close - stop HTTP server, possibly waiting for all active connections to be closed
|
||||
func (w *Web) Close() {
|
||||
log.Info("Stopping HTTP server...")
|
||||
w.httpsServer.cond.L.Lock()
|
||||
w.httpsServer.shutdown = true
|
||||
w.httpsServer.cond.L.Unlock()
|
||||
if w.httpsServer.server != nil {
|
||||
_ = w.httpsServer.server.Shutdown(context.TODO())
|
||||
}
|
||||
if w.httpServer != nil {
|
||||
_ = w.httpServer.Shutdown(context.TODO())
|
||||
}
|
||||
|
||||
log.Info("Stopped HTTP server")
|
||||
}
|
||||
|
||||
func (w *Web) httpServerLoop() {
|
||||
for {
|
||||
w.httpsServer.cond.L.Lock()
|
||||
if w.httpsServer.shutdown {
|
||||
w.httpsServer.cond.L.Unlock()
|
||||
break
|
||||
}
|
||||
|
||||
// this mechanism doesn't let us through until all conditions are met
|
||||
for !w.httpsServer.enabled { // sleep until necessary data is supplied
|
||||
w.httpsServer.cond.Wait()
|
||||
if w.httpsServer.shutdown {
|
||||
w.httpsServer.cond.L.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
w.httpsServer.cond.L.Unlock()
|
||||
|
||||
// prepare HTTPS server
|
||||
address := net.JoinHostPort(w.conf.BindHost, strconv.Itoa(w.conf.PortHTTPS))
|
||||
w.httpsServer.server = &http.Server{
|
||||
Addr: address,
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{w.httpsServer.cert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
RootCAs: Context.tlsRoots,
|
||||
},
|
||||
}
|
||||
|
||||
printHTTPAddresses("https")
|
||||
err := w.httpsServer.server.ListenAndServeTLS("", "")
|
||||
if err != http.ErrServerClosed {
|
||||
cleanupAlways()
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -25,3 +25,8 @@ func SetRlimit(val uint) {
|
|||
func HaveAdminRights() (bool, error) {
|
||||
return os.Getuid() == 0, nil
|
||||
}
|
||||
|
||||
// SendProcessSignal - send signal to a process
|
||||
func SendProcessSignal(pid int, sig syscall.Signal) error {
|
||||
return syscall.Kill(pid, sig)
|
||||
}
|
||||
|
|
|
@ -25,3 +25,8 @@ func SetRlimit(val uint) {
|
|||
func HaveAdminRights() (bool, error) {
|
||||
return os.Getuid() == 0, nil
|
||||
}
|
||||
|
||||
// SendProcessSignal - send signal to a process
|
||||
func SendProcessSignal(pid int, sig syscall.Signal) error {
|
||||
return syscall.Kill(pid, sig)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
package util
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// Set user-specified limit of how many fd's we can use
|
||||
func SetRlimit(val uint) {
|
||||
|
@ -26,3 +31,7 @@ func HaveAdminRights() (bool, error) {
|
|||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func SendProcessSignal(pid int, sig syscall.Signal) error {
|
||||
return fmt.Errorf("not supported on Windows")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue