Merge: Refactoring: move global variables; move initialization of periodic tasks
Close #583 * commit 'b8444ff46aff5e45194c6cb61fdf4d3e7aa798fa': * minor * dnsforward: move initialization of periodic tasks to NewServer() * move "dnsctx" to "config" * move "dnsServer" to "config" * move "dhcpServer" to "config" * move "httpServer" to "config" * move "httpsServer" to "config" * move "pidFileName" to "config" * move "versionCheckJSON" to "config" * move "client", "transport" to "config" * move "controlLock" mutex to "config" * clients: move container object to "config"
This commit is contained in:
commit
bd5162ada3
|
@ -41,7 +41,6 @@ type Server struct {
|
|||
dnsFilter *dnsfilter.Dnsfilter // DNS filter instance
|
||||
queryLog *queryLog // Query log instance
|
||||
stats *stats // General server statistics
|
||||
once sync.Once
|
||||
|
||||
AllowedClients map[string]bool // IP addresses of whitelist clients
|
||||
DisallowedClients map[string]bool // IP addresses of clients that should be blocked
|
||||
|
@ -55,11 +54,24 @@ type Server struct {
|
|||
|
||||
// NewServer creates a new instance of the dnsforward.Server
|
||||
// baseDir is the base directory for query logs
|
||||
// Note: this function must be called only once
|
||||
func NewServer(baseDir string) *Server {
|
||||
return &Server{
|
||||
s := &Server{
|
||||
queryLog: newQueryLog(baseDir),
|
||||
stats: newStats(),
|
||||
}
|
||||
|
||||
log.Tracef("Loading stats from querylog")
|
||||
err := s.queryLog.fillStatsFromQueryLog(s.stats)
|
||||
if err != nil {
|
||||
log.Error("failed to load stats from querylog: %s", err)
|
||||
}
|
||||
|
||||
log.Printf("Start DNS server periodic jobs")
|
||||
go s.queryLog.periodicQueryLogRotate()
|
||||
go s.queryLog.runningTop.periodicHourlyTopRotate()
|
||||
go s.stats.statsRotator()
|
||||
return s
|
||||
}
|
||||
|
||||
// FilteringConfig represents the DNS filtering configuration of AdGuard Home
|
||||
|
@ -169,33 +181,11 @@ func (s *Server) startInternal(config *ServerConfig) error {
|
|||
return errors.New("DNS server is already started")
|
||||
}
|
||||
|
||||
if s.queryLog == nil {
|
||||
s.queryLog = newQueryLog(".")
|
||||
}
|
||||
|
||||
if s.stats == nil {
|
||||
s.stats = newStats()
|
||||
}
|
||||
|
||||
err := s.initDNSFilter()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Tracef("Loading stats from querylog")
|
||||
err = s.queryLog.fillStatsFromQueryLog(s.stats)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "failed to load stats from querylog")
|
||||
}
|
||||
|
||||
// TODO: Think about reworking this, the current approach won't work properly if AG Home is restarted periodically
|
||||
s.once.Do(func() {
|
||||
log.Printf("Start DNS server periodic jobs")
|
||||
go s.queryLog.periodicQueryLogRotate()
|
||||
go s.queryLog.runningTop.periodicHourlyTopRotate()
|
||||
go s.stats.statsRotator()
|
||||
})
|
||||
|
||||
proxyConfig := proxy.Config{
|
||||
UDPListenAddr: s.conf.UDPListenAddr,
|
||||
TCPListenAddr: s.conf.TCPListenAddr,
|
||||
|
|
|
@ -66,10 +66,9 @@ type clientsContainer struct {
|
|||
lock sync.Mutex
|
||||
}
|
||||
|
||||
var clients clientsContainer
|
||||
|
||||
// Initialize clients container
|
||||
func clientsInit() {
|
||||
// Init initializes clients container
|
||||
// Note: this function must be called only once
|
||||
func (clients *clientsContainer) Init() {
|
||||
if clients.list != nil {
|
||||
log.Fatal("clients.list != nil")
|
||||
}
|
||||
|
@ -77,22 +76,24 @@ func clientsInit() {
|
|||
clients.ipIndex = make(map[string]*Client)
|
||||
clients.ipHost = make(map[string]ClientHost)
|
||||
|
||||
go periodicClientsUpdate()
|
||||
go clients.periodicUpdate()
|
||||
}
|
||||
|
||||
func periodicClientsUpdate() {
|
||||
func (clients *clientsContainer) periodicUpdate() {
|
||||
for {
|
||||
clientsAddFromHostsFile()
|
||||
clientsAddFromSystemARP()
|
||||
clients.addFromHostsFile()
|
||||
clients.addFromSystemARP()
|
||||
time.Sleep(clientsUpdatePeriod)
|
||||
}
|
||||
}
|
||||
|
||||
func clientsGetList() map[string]*Client {
|
||||
// GetList returns the pointer to clients list
|
||||
func (clients *clientsContainer) GetList() map[string]*Client {
|
||||
return clients.list
|
||||
}
|
||||
|
||||
func clientExists(ip string) bool {
|
||||
// Exists checks if client with this IP already exists
|
||||
func (clients *clientsContainer) Exists(ip string) bool {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
|
@ -105,8 +106,8 @@ func clientExists(ip string) bool {
|
|||
return ok
|
||||
}
|
||||
|
||||
// Search for a client by IP
|
||||
func clientFind(ip string) (Client, bool) {
|
||||
// Find searches for a client by IP
|
||||
func (clients *clientsContainer) Find(ip string) (Client, bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
|
@ -121,7 +122,7 @@ func clientFind(ip string) (Client, bool) {
|
|||
if err != nil {
|
||||
continue
|
||||
}
|
||||
ipAddr := dhcpServer.FindIPbyMAC(mac)
|
||||
ipAddr := config.dhcpServer.FindIPbyMAC(mac)
|
||||
if ipAddr == nil {
|
||||
continue
|
||||
}
|
||||
|
@ -135,7 +136,7 @@ func clientFind(ip string) (Client, bool) {
|
|||
}
|
||||
|
||||
// Check if Client object's fields are correct
|
||||
func clientCheck(c *Client) error {
|
||||
func (c *Client) check() error {
|
||||
if len(c.Name) == 0 {
|
||||
return fmt.Errorf("Invalid Name")
|
||||
}
|
||||
|
@ -162,8 +163,8 @@ func clientCheck(c *Client) error {
|
|||
|
||||
// Add a new client object
|
||||
// Return true: success; false: client exists.
|
||||
func clientAdd(c Client) (bool, error) {
|
||||
e := clientCheck(&c)
|
||||
func (clients *clientsContainer) Add(c Client) (bool, error) {
|
||||
e := c.check()
|
||||
if e != nil {
|
||||
return false, e
|
||||
}
|
||||
|
@ -194,8 +195,8 @@ func clientAdd(c Client) (bool, error) {
|
|||
return true, nil
|
||||
}
|
||||
|
||||
// Remove a client
|
||||
func clientDel(name string) bool {
|
||||
// Del removes a client
|
||||
func (clients *clientsContainer) Del(name string) bool {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
|
@ -210,8 +211,8 @@ func clientDel(name string) bool {
|
|||
}
|
||||
|
||||
// Update a client
|
||||
func clientUpdate(name string, c Client) error {
|
||||
err := clientCheck(&c)
|
||||
func (clients *clientsContainer) Update(name string, c Client) error {
|
||||
err := c.check()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -257,10 +258,10 @@ func clientUpdate(name string, c Client) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Add new IP -> Host pair
|
||||
// AddHost adds new IP -> Host pair
|
||||
// Use priority of the source (etc/hosts > ARP > rDNS)
|
||||
// so we overwrite existing entries with an equal or higher priority
|
||||
func clientAddHost(ip, host string, source clientSource) (bool, error) {
|
||||
func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (bool, error) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
|
@ -279,7 +280,7 @@ func clientAddHost(ip, host string, source clientSource) (bool, error) {
|
|||
}
|
||||
|
||||
// Parse system 'hosts' file and fill clients array
|
||||
func clientsAddFromHostsFile() {
|
||||
func (clients *clientsContainer) addFromHostsFile() {
|
||||
hostsFn := "/etc/hosts"
|
||||
if runtime.GOOS == "windows" {
|
||||
hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts")
|
||||
|
@ -304,7 +305,7 @@ func clientsAddFromHostsFile() {
|
|||
continue
|
||||
}
|
||||
|
||||
ok, e := clientAddHost(fields[0], fields[1], ClientSourceHostsFile)
|
||||
ok, e := clients.AddHost(fields[0], fields[1], ClientSourceHostsFile)
|
||||
if e != nil {
|
||||
log.Tracef("%s", e)
|
||||
}
|
||||
|
@ -319,7 +320,7 @@ func clientsAddFromHostsFile() {
|
|||
// Add IP -> Host pairs from the system's `arp -a` command output
|
||||
// The command's output is:
|
||||
// HOST (IP) at MAC on IFACE
|
||||
func clientsAddFromSystemARP() {
|
||||
func (clients *clientsContainer) addFromSystemARP() {
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
return
|
||||
|
@ -350,7 +351,7 @@ func clientsAddFromSystemARP() {
|
|||
continue
|
||||
}
|
||||
|
||||
ok, e := clientAddHost(ip, host, ClientSourceARP)
|
||||
ok, e := clients.AddHost(ip, host, ClientSourceARP)
|
||||
if e != nil {
|
||||
log.Tracef("%s", e)
|
||||
}
|
||||
|
@ -379,8 +380,8 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
data := clientListJSON{}
|
||||
|
||||
clients.lock.Lock()
|
||||
for _, c := range clients.list {
|
||||
config.clients.lock.Lock()
|
||||
for _, c := range config.clients.list {
|
||||
cj := clientJSON{
|
||||
IP: c.IP,
|
||||
MAC: c.MAC,
|
||||
|
@ -394,7 +395,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
if len(c.MAC) != 0 {
|
||||
hwAddr, _ := net.ParseMAC(c.MAC)
|
||||
ipAddr := dhcpServer.FindIPbyMAC(hwAddr)
|
||||
ipAddr := config.dhcpServer.FindIPbyMAC(hwAddr)
|
||||
if ipAddr != nil {
|
||||
cj.IP = ipAddr.String()
|
||||
}
|
||||
|
@ -402,7 +403,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
data.Clients = append(data.Clients, cj)
|
||||
}
|
||||
for ip, ch := range clients.ipHost {
|
||||
for ip, ch := range config.clients.ipHost {
|
||||
cj := clientHostJSON{
|
||||
IP: ip,
|
||||
Name: ch.Host,
|
||||
|
@ -416,7 +417,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
data.AutoClients = append(data.AutoClients, cj)
|
||||
}
|
||||
clients.lock.Unlock()
|
||||
config.clients.lock.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
e := json.NewEncoder(w).Encode(data)
|
||||
|
@ -462,7 +463,7 @@ func handleAddClient(w http.ResponseWriter, r *http.Request) {
|
|||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
ok, err := clientAdd(*c)
|
||||
ok, err := config.clients.Add(*c)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
|
@ -492,7 +493,7 @@ func handleDelClient(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if !clientDel(cj.Name) {
|
||||
if !config.clients.Del(cj.Name) {
|
||||
httpError(w, http.StatusBadRequest, "Client not found")
|
||||
return
|
||||
}
|
||||
|
@ -501,7 +502,7 @@ func handleDelClient(w http.ResponseWriter, r *http.Request) {
|
|||
returnOK(w)
|
||||
}
|
||||
|
||||
type clientUpdateJSON struct {
|
||||
type updateJSON struct {
|
||||
Name string `json:"name"`
|
||||
Data clientJSON `json:"data"`
|
||||
}
|
||||
|
@ -515,7 +516,7 @@ func handleUpdateClient(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
var dj clientUpdateJSON
|
||||
var dj updateJSON
|
||||
err = json.Unmarshal(body, &dj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
|
||||
|
@ -532,7 +533,7 @@ func handleUpdateClient(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
err = clientUpdate(dj.Name, *c)
|
||||
err = config.clients.Update(dj.Name, *c)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
|
|
|
@ -6,17 +6,18 @@ func TestClients(t *testing.T) {
|
|||
var c Client
|
||||
var e error
|
||||
var b bool
|
||||
clients := clientsContainer{}
|
||||
|
||||
clientsInit()
|
||||
clients.Init()
|
||||
|
||||
// add
|
||||
c = Client{
|
||||
IP: "1.1.1.1",
|
||||
Name: "client1",
|
||||
}
|
||||
b, e = clientAdd(c)
|
||||
b, e = clients.Add(c)
|
||||
if !b || e != nil {
|
||||
t.Fatalf("clientAdd #1")
|
||||
t.Fatalf("Add #1")
|
||||
}
|
||||
|
||||
// add #2
|
||||
|
@ -24,19 +25,19 @@ func TestClients(t *testing.T) {
|
|||
IP: "2.2.2.2",
|
||||
Name: "client2",
|
||||
}
|
||||
b, e = clientAdd(c)
|
||||
b, e = clients.Add(c)
|
||||
if !b || e != nil {
|
||||
t.Fatalf("clientAdd #2")
|
||||
t.Fatalf("Add #2")
|
||||
}
|
||||
|
||||
c, b = clientFind("1.1.1.1")
|
||||
c, b = clients.Find("1.1.1.1")
|
||||
if !b || c.Name != "client1" {
|
||||
t.Fatalf("clientFind #1")
|
||||
t.Fatalf("Find #1")
|
||||
}
|
||||
|
||||
c, b = clientFind("2.2.2.2")
|
||||
c, b = clients.Find("2.2.2.2")
|
||||
if !b || c.Name != "client2" {
|
||||
t.Fatalf("clientFind #2")
|
||||
t.Fatalf("Find #2")
|
||||
}
|
||||
|
||||
// failed add - name in use
|
||||
|
@ -44,9 +45,9 @@ func TestClients(t *testing.T) {
|
|||
IP: "1.2.3.5",
|
||||
Name: "client1",
|
||||
}
|
||||
b, _ = clientAdd(c)
|
||||
b, _ = clients.Add(c)
|
||||
if b {
|
||||
t.Fatalf("clientAdd - name in use")
|
||||
t.Fatalf("Add - name in use")
|
||||
}
|
||||
|
||||
// failed add - ip in use
|
||||
|
@ -54,91 +55,91 @@ func TestClients(t *testing.T) {
|
|||
IP: "2.2.2.2",
|
||||
Name: "client3",
|
||||
}
|
||||
b, e = clientAdd(c)
|
||||
b, e = clients.Add(c)
|
||||
if b || e == nil {
|
||||
t.Fatalf("clientAdd - ip in use")
|
||||
t.Fatalf("Add - ip in use")
|
||||
}
|
||||
|
||||
// get
|
||||
if clientExists("1.2.3.4") {
|
||||
t.Fatalf("clientExists")
|
||||
if clients.Exists("1.2.3.4") {
|
||||
t.Fatalf("Exists")
|
||||
}
|
||||
if !clientExists("1.1.1.1") {
|
||||
t.Fatalf("clientExists #1")
|
||||
if !clients.Exists("1.1.1.1") {
|
||||
t.Fatalf("Exists #1")
|
||||
}
|
||||
if !clientExists("2.2.2.2") {
|
||||
t.Fatalf("clientExists #2")
|
||||
if !clients.Exists("2.2.2.2") {
|
||||
t.Fatalf("Exists #2")
|
||||
}
|
||||
|
||||
// failed update - no such name
|
||||
c.IP = "1.2.3.0"
|
||||
c.Name = "client3"
|
||||
if clientUpdate("client3", c) == nil {
|
||||
t.Fatalf("clientUpdate")
|
||||
if clients.Update("client3", c) == nil {
|
||||
t.Fatalf("Update")
|
||||
}
|
||||
|
||||
// failed update - name in use
|
||||
c.IP = "1.2.3.0"
|
||||
c.Name = "client2"
|
||||
if clientUpdate("client1", c) == nil {
|
||||
t.Fatalf("clientUpdate - name in use")
|
||||
if clients.Update("client1", c) == nil {
|
||||
t.Fatalf("Update - name in use")
|
||||
}
|
||||
|
||||
// failed update - ip in use
|
||||
c.IP = "2.2.2.2"
|
||||
c.Name = "client1"
|
||||
if clientUpdate("client1", c) == nil {
|
||||
t.Fatalf("clientUpdate - ip in use")
|
||||
if clients.Update("client1", c) == nil {
|
||||
t.Fatalf("Update - ip in use")
|
||||
}
|
||||
|
||||
// update
|
||||
c.IP = "1.1.1.2"
|
||||
c.Name = "client1"
|
||||
if clientUpdate("client1", c) != nil {
|
||||
t.Fatalf("clientUpdate")
|
||||
if clients.Update("client1", c) != nil {
|
||||
t.Fatalf("Update")
|
||||
}
|
||||
|
||||
// get after update
|
||||
if clientExists("1.1.1.1") || !clientExists("1.1.1.2") {
|
||||
t.Fatalf("clientExists - get after update")
|
||||
if clients.Exists("1.1.1.1") || !clients.Exists("1.1.1.2") {
|
||||
t.Fatalf("Exists - get after update")
|
||||
}
|
||||
|
||||
// failed remove - no such name
|
||||
if clientDel("client3") {
|
||||
t.Fatalf("clientDel - no such name")
|
||||
if clients.Del("client3") {
|
||||
t.Fatalf("Del - no such name")
|
||||
}
|
||||
|
||||
// remove
|
||||
if !clientDel("client1") || clientExists("1.1.1.2") {
|
||||
t.Fatalf("clientDel")
|
||||
if !clients.Del("client1") || clients.Exists("1.1.1.2") {
|
||||
t.Fatalf("Del")
|
||||
}
|
||||
|
||||
// add host client
|
||||
b, e = clientAddHost("1.1.1.1", "host", ClientSourceARP)
|
||||
b, e = clients.AddHost("1.1.1.1", "host", ClientSourceARP)
|
||||
if !b || e != nil {
|
||||
t.Fatalf("clientAddHost")
|
||||
}
|
||||
|
||||
// failed add - ip exists
|
||||
b, e = clientAddHost("1.1.1.1", "host1", ClientSourceRDNS)
|
||||
b, e = clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
|
||||
if b || e != nil {
|
||||
t.Fatalf("clientAddHost - ip exists")
|
||||
}
|
||||
|
||||
// overwrite with new data
|
||||
b, e = clientAddHost("1.1.1.1", "host2", ClientSourceARP)
|
||||
b, e = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
|
||||
if !b || e != nil {
|
||||
t.Fatalf("clientAddHost - overwrite with new data")
|
||||
}
|
||||
|
||||
// overwrite with new data (higher priority)
|
||||
b, e = clientAddHost("1.1.1.1", "host3", ClientSourceHostsFile)
|
||||
b, e = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
|
||||
if !b || e != nil {
|
||||
t.Fatalf("clientAddHost - overwrite with new data (higher priority)")
|
||||
}
|
||||
|
||||
// get
|
||||
if !clientExists("1.1.1.1") {
|
||||
if !clients.Exists("1.1.1.1") {
|
||||
t.Fatalf("clientAddHost")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package home
|
|||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
@ -38,6 +39,13 @@ type clientObject struct {
|
|||
SafeBrowsingEnabled bool `yaml:"safesearch_enabled"`
|
||||
}
|
||||
|
||||
type HTTPSServer struct {
|
||||
server *http.Server
|
||||
cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
|
||||
sync.Mutex // protects config.TLS
|
||||
shutdown bool // if TRUE, don't restart the server
|
||||
}
|
||||
|
||||
// configuration is loaded from YAML
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type configuration struct {
|
||||
|
@ -48,10 +56,25 @@ type configuration struct {
|
|||
ourConfigFilename string // Config filename (can be overridden via the command line arguments)
|
||||
ourWorkingDir string // Location of our directory, used to protect against CWD being somewhere else
|
||||
firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html
|
||||
pidFileName string // PID file name. Empty if no PID file was created.
|
||||
// runningAsService flag is set to true when options are passed from the service runner
|
||||
runningAsService bool
|
||||
disableUpdate bool // If set, don't check for updates
|
||||
appSignalChannel chan os.Signal
|
||||
clients clientsContainer
|
||||
controlLock sync.Mutex
|
||||
transport *http.Transport
|
||||
client *http.Client
|
||||
|
||||
// cached version.json to avoid hammering github.io for each page reload
|
||||
versionCheckJSON []byte
|
||||
versionCheckLastTime time.Time
|
||||
|
||||
dnsctx dnsContext
|
||||
dnsServer *dnsforward.Server
|
||||
dhcpServer dhcpd.Server
|
||||
httpServer *http.Server
|
||||
httpsServer HTTPSServer
|
||||
|
||||
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
|
||||
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
|
||||
|
@ -127,7 +150,6 @@ type tlsConfig struct {
|
|||
}
|
||||
|
||||
// initialize to default values, will be changed later when reading config or parsing command line
|
||||
// TODO: Get rid of global variables
|
||||
var config = configuration{
|
||||
ourConfigFilename: "AdGuardHome.yaml",
|
||||
BindPort: 3000,
|
||||
|
@ -167,8 +189,16 @@ var config = configuration{
|
|||
SchemaVersion: currentSchemaVersion,
|
||||
}
|
||||
|
||||
// init initializes default configuration for the current OS&ARCH
|
||||
func init() {
|
||||
// initConfig initializes default configuration for the current OS&ARCH
|
||||
func initConfig() {
|
||||
config.transport = &http.Transport{
|
||||
DialContext: customDialContext,
|
||||
}
|
||||
config.client = &http.Client{
|
||||
Timeout: time.Minute * 5,
|
||||
Transport: config.transport,
|
||||
}
|
||||
|
||||
if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" {
|
||||
// Use plain DNS on MIPS, encryption is too slow
|
||||
defaultDNS = []string{"1.1.1.1", "1.0.0.1"}
|
||||
|
@ -233,7 +263,7 @@ func parseConfig() error {
|
|||
SafeSearchEnabled: cy.SafeSearchEnabled,
|
||||
SafeBrowsingEnabled: cy.SafeBrowsingEnabled,
|
||||
}
|
||||
_, err = clientAdd(cli)
|
||||
_, err = config.clients.Add(cli)
|
||||
if err != nil {
|
||||
log.Tracef("clientAdd: %s", err)
|
||||
}
|
||||
|
@ -268,7 +298,7 @@ func (c *configuration) write() error {
|
|||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
clientsList := clientsGetList()
|
||||
clientsList := config.clients.GetList()
|
||||
for _, cli := range clientsList {
|
||||
ip := cli.IP
|
||||
if len(cli.MAC) != 0 {
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
|
@ -25,23 +24,8 @@ import (
|
|||
|
||||
const updatePeriod = time.Minute * 30
|
||||
|
||||
// cached version.json to avoid hammering github.io for each page reload
|
||||
var versionCheckJSON []byte
|
||||
var versionCheckLastTime time.Time
|
||||
|
||||
var protocols = []string{"tls://", "https://", "tcp://", "sdns://"}
|
||||
|
||||
var transport = &http.Transport{
|
||||
DialContext: customDialContext,
|
||||
}
|
||||
|
||||
var client = &http.Client{
|
||||
Timeout: time.Minute * 5,
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
var controlLock sync.Mutex
|
||||
|
||||
// ----------------
|
||||
// helper functions
|
||||
// ----------------
|
||||
|
@ -188,7 +172,7 @@ func handleQueryLogDisable(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
func handleQueryLog(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
data := dnsServer.GetQueryLog()
|
||||
data := config.dnsServer.GetQueryLog()
|
||||
|
||||
jsonVal, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
|
@ -205,7 +189,7 @@ func handleQueryLog(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
func handleStatsTop(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
s := dnsServer.GetStatsTop()
|
||||
s := config.dnsServer.GetStatsTop()
|
||||
|
||||
// use manual json marshalling because we want maps to be sorted by value
|
||||
statsJSON := bytes.Buffer{}
|
||||
|
@ -252,7 +236,7 @@ func handleStatsTop(w http.ResponseWriter, r *http.Request) {
|
|||
// handleStatsReset resets the stats caches
|
||||
func handleStatsReset(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
dnsServer.PurgeStats()
|
||||
config.dnsServer.PurgeStats()
|
||||
_, err := fmt.Fprintf(w, "OK\n")
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
||||
|
@ -262,7 +246,7 @@ func handleStatsReset(w http.ResponseWriter, r *http.Request) {
|
|||
// handleStats returns aggregated stats data for the 24 hours
|
||||
func handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
summed := dnsServer.GetAggregatedStats()
|
||||
summed := config.dnsServer.GetAggregatedStats()
|
||||
|
||||
statsJSON, err := json.Marshal(summed)
|
||||
if err != nil {
|
||||
|
@ -309,7 +293,7 @@ func handleStatsHistory(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
data, err := dnsServer.GetStatsHistory(timeUnit, startTime, endTime)
|
||||
data, err := config.dnsServer.GetStatsHistory(timeUnit, startTime, endTime)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Cannot get stats history: %s", err)
|
||||
return
|
||||
|
@ -725,7 +709,7 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
|
|||
// Stop DNS server:
|
||||
// we close urlfilter object which in turn closes file descriptors to filter files.
|
||||
// Otherwise, Windows won't allow us to remove the file which is being currently used.
|
||||
_ = dnsServer.Stop()
|
||||
_ = config.dnsServer.Stop()
|
||||
|
||||
// go through each element and delete if url matches
|
||||
config.Lock()
|
||||
|
@ -984,7 +968,7 @@ func handleDOH(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
dnsServer.ServeHTTP(w, r)
|
||||
config.dnsServer.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// ------------------------
|
||||
|
|
|
@ -17,13 +17,13 @@ type accessListJSON struct {
|
|||
func handleAccessList(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
|
||||
controlLock.Lock()
|
||||
config.controlLock.Lock()
|
||||
j := accessListJSON{
|
||||
AllowedClients: config.DNS.AllowedClients,
|
||||
DisallowedClients: config.DNS.DisallowedClients,
|
||||
BlockedHosts: config.DNS.BlockedHosts,
|
||||
}
|
||||
controlLock.Unlock()
|
||||
config.controlLock.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(j)
|
||||
|
|
|
@ -264,7 +264,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
|||
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
|
||||
if restartHTTP {
|
||||
go func() {
|
||||
httpServer.Shutdown(context.TODO())
|
||||
config.httpServer.Shutdown(context.TODO())
|
||||
}()
|
||||
}
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ func handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
|||
// check if port is available
|
||||
// BUT: if we are already using this port, no need
|
||||
alreadyRunning := false
|
||||
if httpsServer.server != nil {
|
||||
if config.httpsServer.server != nil {
|
||||
alreadyRunning = true
|
||||
}
|
||||
if !alreadyRunning {
|
||||
|
@ -72,7 +72,7 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
|||
// check if port is available
|
||||
// BUT: if we are already using this port, no need
|
||||
alreadyRunning := false
|
||||
if httpsServer.server != nil {
|
||||
if config.httpsServer.server != nil {
|
||||
alreadyRunning = true
|
||||
}
|
||||
if !alreadyRunning {
|
||||
|
@ -101,12 +101,12 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
|||
if restartHTTPS {
|
||||
go func() {
|
||||
time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server
|
||||
httpsServer.cond.L.Lock()
|
||||
httpsServer.cond.Broadcast()
|
||||
if httpsServer.server != nil {
|
||||
httpsServer.server.Shutdown(context.TODO())
|
||||
config.httpsServer.cond.L.Lock()
|
||||
config.httpsServer.cond.Broadcast()
|
||||
if config.httpsServer.server != nil {
|
||||
config.httpsServer.server.Shutdown(context.TODO())
|
||||
}
|
||||
httpsServer.cond.L.Unlock()
|
||||
config.httpsServer.cond.L.Unlock()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -73,10 +73,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
now := time.Now()
|
||||
if !req.RecheckNow {
|
||||
controlLock.Lock()
|
||||
cached := now.Sub(versionCheckLastTime) <= versionCheckPeriod && len(versionCheckJSON) != 0
|
||||
data := versionCheckJSON
|
||||
controlLock.Unlock()
|
||||
config.controlLock.Lock()
|
||||
cached := now.Sub(config.versionCheckLastTime) <= versionCheckPeriod && len(config.versionCheckJSON) != 0
|
||||
data := config.versionCheckJSON
|
||||
config.controlLock.Unlock()
|
||||
|
||||
if cached {
|
||||
log.Tracef("Returning cached data")
|
||||
|
@ -87,7 +87,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
log.Tracef("Downloading data from %s", versionCheckURL)
|
||||
resp, err := client.Get(versionCheckURL)
|
||||
resp, err := config.client.Get(versionCheckURL)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
|
||||
return
|
||||
|
@ -103,10 +103,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
controlLock.Lock()
|
||||
versionCheckLastTime = now
|
||||
versionCheckJSON = body
|
||||
controlLock.Unlock()
|
||||
config.controlLock.Lock()
|
||||
config.versionCheckLastTime = now
|
||||
config.versionCheckJSON = body
|
||||
config.controlLock.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(getVersionResp(body))
|
||||
|
@ -349,7 +349,7 @@ func copySupportingFiles(files []string, srcdir, dstdir string, useSrcNameOnly,
|
|||
|
||||
// Download package file and save it to disk
|
||||
func getPackageFile(u *updateInfo) error {
|
||||
resp, err := client.Get(u.pkgURL)
|
||||
resp, err := config.client.Get(u.pkgURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("HTTP request failed: %s", err)
|
||||
}
|
||||
|
@ -501,12 +501,12 @@ func finishUpdate(u *updateInfo) {
|
|||
func handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
|
||||
if len(versionCheckJSON) == 0 {
|
||||
if len(config.versionCheckJSON) == 0 {
|
||||
httpError(w, http.StatusBadRequest, "/update request isn't allowed now")
|
||||
return
|
||||
}
|
||||
|
||||
u, err := getUpdateInfo(versionCheckJSON)
|
||||
u, err := getUpdateInfo(config.versionCheckJSON)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||
return
|
||||
|
|
24
home/dhcp.go
24
home/dhcp.go
|
@ -18,8 +18,6 @@ import (
|
|||
"github.com/joomcode/errorx"
|
||||
)
|
||||
|
||||
var dhcpServer = dhcpd.Server{}
|
||||
|
||||
// []dhcpd.Lease -> JSON
|
||||
func convertLeases(inputLeases []dhcpd.Lease, includeExpires bool) []map[string]string {
|
||||
leases := []map[string]string{}
|
||||
|
@ -41,8 +39,8 @@ func convertLeases(inputLeases []dhcpd.Lease, includeExpires bool) []map[string]
|
|||
|
||||
func handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
leases := convertLeases(dhcpServer.Leases(), true)
|
||||
staticLeases := convertLeases(dhcpServer.StaticLeases(), false)
|
||||
leases := convertLeases(config.dhcpServer.Leases(), true)
|
||||
staticLeases := convertLeases(config.dhcpServer.StaticLeases(), false)
|
||||
status := map[string]interface{}{
|
||||
"config": config.DHCP,
|
||||
"leases": leases,
|
||||
|
@ -77,18 +75,18 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
err = dhcpServer.CheckConfig(newconfig.ServerConfig)
|
||||
err = config.dhcpServer.CheckConfig(newconfig.ServerConfig)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = dhcpServer.Stop()
|
||||
err = config.dhcpServer.Stop()
|
||||
if err != nil {
|
||||
log.Error("failed to stop the DHCP server: %s", err)
|
||||
}
|
||||
|
||||
err = dhcpServer.Init(newconfig.ServerConfig)
|
||||
err = config.dhcpServer.Init(newconfig.ServerConfig)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err)
|
||||
return
|
||||
|
@ -105,7 +103,7 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
err = dhcpServer.Start()
|
||||
err = config.dhcpServer.Start()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to start DHCP server: %s", err)
|
||||
return
|
||||
|
@ -389,7 +387,7 @@ func handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
|
|||
HWAddr: mac,
|
||||
Hostname: lj.Hostname,
|
||||
}
|
||||
err = dhcpServer.AddStaticLease(lease)
|
||||
err = config.dhcpServer.AddStaticLease(lease)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
|
@ -420,7 +418,7 @@ func handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
|
|||
HWAddr: mac,
|
||||
Hostname: lj.Hostname,
|
||||
}
|
||||
err = dhcpServer.RemoveStaticLease(lease)
|
||||
err = config.dhcpServer.RemoveStaticLease(lease)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
|
@ -434,12 +432,12 @@ func startDHCPServer() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
err := dhcpServer.Init(config.DHCP)
|
||||
err := config.dhcpServer.Init(config.DHCP)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't init DHCP server")
|
||||
}
|
||||
|
||||
err = dhcpServer.Start()
|
||||
err = config.dhcpServer.Start()
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't start DHCP server")
|
||||
}
|
||||
|
@ -451,7 +449,7 @@ func stopDHCPServer() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
err := dhcpServer.Stop()
|
||||
err := config.dhcpServer.Stop()
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't stop DHCP server")
|
||||
}
|
||||
|
|
50
home/dns.go
50
home/dns.go
|
@ -17,8 +17,6 @@ import (
|
|||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
var dnsServer *dnsforward.Server
|
||||
|
||||
const (
|
||||
rdnsTimeout = 3 * time.Second // max time to wait for rDNS response
|
||||
)
|
||||
|
@ -32,8 +30,6 @@ type dnsContext struct {
|
|||
upstream upstream.Upstream // Upstream object for our own DNS server
|
||||
}
|
||||
|
||||
var dnsctx dnsContext
|
||||
|
||||
// initDNSServer creates an instance of the dnsforward.Server
|
||||
// Please note that we must do it even if we don't start it
|
||||
// so that we had access to the query log and the stats
|
||||
|
@ -43,7 +39,7 @@ func initDNSServer(baseDir string) {
|
|||
log.Fatalf("Cannot create DNS data dir at %s: %s", baseDir, err)
|
||||
}
|
||||
|
||||
dnsServer = dnsforward.NewServer(baseDir)
|
||||
config.dnsServer = dnsforward.NewServer(baseDir)
|
||||
|
||||
bindhost := config.DNS.BindHost
|
||||
if config.DNS.BindHost == "0.0.0.0" {
|
||||
|
@ -53,37 +49,37 @@ func initDNSServer(baseDir string) {
|
|||
opts := upstream.Options{
|
||||
Timeout: rdnsTimeout,
|
||||
}
|
||||
dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts)
|
||||
config.dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts)
|
||||
if err != nil {
|
||||
log.Error("upstream.AddressToUpstream: %s", err)
|
||||
return
|
||||
}
|
||||
dnsctx.rdnsIP = make(map[string]bool)
|
||||
dnsctx.rdnsChannel = make(chan string, 256)
|
||||
config.dnsctx.rdnsIP = make(map[string]bool)
|
||||
config.dnsctx.rdnsChannel = make(chan string, 256)
|
||||
go asyncRDNSLoop()
|
||||
}
|
||||
|
||||
func isRunning() bool {
|
||||
return dnsServer != nil && dnsServer.IsRunning()
|
||||
return config.dnsServer != nil && config.dnsServer.IsRunning()
|
||||
}
|
||||
|
||||
func beginAsyncRDNS(ip string) {
|
||||
if clientExists(ip) {
|
||||
if config.clients.Exists(ip) {
|
||||
return
|
||||
}
|
||||
|
||||
// add IP to rdnsIP, if not exists
|
||||
dnsctx.rdnsLock.Lock()
|
||||
defer dnsctx.rdnsLock.Unlock()
|
||||
_, ok := dnsctx.rdnsIP[ip]
|
||||
config.dnsctx.rdnsLock.Lock()
|
||||
defer config.dnsctx.rdnsLock.Unlock()
|
||||
_, ok := config.dnsctx.rdnsIP[ip]
|
||||
if ok {
|
||||
return
|
||||
}
|
||||
dnsctx.rdnsIP[ip] = true
|
||||
config.dnsctx.rdnsIP[ip] = true
|
||||
|
||||
log.Tracef("Adding %s for rDNS resolve", ip)
|
||||
select {
|
||||
case dnsctx.rdnsChannel <- ip:
|
||||
case config.dnsctx.rdnsChannel <- ip:
|
||||
//
|
||||
default:
|
||||
log.Tracef("rDNS queue is full")
|
||||
|
@ -110,7 +106,7 @@ func resolveRDNS(ip string) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
resp, err := dnsctx.upstream.Exchange(&req)
|
||||
resp, err := config.dnsctx.upstream.Exchange(&req)
|
||||
if err != nil {
|
||||
log.Error("Error while making an rDNS lookup for %s: %s", ip, err)
|
||||
return ""
|
||||
|
@ -138,18 +134,18 @@ func resolveRDNS(ip string) string {
|
|||
func asyncRDNSLoop() {
|
||||
for {
|
||||
var ip string
|
||||
ip = <-dnsctx.rdnsChannel
|
||||
ip = <-config.dnsctx.rdnsChannel
|
||||
|
||||
host := resolveRDNS(ip)
|
||||
if len(host) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
dnsctx.rdnsLock.Lock()
|
||||
delete(dnsctx.rdnsIP, ip)
|
||||
dnsctx.rdnsLock.Unlock()
|
||||
config.dnsctx.rdnsLock.Lock()
|
||||
delete(config.dnsctx.rdnsIP, ip)
|
||||
config.dnsctx.rdnsLock.Unlock()
|
||||
|
||||
_, _ = clientAddHost(ip, host, ClientSourceRDNS)
|
||||
_, _ = config.clients.AddHost(ip, host, ClientSourceRDNS)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -221,7 +217,7 @@ func generateServerConfig() (dnsforward.ServerConfig, error) {
|
|||
|
||||
// If a client has his own settings, apply them
|
||||
func applyClientSettings(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
||||
c, ok := clientFind(clientAddr)
|
||||
c, ok := config.clients.Find(clientAddr)
|
||||
if !ok || !c.UseOwnSettings {
|
||||
return
|
||||
}
|
||||
|
@ -242,12 +238,12 @@ func startDNSServer() error {
|
|||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||
}
|
||||
err = dnsServer.Start(&newconfig)
|
||||
err = config.dnsServer.Start(&newconfig)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||
}
|
||||
|
||||
top := dnsServer.GetStatsTop()
|
||||
top := config.dnsServer.GetStatsTop()
|
||||
for k := range top.Clients {
|
||||
beginAsyncRDNS(k)
|
||||
}
|
||||
|
@ -256,11 +252,11 @@ func startDNSServer() error {
|
|||
}
|
||||
|
||||
func reconfigureDNSServer() error {
|
||||
config, err := generateServerConfig()
|
||||
newconfig, err := generateServerConfig()
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||
}
|
||||
err = dnsServer.Reconfigure(&config)
|
||||
err = config.dnsServer.Reconfigure(&newconfig)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||
}
|
||||
|
@ -273,7 +269,7 @@ func stopDNSServer() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
err := dnsServer.Stop()
|
||||
err := config.dnsServer.Stop()
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't stop forwarding DNS server")
|
||||
}
|
||||
|
|
|
@ -222,7 +222,7 @@ func refreshFiltersIfNecessary(force bool) int {
|
|||
|
||||
stopped := false
|
||||
if updateCount != 0 {
|
||||
_ = dnsServer.Stop()
|
||||
_ = config.dnsServer.Stop()
|
||||
stopped = true
|
||||
}
|
||||
|
||||
|
@ -308,7 +308,7 @@ func parseFilterContents(contents []byte) (int, string) {
|
|||
func (filter *filter) update() (bool, error) {
|
||||
log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL)
|
||||
|
||||
resp, err := client.Get(filter.URL)
|
||||
resp, err := config.client.Get(filter.URL)
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
|
|
@ -35,8 +35,8 @@ func ensure(method string, handler func(http.ResponseWriter, *http.Request)) fun
|
|||
}
|
||||
|
||||
if method == "POST" || method == "PUT" || method == "DELETE" {
|
||||
controlLock.Lock()
|
||||
defer controlLock.Unlock()
|
||||
config.controlLock.Lock()
|
||||
defer config.controlLock.Unlock()
|
||||
}
|
||||
|
||||
handler(w, r)
|
||||
|
@ -148,7 +148,7 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res
|
|||
return
|
||||
}
|
||||
// enforce https?
|
||||
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && httpsServer.server != nil {
|
||||
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && config.httpsServer.server != nil {
|
||||
// yes, and we want host from host:port
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
|
|
48
home/home.go
48
home/home.go
|
@ -25,15 +25,6 @@ import (
|
|||
"github.com/gobuffalo/packr"
|
||||
)
|
||||
|
||||
var httpServer *http.Server
|
||||
var httpsServer struct {
|
||||
server *http.Server
|
||||
cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
|
||||
sync.Mutex // protects config.TLS
|
||||
shutdown bool // if TRUE, don't restart the server
|
||||
}
|
||||
var pidFileName string // PID file name. Empty if no PID file was created.
|
||||
|
||||
const (
|
||||
// Used in config to indicate that syslog or eventlog (win) should be used for logger output
|
||||
configSyslog = "syslog"
|
||||
|
@ -48,7 +39,7 @@ var (
|
|||
|
||||
const versionCheckPeriod = time.Hour * 8
|
||||
|
||||
// main is the entry point
|
||||
// Main is the entry point
|
||||
func Main(version string, channel string) {
|
||||
// Init update-related global variables
|
||||
versionString = version
|
||||
|
@ -108,7 +99,8 @@ func run(args options) {
|
|||
os.Exit(0)
|
||||
}()
|
||||
|
||||
clientsInit()
|
||||
initConfig()
|
||||
config.clients.Init()
|
||||
|
||||
if !config.firstRun {
|
||||
// Do the upgrade if necessary
|
||||
|
@ -168,7 +160,7 @@ func run(args options) {
|
|||
}
|
||||
|
||||
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
|
||||
pidFileName = args.pidFile
|
||||
config.pidFileName = args.pidFile
|
||||
}
|
||||
|
||||
// Update filters we've just loaded right away, don't wait for periodic update timer
|
||||
|
@ -192,21 +184,21 @@ func run(args options) {
|
|||
registerInstallHandlers()
|
||||
}
|
||||
|
||||
httpsServer.cond = sync.NewCond(&httpsServer.Mutex)
|
||||
config.httpsServer.cond = sync.NewCond(&config.httpsServer.Mutex)
|
||||
|
||||
// for https, we have a separate goroutine loop
|
||||
go httpServerLoop()
|
||||
|
||||
// this loop is used as an ability to change listening host and/or port
|
||||
for !httpsServer.shutdown {
|
||||
for !config.httpsServer.shutdown {
|
||||
printHTTPAddresses("http")
|
||||
|
||||
// we need to have new instance, because after Shutdown() the Server is not usable
|
||||
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
|
||||
httpServer = &http.Server{
|
||||
config.httpServer = &http.Server{
|
||||
Addr: address,
|
||||
}
|
||||
err := httpServer.ListenAndServe()
|
||||
err := config.httpServer.ListenAndServe()
|
||||
if err != http.ErrServerClosed {
|
||||
cleanupAlways()
|
||||
log.Fatal(err)
|
||||
|
@ -219,14 +211,14 @@ func run(args options) {
|
|||
}
|
||||
|
||||
func httpServerLoop() {
|
||||
for !httpsServer.shutdown {
|
||||
httpsServer.cond.L.Lock()
|
||||
for !config.httpsServer.shutdown {
|
||||
config.httpsServer.cond.L.Lock()
|
||||
// this mechanism doesn't let us through until all conditions are met
|
||||
for config.TLS.Enabled == false ||
|
||||
config.TLS.PortHTTPS == 0 ||
|
||||
config.TLS.PrivateKey == "" ||
|
||||
config.TLS.CertificateChain == "" { // sleep until necessary data is supplied
|
||||
httpsServer.cond.Wait()
|
||||
config.httpsServer.cond.Wait()
|
||||
}
|
||||
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS))
|
||||
// validate current TLS config and update warnings (it could have been loaded from file)
|
||||
|
@ -250,10 +242,10 @@ func httpServerLoop() {
|
|||
cleanupAlways()
|
||||
log.Fatal(err)
|
||||
}
|
||||
httpsServer.cond.L.Unlock()
|
||||
config.httpsServer.cond.L.Unlock()
|
||||
|
||||
// prepare HTTPS server
|
||||
httpsServer.server = &http.Server{
|
||||
config.httpsServer.server = &http.Server{
|
||||
Addr: address,
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
|
@ -262,7 +254,7 @@ func httpServerLoop() {
|
|||
}
|
||||
|
||||
printHTTPAddresses("https")
|
||||
err = httpsServer.server.ListenAndServeTLS("", "")
|
||||
err = config.httpsServer.server.ListenAndServeTLS("", "")
|
||||
if err != http.ErrServerClosed {
|
||||
cleanupAlways()
|
||||
log.Fatal(err)
|
||||
|
@ -399,17 +391,17 @@ func cleanup() {
|
|||
|
||||
// Stop HTTP server, possibly waiting for all active connections to be closed
|
||||
func stopHTTPServer() {
|
||||
httpsServer.shutdown = true
|
||||
if httpsServer.server != nil {
|
||||
httpsServer.server.Shutdown(context.TODO())
|
||||
config.httpsServer.shutdown = true
|
||||
if config.httpsServer.server != nil {
|
||||
config.httpsServer.server.Shutdown(context.TODO())
|
||||
}
|
||||
httpServer.Shutdown(context.TODO())
|
||||
config.httpServer.Shutdown(context.TODO())
|
||||
}
|
||||
|
||||
// This function is called before application exits
|
||||
func cleanupAlways() {
|
||||
if len(pidFileName) != 0 {
|
||||
os.Remove(pidFileName)
|
||||
if len(config.pidFileName) != 0 {
|
||||
os.Remove(config.pidFileName)
|
||||
}
|
||||
log.Info("Stopped")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue