Merge: Refactoring: move global variables; move initialization of periodic tasks

Close #583

* commit 'b8444ff46aff5e45194c6cb61fdf4d3e7aa798fa':
  * minor
  * dnsforward: move initialization of periodic tasks to NewServer()
  * move "dnsctx" to "config"
  * move "dnsServer" to "config"
  * move "dhcpServer" to "config"
  * move "httpServer" to "config"
  * move "httpsServer" to "config"
  * move "pidFileName" to "config"
  * move "versionCheckJSON" to "config"
  * move "client", "transport" to "config"
  * move "controlLock" mutex to "config"
  * clients: move container object to "config"
This commit is contained in:
Simon Zolin 2019-07-19 15:18:51 +03:00
commit bd5162ada3
14 changed files with 213 additions and 221 deletions

View File

@ -41,7 +41,6 @@ type Server struct {
dnsFilter *dnsfilter.Dnsfilter // DNS filter instance
queryLog *queryLog // Query log instance
stats *stats // General server statistics
once sync.Once
AllowedClients map[string]bool // IP addresses of whitelist clients
DisallowedClients map[string]bool // IP addresses of clients that should be blocked
@ -55,11 +54,24 @@ type Server struct {
// NewServer creates a new instance of the dnsforward.Server
// baseDir is the base directory for query logs
// Note: this function must be called only once
func NewServer(baseDir string) *Server {
return &Server{
s := &Server{
queryLog: newQueryLog(baseDir),
stats: newStats(),
}
log.Tracef("Loading stats from querylog")
err := s.queryLog.fillStatsFromQueryLog(s.stats)
if err != nil {
log.Error("failed to load stats from querylog: %s", err)
}
log.Printf("Start DNS server periodic jobs")
go s.queryLog.periodicQueryLogRotate()
go s.queryLog.runningTop.periodicHourlyTopRotate()
go s.stats.statsRotator()
return s
}
// FilteringConfig represents the DNS filtering configuration of AdGuard Home
@ -169,33 +181,11 @@ func (s *Server) startInternal(config *ServerConfig) error {
return errors.New("DNS server is already started")
}
if s.queryLog == nil {
s.queryLog = newQueryLog(".")
}
if s.stats == nil {
s.stats = newStats()
}
err := s.initDNSFilter()
if err != nil {
return err
}
log.Tracef("Loading stats from querylog")
err = s.queryLog.fillStatsFromQueryLog(s.stats)
if err != nil {
return errorx.Decorate(err, "failed to load stats from querylog")
}
// TODO: Think about reworking this, the current approach won't work properly if AG Home is restarted periodically
s.once.Do(func() {
log.Printf("Start DNS server periodic jobs")
go s.queryLog.periodicQueryLogRotate()
go s.queryLog.runningTop.periodicHourlyTopRotate()
go s.stats.statsRotator()
})
proxyConfig := proxy.Config{
UDPListenAddr: s.conf.UDPListenAddr,
TCPListenAddr: s.conf.TCPListenAddr,

View File

@ -66,10 +66,9 @@ type clientsContainer struct {
lock sync.Mutex
}
var clients clientsContainer
// Initialize clients container
func clientsInit() {
// Init initializes clients container
// Note: this function must be called only once
func (clients *clientsContainer) Init() {
if clients.list != nil {
log.Fatal("clients.list != nil")
}
@ -77,22 +76,24 @@ func clientsInit() {
clients.ipIndex = make(map[string]*Client)
clients.ipHost = make(map[string]ClientHost)
go periodicClientsUpdate()
go clients.periodicUpdate()
}
func periodicClientsUpdate() {
func (clients *clientsContainer) periodicUpdate() {
for {
clientsAddFromHostsFile()
clientsAddFromSystemARP()
clients.addFromHostsFile()
clients.addFromSystemARP()
time.Sleep(clientsUpdatePeriod)
}
}
func clientsGetList() map[string]*Client {
// GetList returns the pointer to clients list
func (clients *clientsContainer) GetList() map[string]*Client {
return clients.list
}
func clientExists(ip string) bool {
// Exists checks if client with this IP already exists
func (clients *clientsContainer) Exists(ip string) bool {
clients.lock.Lock()
defer clients.lock.Unlock()
@ -105,8 +106,8 @@ func clientExists(ip string) bool {
return ok
}
// Search for a client by IP
func clientFind(ip string) (Client, bool) {
// Find searches for a client by IP
func (clients *clientsContainer) Find(ip string) (Client, bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
@ -121,7 +122,7 @@ func clientFind(ip string) (Client, bool) {
if err != nil {
continue
}
ipAddr := dhcpServer.FindIPbyMAC(mac)
ipAddr := config.dhcpServer.FindIPbyMAC(mac)
if ipAddr == nil {
continue
}
@ -135,7 +136,7 @@ func clientFind(ip string) (Client, bool) {
}
// Check if Client object's fields are correct
func clientCheck(c *Client) error {
func (c *Client) check() error {
if len(c.Name) == 0 {
return fmt.Errorf("Invalid Name")
}
@ -162,8 +163,8 @@ func clientCheck(c *Client) error {
// Add a new client object
// Return true: success; false: client exists.
func clientAdd(c Client) (bool, error) {
e := clientCheck(&c)
func (clients *clientsContainer) Add(c Client) (bool, error) {
e := c.check()
if e != nil {
return false, e
}
@ -194,8 +195,8 @@ func clientAdd(c Client) (bool, error) {
return true, nil
}
// Remove a client
func clientDel(name string) bool {
// Del removes a client
func (clients *clientsContainer) Del(name string) bool {
clients.lock.Lock()
defer clients.lock.Unlock()
@ -210,8 +211,8 @@ func clientDel(name string) bool {
}
// Update a client
func clientUpdate(name string, c Client) error {
err := clientCheck(&c)
func (clients *clientsContainer) Update(name string, c Client) error {
err := c.check()
if err != nil {
return err
}
@ -257,10 +258,10 @@ func clientUpdate(name string, c Client) error {
return nil
}
// Add new IP -> Host pair
// AddHost adds new IP -> Host pair
// Use priority of the source (etc/hosts > ARP > rDNS)
// so we overwrite existing entries with an equal or higher priority
func clientAddHost(ip, host string, source clientSource) (bool, error) {
func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (bool, error) {
clients.lock.Lock()
defer clients.lock.Unlock()
@ -279,7 +280,7 @@ func clientAddHost(ip, host string, source clientSource) (bool, error) {
}
// Parse system 'hosts' file and fill clients array
func clientsAddFromHostsFile() {
func (clients *clientsContainer) addFromHostsFile() {
hostsFn := "/etc/hosts"
if runtime.GOOS == "windows" {
hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts")
@ -304,7 +305,7 @@ func clientsAddFromHostsFile() {
continue
}
ok, e := clientAddHost(fields[0], fields[1], ClientSourceHostsFile)
ok, e := clients.AddHost(fields[0], fields[1], ClientSourceHostsFile)
if e != nil {
log.Tracef("%s", e)
}
@ -319,7 +320,7 @@ func clientsAddFromHostsFile() {
// Add IP -> Host pairs from the system's `arp -a` command output
// The command's output is:
// HOST (IP) at MAC on IFACE
func clientsAddFromSystemARP() {
func (clients *clientsContainer) addFromSystemARP() {
if runtime.GOOS == "windows" {
return
@ -350,7 +351,7 @@ func clientsAddFromSystemARP() {
continue
}
ok, e := clientAddHost(ip, host, ClientSourceARP)
ok, e := clients.AddHost(ip, host, ClientSourceARP)
if e != nil {
log.Tracef("%s", e)
}
@ -379,8 +380,8 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
data := clientListJSON{}
clients.lock.Lock()
for _, c := range clients.list {
config.clients.lock.Lock()
for _, c := range config.clients.list {
cj := clientJSON{
IP: c.IP,
MAC: c.MAC,
@ -394,7 +395,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
if len(c.MAC) != 0 {
hwAddr, _ := net.ParseMAC(c.MAC)
ipAddr := dhcpServer.FindIPbyMAC(hwAddr)
ipAddr := config.dhcpServer.FindIPbyMAC(hwAddr)
if ipAddr != nil {
cj.IP = ipAddr.String()
}
@ -402,7 +403,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
data.Clients = append(data.Clients, cj)
}
for ip, ch := range clients.ipHost {
for ip, ch := range config.clients.ipHost {
cj := clientHostJSON{
IP: ip,
Name: ch.Host,
@ -416,7 +417,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
}
data.AutoClients = append(data.AutoClients, cj)
}
clients.lock.Unlock()
config.clients.lock.Unlock()
w.Header().Set("Content-Type", "application/json")
e := json.NewEncoder(w).Encode(data)
@ -462,7 +463,7 @@ func handleAddClient(w http.ResponseWriter, r *http.Request) {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
ok, err := clientAdd(*c)
ok, err := config.clients.Add(*c)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
@ -492,7 +493,7 @@ func handleDelClient(w http.ResponseWriter, r *http.Request) {
return
}
if !clientDel(cj.Name) {
if !config.clients.Del(cj.Name) {
httpError(w, http.StatusBadRequest, "Client not found")
return
}
@ -501,7 +502,7 @@ func handleDelClient(w http.ResponseWriter, r *http.Request) {
returnOK(w)
}
type clientUpdateJSON struct {
type updateJSON struct {
Name string `json:"name"`
Data clientJSON `json:"data"`
}
@ -515,7 +516,7 @@ func handleUpdateClient(w http.ResponseWriter, r *http.Request) {
return
}
var dj clientUpdateJSON
var dj updateJSON
err = json.Unmarshal(body, &dj)
if err != nil {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
@ -532,7 +533,7 @@ func handleUpdateClient(w http.ResponseWriter, r *http.Request) {
return
}
err = clientUpdate(dj.Name, *c)
err = config.clients.Update(dj.Name, *c)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return

View File

@ -6,17 +6,18 @@ func TestClients(t *testing.T) {
var c Client
var e error
var b bool
clients := clientsContainer{}
clientsInit()
clients.Init()
// add
c = Client{
IP: "1.1.1.1",
Name: "client1",
}
b, e = clientAdd(c)
b, e = clients.Add(c)
if !b || e != nil {
t.Fatalf("clientAdd #1")
t.Fatalf("Add #1")
}
// add #2
@ -24,19 +25,19 @@ func TestClients(t *testing.T) {
IP: "2.2.2.2",
Name: "client2",
}
b, e = clientAdd(c)
b, e = clients.Add(c)
if !b || e != nil {
t.Fatalf("clientAdd #2")
t.Fatalf("Add #2")
}
c, b = clientFind("1.1.1.1")
c, b = clients.Find("1.1.1.1")
if !b || c.Name != "client1" {
t.Fatalf("clientFind #1")
t.Fatalf("Find #1")
}
c, b = clientFind("2.2.2.2")
c, b = clients.Find("2.2.2.2")
if !b || c.Name != "client2" {
t.Fatalf("clientFind #2")
t.Fatalf("Find #2")
}
// failed add - name in use
@ -44,9 +45,9 @@ func TestClients(t *testing.T) {
IP: "1.2.3.5",
Name: "client1",
}
b, _ = clientAdd(c)
b, _ = clients.Add(c)
if b {
t.Fatalf("clientAdd - name in use")
t.Fatalf("Add - name in use")
}
// failed add - ip in use
@ -54,91 +55,91 @@ func TestClients(t *testing.T) {
IP: "2.2.2.2",
Name: "client3",
}
b, e = clientAdd(c)
b, e = clients.Add(c)
if b || e == nil {
t.Fatalf("clientAdd - ip in use")
t.Fatalf("Add - ip in use")
}
// get
if clientExists("1.2.3.4") {
t.Fatalf("clientExists")
if clients.Exists("1.2.3.4") {
t.Fatalf("Exists")
}
if !clientExists("1.1.1.1") {
t.Fatalf("clientExists #1")
if !clients.Exists("1.1.1.1") {
t.Fatalf("Exists #1")
}
if !clientExists("2.2.2.2") {
t.Fatalf("clientExists #2")
if !clients.Exists("2.2.2.2") {
t.Fatalf("Exists #2")
}
// failed update - no such name
c.IP = "1.2.3.0"
c.Name = "client3"
if clientUpdate("client3", c) == nil {
t.Fatalf("clientUpdate")
if clients.Update("client3", c) == nil {
t.Fatalf("Update")
}
// failed update - name in use
c.IP = "1.2.3.0"
c.Name = "client2"
if clientUpdate("client1", c) == nil {
t.Fatalf("clientUpdate - name in use")
if clients.Update("client1", c) == nil {
t.Fatalf("Update - name in use")
}
// failed update - ip in use
c.IP = "2.2.2.2"
c.Name = "client1"
if clientUpdate("client1", c) == nil {
t.Fatalf("clientUpdate - ip in use")
if clients.Update("client1", c) == nil {
t.Fatalf("Update - ip in use")
}
// update
c.IP = "1.1.1.2"
c.Name = "client1"
if clientUpdate("client1", c) != nil {
t.Fatalf("clientUpdate")
if clients.Update("client1", c) != nil {
t.Fatalf("Update")
}
// get after update
if clientExists("1.1.1.1") || !clientExists("1.1.1.2") {
t.Fatalf("clientExists - get after update")
if clients.Exists("1.1.1.1") || !clients.Exists("1.1.1.2") {
t.Fatalf("Exists - get after update")
}
// failed remove - no such name
if clientDel("client3") {
t.Fatalf("clientDel - no such name")
if clients.Del("client3") {
t.Fatalf("Del - no such name")
}
// remove
if !clientDel("client1") || clientExists("1.1.1.2") {
t.Fatalf("clientDel")
if !clients.Del("client1") || clients.Exists("1.1.1.2") {
t.Fatalf("Del")
}
// add host client
b, e = clientAddHost("1.1.1.1", "host", ClientSourceARP)
b, e = clients.AddHost("1.1.1.1", "host", ClientSourceARP)
if !b || e != nil {
t.Fatalf("clientAddHost")
}
// failed add - ip exists
b, e = clientAddHost("1.1.1.1", "host1", ClientSourceRDNS)
b, e = clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
if b || e != nil {
t.Fatalf("clientAddHost - ip exists")
}
// overwrite with new data
b, e = clientAddHost("1.1.1.1", "host2", ClientSourceARP)
b, e = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
if !b || e != nil {
t.Fatalf("clientAddHost - overwrite with new data")
}
// overwrite with new data (higher priority)
b, e = clientAddHost("1.1.1.1", "host3", ClientSourceHostsFile)
b, e = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
if !b || e != nil {
t.Fatalf("clientAddHost - overwrite with new data (higher priority)")
}
// get
if !clientExists("1.1.1.1") {
if !clients.Exists("1.1.1.1") {
t.Fatalf("clientAddHost")
}
}

View File

@ -2,6 +2,7 @@ package home
import (
"io/ioutil"
"net/http"
"os"
"path/filepath"
"runtime"
@ -38,6 +39,13 @@ type clientObject struct {
SafeBrowsingEnabled bool `yaml:"safesearch_enabled"`
}
type HTTPSServer struct {
server *http.Server
cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
sync.Mutex // protects config.TLS
shutdown bool // if TRUE, don't restart the server
}
// configuration is loaded from YAML
// field ordering is important -- yaml fields will mirror ordering from here
type configuration struct {
@ -48,10 +56,25 @@ type configuration struct {
ourConfigFilename string // Config filename (can be overridden via the command line arguments)
ourWorkingDir string // Location of our directory, used to protect against CWD being somewhere else
firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html
pidFileName string // PID file name. Empty if no PID file was created.
// runningAsService flag is set to true when options are passed from the service runner
runningAsService bool
disableUpdate bool // If set, don't check for updates
appSignalChannel chan os.Signal
clients clientsContainer
controlLock sync.Mutex
transport *http.Transport
client *http.Client
// cached version.json to avoid hammering github.io for each page reload
versionCheckJSON []byte
versionCheckLastTime time.Time
dnsctx dnsContext
dnsServer *dnsforward.Server
dhcpServer dhcpd.Server
httpServer *http.Server
httpsServer HTTPSServer
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
@ -127,7 +150,6 @@ type tlsConfig struct {
}
// initialize to default values, will be changed later when reading config or parsing command line
// TODO: Get rid of global variables
var config = configuration{
ourConfigFilename: "AdGuardHome.yaml",
BindPort: 3000,
@ -167,8 +189,16 @@ var config = configuration{
SchemaVersion: currentSchemaVersion,
}
// init initializes default configuration for the current OS&ARCH
func init() {
// initConfig initializes default configuration for the current OS&ARCH
func initConfig() {
config.transport = &http.Transport{
DialContext: customDialContext,
}
config.client = &http.Client{
Timeout: time.Minute * 5,
Transport: config.transport,
}
if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" {
// Use plain DNS on MIPS, encryption is too slow
defaultDNS = []string{"1.1.1.1", "1.0.0.1"}
@ -233,7 +263,7 @@ func parseConfig() error {
SafeSearchEnabled: cy.SafeSearchEnabled,
SafeBrowsingEnabled: cy.SafeBrowsingEnabled,
}
_, err = clientAdd(cli)
_, err = config.clients.Add(cli)
if err != nil {
log.Tracef("clientAdd: %s", err)
}
@ -268,7 +298,7 @@ func (c *configuration) write() error {
c.Lock()
defer c.Unlock()
clientsList := clientsGetList()
clientsList := config.clients.GetList()
for _, cli := range clientsList {
ip := cli.IP
if len(cli.MAC) != 0 {

View File

@ -11,7 +11,6 @@ import (
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
@ -25,23 +24,8 @@ import (
const updatePeriod = time.Minute * 30
// cached version.json to avoid hammering github.io for each page reload
var versionCheckJSON []byte
var versionCheckLastTime time.Time
var protocols = []string{"tls://", "https://", "tcp://", "sdns://"}
var transport = &http.Transport{
DialContext: customDialContext,
}
var client = &http.Client{
Timeout: time.Minute * 5,
Transport: transport,
}
var controlLock sync.Mutex
// ----------------
// helper functions
// ----------------
@ -188,7 +172,7 @@ func handleQueryLogDisable(w http.ResponseWriter, r *http.Request) {
func handleQueryLog(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
data := dnsServer.GetQueryLog()
data := config.dnsServer.GetQueryLog()
jsonVal, err := json.Marshal(data)
if err != nil {
@ -205,7 +189,7 @@ func handleQueryLog(w http.ResponseWriter, r *http.Request) {
func handleStatsTop(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
s := dnsServer.GetStatsTop()
s := config.dnsServer.GetStatsTop()
// use manual json marshalling because we want maps to be sorted by value
statsJSON := bytes.Buffer{}
@ -252,7 +236,7 @@ func handleStatsTop(w http.ResponseWriter, r *http.Request) {
// handleStatsReset resets the stats caches
func handleStatsReset(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
dnsServer.PurgeStats()
config.dnsServer.PurgeStats()
_, err := fmt.Fprintf(w, "OK\n")
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
@ -262,7 +246,7 @@ func handleStatsReset(w http.ResponseWriter, r *http.Request) {
// handleStats returns aggregated stats data for the 24 hours
func handleStats(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
summed := dnsServer.GetAggregatedStats()
summed := config.dnsServer.GetAggregatedStats()
statsJSON, err := json.Marshal(summed)
if err != nil {
@ -309,7 +293,7 @@ func handleStatsHistory(w http.ResponseWriter, r *http.Request) {
return
}
data, err := dnsServer.GetStatsHistory(timeUnit, startTime, endTime)
data, err := config.dnsServer.GetStatsHistory(timeUnit, startTime, endTime)
if err != nil {
httpError(w, http.StatusBadRequest, "Cannot get stats history: %s", err)
return
@ -725,7 +709,7 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
// Stop DNS server:
// we close urlfilter object which in turn closes file descriptors to filter files.
// Otherwise, Windows won't allow us to remove the file which is being currently used.
_ = dnsServer.Stop()
_ = config.dnsServer.Stop()
// go through each element and delete if url matches
config.Lock()
@ -984,7 +968,7 @@ func handleDOH(w http.ResponseWriter, r *http.Request) {
return
}
dnsServer.ServeHTTP(w, r)
config.dnsServer.ServeHTTP(w, r)
}
// ------------------------

View File

@ -17,13 +17,13 @@ type accessListJSON struct {
func handleAccessList(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
controlLock.Lock()
config.controlLock.Lock()
j := accessListJSON{
AllowedClients: config.DNS.AllowedClients,
DisallowedClients: config.DNS.DisallowedClients,
BlockedHosts: config.DNS.BlockedHosts,
}
controlLock.Unlock()
config.controlLock.Unlock()
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(j)

View File

@ -264,7 +264,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
if restartHTTP {
go func() {
httpServer.Shutdown(context.TODO())
config.httpServer.Shutdown(context.TODO())
}()
}

View File

@ -46,7 +46,7 @@ func handleTLSValidate(w http.ResponseWriter, r *http.Request) {
// check if port is available
// BUT: if we are already using this port, no need
alreadyRunning := false
if httpsServer.server != nil {
if config.httpsServer.server != nil {
alreadyRunning = true
}
if !alreadyRunning {
@ -72,7 +72,7 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
// check if port is available
// BUT: if we are already using this port, no need
alreadyRunning := false
if httpsServer.server != nil {
if config.httpsServer.server != nil {
alreadyRunning = true
}
if !alreadyRunning {
@ -101,12 +101,12 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
if restartHTTPS {
go func() {
time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server
httpsServer.cond.L.Lock()
httpsServer.cond.Broadcast()
if httpsServer.server != nil {
httpsServer.server.Shutdown(context.TODO())
config.httpsServer.cond.L.Lock()
config.httpsServer.cond.Broadcast()
if config.httpsServer.server != nil {
config.httpsServer.server.Shutdown(context.TODO())
}
httpsServer.cond.L.Unlock()
config.httpsServer.cond.L.Unlock()
}()
}
}

View File

@ -73,10 +73,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
now := time.Now()
if !req.RecheckNow {
controlLock.Lock()
cached := now.Sub(versionCheckLastTime) <= versionCheckPeriod && len(versionCheckJSON) != 0
data := versionCheckJSON
controlLock.Unlock()
config.controlLock.Lock()
cached := now.Sub(config.versionCheckLastTime) <= versionCheckPeriod && len(config.versionCheckJSON) != 0
data := config.versionCheckJSON
config.controlLock.Unlock()
if cached {
log.Tracef("Returning cached data")
@ -87,7 +87,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
}
log.Tracef("Downloading data from %s", versionCheckURL)
resp, err := client.Get(versionCheckURL)
resp, err := config.client.Get(versionCheckURL)
if err != nil {
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
return
@ -103,10 +103,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
return
}
controlLock.Lock()
versionCheckLastTime = now
versionCheckJSON = body
controlLock.Unlock()
config.controlLock.Lock()
config.versionCheckLastTime = now
config.versionCheckJSON = body
config.controlLock.Unlock()
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(getVersionResp(body))
@ -349,7 +349,7 @@ func copySupportingFiles(files []string, srcdir, dstdir string, useSrcNameOnly,
// Download package file and save it to disk
func getPackageFile(u *updateInfo) error {
resp, err := client.Get(u.pkgURL)
resp, err := config.client.Get(u.pkgURL)
if err != nil {
return fmt.Errorf("HTTP request failed: %s", err)
}
@ -501,12 +501,12 @@ func finishUpdate(u *updateInfo) {
func handleUpdate(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
if len(versionCheckJSON) == 0 {
if len(config.versionCheckJSON) == 0 {
httpError(w, http.StatusBadRequest, "/update request isn't allowed now")
return
}
u, err := getUpdateInfo(versionCheckJSON)
u, err := getUpdateInfo(config.versionCheckJSON)
if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err)
return

View File

@ -18,8 +18,6 @@ import (
"github.com/joomcode/errorx"
)
var dhcpServer = dhcpd.Server{}
// []dhcpd.Lease -> JSON
func convertLeases(inputLeases []dhcpd.Lease, includeExpires bool) []map[string]string {
leases := []map[string]string{}
@ -41,8 +39,8 @@ func convertLeases(inputLeases []dhcpd.Lease, includeExpires bool) []map[string]
func handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
leases := convertLeases(dhcpServer.Leases(), true)
staticLeases := convertLeases(dhcpServer.StaticLeases(), false)
leases := convertLeases(config.dhcpServer.Leases(), true)
staticLeases := convertLeases(config.dhcpServer.StaticLeases(), false)
status := map[string]interface{}{
"config": config.DHCP,
"leases": leases,
@ -77,18 +75,18 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
return
}
err = dhcpServer.CheckConfig(newconfig.ServerConfig)
err = config.dhcpServer.CheckConfig(newconfig.ServerConfig)
if err != nil {
httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err)
return
}
err = dhcpServer.Stop()
err = config.dhcpServer.Stop()
if err != nil {
log.Error("failed to stop the DHCP server: %s", err)
}
err = dhcpServer.Init(newconfig.ServerConfig)
err = config.dhcpServer.Init(newconfig.ServerConfig)
if err != nil {
httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err)
return
@ -105,7 +103,7 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
}
}
err = dhcpServer.Start()
err = config.dhcpServer.Start()
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to start DHCP server: %s", err)
return
@ -389,7 +387,7 @@ func handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
HWAddr: mac,
Hostname: lj.Hostname,
}
err = dhcpServer.AddStaticLease(lease)
err = config.dhcpServer.AddStaticLease(lease)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
@ -420,7 +418,7 @@ func handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
HWAddr: mac,
Hostname: lj.Hostname,
}
err = dhcpServer.RemoveStaticLease(lease)
err = config.dhcpServer.RemoveStaticLease(lease)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
@ -434,12 +432,12 @@ func startDHCPServer() error {
return nil
}
err := dhcpServer.Init(config.DHCP)
err := config.dhcpServer.Init(config.DHCP)
if err != nil {
return errorx.Decorate(err, "Couldn't init DHCP server")
}
err = dhcpServer.Start()
err = config.dhcpServer.Start()
if err != nil {
return errorx.Decorate(err, "Couldn't start DHCP server")
}
@ -451,7 +449,7 @@ func stopDHCPServer() error {
return nil
}
err := dhcpServer.Stop()
err := config.dhcpServer.Stop()
if err != nil {
return errorx.Decorate(err, "Couldn't stop DHCP server")
}

View File

@ -17,8 +17,6 @@ import (
"github.com/miekg/dns"
)
var dnsServer *dnsforward.Server
const (
rdnsTimeout = 3 * time.Second // max time to wait for rDNS response
)
@ -32,8 +30,6 @@ type dnsContext struct {
upstream upstream.Upstream // Upstream object for our own DNS server
}
var dnsctx dnsContext
// initDNSServer creates an instance of the dnsforward.Server
// Please note that we must do it even if we don't start it
// so that we had access to the query log and the stats
@ -43,7 +39,7 @@ func initDNSServer(baseDir string) {
log.Fatalf("Cannot create DNS data dir at %s: %s", baseDir, err)
}
dnsServer = dnsforward.NewServer(baseDir)
config.dnsServer = dnsforward.NewServer(baseDir)
bindhost := config.DNS.BindHost
if config.DNS.BindHost == "0.0.0.0" {
@ -53,37 +49,37 @@ func initDNSServer(baseDir string) {
opts := upstream.Options{
Timeout: rdnsTimeout,
}
dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts)
config.dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts)
if err != nil {
log.Error("upstream.AddressToUpstream: %s", err)
return
}
dnsctx.rdnsIP = make(map[string]bool)
dnsctx.rdnsChannel = make(chan string, 256)
config.dnsctx.rdnsIP = make(map[string]bool)
config.dnsctx.rdnsChannel = make(chan string, 256)
go asyncRDNSLoop()
}
func isRunning() bool {
return dnsServer != nil && dnsServer.IsRunning()
return config.dnsServer != nil && config.dnsServer.IsRunning()
}
func beginAsyncRDNS(ip string) {
if clientExists(ip) {
if config.clients.Exists(ip) {
return
}
// add IP to rdnsIP, if not exists
dnsctx.rdnsLock.Lock()
defer dnsctx.rdnsLock.Unlock()
_, ok := dnsctx.rdnsIP[ip]
config.dnsctx.rdnsLock.Lock()
defer config.dnsctx.rdnsLock.Unlock()
_, ok := config.dnsctx.rdnsIP[ip]
if ok {
return
}
dnsctx.rdnsIP[ip] = true
config.dnsctx.rdnsIP[ip] = true
log.Tracef("Adding %s for rDNS resolve", ip)
select {
case dnsctx.rdnsChannel <- ip:
case config.dnsctx.rdnsChannel <- ip:
//
default:
log.Tracef("rDNS queue is full")
@ -110,7 +106,7 @@ func resolveRDNS(ip string) string {
return ""
}
resp, err := dnsctx.upstream.Exchange(&req)
resp, err := config.dnsctx.upstream.Exchange(&req)
if err != nil {
log.Error("Error while making an rDNS lookup for %s: %s", ip, err)
return ""
@ -138,18 +134,18 @@ func resolveRDNS(ip string) string {
func asyncRDNSLoop() {
for {
var ip string
ip = <-dnsctx.rdnsChannel
ip = <-config.dnsctx.rdnsChannel
host := resolveRDNS(ip)
if len(host) == 0 {
continue
}
dnsctx.rdnsLock.Lock()
delete(dnsctx.rdnsIP, ip)
dnsctx.rdnsLock.Unlock()
config.dnsctx.rdnsLock.Lock()
delete(config.dnsctx.rdnsIP, ip)
config.dnsctx.rdnsLock.Unlock()
_, _ = clientAddHost(ip, host, ClientSourceRDNS)
_, _ = config.clients.AddHost(ip, host, ClientSourceRDNS)
}
}
@ -221,7 +217,7 @@ func generateServerConfig() (dnsforward.ServerConfig, error) {
// If a client has his own settings, apply them
func applyClientSettings(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
c, ok := clientFind(clientAddr)
c, ok := config.clients.Find(clientAddr)
if !ok || !c.UseOwnSettings {
return
}
@ -242,12 +238,12 @@ func startDNSServer() error {
if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
}
err = dnsServer.Start(&newconfig)
err = config.dnsServer.Start(&newconfig)
if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
}
top := dnsServer.GetStatsTop()
top := config.dnsServer.GetStatsTop()
for k := range top.Clients {
beginAsyncRDNS(k)
}
@ -256,11 +252,11 @@ func startDNSServer() error {
}
func reconfigureDNSServer() error {
config, err := generateServerConfig()
newconfig, err := generateServerConfig()
if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
}
err = dnsServer.Reconfigure(&config)
err = config.dnsServer.Reconfigure(&newconfig)
if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
}
@ -273,7 +269,7 @@ func stopDNSServer() error {
return nil
}
err := dnsServer.Stop()
err := config.dnsServer.Stop()
if err != nil {
return errorx.Decorate(err, "Couldn't stop forwarding DNS server")
}

View File

@ -222,7 +222,7 @@ func refreshFiltersIfNecessary(force bool) int {
stopped := false
if updateCount != 0 {
_ = dnsServer.Stop()
_ = config.dnsServer.Stop()
stopped = true
}
@ -308,7 +308,7 @@ func parseFilterContents(contents []byte) (int, string) {
func (filter *filter) update() (bool, error) {
log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL)
resp, err := client.Get(filter.URL)
resp, err := config.client.Get(filter.URL)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}

View File

@ -35,8 +35,8 @@ func ensure(method string, handler func(http.ResponseWriter, *http.Request)) fun
}
if method == "POST" || method == "PUT" || method == "DELETE" {
controlLock.Lock()
defer controlLock.Unlock()
config.controlLock.Lock()
defer config.controlLock.Unlock()
}
handler(w, r)
@ -148,7 +148,7 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res
return
}
// enforce https?
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && httpsServer.server != nil {
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && config.httpsServer.server != nil {
// yes, and we want host from host:port
host, _, err := net.SplitHostPort(r.Host)
if err != nil {

View File

@ -25,15 +25,6 @@ import (
"github.com/gobuffalo/packr"
)
var httpServer *http.Server
var httpsServer struct {
server *http.Server
cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
sync.Mutex // protects config.TLS
shutdown bool // if TRUE, don't restart the server
}
var pidFileName string // PID file name. Empty if no PID file was created.
const (
// Used in config to indicate that syslog or eventlog (win) should be used for logger output
configSyslog = "syslog"
@ -48,7 +39,7 @@ var (
const versionCheckPeriod = time.Hour * 8
// main is the entry point
// Main is the entry point
func Main(version string, channel string) {
// Init update-related global variables
versionString = version
@ -108,7 +99,8 @@ func run(args options) {
os.Exit(0)
}()
clientsInit()
initConfig()
config.clients.Init()
if !config.firstRun {
// Do the upgrade if necessary
@ -168,7 +160,7 @@ func run(args options) {
}
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
pidFileName = args.pidFile
config.pidFileName = args.pidFile
}
// Update filters we've just loaded right away, don't wait for periodic update timer
@ -192,21 +184,21 @@ func run(args options) {
registerInstallHandlers()
}
httpsServer.cond = sync.NewCond(&httpsServer.Mutex)
config.httpsServer.cond = sync.NewCond(&config.httpsServer.Mutex)
// for https, we have a separate goroutine loop
go httpServerLoop()
// this loop is used as an ability to change listening host and/or port
for !httpsServer.shutdown {
for !config.httpsServer.shutdown {
printHTTPAddresses("http")
// we need to have new instance, because after Shutdown() the Server is not usable
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
httpServer = &http.Server{
config.httpServer = &http.Server{
Addr: address,
}
err := httpServer.ListenAndServe()
err := config.httpServer.ListenAndServe()
if err != http.ErrServerClosed {
cleanupAlways()
log.Fatal(err)
@ -219,14 +211,14 @@ func run(args options) {
}
func httpServerLoop() {
for !httpsServer.shutdown {
httpsServer.cond.L.Lock()
for !config.httpsServer.shutdown {
config.httpsServer.cond.L.Lock()
// this mechanism doesn't let us through until all conditions are met
for config.TLS.Enabled == false ||
config.TLS.PortHTTPS == 0 ||
config.TLS.PrivateKey == "" ||
config.TLS.CertificateChain == "" { // sleep until necessary data is supplied
httpsServer.cond.Wait()
config.httpsServer.cond.Wait()
}
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS))
// validate current TLS config and update warnings (it could have been loaded from file)
@ -250,10 +242,10 @@ func httpServerLoop() {
cleanupAlways()
log.Fatal(err)
}
httpsServer.cond.L.Unlock()
config.httpsServer.cond.L.Unlock()
// prepare HTTPS server
httpsServer.server = &http.Server{
config.httpsServer.server = &http.Server{
Addr: address,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
@ -262,7 +254,7 @@ func httpServerLoop() {
}
printHTTPAddresses("https")
err = httpsServer.server.ListenAndServeTLS("", "")
err = config.httpsServer.server.ListenAndServeTLS("", "")
if err != http.ErrServerClosed {
cleanupAlways()
log.Fatal(err)
@ -399,17 +391,17 @@ func cleanup() {
// Stop HTTP server, possibly waiting for all active connections to be closed
func stopHTTPServer() {
httpsServer.shutdown = true
if httpsServer.server != nil {
httpsServer.server.Shutdown(context.TODO())
config.httpsServer.shutdown = true
if config.httpsServer.server != nil {
config.httpsServer.server.Shutdown(context.TODO())
}
httpServer.Shutdown(context.TODO())
config.httpServer.Shutdown(context.TODO())
}
// This function is called before application exits
func cleanupAlways() {
if len(pidFileName) != 0 {
os.Remove(pidFileName)
if len(config.pidFileName) != 0 {
os.Remove(config.pidFileName)
}
log.Info("Stopped")
}