* clients: move container object to "config"

Now the functions use parameter, not a global object.
This commit is contained in:
Simon Zolin 2019-07-09 18:19:50 +03:00
parent 1973901802
commit d51f43e27a
5 changed files with 81 additions and 78 deletions

View File

@ -66,10 +66,9 @@ type clientsContainer struct {
lock sync.Mutex
}
var clients clientsContainer
// Initialize clients container
func clientsInit() {
// Init initializes clients container
// Note: this function must be called only once
func (clients *clientsContainer) Init() {
if clients.list != nil {
log.Fatal("clients.list != nil")
}
@ -77,22 +76,24 @@ func clientsInit() {
clients.ipIndex = make(map[string]*Client)
clients.ipHost = make(map[string]ClientHost)
go periodicClientsUpdate()
go clients.periodicUpdate()
}
func periodicClientsUpdate() {
func (clients *clientsContainer) periodicUpdate() {
for {
clientsAddFromHostsFile()
clientsAddFromSystemARP()
clients.addFromHostsFile()
clients.addFromSystemARP()
time.Sleep(clientsUpdatePeriod)
}
}
func clientsGetList() map[string]*Client {
// GetList returns the pointer to clients list
func (clients *clientsContainer) GetList() map[string]*Client {
return clients.list
}
func clientExists(ip string) bool {
// Exists checks if client with this IP already exists
func (clients *clientsContainer) Exists(ip string) bool {
clients.lock.Lock()
defer clients.lock.Unlock()
@ -105,8 +106,8 @@ func clientExists(ip string) bool {
return ok
}
// Search for a client by IP
func clientFind(ip string) (Client, bool) {
// Find searches for a client by IP
func (clients *clientsContainer) Find(ip string) (Client, bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
@ -135,7 +136,7 @@ func clientFind(ip string) (Client, bool) {
}
// Check if Client object's fields are correct
func clientCheck(c *Client) error {
func (c *Client) check() error {
if len(c.Name) == 0 {
return fmt.Errorf("Invalid Name")
}
@ -162,8 +163,8 @@ func clientCheck(c *Client) error {
// Add a new client object
// Return true: success; false: client exists.
func clientAdd(c Client) (bool, error) {
e := clientCheck(&c)
func (clients *clientsContainer) Add(c Client) (bool, error) {
e := c.check()
if e != nil {
return false, e
}
@ -194,8 +195,8 @@ func clientAdd(c Client) (bool, error) {
return true, nil
}
// Remove a client
func clientDel(name string) bool {
// Del removes a client
func (clients *clientsContainer) Del(name string) bool {
clients.lock.Lock()
defer clients.lock.Unlock()
@ -210,8 +211,8 @@ func clientDel(name string) bool {
}
// Update a client
func clientUpdate(name string, c Client) error {
err := clientCheck(&c)
func (clients *clientsContainer) Update(name string, c Client) error {
err := c.check()
if err != nil {
return err
}
@ -257,10 +258,10 @@ func clientUpdate(name string, c Client) error {
return nil
}
// Add new IP -> Host pair
// AddHost adds new IP -> Host pair
// Use priority of the source (etc/hosts > ARP > rDNS)
// so we overwrite existing entries with an equal or higher priority
func clientAddHost(ip, host string, source clientSource) (bool, error) {
func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (bool, error) {
clients.lock.Lock()
defer clients.lock.Unlock()
@ -279,7 +280,7 @@ func clientAddHost(ip, host string, source clientSource) (bool, error) {
}
// Parse system 'hosts' file and fill clients array
func clientsAddFromHostsFile() {
func (clients *clientsContainer) addFromHostsFile() {
hostsFn := "/etc/hosts"
if runtime.GOOS == "windows" {
hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts")
@ -304,7 +305,7 @@ func clientsAddFromHostsFile() {
continue
}
ok, e := clientAddHost(fields[0], fields[1], ClientSourceHostsFile)
ok, e := clients.AddHost(fields[0], fields[1], ClientSourceHostsFile)
if e != nil {
log.Tracef("%s", e)
}
@ -319,7 +320,7 @@ func clientsAddFromHostsFile() {
// Add IP -> Host pairs from the system's `arp -a` command output
// The command's output is:
// HOST (IP) at MAC on IFACE
func clientsAddFromSystemARP() {
func (clients *clientsContainer) addFromSystemARP() {
if runtime.GOOS == "windows" {
return
@ -350,7 +351,7 @@ func clientsAddFromSystemARP() {
continue
}
ok, e := clientAddHost(ip, host, ClientSourceARP)
ok, e := clients.AddHost(ip, host, ClientSourceARP)
if e != nil {
log.Tracef("%s", e)
}
@ -379,8 +380,8 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
data := clientListJSON{}
clients.lock.Lock()
for _, c := range clients.list {
config.clients.lock.Lock()
for _, c := range config.clients.list {
cj := clientJSON{
IP: c.IP,
MAC: c.MAC,
@ -402,7 +403,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
data.Clients = append(data.Clients, cj)
}
for ip, ch := range clients.ipHost {
for ip, ch := range config.clients.ipHost {
cj := clientHostJSON{
IP: ip,
Name: ch.Host,
@ -416,7 +417,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
}
data.AutoClients = append(data.AutoClients, cj)
}
clients.lock.Unlock()
config.clients.lock.Unlock()
w.Header().Set("Content-Type", "application/json")
e := json.NewEncoder(w).Encode(data)
@ -462,7 +463,7 @@ func handleAddClient(w http.ResponseWriter, r *http.Request) {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
ok, err := clientAdd(*c)
ok, err := config.clients.Add(*c)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
@ -492,7 +493,7 @@ func handleDelClient(w http.ResponseWriter, r *http.Request) {
return
}
if !clientDel(cj.Name) {
if !config.clients.Del(cj.Name) {
httpError(w, http.StatusBadRequest, "Client not found")
return
}
@ -501,7 +502,7 @@ func handleDelClient(w http.ResponseWriter, r *http.Request) {
returnOK(w)
}
type clientUpdateJSON struct {
type updateJSON struct {
Name string `json:"name"`
Data clientJSON `json:"data"`
}
@ -515,7 +516,7 @@ func handleUpdateClient(w http.ResponseWriter, r *http.Request) {
return
}
var dj clientUpdateJSON
var dj updateJSON
err = json.Unmarshal(body, &dj)
if err != nil {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
@ -532,7 +533,7 @@ func handleUpdateClient(w http.ResponseWriter, r *http.Request) {
return
}
err = clientUpdate(dj.Name, *c)
err = config.clients.Update(dj.Name, *c)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return

View File

@ -6,17 +6,18 @@ func TestClients(t *testing.T) {
var c Client
var e error
var b bool
clients := clientsContainer{}
clientsInit()
clients.Init()
// add
c = Client{
IP: "1.1.1.1",
Name: "client1",
}
b, e = clientAdd(c)
b, e = clients.Add(c)
if !b || e != nil {
t.Fatalf("clientAdd #1")
t.Fatalf("Add #1")
}
// add #2
@ -24,19 +25,19 @@ func TestClients(t *testing.T) {
IP: "2.2.2.2",
Name: "client2",
}
b, e = clientAdd(c)
b, e = clients.Add(c)
if !b || e != nil {
t.Fatalf("clientAdd #2")
t.Fatalf("Add #2")
}
c, b = clientFind("1.1.1.1")
c, b = clients.Find("1.1.1.1")
if !b || c.Name != "client1" {
t.Fatalf("clientFind #1")
t.Fatalf("Find #1")
}
c, b = clientFind("2.2.2.2")
c, b = clients.Find("2.2.2.2")
if !b || c.Name != "client2" {
t.Fatalf("clientFind #2")
t.Fatalf("Find #2")
}
// failed add - name in use
@ -44,9 +45,9 @@ func TestClients(t *testing.T) {
IP: "1.2.3.5",
Name: "client1",
}
b, _ = clientAdd(c)
b, _ = clients.Add(c)
if b {
t.Fatalf("clientAdd - name in use")
t.Fatalf("Add - name in use")
}
// failed add - ip in use
@ -54,91 +55,91 @@ func TestClients(t *testing.T) {
IP: "2.2.2.2",
Name: "client3",
}
b, e = clientAdd(c)
b, e = clients.Add(c)
if b || e == nil {
t.Fatalf("clientAdd - ip in use")
t.Fatalf("Add - ip in use")
}
// get
if clientExists("1.2.3.4") {
t.Fatalf("clientExists")
if clients.Exists("1.2.3.4") {
t.Fatalf("Exists")
}
if !clientExists("1.1.1.1") {
t.Fatalf("clientExists #1")
if !clients.Exists("1.1.1.1") {
t.Fatalf("Exists #1")
}
if !clientExists("2.2.2.2") {
t.Fatalf("clientExists #2")
if !clients.Exists("2.2.2.2") {
t.Fatalf("Exists #2")
}
// failed update - no such name
c.IP = "1.2.3.0"
c.Name = "client3"
if clientUpdate("client3", c) == nil {
t.Fatalf("clientUpdate")
if clients.Update("client3", c) == nil {
t.Fatalf("Update")
}
// failed update - name in use
c.IP = "1.2.3.0"
c.Name = "client2"
if clientUpdate("client1", c) == nil {
t.Fatalf("clientUpdate - name in use")
if clients.Update("client1", c) == nil {
t.Fatalf("Update - name in use")
}
// failed update - ip in use
c.IP = "2.2.2.2"
c.Name = "client1"
if clientUpdate("client1", c) == nil {
t.Fatalf("clientUpdate - ip in use")
if clients.Update("client1", c) == nil {
t.Fatalf("Update - ip in use")
}
// update
c.IP = "1.1.1.2"
c.Name = "client1"
if clientUpdate("client1", c) != nil {
t.Fatalf("clientUpdate")
if clients.Update("client1", c) != nil {
t.Fatalf("Update")
}
// get after update
if clientExists("1.1.1.1") || !clientExists("1.1.1.2") {
t.Fatalf("clientExists - get after update")
if clients.Exists("1.1.1.1") || !clients.Exists("1.1.1.2") {
t.Fatalf("Exists - get after update")
}
// failed remove - no such name
if clientDel("client3") {
t.Fatalf("clientDel - no such name")
if clients.Del("client3") {
t.Fatalf("Del - no such name")
}
// remove
if !clientDel("client1") || clientExists("1.1.1.2") {
t.Fatalf("clientDel")
if !clients.Del("client1") || clients.Exists("1.1.1.2") {
t.Fatalf("Del")
}
// add host client
b, e = clientAddHost("1.1.1.1", "host", ClientSourceARP)
b, e = clients.AddHost("1.1.1.1", "host", ClientSourceARP)
if !b || e != nil {
t.Fatalf("clientAddHost")
}
// failed add - ip exists
b, e = clientAddHost("1.1.1.1", "host1", ClientSourceRDNS)
b, e = clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
if b || e != nil {
t.Fatalf("clientAddHost - ip exists")
}
// overwrite with new data
b, e = clientAddHost("1.1.1.1", "host2", ClientSourceARP)
b, e = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
if !b || e != nil {
t.Fatalf("clientAddHost - overwrite with new data")
}
// overwrite with new data (higher priority)
b, e = clientAddHost("1.1.1.1", "host3", ClientSourceHostsFile)
b, e = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
if !b || e != nil {
t.Fatalf("clientAddHost - overwrite with new data (higher priority)")
}
// get
if !clientExists("1.1.1.1") {
if !clients.Exists("1.1.1.1") {
t.Fatalf("clientAddHost")
}
}

View File

@ -52,6 +52,7 @@ type configuration struct {
runningAsService bool
disableUpdate bool // If set, don't check for updates
appSignalChannel chan os.Signal
clients clientsContainer
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
@ -233,7 +234,7 @@ func parseConfig() error {
SafeSearchEnabled: cy.SafeSearchEnabled,
SafeBrowsingEnabled: cy.SafeBrowsingEnabled,
}
_, err = clientAdd(cli)
_, err = config.clients.Add(cli)
if err != nil {
log.Tracef("clientAdd: %s", err)
}
@ -268,7 +269,7 @@ func (c *configuration) write() error {
c.Lock()
defer c.Unlock()
clientsList := clientsGetList()
clientsList := config.clients.GetList()
for _, cli := range clientsList {
ip := cli.IP
if len(cli.MAC) != 0 {

View File

@ -68,7 +68,7 @@ func isRunning() bool {
}
func beginAsyncRDNS(ip string) {
if clientExists(ip) {
if config.clients.Exists(ip) {
return
}
@ -149,7 +149,7 @@ func asyncRDNSLoop() {
delete(dnsctx.rdnsIP, ip)
dnsctx.rdnsLock.Unlock()
_, _ = clientAddHost(ip, host, ClientSourceRDNS)
_, _ = config.clients.AddHost(ip, host, ClientSourceRDNS)
}
}
@ -221,7 +221,7 @@ func generateServerConfig() (dnsforward.ServerConfig, error) {
// If a client has his own settings, apply them
func applyClientSettings(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
c, ok := clientFind(clientAddr)
c, ok := config.clients.Find(clientAddr)
if !ok || !c.UseOwnSettings {
return
}

View File

@ -108,7 +108,7 @@ func run(args options) {
os.Exit(0)
}()
clientsInit()
config.clients.Init()
if !config.firstRun {
// Do the upgrade if necessary