* clients: move container object to "config"

Now the functions use parameter, not a global object.
This commit is contained in:
Simon Zolin 2019-07-09 18:19:50 +03:00
parent 1973901802
commit d51f43e27a
5 changed files with 81 additions and 78 deletions

View File

@ -66,10 +66,9 @@ type clientsContainer struct {
lock sync.Mutex lock sync.Mutex
} }
var clients clientsContainer // Init initializes clients container
// Note: this function must be called only once
// Initialize clients container func (clients *clientsContainer) Init() {
func clientsInit() {
if clients.list != nil { if clients.list != nil {
log.Fatal("clients.list != nil") log.Fatal("clients.list != nil")
} }
@ -77,22 +76,24 @@ func clientsInit() {
clients.ipIndex = make(map[string]*Client) clients.ipIndex = make(map[string]*Client)
clients.ipHost = make(map[string]ClientHost) clients.ipHost = make(map[string]ClientHost)
go periodicClientsUpdate() go clients.periodicUpdate()
} }
func periodicClientsUpdate() { func (clients *clientsContainer) periodicUpdate() {
for { for {
clientsAddFromHostsFile() clients.addFromHostsFile()
clientsAddFromSystemARP() clients.addFromSystemARP()
time.Sleep(clientsUpdatePeriod) time.Sleep(clientsUpdatePeriod)
} }
} }
func clientsGetList() map[string]*Client { // GetList returns the pointer to clients list
func (clients *clientsContainer) GetList() map[string]*Client {
return clients.list return clients.list
} }
func clientExists(ip string) bool { // Exists checks if client with this IP already exists
func (clients *clientsContainer) Exists(ip string) bool {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
@ -105,8 +106,8 @@ func clientExists(ip string) bool {
return ok return ok
} }
// Search for a client by IP // Find searches for a client by IP
func clientFind(ip string) (Client, bool) { func (clients *clientsContainer) Find(ip string) (Client, bool) {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
@ -135,7 +136,7 @@ func clientFind(ip string) (Client, bool) {
} }
// Check if Client object's fields are correct // Check if Client object's fields are correct
func clientCheck(c *Client) error { func (c *Client) check() error {
if len(c.Name) == 0 { if len(c.Name) == 0 {
return fmt.Errorf("Invalid Name") return fmt.Errorf("Invalid Name")
} }
@ -162,8 +163,8 @@ func clientCheck(c *Client) error {
// Add a new client object // Add a new client object
// Return true: success; false: client exists. // Return true: success; false: client exists.
func clientAdd(c Client) (bool, error) { func (clients *clientsContainer) Add(c Client) (bool, error) {
e := clientCheck(&c) e := c.check()
if e != nil { if e != nil {
return false, e return false, e
} }
@ -194,8 +195,8 @@ func clientAdd(c Client) (bool, error) {
return true, nil return true, nil
} }
// Remove a client // Del removes a client
func clientDel(name string) bool { func (clients *clientsContainer) Del(name string) bool {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
@ -210,8 +211,8 @@ func clientDel(name string) bool {
} }
// Update a client // Update a client
func clientUpdate(name string, c Client) error { func (clients *clientsContainer) Update(name string, c Client) error {
err := clientCheck(&c) err := c.check()
if err != nil { if err != nil {
return err return err
} }
@ -257,10 +258,10 @@ func clientUpdate(name string, c Client) error {
return nil return nil
} }
// Add new IP -> Host pair // AddHost adds new IP -> Host pair
// Use priority of the source (etc/hosts > ARP > rDNS) // Use priority of the source (etc/hosts > ARP > rDNS)
// so we overwrite existing entries with an equal or higher priority // so we overwrite existing entries with an equal or higher priority
func clientAddHost(ip, host string, source clientSource) (bool, error) { func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (bool, error) {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
@ -279,7 +280,7 @@ func clientAddHost(ip, host string, source clientSource) (bool, error) {
} }
// Parse system 'hosts' file and fill clients array // Parse system 'hosts' file and fill clients array
func clientsAddFromHostsFile() { func (clients *clientsContainer) addFromHostsFile() {
hostsFn := "/etc/hosts" hostsFn := "/etc/hosts"
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts") hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts")
@ -304,7 +305,7 @@ func clientsAddFromHostsFile() {
continue continue
} }
ok, e := clientAddHost(fields[0], fields[1], ClientSourceHostsFile) ok, e := clients.AddHost(fields[0], fields[1], ClientSourceHostsFile)
if e != nil { if e != nil {
log.Tracef("%s", e) log.Tracef("%s", e)
} }
@ -319,7 +320,7 @@ func clientsAddFromHostsFile() {
// Add IP -> Host pairs from the system's `arp -a` command output // Add IP -> Host pairs from the system's `arp -a` command output
// The command's output is: // The command's output is:
// HOST (IP) at MAC on IFACE // HOST (IP) at MAC on IFACE
func clientsAddFromSystemARP() { func (clients *clientsContainer) addFromSystemARP() {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
return return
@ -350,7 +351,7 @@ func clientsAddFromSystemARP() {
continue continue
} }
ok, e := clientAddHost(ip, host, ClientSourceARP) ok, e := clients.AddHost(ip, host, ClientSourceARP)
if e != nil { if e != nil {
log.Tracef("%s", e) log.Tracef("%s", e)
} }
@ -379,8 +380,8 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
data := clientListJSON{} data := clientListJSON{}
clients.lock.Lock() config.clients.lock.Lock()
for _, c := range clients.list { for _, c := range config.clients.list {
cj := clientJSON{ cj := clientJSON{
IP: c.IP, IP: c.IP,
MAC: c.MAC, MAC: c.MAC,
@ -402,7 +403,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
data.Clients = append(data.Clients, cj) data.Clients = append(data.Clients, cj)
} }
for ip, ch := range clients.ipHost { for ip, ch := range config.clients.ipHost {
cj := clientHostJSON{ cj := clientHostJSON{
IP: ip, IP: ip,
Name: ch.Host, Name: ch.Host,
@ -416,7 +417,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
} }
data.AutoClients = append(data.AutoClients, cj) data.AutoClients = append(data.AutoClients, cj)
} }
clients.lock.Unlock() config.clients.lock.Unlock()
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
e := json.NewEncoder(w).Encode(data) e := json.NewEncoder(w).Encode(data)
@ -462,7 +463,7 @@ func handleAddClient(w http.ResponseWriter, r *http.Request) {
httpError(w, http.StatusBadRequest, "%s", err) httpError(w, http.StatusBadRequest, "%s", err)
return return
} }
ok, err := clientAdd(*c) ok, err := config.clients.Add(*c)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "%s", err) httpError(w, http.StatusBadRequest, "%s", err)
return return
@ -492,7 +493,7 @@ func handleDelClient(w http.ResponseWriter, r *http.Request) {
return return
} }
if !clientDel(cj.Name) { if !config.clients.Del(cj.Name) {
httpError(w, http.StatusBadRequest, "Client not found") httpError(w, http.StatusBadRequest, "Client not found")
return return
} }
@ -501,7 +502,7 @@ func handleDelClient(w http.ResponseWriter, r *http.Request) {
returnOK(w) returnOK(w)
} }
type clientUpdateJSON struct { type updateJSON struct {
Name string `json:"name"` Name string `json:"name"`
Data clientJSON `json:"data"` Data clientJSON `json:"data"`
} }
@ -515,7 +516,7 @@ func handleUpdateClient(w http.ResponseWriter, r *http.Request) {
return return
} }
var dj clientUpdateJSON var dj updateJSON
err = json.Unmarshal(body, &dj) err = json.Unmarshal(body, &dj)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err) httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
@ -532,7 +533,7 @@ func handleUpdateClient(w http.ResponseWriter, r *http.Request) {
return return
} }
err = clientUpdate(dj.Name, *c) err = config.clients.Update(dj.Name, *c)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "%s", err) httpError(w, http.StatusBadRequest, "%s", err)
return return

View File

@ -6,17 +6,18 @@ func TestClients(t *testing.T) {
var c Client var c Client
var e error var e error
var b bool var b bool
clients := clientsContainer{}
clientsInit() clients.Init()
// add // add
c = Client{ c = Client{
IP: "1.1.1.1", IP: "1.1.1.1",
Name: "client1", Name: "client1",
} }
b, e = clientAdd(c) b, e = clients.Add(c)
if !b || e != nil { if !b || e != nil {
t.Fatalf("clientAdd #1") t.Fatalf("Add #1")
} }
// add #2 // add #2
@ -24,19 +25,19 @@ func TestClients(t *testing.T) {
IP: "2.2.2.2", IP: "2.2.2.2",
Name: "client2", Name: "client2",
} }
b, e = clientAdd(c) b, e = clients.Add(c)
if !b || e != nil { if !b || e != nil {
t.Fatalf("clientAdd #2") t.Fatalf("Add #2")
} }
c, b = clientFind("1.1.1.1") c, b = clients.Find("1.1.1.1")
if !b || c.Name != "client1" { if !b || c.Name != "client1" {
t.Fatalf("clientFind #1") t.Fatalf("Find #1")
} }
c, b = clientFind("2.2.2.2") c, b = clients.Find("2.2.2.2")
if !b || c.Name != "client2" { if !b || c.Name != "client2" {
t.Fatalf("clientFind #2") t.Fatalf("Find #2")
} }
// failed add - name in use // failed add - name in use
@ -44,9 +45,9 @@ func TestClients(t *testing.T) {
IP: "1.2.3.5", IP: "1.2.3.5",
Name: "client1", Name: "client1",
} }
b, _ = clientAdd(c) b, _ = clients.Add(c)
if b { if b {
t.Fatalf("clientAdd - name in use") t.Fatalf("Add - name in use")
} }
// failed add - ip in use // failed add - ip in use
@ -54,91 +55,91 @@ func TestClients(t *testing.T) {
IP: "2.2.2.2", IP: "2.2.2.2",
Name: "client3", Name: "client3",
} }
b, e = clientAdd(c) b, e = clients.Add(c)
if b || e == nil { if b || e == nil {
t.Fatalf("clientAdd - ip in use") t.Fatalf("Add - ip in use")
} }
// get // get
if clientExists("1.2.3.4") { if clients.Exists("1.2.3.4") {
t.Fatalf("clientExists") t.Fatalf("Exists")
} }
if !clientExists("1.1.1.1") { if !clients.Exists("1.1.1.1") {
t.Fatalf("clientExists #1") t.Fatalf("Exists #1")
} }
if !clientExists("2.2.2.2") { if !clients.Exists("2.2.2.2") {
t.Fatalf("clientExists #2") t.Fatalf("Exists #2")
} }
// failed update - no such name // failed update - no such name
c.IP = "1.2.3.0" c.IP = "1.2.3.0"
c.Name = "client3" c.Name = "client3"
if clientUpdate("client3", c) == nil { if clients.Update("client3", c) == nil {
t.Fatalf("clientUpdate") t.Fatalf("Update")
} }
// failed update - name in use // failed update - name in use
c.IP = "1.2.3.0" c.IP = "1.2.3.0"
c.Name = "client2" c.Name = "client2"
if clientUpdate("client1", c) == nil { if clients.Update("client1", c) == nil {
t.Fatalf("clientUpdate - name in use") t.Fatalf("Update - name in use")
} }
// failed update - ip in use // failed update - ip in use
c.IP = "2.2.2.2" c.IP = "2.2.2.2"
c.Name = "client1" c.Name = "client1"
if clientUpdate("client1", c) == nil { if clients.Update("client1", c) == nil {
t.Fatalf("clientUpdate - ip in use") t.Fatalf("Update - ip in use")
} }
// update // update
c.IP = "1.1.1.2" c.IP = "1.1.1.2"
c.Name = "client1" c.Name = "client1"
if clientUpdate("client1", c) != nil { if clients.Update("client1", c) != nil {
t.Fatalf("clientUpdate") t.Fatalf("Update")
} }
// get after update // get after update
if clientExists("1.1.1.1") || !clientExists("1.1.1.2") { if clients.Exists("1.1.1.1") || !clients.Exists("1.1.1.2") {
t.Fatalf("clientExists - get after update") t.Fatalf("Exists - get after update")
} }
// failed remove - no such name // failed remove - no such name
if clientDel("client3") { if clients.Del("client3") {
t.Fatalf("clientDel - no such name") t.Fatalf("Del - no such name")
} }
// remove // remove
if !clientDel("client1") || clientExists("1.1.1.2") { if !clients.Del("client1") || clients.Exists("1.1.1.2") {
t.Fatalf("clientDel") t.Fatalf("Del")
} }
// add host client // add host client
b, e = clientAddHost("1.1.1.1", "host", ClientSourceARP) b, e = clients.AddHost("1.1.1.1", "host", ClientSourceARP)
if !b || e != nil { if !b || e != nil {
t.Fatalf("clientAddHost") t.Fatalf("clientAddHost")
} }
// failed add - ip exists // failed add - ip exists
b, e = clientAddHost("1.1.1.1", "host1", ClientSourceRDNS) b, e = clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
if b || e != nil { if b || e != nil {
t.Fatalf("clientAddHost - ip exists") t.Fatalf("clientAddHost - ip exists")
} }
// overwrite with new data // overwrite with new data
b, e = clientAddHost("1.1.1.1", "host2", ClientSourceARP) b, e = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
if !b || e != nil { if !b || e != nil {
t.Fatalf("clientAddHost - overwrite with new data") t.Fatalf("clientAddHost - overwrite with new data")
} }
// overwrite with new data (higher priority) // overwrite with new data (higher priority)
b, e = clientAddHost("1.1.1.1", "host3", ClientSourceHostsFile) b, e = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
if !b || e != nil { if !b || e != nil {
t.Fatalf("clientAddHost - overwrite with new data (higher priority)") t.Fatalf("clientAddHost - overwrite with new data (higher priority)")
} }
// get // get
if !clientExists("1.1.1.1") { if !clients.Exists("1.1.1.1") {
t.Fatalf("clientAddHost") t.Fatalf("clientAddHost")
} }
} }

View File

@ -52,6 +52,7 @@ type configuration struct {
runningAsService bool runningAsService bool
disableUpdate bool // If set, don't check for updates disableUpdate bool // If set, don't check for updates
appSignalChannel chan os.Signal appSignalChannel chan os.Signal
clients clientsContainer
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
@ -233,7 +234,7 @@ func parseConfig() error {
SafeSearchEnabled: cy.SafeSearchEnabled, SafeSearchEnabled: cy.SafeSearchEnabled,
SafeBrowsingEnabled: cy.SafeBrowsingEnabled, SafeBrowsingEnabled: cy.SafeBrowsingEnabled,
} }
_, err = clientAdd(cli) _, err = config.clients.Add(cli)
if err != nil { if err != nil {
log.Tracef("clientAdd: %s", err) log.Tracef("clientAdd: %s", err)
} }
@ -268,7 +269,7 @@ func (c *configuration) write() error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
clientsList := clientsGetList() clientsList := config.clients.GetList()
for _, cli := range clientsList { for _, cli := range clientsList {
ip := cli.IP ip := cli.IP
if len(cli.MAC) != 0 { if len(cli.MAC) != 0 {

View File

@ -68,7 +68,7 @@ func isRunning() bool {
} }
func beginAsyncRDNS(ip string) { func beginAsyncRDNS(ip string) {
if clientExists(ip) { if config.clients.Exists(ip) {
return return
} }
@ -149,7 +149,7 @@ func asyncRDNSLoop() {
delete(dnsctx.rdnsIP, ip) delete(dnsctx.rdnsIP, ip)
dnsctx.rdnsLock.Unlock() dnsctx.rdnsLock.Unlock()
_, _ = clientAddHost(ip, host, ClientSourceRDNS) _, _ = config.clients.AddHost(ip, host, ClientSourceRDNS)
} }
} }
@ -221,7 +221,7 @@ func generateServerConfig() (dnsforward.ServerConfig, error) {
// If a client has his own settings, apply them // If a client has his own settings, apply them
func applyClientSettings(clientAddr string, setts *dnsfilter.RequestFilteringSettings) { func applyClientSettings(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
c, ok := clientFind(clientAddr) c, ok := config.clients.Find(clientAddr)
if !ok || !c.UseOwnSettings { if !ok || !c.UseOwnSettings {
return return
} }

View File

@ -108,7 +108,7 @@ func run(args options) {
os.Exit(0) os.Exit(0)
}() }()
clientsInit() config.clients.Init()
if !config.firstRun { if !config.firstRun {
// Do the upgrade if necessary // Do the upgrade if necessary