diff --git a/home/clients.go b/home/clients.go index 4ae478be..d0767e47 100644 --- a/home/clients.go +++ b/home/clients.go @@ -66,10 +66,9 @@ type clientsContainer struct { lock sync.Mutex } -var clients clientsContainer - -// Initialize clients container -func clientsInit() { +// Init initializes clients container +// Note: this function must be called only once +func (clients *clientsContainer) Init() { if clients.list != nil { log.Fatal("clients.list != nil") } @@ -77,22 +76,24 @@ func clientsInit() { clients.ipIndex = make(map[string]*Client) clients.ipHost = make(map[string]ClientHost) - go periodicClientsUpdate() + go clients.periodicUpdate() } -func periodicClientsUpdate() { +func (clients *clientsContainer) periodicUpdate() { for { - clientsAddFromHostsFile() - clientsAddFromSystemARP() + clients.addFromHostsFile() + clients.addFromSystemARP() time.Sleep(clientsUpdatePeriod) } } -func clientsGetList() map[string]*Client { +// GetList returns the pointer to clients list +func (clients *clientsContainer) GetList() map[string]*Client { return clients.list } -func clientExists(ip string) bool { +// Exists checks if client with this IP already exists +func (clients *clientsContainer) Exists(ip string) bool { clients.lock.Lock() defer clients.lock.Unlock() @@ -105,8 +106,8 @@ func clientExists(ip string) bool { return ok } -// Search for a client by IP -func clientFind(ip string) (Client, bool) { +// Find searches for a client by IP +func (clients *clientsContainer) Find(ip string) (Client, bool) { clients.lock.Lock() defer clients.lock.Unlock() @@ -135,7 +136,7 @@ func clientFind(ip string) (Client, bool) { } // Check if Client object's fields are correct -func clientCheck(c *Client) error { +func (c *Client) check() error { if len(c.Name) == 0 { return fmt.Errorf("Invalid Name") } @@ -162,8 +163,8 @@ func clientCheck(c *Client) error { // Add a new client object // Return true: success; false: client exists. -func clientAdd(c Client) (bool, error) { - e := clientCheck(&c) +func (clients *clientsContainer) Add(c Client) (bool, error) { + e := c.check() if e != nil { return false, e } @@ -194,8 +195,8 @@ func clientAdd(c Client) (bool, error) { return true, nil } -// Remove a client -func clientDel(name string) bool { +// Del removes a client +func (clients *clientsContainer) Del(name string) bool { clients.lock.Lock() defer clients.lock.Unlock() @@ -210,8 +211,8 @@ func clientDel(name string) bool { } // Update a client -func clientUpdate(name string, c Client) error { - err := clientCheck(&c) +func (clients *clientsContainer) Update(name string, c Client) error { + err := c.check() if err != nil { return err } @@ -257,10 +258,10 @@ func clientUpdate(name string, c Client) error { return nil } -// Add new IP -> Host pair +// AddHost adds new IP -> Host pair // Use priority of the source (etc/hosts > ARP > rDNS) // so we overwrite existing entries with an equal or higher priority -func clientAddHost(ip, host string, source clientSource) (bool, error) { +func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (bool, error) { clients.lock.Lock() defer clients.lock.Unlock() @@ -279,7 +280,7 @@ func clientAddHost(ip, host string, source clientSource) (bool, error) { } // Parse system 'hosts' file and fill clients array -func clientsAddFromHostsFile() { +func (clients *clientsContainer) addFromHostsFile() { hostsFn := "/etc/hosts" if runtime.GOOS == "windows" { hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts") @@ -304,7 +305,7 @@ func clientsAddFromHostsFile() { continue } - ok, e := clientAddHost(fields[0], fields[1], ClientSourceHostsFile) + ok, e := clients.AddHost(fields[0], fields[1], ClientSourceHostsFile) if e != nil { log.Tracef("%s", e) } @@ -319,7 +320,7 @@ func clientsAddFromHostsFile() { // Add IP -> Host pairs from the system's `arp -a` command output // The command's output is: // HOST (IP) at MAC on IFACE -func clientsAddFromSystemARP() { +func (clients *clientsContainer) addFromSystemARP() { if runtime.GOOS == "windows" { return @@ -350,7 +351,7 @@ func clientsAddFromSystemARP() { continue } - ok, e := clientAddHost(ip, host, ClientSourceARP) + ok, e := clients.AddHost(ip, host, ClientSourceARP) if e != nil { log.Tracef("%s", e) } @@ -379,8 +380,8 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) { data := clientListJSON{} - clients.lock.Lock() - for _, c := range clients.list { + config.clients.lock.Lock() + for _, c := range config.clients.list { cj := clientJSON{ IP: c.IP, MAC: c.MAC, @@ -402,7 +403,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) { data.Clients = append(data.Clients, cj) } - for ip, ch := range clients.ipHost { + for ip, ch := range config.clients.ipHost { cj := clientHostJSON{ IP: ip, Name: ch.Host, @@ -416,7 +417,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) { } data.AutoClients = append(data.AutoClients, cj) } - clients.lock.Unlock() + config.clients.lock.Unlock() w.Header().Set("Content-Type", "application/json") e := json.NewEncoder(w).Encode(data) @@ -462,7 +463,7 @@ func handleAddClient(w http.ResponseWriter, r *http.Request) { httpError(w, http.StatusBadRequest, "%s", err) return } - ok, err := clientAdd(*c) + ok, err := config.clients.Add(*c) if err != nil { httpError(w, http.StatusBadRequest, "%s", err) return @@ -492,7 +493,7 @@ func handleDelClient(w http.ResponseWriter, r *http.Request) { return } - if !clientDel(cj.Name) { + if !config.clients.Del(cj.Name) { httpError(w, http.StatusBadRequest, "Client not found") return } @@ -501,7 +502,7 @@ func handleDelClient(w http.ResponseWriter, r *http.Request) { returnOK(w) } -type clientUpdateJSON struct { +type updateJSON struct { Name string `json:"name"` Data clientJSON `json:"data"` } @@ -515,7 +516,7 @@ func handleUpdateClient(w http.ResponseWriter, r *http.Request) { return } - var dj clientUpdateJSON + var dj updateJSON err = json.Unmarshal(body, &dj) if err != nil { httpError(w, http.StatusBadRequest, "JSON parse: %s", err) @@ -532,7 +533,7 @@ func handleUpdateClient(w http.ResponseWriter, r *http.Request) { return } - err = clientUpdate(dj.Name, *c) + err = config.clients.Update(dj.Name, *c) if err != nil { httpError(w, http.StatusBadRequest, "%s", err) return diff --git a/home/clients_test.go b/home/clients_test.go index 4b45cee2..d5dc5143 100644 --- a/home/clients_test.go +++ b/home/clients_test.go @@ -6,17 +6,18 @@ func TestClients(t *testing.T) { var c Client var e error var b bool + clients := clientsContainer{} - clientsInit() + clients.Init() // add c = Client{ IP: "1.1.1.1", Name: "client1", } - b, e = clientAdd(c) + b, e = clients.Add(c) if !b || e != nil { - t.Fatalf("clientAdd #1") + t.Fatalf("Add #1") } // add #2 @@ -24,19 +25,19 @@ func TestClients(t *testing.T) { IP: "2.2.2.2", Name: "client2", } - b, e = clientAdd(c) + b, e = clients.Add(c) if !b || e != nil { - t.Fatalf("clientAdd #2") + t.Fatalf("Add #2") } - c, b = clientFind("1.1.1.1") + c, b = clients.Find("1.1.1.1") if !b || c.Name != "client1" { - t.Fatalf("clientFind #1") + t.Fatalf("Find #1") } - c, b = clientFind("2.2.2.2") + c, b = clients.Find("2.2.2.2") if !b || c.Name != "client2" { - t.Fatalf("clientFind #2") + t.Fatalf("Find #2") } // failed add - name in use @@ -44,9 +45,9 @@ func TestClients(t *testing.T) { IP: "1.2.3.5", Name: "client1", } - b, _ = clientAdd(c) + b, _ = clients.Add(c) if b { - t.Fatalf("clientAdd - name in use") + t.Fatalf("Add - name in use") } // failed add - ip in use @@ -54,91 +55,91 @@ func TestClients(t *testing.T) { IP: "2.2.2.2", Name: "client3", } - b, e = clientAdd(c) + b, e = clients.Add(c) if b || e == nil { - t.Fatalf("clientAdd - ip in use") + t.Fatalf("Add - ip in use") } // get - if clientExists("1.2.3.4") { - t.Fatalf("clientExists") + if clients.Exists("1.2.3.4") { + t.Fatalf("Exists") } - if !clientExists("1.1.1.1") { - t.Fatalf("clientExists #1") + if !clients.Exists("1.1.1.1") { + t.Fatalf("Exists #1") } - if !clientExists("2.2.2.2") { - t.Fatalf("clientExists #2") + if !clients.Exists("2.2.2.2") { + t.Fatalf("Exists #2") } // failed update - no such name c.IP = "1.2.3.0" c.Name = "client3" - if clientUpdate("client3", c) == nil { - t.Fatalf("clientUpdate") + if clients.Update("client3", c) == nil { + t.Fatalf("Update") } // failed update - name in use c.IP = "1.2.3.0" c.Name = "client2" - if clientUpdate("client1", c) == nil { - t.Fatalf("clientUpdate - name in use") + if clients.Update("client1", c) == nil { + t.Fatalf("Update - name in use") } // failed update - ip in use c.IP = "2.2.2.2" c.Name = "client1" - if clientUpdate("client1", c) == nil { - t.Fatalf("clientUpdate - ip in use") + if clients.Update("client1", c) == nil { + t.Fatalf("Update - ip in use") } // update c.IP = "1.1.1.2" c.Name = "client1" - if clientUpdate("client1", c) != nil { - t.Fatalf("clientUpdate") + if clients.Update("client1", c) != nil { + t.Fatalf("Update") } // get after update - if clientExists("1.1.1.1") || !clientExists("1.1.1.2") { - t.Fatalf("clientExists - get after update") + if clients.Exists("1.1.1.1") || !clients.Exists("1.1.1.2") { + t.Fatalf("Exists - get after update") } // failed remove - no such name - if clientDel("client3") { - t.Fatalf("clientDel - no such name") + if clients.Del("client3") { + t.Fatalf("Del - no such name") } // remove - if !clientDel("client1") || clientExists("1.1.1.2") { - t.Fatalf("clientDel") + if !clients.Del("client1") || clients.Exists("1.1.1.2") { + t.Fatalf("Del") } // add host client - b, e = clientAddHost("1.1.1.1", "host", ClientSourceARP) + b, e = clients.AddHost("1.1.1.1", "host", ClientSourceARP) if !b || e != nil { t.Fatalf("clientAddHost") } // failed add - ip exists - b, e = clientAddHost("1.1.1.1", "host1", ClientSourceRDNS) + b, e = clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS) if b || e != nil { t.Fatalf("clientAddHost - ip exists") } // overwrite with new data - b, e = clientAddHost("1.1.1.1", "host2", ClientSourceARP) + b, e = clients.AddHost("1.1.1.1", "host2", ClientSourceARP) if !b || e != nil { t.Fatalf("clientAddHost - overwrite with new data") } // overwrite with new data (higher priority) - b, e = clientAddHost("1.1.1.1", "host3", ClientSourceHostsFile) + b, e = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile) if !b || e != nil { t.Fatalf("clientAddHost - overwrite with new data (higher priority)") } // get - if !clientExists("1.1.1.1") { + if !clients.Exists("1.1.1.1") { t.Fatalf("clientAddHost") } } diff --git a/home/config.go b/home/config.go index 9a294376..a0622b49 100644 --- a/home/config.go +++ b/home/config.go @@ -52,6 +52,7 @@ type configuration struct { runningAsService bool disableUpdate bool // If set, don't check for updates appSignalChannel chan os.Signal + clients clientsContainer BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server @@ -233,7 +234,7 @@ func parseConfig() error { SafeSearchEnabled: cy.SafeSearchEnabled, SafeBrowsingEnabled: cy.SafeBrowsingEnabled, } - _, err = clientAdd(cli) + _, err = config.clients.Add(cli) if err != nil { log.Tracef("clientAdd: %s", err) } @@ -268,7 +269,7 @@ func (c *configuration) write() error { c.Lock() defer c.Unlock() - clientsList := clientsGetList() + clientsList := config.clients.GetList() for _, cli := range clientsList { ip := cli.IP if len(cli.MAC) != 0 { diff --git a/home/dns.go b/home/dns.go index 52f57202..4f3ed2d2 100644 --- a/home/dns.go +++ b/home/dns.go @@ -68,7 +68,7 @@ func isRunning() bool { } func beginAsyncRDNS(ip string) { - if clientExists(ip) { + if config.clients.Exists(ip) { return } @@ -149,7 +149,7 @@ func asyncRDNSLoop() { delete(dnsctx.rdnsIP, ip) dnsctx.rdnsLock.Unlock() - _, _ = clientAddHost(ip, host, ClientSourceRDNS) + _, _ = config.clients.AddHost(ip, host, ClientSourceRDNS) } } @@ -221,7 +221,7 @@ func generateServerConfig() (dnsforward.ServerConfig, error) { // If a client has his own settings, apply them func applyClientSettings(clientAddr string, setts *dnsfilter.RequestFilteringSettings) { - c, ok := clientFind(clientAddr) + c, ok := config.clients.Find(clientAddr) if !ok || !c.UseOwnSettings { return } diff --git a/home/home.go b/home/home.go index 768bd17a..10801c05 100644 --- a/home/home.go +++ b/home/home.go @@ -108,7 +108,7 @@ func run(args options) { os.Exit(0) }() - clientsInit() + config.clients.Init() if !config.firstRun { // Do the upgrade if necessary