Merge: Refactoring: move global variables; move initialization of periodic tasks

Close #583

* commit 'b8444ff46aff5e45194c6cb61fdf4d3e7aa798fa':
  * minor
  * dnsforward: move initialization of periodic tasks to NewServer()
  * move "dnsctx" to "config"
  * move "dnsServer" to "config"
  * move "dhcpServer" to "config"
  * move "httpServer" to "config"
  * move "httpsServer" to "config"
  * move "pidFileName" to "config"
  * move "versionCheckJSON" to "config"
  * move "client", "transport" to "config"
  * move "controlLock" mutex to "config"
  * clients: move container object to "config"
This commit is contained in:
Simon Zolin 2019-07-19 15:18:51 +03:00
commit bd5162ada3
14 changed files with 213 additions and 221 deletions

View File

@ -41,7 +41,6 @@ type Server struct {
dnsFilter *dnsfilter.Dnsfilter // DNS filter instance dnsFilter *dnsfilter.Dnsfilter // DNS filter instance
queryLog *queryLog // Query log instance queryLog *queryLog // Query log instance
stats *stats // General server statistics stats *stats // General server statistics
once sync.Once
AllowedClients map[string]bool // IP addresses of whitelist clients AllowedClients map[string]bool // IP addresses of whitelist clients
DisallowedClients map[string]bool // IP addresses of clients that should be blocked DisallowedClients map[string]bool // IP addresses of clients that should be blocked
@ -55,11 +54,24 @@ type Server struct {
// NewServer creates a new instance of the dnsforward.Server // NewServer creates a new instance of the dnsforward.Server
// baseDir is the base directory for query logs // baseDir is the base directory for query logs
// Note: this function must be called only once
func NewServer(baseDir string) *Server { func NewServer(baseDir string) *Server {
return &Server{ s := &Server{
queryLog: newQueryLog(baseDir), queryLog: newQueryLog(baseDir),
stats: newStats(), stats: newStats(),
} }
log.Tracef("Loading stats from querylog")
err := s.queryLog.fillStatsFromQueryLog(s.stats)
if err != nil {
log.Error("failed to load stats from querylog: %s", err)
}
log.Printf("Start DNS server periodic jobs")
go s.queryLog.periodicQueryLogRotate()
go s.queryLog.runningTop.periodicHourlyTopRotate()
go s.stats.statsRotator()
return s
} }
// FilteringConfig represents the DNS filtering configuration of AdGuard Home // FilteringConfig represents the DNS filtering configuration of AdGuard Home
@ -169,33 +181,11 @@ func (s *Server) startInternal(config *ServerConfig) error {
return errors.New("DNS server is already started") return errors.New("DNS server is already started")
} }
if s.queryLog == nil {
s.queryLog = newQueryLog(".")
}
if s.stats == nil {
s.stats = newStats()
}
err := s.initDNSFilter() err := s.initDNSFilter()
if err != nil { if err != nil {
return err return err
} }
log.Tracef("Loading stats from querylog")
err = s.queryLog.fillStatsFromQueryLog(s.stats)
if err != nil {
return errorx.Decorate(err, "failed to load stats from querylog")
}
// TODO: Think about reworking this, the current approach won't work properly if AG Home is restarted periodically
s.once.Do(func() {
log.Printf("Start DNS server periodic jobs")
go s.queryLog.periodicQueryLogRotate()
go s.queryLog.runningTop.periodicHourlyTopRotate()
go s.stats.statsRotator()
})
proxyConfig := proxy.Config{ proxyConfig := proxy.Config{
UDPListenAddr: s.conf.UDPListenAddr, UDPListenAddr: s.conf.UDPListenAddr,
TCPListenAddr: s.conf.TCPListenAddr, TCPListenAddr: s.conf.TCPListenAddr,

View File

@ -66,10 +66,9 @@ type clientsContainer struct {
lock sync.Mutex lock sync.Mutex
} }
var clients clientsContainer // Init initializes clients container
// Note: this function must be called only once
// Initialize clients container func (clients *clientsContainer) Init() {
func clientsInit() {
if clients.list != nil { if clients.list != nil {
log.Fatal("clients.list != nil") log.Fatal("clients.list != nil")
} }
@ -77,22 +76,24 @@ func clientsInit() {
clients.ipIndex = make(map[string]*Client) clients.ipIndex = make(map[string]*Client)
clients.ipHost = make(map[string]ClientHost) clients.ipHost = make(map[string]ClientHost)
go periodicClientsUpdate() go clients.periodicUpdate()
} }
func periodicClientsUpdate() { func (clients *clientsContainer) periodicUpdate() {
for { for {
clientsAddFromHostsFile() clients.addFromHostsFile()
clientsAddFromSystemARP() clients.addFromSystemARP()
time.Sleep(clientsUpdatePeriod) time.Sleep(clientsUpdatePeriod)
} }
} }
func clientsGetList() map[string]*Client { // GetList returns the pointer to clients list
func (clients *clientsContainer) GetList() map[string]*Client {
return clients.list return clients.list
} }
func clientExists(ip string) bool { // Exists checks if client with this IP already exists
func (clients *clientsContainer) Exists(ip string) bool {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
@ -105,8 +106,8 @@ func clientExists(ip string) bool {
return ok return ok
} }
// Search for a client by IP // Find searches for a client by IP
func clientFind(ip string) (Client, bool) { func (clients *clientsContainer) Find(ip string) (Client, bool) {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
@ -121,7 +122,7 @@ func clientFind(ip string) (Client, bool) {
if err != nil { if err != nil {
continue continue
} }
ipAddr := dhcpServer.FindIPbyMAC(mac) ipAddr := config.dhcpServer.FindIPbyMAC(mac)
if ipAddr == nil { if ipAddr == nil {
continue continue
} }
@ -135,7 +136,7 @@ func clientFind(ip string) (Client, bool) {
} }
// Check if Client object's fields are correct // Check if Client object's fields are correct
func clientCheck(c *Client) error { func (c *Client) check() error {
if len(c.Name) == 0 { if len(c.Name) == 0 {
return fmt.Errorf("Invalid Name") return fmt.Errorf("Invalid Name")
} }
@ -162,8 +163,8 @@ func clientCheck(c *Client) error {
// Add a new client object // Add a new client object
// Return true: success; false: client exists. // Return true: success; false: client exists.
func clientAdd(c Client) (bool, error) { func (clients *clientsContainer) Add(c Client) (bool, error) {
e := clientCheck(&c) e := c.check()
if e != nil { if e != nil {
return false, e return false, e
} }
@ -194,8 +195,8 @@ func clientAdd(c Client) (bool, error) {
return true, nil return true, nil
} }
// Remove a client // Del removes a client
func clientDel(name string) bool { func (clients *clientsContainer) Del(name string) bool {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
@ -210,8 +211,8 @@ func clientDel(name string) bool {
} }
// Update a client // Update a client
func clientUpdate(name string, c Client) error { func (clients *clientsContainer) Update(name string, c Client) error {
err := clientCheck(&c) err := c.check()
if err != nil { if err != nil {
return err return err
} }
@ -257,10 +258,10 @@ func clientUpdate(name string, c Client) error {
return nil return nil
} }
// Add new IP -> Host pair // AddHost adds new IP -> Host pair
// Use priority of the source (etc/hosts > ARP > rDNS) // Use priority of the source (etc/hosts > ARP > rDNS)
// so we overwrite existing entries with an equal or higher priority // so we overwrite existing entries with an equal or higher priority
func clientAddHost(ip, host string, source clientSource) (bool, error) { func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (bool, error) {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
@ -279,7 +280,7 @@ func clientAddHost(ip, host string, source clientSource) (bool, error) {
} }
// Parse system 'hosts' file and fill clients array // Parse system 'hosts' file and fill clients array
func clientsAddFromHostsFile() { func (clients *clientsContainer) addFromHostsFile() {
hostsFn := "/etc/hosts" hostsFn := "/etc/hosts"
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts") hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts")
@ -304,7 +305,7 @@ func clientsAddFromHostsFile() {
continue continue
} }
ok, e := clientAddHost(fields[0], fields[1], ClientSourceHostsFile) ok, e := clients.AddHost(fields[0], fields[1], ClientSourceHostsFile)
if e != nil { if e != nil {
log.Tracef("%s", e) log.Tracef("%s", e)
} }
@ -319,7 +320,7 @@ func clientsAddFromHostsFile() {
// Add IP -> Host pairs from the system's `arp -a` command output // Add IP -> Host pairs from the system's `arp -a` command output
// The command's output is: // The command's output is:
// HOST (IP) at MAC on IFACE // HOST (IP) at MAC on IFACE
func clientsAddFromSystemARP() { func (clients *clientsContainer) addFromSystemARP() {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
return return
@ -350,7 +351,7 @@ func clientsAddFromSystemARP() {
continue continue
} }
ok, e := clientAddHost(ip, host, ClientSourceARP) ok, e := clients.AddHost(ip, host, ClientSourceARP)
if e != nil { if e != nil {
log.Tracef("%s", e) log.Tracef("%s", e)
} }
@ -379,8 +380,8 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
data := clientListJSON{} data := clientListJSON{}
clients.lock.Lock() config.clients.lock.Lock()
for _, c := range clients.list { for _, c := range config.clients.list {
cj := clientJSON{ cj := clientJSON{
IP: c.IP, IP: c.IP,
MAC: c.MAC, MAC: c.MAC,
@ -394,7 +395,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
if len(c.MAC) != 0 { if len(c.MAC) != 0 {
hwAddr, _ := net.ParseMAC(c.MAC) hwAddr, _ := net.ParseMAC(c.MAC)
ipAddr := dhcpServer.FindIPbyMAC(hwAddr) ipAddr := config.dhcpServer.FindIPbyMAC(hwAddr)
if ipAddr != nil { if ipAddr != nil {
cj.IP = ipAddr.String() cj.IP = ipAddr.String()
} }
@ -402,7 +403,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
data.Clients = append(data.Clients, cj) data.Clients = append(data.Clients, cj)
} }
for ip, ch := range clients.ipHost { for ip, ch := range config.clients.ipHost {
cj := clientHostJSON{ cj := clientHostJSON{
IP: ip, IP: ip,
Name: ch.Host, Name: ch.Host,
@ -416,7 +417,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
} }
data.AutoClients = append(data.AutoClients, cj) data.AutoClients = append(data.AutoClients, cj)
} }
clients.lock.Unlock() config.clients.lock.Unlock()
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
e := json.NewEncoder(w).Encode(data) e := json.NewEncoder(w).Encode(data)
@ -462,7 +463,7 @@ func handleAddClient(w http.ResponseWriter, r *http.Request) {
httpError(w, http.StatusBadRequest, "%s", err) httpError(w, http.StatusBadRequest, "%s", err)
return return
} }
ok, err := clientAdd(*c) ok, err := config.clients.Add(*c)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "%s", err) httpError(w, http.StatusBadRequest, "%s", err)
return return
@ -492,7 +493,7 @@ func handleDelClient(w http.ResponseWriter, r *http.Request) {
return return
} }
if !clientDel(cj.Name) { if !config.clients.Del(cj.Name) {
httpError(w, http.StatusBadRequest, "Client not found") httpError(w, http.StatusBadRequest, "Client not found")
return return
} }
@ -501,7 +502,7 @@ func handleDelClient(w http.ResponseWriter, r *http.Request) {
returnOK(w) returnOK(w)
} }
type clientUpdateJSON struct { type updateJSON struct {
Name string `json:"name"` Name string `json:"name"`
Data clientJSON `json:"data"` Data clientJSON `json:"data"`
} }
@ -515,7 +516,7 @@ func handleUpdateClient(w http.ResponseWriter, r *http.Request) {
return return
} }
var dj clientUpdateJSON var dj updateJSON
err = json.Unmarshal(body, &dj) err = json.Unmarshal(body, &dj)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err) httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
@ -532,7 +533,7 @@ func handleUpdateClient(w http.ResponseWriter, r *http.Request) {
return return
} }
err = clientUpdate(dj.Name, *c) err = config.clients.Update(dj.Name, *c)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "%s", err) httpError(w, http.StatusBadRequest, "%s", err)
return return

View File

@ -6,17 +6,18 @@ func TestClients(t *testing.T) {
var c Client var c Client
var e error var e error
var b bool var b bool
clients := clientsContainer{}
clientsInit() clients.Init()
// add // add
c = Client{ c = Client{
IP: "1.1.1.1", IP: "1.1.1.1",
Name: "client1", Name: "client1",
} }
b, e = clientAdd(c) b, e = clients.Add(c)
if !b || e != nil { if !b || e != nil {
t.Fatalf("clientAdd #1") t.Fatalf("Add #1")
} }
// add #2 // add #2
@ -24,19 +25,19 @@ func TestClients(t *testing.T) {
IP: "2.2.2.2", IP: "2.2.2.2",
Name: "client2", Name: "client2",
} }
b, e = clientAdd(c) b, e = clients.Add(c)
if !b || e != nil { if !b || e != nil {
t.Fatalf("clientAdd #2") t.Fatalf("Add #2")
} }
c, b = clientFind("1.1.1.1") c, b = clients.Find("1.1.1.1")
if !b || c.Name != "client1" { if !b || c.Name != "client1" {
t.Fatalf("clientFind #1") t.Fatalf("Find #1")
} }
c, b = clientFind("2.2.2.2") c, b = clients.Find("2.2.2.2")
if !b || c.Name != "client2" { if !b || c.Name != "client2" {
t.Fatalf("clientFind #2") t.Fatalf("Find #2")
} }
// failed add - name in use // failed add - name in use
@ -44,9 +45,9 @@ func TestClients(t *testing.T) {
IP: "1.2.3.5", IP: "1.2.3.5",
Name: "client1", Name: "client1",
} }
b, _ = clientAdd(c) b, _ = clients.Add(c)
if b { if b {
t.Fatalf("clientAdd - name in use") t.Fatalf("Add - name in use")
} }
// failed add - ip in use // failed add - ip in use
@ -54,91 +55,91 @@ func TestClients(t *testing.T) {
IP: "2.2.2.2", IP: "2.2.2.2",
Name: "client3", Name: "client3",
} }
b, e = clientAdd(c) b, e = clients.Add(c)
if b || e == nil { if b || e == nil {
t.Fatalf("clientAdd - ip in use") t.Fatalf("Add - ip in use")
} }
// get // get
if clientExists("1.2.3.4") { if clients.Exists("1.2.3.4") {
t.Fatalf("clientExists") t.Fatalf("Exists")
} }
if !clientExists("1.1.1.1") { if !clients.Exists("1.1.1.1") {
t.Fatalf("clientExists #1") t.Fatalf("Exists #1")
} }
if !clientExists("2.2.2.2") { if !clients.Exists("2.2.2.2") {
t.Fatalf("clientExists #2") t.Fatalf("Exists #2")
} }
// failed update - no such name // failed update - no such name
c.IP = "1.2.3.0" c.IP = "1.2.3.0"
c.Name = "client3" c.Name = "client3"
if clientUpdate("client3", c) == nil { if clients.Update("client3", c) == nil {
t.Fatalf("clientUpdate") t.Fatalf("Update")
} }
// failed update - name in use // failed update - name in use
c.IP = "1.2.3.0" c.IP = "1.2.3.0"
c.Name = "client2" c.Name = "client2"
if clientUpdate("client1", c) == nil { if clients.Update("client1", c) == nil {
t.Fatalf("clientUpdate - name in use") t.Fatalf("Update - name in use")
} }
// failed update - ip in use // failed update - ip in use
c.IP = "2.2.2.2" c.IP = "2.2.2.2"
c.Name = "client1" c.Name = "client1"
if clientUpdate("client1", c) == nil { if clients.Update("client1", c) == nil {
t.Fatalf("clientUpdate - ip in use") t.Fatalf("Update - ip in use")
} }
// update // update
c.IP = "1.1.1.2" c.IP = "1.1.1.2"
c.Name = "client1" c.Name = "client1"
if clientUpdate("client1", c) != nil { if clients.Update("client1", c) != nil {
t.Fatalf("clientUpdate") t.Fatalf("Update")
} }
// get after update // get after update
if clientExists("1.1.1.1") || !clientExists("1.1.1.2") { if clients.Exists("1.1.1.1") || !clients.Exists("1.1.1.2") {
t.Fatalf("clientExists - get after update") t.Fatalf("Exists - get after update")
} }
// failed remove - no such name // failed remove - no such name
if clientDel("client3") { if clients.Del("client3") {
t.Fatalf("clientDel - no such name") t.Fatalf("Del - no such name")
} }
// remove // remove
if !clientDel("client1") || clientExists("1.1.1.2") { if !clients.Del("client1") || clients.Exists("1.1.1.2") {
t.Fatalf("clientDel") t.Fatalf("Del")
} }
// add host client // add host client
b, e = clientAddHost("1.1.1.1", "host", ClientSourceARP) b, e = clients.AddHost("1.1.1.1", "host", ClientSourceARP)
if !b || e != nil { if !b || e != nil {
t.Fatalf("clientAddHost") t.Fatalf("clientAddHost")
} }
// failed add - ip exists // failed add - ip exists
b, e = clientAddHost("1.1.1.1", "host1", ClientSourceRDNS) b, e = clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
if b || e != nil { if b || e != nil {
t.Fatalf("clientAddHost - ip exists") t.Fatalf("clientAddHost - ip exists")
} }
// overwrite with new data // overwrite with new data
b, e = clientAddHost("1.1.1.1", "host2", ClientSourceARP) b, e = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
if !b || e != nil { if !b || e != nil {
t.Fatalf("clientAddHost - overwrite with new data") t.Fatalf("clientAddHost - overwrite with new data")
} }
// overwrite with new data (higher priority) // overwrite with new data (higher priority)
b, e = clientAddHost("1.1.1.1", "host3", ClientSourceHostsFile) b, e = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
if !b || e != nil { if !b || e != nil {
t.Fatalf("clientAddHost - overwrite with new data (higher priority)") t.Fatalf("clientAddHost - overwrite with new data (higher priority)")
} }
// get // get
if !clientExists("1.1.1.1") { if !clients.Exists("1.1.1.1") {
t.Fatalf("clientAddHost") t.Fatalf("clientAddHost")
} }
} }

View File

@ -2,6 +2,7 @@ package home
import ( import (
"io/ioutil" "io/ioutil"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
@ -38,6 +39,13 @@ type clientObject struct {
SafeBrowsingEnabled bool `yaml:"safesearch_enabled"` SafeBrowsingEnabled bool `yaml:"safesearch_enabled"`
} }
type HTTPSServer struct {
server *http.Server
cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
sync.Mutex // protects config.TLS
shutdown bool // if TRUE, don't restart the server
}
// configuration is loaded from YAML // configuration is loaded from YAML
// field ordering is important -- yaml fields will mirror ordering from here // field ordering is important -- yaml fields will mirror ordering from here
type configuration struct { type configuration struct {
@ -48,10 +56,25 @@ type configuration struct {
ourConfigFilename string // Config filename (can be overridden via the command line arguments) ourConfigFilename string // Config filename (can be overridden via the command line arguments)
ourWorkingDir string // Location of our directory, used to protect against CWD being somewhere else ourWorkingDir string // Location of our directory, used to protect against CWD being somewhere else
firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html
pidFileName string // PID file name. Empty if no PID file was created.
// runningAsService flag is set to true when options are passed from the service runner // runningAsService flag is set to true when options are passed from the service runner
runningAsService bool runningAsService bool
disableUpdate bool // If set, don't check for updates disableUpdate bool // If set, don't check for updates
appSignalChannel chan os.Signal appSignalChannel chan os.Signal
clients clientsContainer
controlLock sync.Mutex
transport *http.Transport
client *http.Client
// cached version.json to avoid hammering github.io for each page reload
versionCheckJSON []byte
versionCheckLastTime time.Time
dnsctx dnsContext
dnsServer *dnsforward.Server
dhcpServer dhcpd.Server
httpServer *http.Server
httpsServer HTTPSServer
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
@ -127,7 +150,6 @@ type tlsConfig struct {
} }
// initialize to default values, will be changed later when reading config or parsing command line // initialize to default values, will be changed later when reading config or parsing command line
// TODO: Get rid of global variables
var config = configuration{ var config = configuration{
ourConfigFilename: "AdGuardHome.yaml", ourConfigFilename: "AdGuardHome.yaml",
BindPort: 3000, BindPort: 3000,
@ -167,8 +189,16 @@ var config = configuration{
SchemaVersion: currentSchemaVersion, SchemaVersion: currentSchemaVersion,
} }
// init initializes default configuration for the current OS&ARCH // initConfig initializes default configuration for the current OS&ARCH
func init() { func initConfig() {
config.transport = &http.Transport{
DialContext: customDialContext,
}
config.client = &http.Client{
Timeout: time.Minute * 5,
Transport: config.transport,
}
if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" { if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" {
// Use plain DNS on MIPS, encryption is too slow // Use plain DNS on MIPS, encryption is too slow
defaultDNS = []string{"1.1.1.1", "1.0.0.1"} defaultDNS = []string{"1.1.1.1", "1.0.0.1"}
@ -233,7 +263,7 @@ func parseConfig() error {
SafeSearchEnabled: cy.SafeSearchEnabled, SafeSearchEnabled: cy.SafeSearchEnabled,
SafeBrowsingEnabled: cy.SafeBrowsingEnabled, SafeBrowsingEnabled: cy.SafeBrowsingEnabled,
} }
_, err = clientAdd(cli) _, err = config.clients.Add(cli)
if err != nil { if err != nil {
log.Tracef("clientAdd: %s", err) log.Tracef("clientAdd: %s", err)
} }
@ -268,7 +298,7 @@ func (c *configuration) write() error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
clientsList := clientsGetList() clientsList := config.clients.GetList()
for _, cli := range clientsList { for _, cli := range clientsList {
ip := cli.IP ip := cli.IP
if len(cli.MAC) != 0 { if len(cli.MAC) != 0 {

View File

@ -11,7 +11,6 @@ import (
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/dnsforward"
@ -25,23 +24,8 @@ import (
const updatePeriod = time.Minute * 30 const updatePeriod = time.Minute * 30
// cached version.json to avoid hammering github.io for each page reload
var versionCheckJSON []byte
var versionCheckLastTime time.Time
var protocols = []string{"tls://", "https://", "tcp://", "sdns://"} var protocols = []string{"tls://", "https://", "tcp://", "sdns://"}
var transport = &http.Transport{
DialContext: customDialContext,
}
var client = &http.Client{
Timeout: time.Minute * 5,
Transport: transport,
}
var controlLock sync.Mutex
// ---------------- // ----------------
// helper functions // helper functions
// ---------------- // ----------------
@ -188,7 +172,7 @@ func handleQueryLogDisable(w http.ResponseWriter, r *http.Request) {
func handleQueryLog(w http.ResponseWriter, r *http.Request) { func handleQueryLog(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL) log.Tracef("%s %v", r.Method, r.URL)
data := dnsServer.GetQueryLog() data := config.dnsServer.GetQueryLog()
jsonVal, err := json.Marshal(data) jsonVal, err := json.Marshal(data)
if err != nil { if err != nil {
@ -205,7 +189,7 @@ func handleQueryLog(w http.ResponseWriter, r *http.Request) {
func handleStatsTop(w http.ResponseWriter, r *http.Request) { func handleStatsTop(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL) log.Tracef("%s %v", r.Method, r.URL)
s := dnsServer.GetStatsTop() s := config.dnsServer.GetStatsTop()
// use manual json marshalling because we want maps to be sorted by value // use manual json marshalling because we want maps to be sorted by value
statsJSON := bytes.Buffer{} statsJSON := bytes.Buffer{}
@ -252,7 +236,7 @@ func handleStatsTop(w http.ResponseWriter, r *http.Request) {
// handleStatsReset resets the stats caches // handleStatsReset resets the stats caches
func handleStatsReset(w http.ResponseWriter, r *http.Request) { func handleStatsReset(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL) log.Tracef("%s %v", r.Method, r.URL)
dnsServer.PurgeStats() config.dnsServer.PurgeStats()
_, err := fmt.Fprintf(w, "OK\n") _, err := fmt.Fprintf(w, "OK\n")
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
@ -262,7 +246,7 @@ func handleStatsReset(w http.ResponseWriter, r *http.Request) {
// handleStats returns aggregated stats data for the 24 hours // handleStats returns aggregated stats data for the 24 hours
func handleStats(w http.ResponseWriter, r *http.Request) { func handleStats(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL) log.Tracef("%s %v", r.Method, r.URL)
summed := dnsServer.GetAggregatedStats() summed := config.dnsServer.GetAggregatedStats()
statsJSON, err := json.Marshal(summed) statsJSON, err := json.Marshal(summed)
if err != nil { if err != nil {
@ -309,7 +293,7 @@ func handleStatsHistory(w http.ResponseWriter, r *http.Request) {
return return
} }
data, err := dnsServer.GetStatsHistory(timeUnit, startTime, endTime) data, err := config.dnsServer.GetStatsHistory(timeUnit, startTime, endTime)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Cannot get stats history: %s", err) httpError(w, http.StatusBadRequest, "Cannot get stats history: %s", err)
return return
@ -725,7 +709,7 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
// Stop DNS server: // Stop DNS server:
// we close urlfilter object which in turn closes file descriptors to filter files. // we close urlfilter object which in turn closes file descriptors to filter files.
// Otherwise, Windows won't allow us to remove the file which is being currently used. // Otherwise, Windows won't allow us to remove the file which is being currently used.
_ = dnsServer.Stop() _ = config.dnsServer.Stop()
// go through each element and delete if url matches // go through each element and delete if url matches
config.Lock() config.Lock()
@ -984,7 +968,7 @@ func handleDOH(w http.ResponseWriter, r *http.Request) {
return return
} }
dnsServer.ServeHTTP(w, r) config.dnsServer.ServeHTTP(w, r)
} }
// ------------------------ // ------------------------

View File

@ -17,13 +17,13 @@ type accessListJSON struct {
func handleAccessList(w http.ResponseWriter, r *http.Request) { func handleAccessList(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL) log.Tracef("%s %v", r.Method, r.URL)
controlLock.Lock() config.controlLock.Lock()
j := accessListJSON{ j := accessListJSON{
AllowedClients: config.DNS.AllowedClients, AllowedClients: config.DNS.AllowedClients,
DisallowedClients: config.DNS.DisallowedClients, DisallowedClients: config.DNS.DisallowedClients,
BlockedHosts: config.DNS.BlockedHosts, BlockedHosts: config.DNS.BlockedHosts,
} }
controlLock.Unlock() config.controlLock.Unlock()
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(j) err := json.NewEncoder(w).Encode(j)

View File

@ -264,7 +264,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely // until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
if restartHTTP { if restartHTTP {
go func() { go func() {
httpServer.Shutdown(context.TODO()) config.httpServer.Shutdown(context.TODO())
}() }()
} }

View File

@ -46,7 +46,7 @@ func handleTLSValidate(w http.ResponseWriter, r *http.Request) {
// check if port is available // check if port is available
// BUT: if we are already using this port, no need // BUT: if we are already using this port, no need
alreadyRunning := false alreadyRunning := false
if httpsServer.server != nil { if config.httpsServer.server != nil {
alreadyRunning = true alreadyRunning = true
} }
if !alreadyRunning { if !alreadyRunning {
@ -72,7 +72,7 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
// check if port is available // check if port is available
// BUT: if we are already using this port, no need // BUT: if we are already using this port, no need
alreadyRunning := false alreadyRunning := false
if httpsServer.server != nil { if config.httpsServer.server != nil {
alreadyRunning = true alreadyRunning = true
} }
if !alreadyRunning { if !alreadyRunning {
@ -101,12 +101,12 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
if restartHTTPS { if restartHTTPS {
go func() { go func() {
time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server
httpsServer.cond.L.Lock() config.httpsServer.cond.L.Lock()
httpsServer.cond.Broadcast() config.httpsServer.cond.Broadcast()
if httpsServer.server != nil { if config.httpsServer.server != nil {
httpsServer.server.Shutdown(context.TODO()) config.httpsServer.server.Shutdown(context.TODO())
} }
httpsServer.cond.L.Unlock() config.httpsServer.cond.L.Unlock()
}() }()
} }
} }

View File

@ -73,10 +73,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
now := time.Now() now := time.Now()
if !req.RecheckNow { if !req.RecheckNow {
controlLock.Lock() config.controlLock.Lock()
cached := now.Sub(versionCheckLastTime) <= versionCheckPeriod && len(versionCheckJSON) != 0 cached := now.Sub(config.versionCheckLastTime) <= versionCheckPeriod && len(config.versionCheckJSON) != 0
data := versionCheckJSON data := config.versionCheckJSON
controlLock.Unlock() config.controlLock.Unlock()
if cached { if cached {
log.Tracef("Returning cached data") log.Tracef("Returning cached data")
@ -87,7 +87,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
} }
log.Tracef("Downloading data from %s", versionCheckURL) log.Tracef("Downloading data from %s", versionCheckURL)
resp, err := client.Get(versionCheckURL) resp, err := config.client.Get(versionCheckURL)
if err != nil { if err != nil {
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err) httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
return return
@ -103,10 +103,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
return return
} }
controlLock.Lock() config.controlLock.Lock()
versionCheckLastTime = now config.versionCheckLastTime = now
versionCheckJSON = body config.versionCheckJSON = body
controlLock.Unlock() config.controlLock.Unlock()
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(getVersionResp(body)) _, err = w.Write(getVersionResp(body))
@ -349,7 +349,7 @@ func copySupportingFiles(files []string, srcdir, dstdir string, useSrcNameOnly,
// Download package file and save it to disk // Download package file and save it to disk
func getPackageFile(u *updateInfo) error { func getPackageFile(u *updateInfo) error {
resp, err := client.Get(u.pkgURL) resp, err := config.client.Get(u.pkgURL)
if err != nil { if err != nil {
return fmt.Errorf("HTTP request failed: %s", err) return fmt.Errorf("HTTP request failed: %s", err)
} }
@ -501,12 +501,12 @@ func finishUpdate(u *updateInfo) {
func handleUpdate(w http.ResponseWriter, r *http.Request) { func handleUpdate(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL) log.Tracef("%s %v", r.Method, r.URL)
if len(versionCheckJSON) == 0 { if len(config.versionCheckJSON) == 0 {
httpError(w, http.StatusBadRequest, "/update request isn't allowed now") httpError(w, http.StatusBadRequest, "/update request isn't allowed now")
return return
} }
u, err := getUpdateInfo(versionCheckJSON) u, err := getUpdateInfo(config.versionCheckJSON)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err) httpError(w, http.StatusInternalServerError, "%s", err)
return return

View File

@ -18,8 +18,6 @@ import (
"github.com/joomcode/errorx" "github.com/joomcode/errorx"
) )
var dhcpServer = dhcpd.Server{}
// []dhcpd.Lease -> JSON // []dhcpd.Lease -> JSON
func convertLeases(inputLeases []dhcpd.Lease, includeExpires bool) []map[string]string { func convertLeases(inputLeases []dhcpd.Lease, includeExpires bool) []map[string]string {
leases := []map[string]string{} leases := []map[string]string{}
@ -41,8 +39,8 @@ func convertLeases(inputLeases []dhcpd.Lease, includeExpires bool) []map[string]
func handleDHCPStatus(w http.ResponseWriter, r *http.Request) { func handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL) log.Tracef("%s %v", r.Method, r.URL)
leases := convertLeases(dhcpServer.Leases(), true) leases := convertLeases(config.dhcpServer.Leases(), true)
staticLeases := convertLeases(dhcpServer.StaticLeases(), false) staticLeases := convertLeases(config.dhcpServer.StaticLeases(), false)
status := map[string]interface{}{ status := map[string]interface{}{
"config": config.DHCP, "config": config.DHCP,
"leases": leases, "leases": leases,
@ -77,18 +75,18 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
return return
} }
err = dhcpServer.CheckConfig(newconfig.ServerConfig) err = config.dhcpServer.CheckConfig(newconfig.ServerConfig)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err) httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err)
return return
} }
err = dhcpServer.Stop() err = config.dhcpServer.Stop()
if err != nil { if err != nil {
log.Error("failed to stop the DHCP server: %s", err) log.Error("failed to stop the DHCP server: %s", err)
} }
err = dhcpServer.Init(newconfig.ServerConfig) err = config.dhcpServer.Init(newconfig.ServerConfig)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err) httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err)
return return
@ -105,7 +103,7 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
} }
} }
err = dhcpServer.Start() err = config.dhcpServer.Start()
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Failed to start DHCP server: %s", err) httpError(w, http.StatusBadRequest, "Failed to start DHCP server: %s", err)
return return
@ -389,7 +387,7 @@ func handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
HWAddr: mac, HWAddr: mac,
Hostname: lj.Hostname, Hostname: lj.Hostname,
} }
err = dhcpServer.AddStaticLease(lease) err = config.dhcpServer.AddStaticLease(lease)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "%s", err) httpError(w, http.StatusBadRequest, "%s", err)
return return
@ -420,7 +418,7 @@ func handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
HWAddr: mac, HWAddr: mac,
Hostname: lj.Hostname, Hostname: lj.Hostname,
} }
err = dhcpServer.RemoveStaticLease(lease) err = config.dhcpServer.RemoveStaticLease(lease)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "%s", err) httpError(w, http.StatusBadRequest, "%s", err)
return return
@ -434,12 +432,12 @@ func startDHCPServer() error {
return nil return nil
} }
err := dhcpServer.Init(config.DHCP) err := config.dhcpServer.Init(config.DHCP)
if err != nil { if err != nil {
return errorx.Decorate(err, "Couldn't init DHCP server") return errorx.Decorate(err, "Couldn't init DHCP server")
} }
err = dhcpServer.Start() err = config.dhcpServer.Start()
if err != nil { if err != nil {
return errorx.Decorate(err, "Couldn't start DHCP server") return errorx.Decorate(err, "Couldn't start DHCP server")
} }
@ -451,7 +449,7 @@ func stopDHCPServer() error {
return nil return nil
} }
err := dhcpServer.Stop() err := config.dhcpServer.Stop()
if err != nil { if err != nil {
return errorx.Decorate(err, "Couldn't stop DHCP server") return errorx.Decorate(err, "Couldn't stop DHCP server")
} }

View File

@ -17,8 +17,6 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
var dnsServer *dnsforward.Server
const ( const (
rdnsTimeout = 3 * time.Second // max time to wait for rDNS response rdnsTimeout = 3 * time.Second // max time to wait for rDNS response
) )
@ -32,8 +30,6 @@ type dnsContext struct {
upstream upstream.Upstream // Upstream object for our own DNS server upstream upstream.Upstream // Upstream object for our own DNS server
} }
var dnsctx dnsContext
// initDNSServer creates an instance of the dnsforward.Server // initDNSServer creates an instance of the dnsforward.Server
// Please note that we must do it even if we don't start it // Please note that we must do it even if we don't start it
// so that we had access to the query log and the stats // so that we had access to the query log and the stats
@ -43,7 +39,7 @@ func initDNSServer(baseDir string) {
log.Fatalf("Cannot create DNS data dir at %s: %s", baseDir, err) log.Fatalf("Cannot create DNS data dir at %s: %s", baseDir, err)
} }
dnsServer = dnsforward.NewServer(baseDir) config.dnsServer = dnsforward.NewServer(baseDir)
bindhost := config.DNS.BindHost bindhost := config.DNS.BindHost
if config.DNS.BindHost == "0.0.0.0" { if config.DNS.BindHost == "0.0.0.0" {
@ -53,37 +49,37 @@ func initDNSServer(baseDir string) {
opts := upstream.Options{ opts := upstream.Options{
Timeout: rdnsTimeout, Timeout: rdnsTimeout,
} }
dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts) config.dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts)
if err != nil { if err != nil {
log.Error("upstream.AddressToUpstream: %s", err) log.Error("upstream.AddressToUpstream: %s", err)
return return
} }
dnsctx.rdnsIP = make(map[string]bool) config.dnsctx.rdnsIP = make(map[string]bool)
dnsctx.rdnsChannel = make(chan string, 256) config.dnsctx.rdnsChannel = make(chan string, 256)
go asyncRDNSLoop() go asyncRDNSLoop()
} }
func isRunning() bool { func isRunning() bool {
return dnsServer != nil && dnsServer.IsRunning() return config.dnsServer != nil && config.dnsServer.IsRunning()
} }
func beginAsyncRDNS(ip string) { func beginAsyncRDNS(ip string) {
if clientExists(ip) { if config.clients.Exists(ip) {
return return
} }
// add IP to rdnsIP, if not exists // add IP to rdnsIP, if not exists
dnsctx.rdnsLock.Lock() config.dnsctx.rdnsLock.Lock()
defer dnsctx.rdnsLock.Unlock() defer config.dnsctx.rdnsLock.Unlock()
_, ok := dnsctx.rdnsIP[ip] _, ok := config.dnsctx.rdnsIP[ip]
if ok { if ok {
return return
} }
dnsctx.rdnsIP[ip] = true config.dnsctx.rdnsIP[ip] = true
log.Tracef("Adding %s for rDNS resolve", ip) log.Tracef("Adding %s for rDNS resolve", ip)
select { select {
case dnsctx.rdnsChannel <- ip: case config.dnsctx.rdnsChannel <- ip:
// //
default: default:
log.Tracef("rDNS queue is full") log.Tracef("rDNS queue is full")
@ -110,7 +106,7 @@ func resolveRDNS(ip string) string {
return "" return ""
} }
resp, err := dnsctx.upstream.Exchange(&req) resp, err := config.dnsctx.upstream.Exchange(&req)
if err != nil { if err != nil {
log.Error("Error while making an rDNS lookup for %s: %s", ip, err) log.Error("Error while making an rDNS lookup for %s: %s", ip, err)
return "" return ""
@ -138,18 +134,18 @@ func resolveRDNS(ip string) string {
func asyncRDNSLoop() { func asyncRDNSLoop() {
for { for {
var ip string var ip string
ip = <-dnsctx.rdnsChannel ip = <-config.dnsctx.rdnsChannel
host := resolveRDNS(ip) host := resolveRDNS(ip)
if len(host) == 0 { if len(host) == 0 {
continue continue
} }
dnsctx.rdnsLock.Lock() config.dnsctx.rdnsLock.Lock()
delete(dnsctx.rdnsIP, ip) delete(config.dnsctx.rdnsIP, ip)
dnsctx.rdnsLock.Unlock() config.dnsctx.rdnsLock.Unlock()
_, _ = clientAddHost(ip, host, ClientSourceRDNS) _, _ = config.clients.AddHost(ip, host, ClientSourceRDNS)
} }
} }
@ -221,7 +217,7 @@ func generateServerConfig() (dnsforward.ServerConfig, error) {
// If a client has his own settings, apply them // If a client has his own settings, apply them
func applyClientSettings(clientAddr string, setts *dnsfilter.RequestFilteringSettings) { func applyClientSettings(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
c, ok := clientFind(clientAddr) c, ok := config.clients.Find(clientAddr)
if !ok || !c.UseOwnSettings { if !ok || !c.UseOwnSettings {
return return
} }
@ -242,12 +238,12 @@ func startDNSServer() error {
if err != nil { if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server") return errorx.Decorate(err, "Couldn't start forwarding DNS server")
} }
err = dnsServer.Start(&newconfig) err = config.dnsServer.Start(&newconfig)
if err != nil { if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server") return errorx.Decorate(err, "Couldn't start forwarding DNS server")
} }
top := dnsServer.GetStatsTop() top := config.dnsServer.GetStatsTop()
for k := range top.Clients { for k := range top.Clients {
beginAsyncRDNS(k) beginAsyncRDNS(k)
} }
@ -256,11 +252,11 @@ func startDNSServer() error {
} }
func reconfigureDNSServer() error { func reconfigureDNSServer() error {
config, err := generateServerConfig() newconfig, err := generateServerConfig()
if err != nil { if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server") return errorx.Decorate(err, "Couldn't start forwarding DNS server")
} }
err = dnsServer.Reconfigure(&config) err = config.dnsServer.Reconfigure(&newconfig)
if err != nil { if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server") return errorx.Decorate(err, "Couldn't start forwarding DNS server")
} }
@ -273,7 +269,7 @@ func stopDNSServer() error {
return nil return nil
} }
err := dnsServer.Stop() err := config.dnsServer.Stop()
if err != nil { if err != nil {
return errorx.Decorate(err, "Couldn't stop forwarding DNS server") return errorx.Decorate(err, "Couldn't stop forwarding DNS server")
} }

View File

@ -222,7 +222,7 @@ func refreshFiltersIfNecessary(force bool) int {
stopped := false stopped := false
if updateCount != 0 { if updateCount != 0 {
_ = dnsServer.Stop() _ = config.dnsServer.Stop()
stopped = true stopped = true
} }
@ -308,7 +308,7 @@ func parseFilterContents(contents []byte) (int, string) {
func (filter *filter) update() (bool, error) { func (filter *filter) update() (bool, error) {
log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL) log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL)
resp, err := client.Get(filter.URL) resp, err := config.client.Get(filter.URL)
if resp != nil && resp.Body != nil { if resp != nil && resp.Body != nil {
defer resp.Body.Close() defer resp.Body.Close()
} }

View File

@ -35,8 +35,8 @@ func ensure(method string, handler func(http.ResponseWriter, *http.Request)) fun
} }
if method == "POST" || method == "PUT" || method == "DELETE" { if method == "POST" || method == "PUT" || method == "DELETE" {
controlLock.Lock() config.controlLock.Lock()
defer controlLock.Unlock() defer config.controlLock.Unlock()
} }
handler(w, r) handler(w, r)
@ -148,7 +148,7 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res
return return
} }
// enforce https? // enforce https?
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && httpsServer.server != nil { if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && config.httpsServer.server != nil {
// yes, and we want host from host:port // yes, and we want host from host:port
host, _, err := net.SplitHostPort(r.Host) host, _, err := net.SplitHostPort(r.Host)
if err != nil { if err != nil {

View File

@ -25,15 +25,6 @@ import (
"github.com/gobuffalo/packr" "github.com/gobuffalo/packr"
) )
var httpServer *http.Server
var httpsServer struct {
server *http.Server
cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
sync.Mutex // protects config.TLS
shutdown bool // if TRUE, don't restart the server
}
var pidFileName string // PID file name. Empty if no PID file was created.
const ( const (
// Used in config to indicate that syslog or eventlog (win) should be used for logger output // Used in config to indicate that syslog or eventlog (win) should be used for logger output
configSyslog = "syslog" configSyslog = "syslog"
@ -48,7 +39,7 @@ var (
const versionCheckPeriod = time.Hour * 8 const versionCheckPeriod = time.Hour * 8
// main is the entry point // Main is the entry point
func Main(version string, channel string) { func Main(version string, channel string) {
// Init update-related global variables // Init update-related global variables
versionString = version versionString = version
@ -108,7 +99,8 @@ func run(args options) {
os.Exit(0) os.Exit(0)
}() }()
clientsInit() initConfig()
config.clients.Init()
if !config.firstRun { if !config.firstRun {
// Do the upgrade if necessary // Do the upgrade if necessary
@ -168,7 +160,7 @@ func run(args options) {
} }
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) { if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
pidFileName = args.pidFile config.pidFileName = args.pidFile
} }
// Update filters we've just loaded right away, don't wait for periodic update timer // Update filters we've just loaded right away, don't wait for periodic update timer
@ -192,21 +184,21 @@ func run(args options) {
registerInstallHandlers() registerInstallHandlers()
} }
httpsServer.cond = sync.NewCond(&httpsServer.Mutex) config.httpsServer.cond = sync.NewCond(&config.httpsServer.Mutex)
// for https, we have a separate goroutine loop // for https, we have a separate goroutine loop
go httpServerLoop() go httpServerLoop()
// this loop is used as an ability to change listening host and/or port // this loop is used as an ability to change listening host and/or port
for !httpsServer.shutdown { for !config.httpsServer.shutdown {
printHTTPAddresses("http") printHTTPAddresses("http")
// we need to have new instance, because after Shutdown() the Server is not usable // we need to have new instance, because after Shutdown() the Server is not usable
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort)) address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
httpServer = &http.Server{ config.httpServer = &http.Server{
Addr: address, Addr: address,
} }
err := httpServer.ListenAndServe() err := config.httpServer.ListenAndServe()
if err != http.ErrServerClosed { if err != http.ErrServerClosed {
cleanupAlways() cleanupAlways()
log.Fatal(err) log.Fatal(err)
@ -219,14 +211,14 @@ func run(args options) {
} }
func httpServerLoop() { func httpServerLoop() {
for !httpsServer.shutdown { for !config.httpsServer.shutdown {
httpsServer.cond.L.Lock() config.httpsServer.cond.L.Lock()
// this mechanism doesn't let us through until all conditions are met // this mechanism doesn't let us through until all conditions are met
for config.TLS.Enabled == false || for config.TLS.Enabled == false ||
config.TLS.PortHTTPS == 0 || config.TLS.PortHTTPS == 0 ||
config.TLS.PrivateKey == "" || config.TLS.PrivateKey == "" ||
config.TLS.CertificateChain == "" { // sleep until necessary data is supplied config.TLS.CertificateChain == "" { // sleep until necessary data is supplied
httpsServer.cond.Wait() config.httpsServer.cond.Wait()
} }
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS)) address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS))
// validate current TLS config and update warnings (it could have been loaded from file) // validate current TLS config and update warnings (it could have been loaded from file)
@ -250,10 +242,10 @@ func httpServerLoop() {
cleanupAlways() cleanupAlways()
log.Fatal(err) log.Fatal(err)
} }
httpsServer.cond.L.Unlock() config.httpsServer.cond.L.Unlock()
// prepare HTTPS server // prepare HTTPS server
httpsServer.server = &http.Server{ config.httpsServer.server = &http.Server{
Addr: address, Addr: address,
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
@ -262,7 +254,7 @@ func httpServerLoop() {
} }
printHTTPAddresses("https") printHTTPAddresses("https")
err = httpsServer.server.ListenAndServeTLS("", "") err = config.httpsServer.server.ListenAndServeTLS("", "")
if err != http.ErrServerClosed { if err != http.ErrServerClosed {
cleanupAlways() cleanupAlways()
log.Fatal(err) log.Fatal(err)
@ -399,17 +391,17 @@ func cleanup() {
// Stop HTTP server, possibly waiting for all active connections to be closed // Stop HTTP server, possibly waiting for all active connections to be closed
func stopHTTPServer() { func stopHTTPServer() {
httpsServer.shutdown = true config.httpsServer.shutdown = true
if httpsServer.server != nil { if config.httpsServer.server != nil {
httpsServer.server.Shutdown(context.TODO()) config.httpsServer.server.Shutdown(context.TODO())
} }
httpServer.Shutdown(context.TODO()) config.httpServer.Shutdown(context.TODO())
} }
// This function is called before application exits // This function is called before application exits
func cleanupAlways() { func cleanupAlways() {
if len(pidFileName) != 0 { if len(config.pidFileName) != 0 {
os.Remove(pidFileName) os.Remove(config.pidFileName)
} }
log.Info("Stopped") log.Info("Stopped")
} }