From 5c385521c42ee597237998ab6273e2057338c5d8 Mon Sep 17 00:00:00 2001 From: Simon Zolin Date: Fri, 8 Nov 2019 17:53:12 +0300 Subject: [PATCH 01/31] + install: check static ip * move code to network_utils + add tests 'dhcp' package has its own network_utils for now + doc: add static ip configuration algorithm on Ubuntu --- AGHTechDoc.md | 62 +++++++- dhcpd/dhcp_http.go | 162 +------------------ dhcpd/network_utils.go | 317 +++++++++++++++++++++++++++++++++++++ home/control_install.go | 61 ++++++- home/helpers.go | 138 ---------------- home/network_utils.go | 317 +++++++++++++++++++++++++++++++++++++ home/network_utils_test.go | 61 +++++++ 7 files changed, 806 insertions(+), 312 deletions(-) create mode 100644 dhcpd/network_utils.go create mode 100644 home/network_utils.go create mode 100644 home/network_utils_test.go diff --git a/AGHTechDoc.md b/AGHTechDoc.md index 5d812e6b..c5d09f8a 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -138,10 +138,13 @@ Request: { "web":{"port":80,"ip":"192.168.11.33"}, "dns":{"port":53,"ip":"127.0.0.1","autofix":false}, + "set_static_ip": true | false } Server should check whether a port is available only in case it itself isn't already listening on that port. +If `set_static_ip` is `true`, Server attempts to set a static IP for the network interface chosen by `dns.ip` setting. If the operation is successful, `static_ip.static` setting will be `yes`. If it fails, `static_ip.static` setting will be set to `error` and `static_ip.error` will contain the error message. + Server replies on success: 200 OK @@ -149,7 +152,14 @@ Server replies on success: { "web":{"status":""}, "dns":{"status":""}, + "static_ip": { + "static": "yes|no|error", + "ip": "", // set if static=no + "error": "..." // set if static=error } + } + +If `static_ip.static` is `no`, Server has detected that the system uses a dynamic address and it can automatically set a static address if `set_static_ip` in request is `true`. See section `Static IP check/set` for detailed process. Server replies on error: @@ -172,7 +182,11 @@ Request: POST /control/install/check_config { - "dns":{"port":53,"ip":"127.0.0.1","autofix":false} + "dns":{ + "port":53, + "ip":"127.0.0.1", + "autofix":false + } } Check if DNSStubListener is enabled: @@ -499,13 +513,7 @@ which will print: default via 192.168.0.1 proto dhcp metric 100 -#### Phase 2 - -This method only works on Raspbian. - -On Ubuntu DHCP for a network interface can't be disabled via `dhcpcd.conf`. This must be configured in `/etc/netplan/01-netcfg.yaml`. - -Fedora doesn't use `dhcpcd.conf` configuration at all. +#### Phase 2 (Raspbian) Step 1. @@ -526,6 +534,44 @@ If we would set a different IP address, we'd need to replace the IP address for ip addr replace dev eth0 192.168.0.1/24 +#### Phase 2 (Ubuntu) + +`/etc/netplan/01-netcfg.yaml` or `/etc/netplan/01-network-manager-all.yaml` + +This configuration example has a static IP set for `enp0s3` interface: + + network: + version: 2 + renderer: networkd + ethernets: + enp0s3: + dhcp4: no + addresses: [192.168.0.2/24] + gateway: 192.168.0.1 + nameservers: + addresses: [192.168.0.1,8.8.8.8] + +For dynamic configuration `dhcp4: yes` is set. + +Make a backup copy to `/etc/netplan/01-netcfg.yaml.backup`. + +Apply: + + netplan apply + +Restart network: + + systemctl restart networking + +or: + + systemctl restart network-manager + +or: + + systemctl restart system-networkd + + ### Add a static lease Request: diff --git a/dhcpd/dhcp_http.go b/dhcpd/dhcp_http.go index e1b3d4fb..35f6a7ad 100644 --- a/dhcpd/dhcp_http.go +++ b/dhcpd/dhcp_http.go @@ -2,18 +2,14 @@ package dhcpd import ( "encoding/json" - "errors" "fmt" "io/ioutil" "net" "net/http" "os" - "os/exec" - "runtime" "strings" "time" - "github.com/AdguardTeam/golibs/file" "github.com/AdguardTeam/golibs/log" ) @@ -115,7 +111,7 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { } } -type netInterface struct { +type netInterfaceJSON struct { Name string `json:"name"` MTU int `json:"mtu"` HardwareAddr string `json:"hardware_address"` @@ -123,29 +119,6 @@ type netInterface struct { Flags string `json:"flags"` } -// getValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP -// invalid interface is a ppp interface or the one that doesn't allow broadcasts -func getValidNetInterfaces() ([]net.Interface, error) { - ifaces, err := net.Interfaces() - if err != nil { - return nil, fmt.Errorf("Couldn't get list of interfaces: %s", err) - } - - netIfaces := []net.Interface{} - - for i := range ifaces { - if ifaces[i].Flags&net.FlagPointToPoint != 0 { - // this interface is ppp, we're not interested in this one - continue - } - - iface := ifaces[i] - netIfaces = append(netIfaces, iface) - } - - return netIfaces, nil -} - func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { response := map[string]interface{}{} @@ -170,7 +143,7 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { return } - jsonIface := netInterface{ + jsonIface := netInterfaceJSON{ Name: iface.Name, MTU: iface.MTU, HardwareAddr: iface.HardwareAddr.String(), @@ -263,137 +236,6 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque } } -// Check if network interface has a static IP configured -func hasStaticIP(ifaceName string) (bool, error) { - if runtime.GOOS == "windows" { - return false, errors.New("Can't detect static IP: not supported on Windows") - } - - body, err := ioutil.ReadFile("/etc/dhcpcd.conf") - if err != nil { - return false, err - } - lines := strings.Split(string(body), "\n") - nameLine := fmt.Sprintf("interface %s", ifaceName) - withinInterfaceCtx := false - - for _, line := range lines { - line = strings.TrimSpace(line) - - if withinInterfaceCtx && len(line) == 0 { - // an empty line resets our state - withinInterfaceCtx = false - } - - if len(line) == 0 || line[0] == '#' { - continue - } - line = strings.TrimSpace(line) - - if !withinInterfaceCtx { - if line == nameLine { - // we found our interface - withinInterfaceCtx = true - } - - } else { - if strings.HasPrefix(line, "interface ") { - // we found another interface - reset our state - withinInterfaceCtx = false - continue - } - if strings.HasPrefix(line, "static ip_address=") { - return true, nil - } - } - } - - return false, nil -} - -// Get IP address with netmask -func getFullIP(ifaceName string) string { - cmd := exec.Command("ip", "-oneline", "-family", "inet", "address", "show", ifaceName) - log.Tracef("executing %s %v", cmd.Path, cmd.Args) - d, err := cmd.Output() - if err != nil || cmd.ProcessState.ExitCode() != 0 { - return "" - } - - fields := strings.Fields(string(d)) - if len(fields) < 4 { - return "" - } - _, _, err = net.ParseCIDR(fields[3]) - if err != nil { - return "" - } - - return fields[3] -} - -// Get gateway IP address -func getGatewayIP(ifaceName string) string { - cmd := exec.Command("ip", "route", "show", "dev", ifaceName) - log.Tracef("executing %s %v", cmd.Path, cmd.Args) - d, err := cmd.Output() - if err != nil || cmd.ProcessState.ExitCode() != 0 { - return "" - } - - fields := strings.Fields(string(d)) - if len(fields) < 3 || fields[0] != "default" { - return "" - } - - ip := net.ParseIP(fields[2]) - if ip == nil { - return "" - } - - return fields[2] -} - -// Set a static IP for network interface -func setStaticIP(ifaceName string) error { - ip := getFullIP(ifaceName) - if len(ip) == 0 { - return errors.New("Can't get IP address") - } - - body, err := ioutil.ReadFile("/etc/dhcpcd.conf") - if err != nil { - return err - } - - ip4, _, err := net.ParseCIDR(ip) - if err != nil { - return err - } - - add := fmt.Sprintf("\ninterface %s\nstatic ip_address=%s\n", - ifaceName, ip) - body = append(body, []byte(add)...) - - gatewayIP := getGatewayIP(ifaceName) - if len(gatewayIP) != 0 { - add = fmt.Sprintf("static routers=%s\n", - gatewayIP) - body = append(body, []byte(add)...) - } - - add = fmt.Sprintf("static domain_name_servers=%s\n\n", - ip4) - body = append(body, []byte(add)...) - - err = file.SafeWrite("/etc/dhcpcd.conf", body) - if err != nil { - return err - } - - return nil -} - func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) { lj := staticLeaseJSON{} diff --git a/dhcpd/network_utils.go b/dhcpd/network_utils.go new file mode 100644 index 00000000..16a3c7dd --- /dev/null +++ b/dhcpd/network_utils.go @@ -0,0 +1,317 @@ +package dhcpd + +import ( + "errors" + "fmt" + "io/ioutil" + "net" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "syscall" + "time" + + "github.com/AdguardTeam/golibs/file" + "github.com/AdguardTeam/golibs/log" + "github.com/joomcode/errorx" +) + +type netInterface struct { + Name string + MTU int + HardwareAddr string + Addresses []string + Flags string +} + +// getValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP +// invalid interface is a ppp interface or the one that doesn't allow broadcasts +func getValidNetInterfaces() ([]net.Interface, error) { + ifaces, err := net.Interfaces() + if err != nil { + return nil, fmt.Errorf("Couldn't get list of interfaces: %s", err) + } + + netIfaces := []net.Interface{} + + for i := range ifaces { + if ifaces[i].Flags&net.FlagPointToPoint != 0 { + // this interface is ppp, we're not interested in this one + continue + } + + iface := ifaces[i] + netIfaces = append(netIfaces, iface) + } + + return netIfaces, nil +} + +// getValidNetInterfacesMap returns interfaces that are eligible for DNS and WEB only +// we do not return link-local addresses here +func getValidNetInterfacesForWeb() ([]netInterface, error) { + ifaces, err := getValidNetInterfaces() + if err != nil { + return nil, errorx.Decorate(err, "Couldn't get interfaces") + } + if len(ifaces) == 0 { + return nil, errors.New("couldn't find any legible interface") + } + + var netInterfaces []netInterface + + for _, iface := range ifaces { + addrs, e := iface.Addrs() + if e != nil { + return nil, errorx.Decorate(e, "Failed to get addresses for interface %s", iface.Name) + } + + netIface := netInterface{ + Name: iface.Name, + MTU: iface.MTU, + HardwareAddr: iface.HardwareAddr.String(), + } + + if iface.Flags != 0 { + netIface.Flags = iface.Flags.String() + } + + // we don't want link-local addresses in json, so skip them + for _, addr := range addrs { + ipnet, ok := addr.(*net.IPNet) + if !ok { + // not an IPNet, should not happen + return nil, fmt.Errorf("got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr) + } + // ignore link-local + if ipnet.IP.IsLinkLocalUnicast() { + continue + } + netIface.Addresses = append(netIface.Addresses, ipnet.IP.String()) + } + if len(netIface.Addresses) != 0 { + netInterfaces = append(netInterfaces, netIface) + } + } + + return netInterfaces, nil +} + +// Check if network interface has a static IP configured +// Supports: Raspbian. +func hasStaticIP(ifaceName string) (bool, error) { + if runtime.GOOS == "windows" { + return false, errors.New("Can't detect static IP: not supported on Windows") + } + + body, err := ioutil.ReadFile("/etc/dhcpcd.conf") + if err != nil { + return false, err + } + + return hasStaticIPDhcpcdConf(string(body), ifaceName), nil +} + +// for dhcpcd.conf +func hasStaticIPDhcpcdConf(data, ifaceName string) bool { + lines := strings.Split(data, "\n") + nameLine := fmt.Sprintf("interface %s", ifaceName) + withinInterfaceCtx := false + + for _, line := range lines { + line = strings.TrimSpace(line) + + if withinInterfaceCtx && len(line) == 0 { + // an empty line resets our state + withinInterfaceCtx = false + } + + if len(line) == 0 || line[0] == '#' { + continue + } + line = strings.TrimSpace(line) + + if !withinInterfaceCtx { + if line == nameLine { + // we found our interface + withinInterfaceCtx = true + } + + } else { + if strings.HasPrefix(line, "interface ") { + // we found another interface - reset our state + withinInterfaceCtx = false + continue + } + if strings.HasPrefix(line, "static ip_address=") { + return true + } + } + } + return false +} + +// Get IP address with netmask +func getFullIP(ifaceName string) string { + cmd := exec.Command("ip", "-oneline", "-family", "inet", "address", "show", ifaceName) + log.Tracef("executing %s %v", cmd.Path, cmd.Args) + d, err := cmd.Output() + if err != nil || cmd.ProcessState.ExitCode() != 0 { + return "" + } + + fields := strings.Fields(string(d)) + if len(fields) < 4 { + return "" + } + _, _, err = net.ParseCIDR(fields[3]) + if err != nil { + return "" + } + + return fields[3] +} + +// Get interface name by its IP address. +func getInterfaceByIP(ip string) string { + ifaces, err := getValidNetInterfacesForWeb() + if err != nil { + return "" + } + + for _, iface := range ifaces { + for _, addr := range iface.Addresses { + if ip == addr { + return iface.Name + } + } + } + + return "" +} + +// Get gateway IP address +func getGatewayIP(ifaceName string) string { + cmd := exec.Command("ip", "route", "show", "dev", ifaceName) + log.Tracef("executing %s %v", cmd.Path, cmd.Args) + d, err := cmd.Output() + if err != nil || cmd.ProcessState.ExitCode() != 0 { + return "" + } + + fields := strings.Fields(string(d)) + if len(fields) < 3 || fields[0] != "default" { + return "" + } + + ip := net.ParseIP(fields[2]) + if ip == nil { + return "" + } + + return fields[2] +} + +// Set a static IP for network interface +// Supports: Raspbian. +func setStaticIP(ifaceName string) error { + ip := getFullIP(ifaceName) + if len(ip) == 0 { + return errors.New("Can't get IP address") + } + + ip4, _, err := net.ParseCIDR(ip) + if err != nil { + return err + } + gatewayIP := getGatewayIP(ifaceName) + add := setStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, ip4.String()) + + body, err := ioutil.ReadFile("/etc/dhcpcd.conf") + if err != nil { + return err + } + + body = append(body, []byte(add)...) + err = file.SafeWrite("/etc/dhcpcd.conf", body) + if err != nil { + return err + } + + return nil +} + +// for dhcpcd.conf +func setStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, dnsIP string) string { + var body []byte + + add := fmt.Sprintf("\ninterface %s\nstatic ip_address=%s\n", + ifaceName, ip) + body = append(body, []byte(add)...) + + if len(gatewayIP) != 0 { + add = fmt.Sprintf("static routers=%s\n", + gatewayIP) + body = append(body, []byte(add)...) + } + + add = fmt.Sprintf("static domain_name_servers=%s\n\n", + dnsIP) + body = append(body, []byte(add)...) + + return string(body) +} + +// checkPortAvailable is not a cheap test to see if the port is bindable, because it's actually doing the bind momentarily +func checkPortAvailable(host string, port int) error { + ln, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port))) + if err != nil { + return err + } + ln.Close() + + // It seems that net.Listener.Close() doesn't close file descriptors right away. + // We wait for some time and hope that this fd will be closed. + time.Sleep(100 * time.Millisecond) + return nil +} + +func checkPacketPortAvailable(host string, port int) error { + ln, err := net.ListenPacket("udp", net.JoinHostPort(host, strconv.Itoa(port))) + if err != nil { + return err + } + ln.Close() + + // It seems that net.Listener.Close() doesn't close file descriptors right away. + // We wait for some time and hope that this fd will be closed. + time.Sleep(100 * time.Millisecond) + return err +} + +// check if error is "address already in use" +func errorIsAddrInUse(err error) bool { + errOpError, ok := err.(*net.OpError) + if !ok { + return false + } + + errSyscallError, ok := errOpError.Err.(*os.SyscallError) + if !ok { + return false + } + + errErrno, ok := errSyscallError.Err.(syscall.Errno) + if !ok { + return false + } + + if runtime.GOOS == "windows" { + const WSAEADDRINUSE = 10048 + return errErrno == WSAEADDRINUSE + } + + return errErrno == syscall.EADDRINUSE +} diff --git a/home/control_install.go b/home/control_install.go index 5311c091..18fe2a8b 100644 --- a/home/control_install.go +++ b/home/control_install.go @@ -22,6 +22,14 @@ type firstRunData struct { Interfaces map[string]interface{} `json:"interfaces"` } +type netInterfaceJSON struct { + Name string `json:"name"` + MTU int `json:"mtu"` + HardwareAddr string `json:"hardware_address"` + Addresses []string `json:"ip_addresses"` + Flags string `json:"flags"` +} + // Get initial installation settings func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) { data := firstRunData{} @@ -36,7 +44,14 @@ func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) { data.Interfaces = make(map[string]interface{}) for _, iface := range ifaces { - data.Interfaces[iface.Name] = iface + ifaceJSON := netInterfaceJSON{ + Name: iface.Name, + MTU: iface.MTU, + HardwareAddr: iface.HardwareAddr, + Addresses: iface.Addresses, + Flags: iface.Flags, + } + data.Interfaces[iface.Name] = ifaceJSON } w.Header().Set("Content-Type", "application/json") @@ -48,9 +63,10 @@ func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) { } type checkConfigReqEnt struct { - Port int `json:"port"` - IP string `json:"ip"` - Autofix bool `json:"autofix"` + Port int `json:"port"` + IP string `json:"ip"` + Autofix bool `json:"autofix"` + SetStaticIP bool `json:"set_static_ip"` } type checkConfigReq struct { Web checkConfigReqEnt `json:"web"` @@ -61,9 +77,15 @@ type checkConfigRespEnt struct { Status string `json:"status"` CanAutofix bool `json:"can_autofix"` } +type staticIPJSON struct { + Static string `json:"static"` + IP string `json:"ip"` + Error string `json:"error"` +} type checkConfigResp struct { - Web checkConfigRespEnt `json:"web"` - DNS checkConfigRespEnt `json:"dns"` + Web checkConfigRespEnt `json:"web"` + DNS checkConfigRespEnt `json:"dns"` + StaticIP staticIPJSON `json:"static_ip"` } // Check if ports are available, respond with results @@ -108,6 +130,33 @@ func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) { if err != nil { respData.DNS.Status = fmt.Sprintf("%v", err) + + } else { + + interfaceName := getInterfaceByIP(reqData.DNS.IP) + staticIPStatus := "yes" + + if len(interfaceName) == 0 { + staticIPStatus = "error" + respData.StaticIP.Error = fmt.Sprintf("Couldn't find network interface by IP %s", reqData.DNS.IP) + + } else if reqData.DNS.SetStaticIP { + err = setStaticIP(interfaceName) + staticIPStatus = "error" + respData.StaticIP.Error = err.Error() + + } else { + // check if we have a static IP + isStaticIP, err := hasStaticIP(interfaceName) + if err != nil { + staticIPStatus = "error" + respData.StaticIP.Error = err.Error() + } else if !isStaticIP { + staticIPStatus = "no" + respData.StaticIP.IP = getFullIP(interfaceName) + } + } + respData.StaticIP.Static = staticIPStatus } } diff --git a/home/helpers.go b/home/helpers.go index c00fcbc8..e05c4bb2 100644 --- a/home/helpers.go +++ b/home/helpers.go @@ -2,7 +2,6 @@ package home import ( "context" - "errors" "fmt" "net" "net/http" @@ -13,7 +12,6 @@ import ( "runtime" "strconv" "strings" - "syscall" "time" "github.com/AdguardTeam/golibs/log" @@ -151,117 +149,6 @@ func postInstallHandler(handler http.Handler) http.Handler { return &postInstallHandlerStruct{handler} } -// ------------------ -// network interfaces -// ------------------ -type netInterface struct { - Name string `json:"name"` - MTU int `json:"mtu"` - HardwareAddr string `json:"hardware_address"` - Addresses []string `json:"ip_addresses"` - Flags string `json:"flags"` -} - -// getValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP -// invalid interface is a ppp interface or the one that doesn't allow broadcasts -func getValidNetInterfaces() ([]net.Interface, error) { - ifaces, err := net.Interfaces() - if err != nil { - return nil, fmt.Errorf("Couldn't get list of interfaces: %s", err) - } - - netIfaces := []net.Interface{} - - for i := range ifaces { - if ifaces[i].Flags&net.FlagPointToPoint != 0 { - // this interface is ppp, we're not interested in this one - continue - } - - iface := ifaces[i] - netIfaces = append(netIfaces, iface) - } - - return netIfaces, nil -} - -// getValidNetInterfacesMap returns interfaces that are eligible for DNS and WEB only -// we do not return link-local addresses here -func getValidNetInterfacesForWeb() ([]netInterface, error) { - ifaces, err := getValidNetInterfaces() - if err != nil { - return nil, errorx.Decorate(err, "Couldn't get interfaces") - } - if len(ifaces) == 0 { - return nil, errors.New("couldn't find any legible interface") - } - - var netInterfaces []netInterface - - for _, iface := range ifaces { - addrs, e := iface.Addrs() - if e != nil { - return nil, errorx.Decorate(e, "Failed to get addresses for interface %s", iface.Name) - } - - netIface := netInterface{ - Name: iface.Name, - MTU: iface.MTU, - HardwareAddr: iface.HardwareAddr.String(), - } - - if iface.Flags != 0 { - netIface.Flags = iface.Flags.String() - } - - // we don't want link-local addresses in json, so skip them - for _, addr := range addrs { - ipnet, ok := addr.(*net.IPNet) - if !ok { - // not an IPNet, should not happen - return nil, fmt.Errorf("SHOULD NOT HAPPEN: got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr) - } - // ignore link-local - if ipnet.IP.IsLinkLocalUnicast() { - continue - } - netIface.Addresses = append(netIface.Addresses, ipnet.IP.String()) - } - if len(netIface.Addresses) != 0 { - netInterfaces = append(netInterfaces, netIface) - } - } - - return netInterfaces, nil -} - -// checkPortAvailable is not a cheap test to see if the port is bindable, because it's actually doing the bind momentarily -func checkPortAvailable(host string, port int) error { - ln, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port))) - if err != nil { - return err - } - ln.Close() - - // It seems that net.Listener.Close() doesn't close file descriptors right away. - // We wait for some time and hope that this fd will be closed. - time.Sleep(100 * time.Millisecond) - return nil -} - -func checkPacketPortAvailable(host string, port int) error { - ln, err := net.ListenPacket("udp", net.JoinHostPort(host, strconv.Itoa(port))) - if err != nil { - return err - } - ln.Close() - - // It seems that net.Listener.Close() doesn't close file descriptors right away. - // We wait for some time and hope that this fd will be closed. - time.Sleep(100 * time.Millisecond) - return err -} - // Connect to a remote server resolving hostname using our own DNS server func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) { log.Tracef("network:%v addr:%v", network, addr) @@ -303,31 +190,6 @@ func customDialContext(ctx context.Context, network, addr string) (net.Conn, err return nil, errorx.DecorateMany(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...) } -// check if error is "address already in use" -func errorIsAddrInUse(err error) bool { - errOpError, ok := err.(*net.OpError) - if !ok { - return false - } - - errSyscallError, ok := errOpError.Err.(*os.SyscallError) - if !ok { - return false - } - - errErrno, ok := errSyscallError.Err.(syscall.Errno) - if !ok { - return false - } - - if runtime.GOOS == "windows" { - const WSAEADDRINUSE = 10048 - return errErrno == WSAEADDRINUSE - } - - return errErrno == syscall.EADDRINUSE -} - // --------------------- // debug logging helpers // --------------------- diff --git a/home/network_utils.go b/home/network_utils.go new file mode 100644 index 00000000..04112001 --- /dev/null +++ b/home/network_utils.go @@ -0,0 +1,317 @@ +package home + +import ( + "errors" + "fmt" + "io/ioutil" + "net" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "syscall" + "time" + + "github.com/AdguardTeam/golibs/file" + "github.com/AdguardTeam/golibs/log" + "github.com/joomcode/errorx" +) + +type netInterface struct { + Name string + MTU int + HardwareAddr string + Addresses []string + Flags string +} + +// getValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP +// invalid interface is a ppp interface or the one that doesn't allow broadcasts +func getValidNetInterfaces() ([]net.Interface, error) { + ifaces, err := net.Interfaces() + if err != nil { + return nil, fmt.Errorf("Couldn't get list of interfaces: %s", err) + } + + netIfaces := []net.Interface{} + + for i := range ifaces { + if ifaces[i].Flags&net.FlagPointToPoint != 0 { + // this interface is ppp, we're not interested in this one + continue + } + + iface := ifaces[i] + netIfaces = append(netIfaces, iface) + } + + return netIfaces, nil +} + +// getValidNetInterfacesMap returns interfaces that are eligible for DNS and WEB only +// we do not return link-local addresses here +func getValidNetInterfacesForWeb() ([]netInterface, error) { + ifaces, err := getValidNetInterfaces() + if err != nil { + return nil, errorx.Decorate(err, "Couldn't get interfaces") + } + if len(ifaces) == 0 { + return nil, errors.New("couldn't find any legible interface") + } + + var netInterfaces []netInterface + + for _, iface := range ifaces { + addrs, e := iface.Addrs() + if e != nil { + return nil, errorx.Decorate(e, "Failed to get addresses for interface %s", iface.Name) + } + + netIface := netInterface{ + Name: iface.Name, + MTU: iface.MTU, + HardwareAddr: iface.HardwareAddr.String(), + } + + if iface.Flags != 0 { + netIface.Flags = iface.Flags.String() + } + + // we don't want link-local addresses in json, so skip them + for _, addr := range addrs { + ipnet, ok := addr.(*net.IPNet) + if !ok { + // not an IPNet, should not happen + return nil, fmt.Errorf("got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr) + } + // ignore link-local + if ipnet.IP.IsLinkLocalUnicast() { + continue + } + netIface.Addresses = append(netIface.Addresses, ipnet.IP.String()) + } + if len(netIface.Addresses) != 0 { + netInterfaces = append(netInterfaces, netIface) + } + } + + return netInterfaces, nil +} + +// Check if network interface has a static IP configured +// Supports: Raspbian. +func hasStaticIP(ifaceName string) (bool, error) { + if runtime.GOOS == "windows" { + return false, errors.New("Can't detect static IP: not supported on Windows") + } + + body, err := ioutil.ReadFile("/etc/dhcpcd.conf") + if err != nil { + return false, err + } + + return hasStaticIPDhcpcdConf(string(body), ifaceName), nil +} + +// for dhcpcd.conf +func hasStaticIPDhcpcdConf(data, ifaceName string) bool { + lines := strings.Split(data, "\n") + nameLine := fmt.Sprintf("interface %s", ifaceName) + withinInterfaceCtx := false + + for _, line := range lines { + line = strings.TrimSpace(line) + + if withinInterfaceCtx && len(line) == 0 { + // an empty line resets our state + withinInterfaceCtx = false + } + + if len(line) == 0 || line[0] == '#' { + continue + } + line = strings.TrimSpace(line) + + if !withinInterfaceCtx { + if line == nameLine { + // we found our interface + withinInterfaceCtx = true + } + + } else { + if strings.HasPrefix(line, "interface ") { + // we found another interface - reset our state + withinInterfaceCtx = false + continue + } + if strings.HasPrefix(line, "static ip_address=") { + return true + } + } + } + return false +} + +// Get IP address with netmask +func getFullIP(ifaceName string) string { + cmd := exec.Command("ip", "-oneline", "-family", "inet", "address", "show", ifaceName) + log.Tracef("executing %s %v", cmd.Path, cmd.Args) + d, err := cmd.Output() + if err != nil || cmd.ProcessState.ExitCode() != 0 { + return "" + } + + fields := strings.Fields(string(d)) + if len(fields) < 4 { + return "" + } + _, _, err = net.ParseCIDR(fields[3]) + if err != nil { + return "" + } + + return fields[3] +} + +// Get interface name by its IP address. +func getInterfaceByIP(ip string) string { + ifaces, err := getValidNetInterfacesForWeb() + if err != nil { + return "" + } + + for _, iface := range ifaces { + for _, addr := range iface.Addresses { + if ip == addr { + return iface.Name + } + } + } + + return "" +} + +// Get gateway IP address +func getGatewayIP(ifaceName string) string { + cmd := exec.Command("ip", "route", "show", "dev", ifaceName) + log.Tracef("executing %s %v", cmd.Path, cmd.Args) + d, err := cmd.Output() + if err != nil || cmd.ProcessState.ExitCode() != 0 { + return "" + } + + fields := strings.Fields(string(d)) + if len(fields) < 3 || fields[0] != "default" { + return "" + } + + ip := net.ParseIP(fields[2]) + if ip == nil { + return "" + } + + return fields[2] +} + +// Set a static IP for network interface +// Supports: Raspbian. +func setStaticIP(ifaceName string) error { + ip := getFullIP(ifaceName) + if len(ip) == 0 { + return errors.New("Can't get IP address") + } + + ip4, _, err := net.ParseCIDR(ip) + if err != nil { + return err + } + gatewayIP := getGatewayIP(ifaceName) + add := setStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, ip4.String()) + + body, err := ioutil.ReadFile("/etc/dhcpcd.conf") + if err != nil { + return err + } + + body = append(body, []byte(add)...) + err = file.SafeWrite("/etc/dhcpcd.conf", body) + if err != nil { + return err + } + + return nil +} + +// for dhcpcd.conf +func setStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, dnsIP string) string { + var body []byte + + add := fmt.Sprintf("\ninterface %s\nstatic ip_address=%s\n", + ifaceName, ip) + body = append(body, []byte(add)...) + + if len(gatewayIP) != 0 { + add = fmt.Sprintf("static routers=%s\n", + gatewayIP) + body = append(body, []byte(add)...) + } + + add = fmt.Sprintf("static domain_name_servers=%s\n\n", + dnsIP) + body = append(body, []byte(add)...) + + return string(body) +} + +// checkPortAvailable is not a cheap test to see if the port is bindable, because it's actually doing the bind momentarily +func checkPortAvailable(host string, port int) error { + ln, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port))) + if err != nil { + return err + } + ln.Close() + + // It seems that net.Listener.Close() doesn't close file descriptors right away. + // We wait for some time and hope that this fd will be closed. + time.Sleep(100 * time.Millisecond) + return nil +} + +func checkPacketPortAvailable(host string, port int) error { + ln, err := net.ListenPacket("udp", net.JoinHostPort(host, strconv.Itoa(port))) + if err != nil { + return err + } + ln.Close() + + // It seems that net.Listener.Close() doesn't close file descriptors right away. + // We wait for some time and hope that this fd will be closed. + time.Sleep(100 * time.Millisecond) + return err +} + +// check if error is "address already in use" +func errorIsAddrInUse(err error) bool { + errOpError, ok := err.(*net.OpError) + if !ok { + return false + } + + errSyscallError, ok := errOpError.Err.(*os.SyscallError) + if !ok { + return false + } + + errErrno, ok := errSyscallError.Err.(syscall.Errno) + if !ok { + return false + } + + if runtime.GOOS == "windows" { + const WSAEADDRINUSE = 10048 + return errErrno == WSAEADDRINUSE + } + + return errErrno == syscall.EADDRINUSE +} diff --git a/home/network_utils_test.go b/home/network_utils_test.go new file mode 100644 index 00000000..6312e535 --- /dev/null +++ b/home/network_utils_test.go @@ -0,0 +1,61 @@ +package home + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHasStaticIPDhcpcdConf(t *testing.T) { + dhcpdConf := `#comment +# comment + +interface eth0 +static ip_address=192.168.0.1/24 + +# interface wlan0 +static ip_address=192.168.1.1/24 + +# comment +` + assert.True(t, !hasStaticIPDhcpcdConf(dhcpdConf, "wlan0")) + + dhcpdConf = `#comment +# comment + +interface eth0 +static ip_address=192.168.0.1/24 + +# interface wlan0 +static ip_address=192.168.1.1/24 + +# comment + +interface wlan0 +# comment +static ip_address=192.168.2.1/24 +` + assert.True(t, hasStaticIPDhcpcdConf(dhcpdConf, "wlan0")) +} + +func TestSetStaticIPDhcpcdConf(t *testing.T) { + dhcpcdConf := ` +interface wlan0 +static ip_address=192.168.0.2/24 +static routers=192.168.0.1 +static domain_name_servers=192.168.0.2 + +` + s := setStaticIPDhcpcdConf("wlan0", "192.168.0.2/24", "192.168.0.1", "192.168.0.2") + assert.Equal(t, dhcpcdConf, s) + + // without gateway + dhcpcdConf = ` +interface wlan0 +static ip_address=192.168.0.2/24 +static domain_name_servers=192.168.0.2 + +` + s = setStaticIPDhcpcdConf("wlan0", "192.168.0.2/24", "", "192.168.0.2") + assert.Equal(t, dhcpcdConf, s) +} From cc2d953c9d5d036347b3dbe27a82de30eacaba3b Mon Sep 17 00:00:00 2001 From: Ildar Kamalov Date: Thu, 30 Jan 2020 18:26:54 +0300 Subject: [PATCH 02/31] + client: handle static IP while doing initial setup --- client/package-lock.json | 41 ++++++++++----- client/src/__locales/en.json | 5 +- client/src/helpers/form.js | 7 +++ client/src/install/Setup/Settings.js | 74 +++++++++++++++++++++------- client/src/install/Setup/Setup.css | 5 ++ client/src/install/Setup/index.js | 10 +++- client/src/reducers/install.js | 8 ++- 7 files changed, 119 insertions(+), 31 deletions(-) diff --git a/client/package-lock.json b/client/package-lock.json index 81899b0c..3e000a8d 100644 --- a/client/package-lock.json +++ b/client/package-lock.json @@ -5458,7 +5458,8 @@ "ansi-regex": { "version": "2.1.1", "bundled": true, - "dev": true + "dev": true, + "optional": true }, "aproba": { "version": "1.2.0", @@ -5479,12 +5480,14 @@ "balanced-match": { "version": "1.0.0", "bundled": true, - "dev": true + "dev": true, + "optional": true }, "brace-expansion": { "version": "1.1.11", "bundled": true, "dev": true, + "optional": true, "requires": { "balanced-match": "^1.0.0", "concat-map": "0.0.1" @@ -5499,17 +5502,20 @@ "code-point-at": { "version": "1.1.0", "bundled": true, - "dev": true + "dev": true, + "optional": true }, "concat-map": { "version": "0.0.1", "bundled": true, - "dev": true + "dev": true, + "optional": true }, "console-control-strings": { "version": "1.1.0", "bundled": true, - "dev": true + "dev": true, + "optional": true }, "core-util-is": { "version": "1.0.2", @@ -5626,7 +5632,8 @@ "inherits": { "version": "2.0.3", "bundled": true, - "dev": true + "dev": true, + "optional": true }, "ini": { "version": "1.3.5", @@ -5638,6 +5645,7 @@ "version": "1.0.0", "bundled": true, "dev": true, + "optional": true, "requires": { "number-is-nan": "^1.0.0" } @@ -5652,6 +5660,7 @@ "version": "3.0.4", "bundled": true, "dev": true, + "optional": true, "requires": { "brace-expansion": "^1.1.7" } @@ -5659,12 +5668,14 @@ "minimist": { "version": "0.0.8", "bundled": true, - "dev": true + "dev": true, + "optional": true }, "minipass": { "version": "2.3.5", "bundled": true, "dev": true, + "optional": true, "requires": { "safe-buffer": "^5.1.2", "yallist": "^3.0.0" @@ -5683,6 +5694,7 @@ "version": "0.5.1", "bundled": true, "dev": true, + "optional": true, "requires": { "minimist": "0.0.8" } @@ -5763,7 +5775,8 @@ "number-is-nan": { "version": "1.0.1", "bundled": true, - "dev": true + "dev": true, + "optional": true }, "object-assign": { "version": "4.1.1", @@ -5775,6 +5788,7 @@ "version": "1.4.0", "bundled": true, "dev": true, + "optional": true, "requires": { "wrappy": "1" } @@ -5860,7 +5874,8 @@ "safe-buffer": { "version": "5.1.2", "bundled": true, - "dev": true + "dev": true, + "optional": true }, "safer-buffer": { "version": "2.1.2", @@ -5896,6 +5911,7 @@ "version": "1.0.2", "bundled": true, "dev": true, + "optional": true, "requires": { "code-point-at": "^1.0.0", "is-fullwidth-code-point": "^1.0.0", @@ -5915,6 +5931,7 @@ "version": "3.0.1", "bundled": true, "dev": true, + "optional": true, "requires": { "ansi-regex": "^2.0.0" } @@ -5958,12 +5975,14 @@ "wrappy": { "version": "1.0.2", "bundled": true, - "dev": true + "dev": true, + "optional": true }, "yallist": { "version": "3.0.3", "bundled": true, - "dev": true + "dev": true, + "optional": true } } }, diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index 961b47c7..aa63c8a5 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -462,5 +462,8 @@ "client_confirm_block": "Are you sure you want to block the client \"{{ip}}\"?", "client_confirm_unblock": "Are you sure you want to unblock the client \"{{ip}}\"?", "client_blocked": "Client \"{{ip}}\" successfully blocked", - "client_unblocked": "Client \"{{ip}}\" successfully unblocked" + "client_unblocked": "Client \"{{ip}}\" successfully unblocked", + "set_static_ip": "Set static IP address", + "install_static_error": "We failed to determine if this network interface is configured using static IP address. Please set a static IP address manually.", + "install_static_configure": "We have detected that a dynamic IP address is used — {{ip}}. We suggest that you set this IP as static." } diff --git a/client/src/helpers/form.js b/client/src/helpers/form.js index 0c617ff5..7aed918c 100644 --- a/client/src/helpers/form.js +++ b/client/src/helpers/form.js @@ -240,6 +240,13 @@ export const port = (value) => { return undefined; }; +export const validInstallPort = (value) => { + if (value < 1 || value > 65535) { + return form_error_port; + } + return undefined; +}; + export const portTLS = (value) => { if (value === 0) { return undefined; diff --git a/client/src/install/Setup/Settings.js b/client/src/install/Setup/Settings.js index 246206ba..3de151a9 100644 --- a/client/src/install/Setup/Settings.js +++ b/client/src/install/Setup/Settings.js @@ -7,26 +7,18 @@ import flow from 'lodash/flow'; import Controls from './Controls'; import AddressList from './AddressList'; +import Accordion from '../../components/ui/Accordion'; + import { getInterfaceIp } from '../../helpers/helpers'; import { ALL_INTERFACES_IP } from '../../helpers/constants'; -import { renderInputField } from '../../helpers/form'; +import { renderInputField, required, validInstallPort, toNumber } from '../../helpers/form'; -const required = (value) => { - if (value || value === 0) { - return false; - } - return form_error_required; +const STATIC_STATUS = { + ENABLED: 'yes', + DISABLED: 'no', + ERROR: 'error', }; -const port = (value) => { - if (value < 1 || value > 65535) { - return form_error_port; - } - return false; -}; - -const toNumber = value => value && parseInt(value, 10); - const renderInterfaces = (interfaces => ( Object.keys(interfaces).map((item) => { const option = interfaces[item]; @@ -79,11 +71,54 @@ class Settings extends Component { }); } + getStaticIpMessage = (staticIp, handleStaticIp) => { + const { static: status, ip, error } = staticIp; + + if (!status || status === STATIC_STATUS.ENABLED) { + return ''; + } + + return ( +
+ +
+ {status === STATIC_STATUS.DISABLED && ( + + + install_static_configure + + + + )} + {status === STATIC_STATUS.ERROR && ( + + install_static_error +
+ + {error} + +
+
+ )} +
+
+
+
+ ); + }; + render() { const { handleSubmit, handleChange, handleAutofix, + handleStaticIp, webIp, webPort, dnsIp, @@ -100,6 +135,7 @@ class Settings extends Component { status: dnsStatus, can_autofix: isDnsFixAvailable, } = config.dns; + const { staticIp } = config; return (
@@ -137,7 +173,7 @@ class Settings extends Component { type="number" className="form-control" placeholder="80" - validate={[port, required]} + validate={[validInstallPort, required]} normalize={toNumber} onChange={handleChange} /> @@ -205,12 +241,13 @@ class Settings extends Component { type="number" className="form-control" placeholder="80" - validate={[port, required]} + validate={[validInstallPort, required]} normalize={toNumber} onChange={handleChange} /> + {this.getStaticIpMessage(staticIp, handleStaticIp)}
{dnsStatus && @@ -237,6 +274,7 @@ class Settings extends Component { autofix_warning_result

+
} @@ -278,6 +316,8 @@ Settings.propTypes = { interfaces: PropTypes.object.isRequired, invalid: PropTypes.bool.isRequired, initialValues: PropTypes.object, + t: PropTypes.func.isRequired, + handleStaticIp: PropTypes.func.isRequired, }; const selector = formValueSelector('install'); diff --git a/client/src/install/Setup/Setup.css b/client/src/install/Setup/Setup.css index 11ee1430..aac7ea0e 100644 --- a/client/src/install/Setup/Setup.css +++ b/client/src/install/Setup/Setup.css @@ -119,3 +119,8 @@ .setup__error { margin: -5px 0 5px; } + +.divider--small { + margin-top: 1rem; + margin-bottom: 1rem; +} diff --git a/client/src/install/Setup/index.js b/client/src/install/Setup/index.js index ca91cd8f..ad179612 100644 --- a/client/src/install/Setup/index.js +++ b/client/src/install/Setup/index.js @@ -60,6 +60,12 @@ class Setup extends Component { } }; + handleStaticIp = () => { + this.props.checkConfig({ + set_static_ip: true, + }); + }; + openDashboard = (ip, port) => { let address = getWebAddress(ip, port); @@ -96,6 +102,7 @@ class Setup extends Component { onChange={this.handleFormChange} validateForm={this.handleFormChange} handleAutofix={this.handleAutofix} + handleStaticIp={this.handleStaticIp} /> ); case 3: @@ -117,6 +124,7 @@ class Setup extends Component { step, web, dns, + staticIp, interfaces, } = this.props.install; @@ -128,7 +136,7 @@ class Setup extends Component {
logo - {this.renderPage(step, { web, dns }, interfaces)} + {this.renderPage(step, { web, dns, staticIp }, interfaces)}
diff --git a/client/src/reducers/install.js b/client/src/reducers/install.js index 3709b0ec..b3f95dfb 100644 --- a/client/src/reducers/install.js +++ b/client/src/reducers/install.js @@ -32,9 +32,10 @@ const install = handleActions({ [actions.checkConfigSuccess]: (state, { payload }) => { const web = { ...state.web, ...payload.web }; const dns = { ...state.dns, ...payload.dns }; + const staticIp = { ...state.staticIp, ...payload.static_ip }; const newState = { - ...state, web, dns, processingCheck: false, + ...state, web, dns, staticIp, processingCheck: false, }; return newState; }, @@ -55,6 +56,11 @@ const install = handleActions({ status: '', can_autofix: false, }, + staticIp: { + static: '', + ip: '', + error: '', + }, interfaces: {}, }); From eb075198c2760c576d39b64bb8c283c060aa7a49 Mon Sep 17 00:00:00 2001 From: Ildar Kamalov Date: Thu, 30 Jan 2020 18:36:13 +0300 Subject: [PATCH 03/31] - client: send set_static_ip: false as default --- client/src/install/Setup/index.js | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/client/src/install/Setup/index.js b/client/src/install/Setup/index.js index ad179612..03ae76d5 100644 --- a/client/src/install/Setup/index.js +++ b/client/src/install/Setup/index.js @@ -37,8 +37,9 @@ class Setup extends Component { }; handleFormChange = debounce((values) => { - if (values && values.web.port && values.dns.port) { - this.props.checkConfig(values); + const { web, dns } = values; + if (values && web.port && dns.port) { + this.props.checkConfig({ web, dns, set_static_ip: false }); } }, DEBOUNCE_TIMEOUT); From 56f4daeb68f3699febf2562c96a68a1cdf13fd65 Mon Sep 17 00:00:00 2001 From: Ildar Kamalov Date: Thu, 30 Jan 2020 18:55:22 +0300 Subject: [PATCH 04/31] - client: remove staticIp from apply config --- client/src/install/Setup/index.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/client/src/install/Setup/index.js b/client/src/install/Setup/index.js index 03ae76d5..943fece9 100644 --- a/client/src/install/Setup/index.js +++ b/client/src/install/Setup/index.js @@ -33,7 +33,8 @@ class Setup extends Component { } handleFormSubmit = (values) => { - this.props.setAllSettings(values); + const { staticIp, ...config } = values; + this.props.setAllSettings(config); }; handleFormChange = debounce((values) => { From 8575042c23142deda71c56eb0f19c2bdfa1151f1 Mon Sep 17 00:00:00 2001 From: Ildar Kamalov Date: Thu, 30 Jan 2020 18:57:55 +0300 Subject: [PATCH 05/31] - client: fix locales --- client/src/__locales/en.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index aa63c8a5..4c837637 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -464,6 +464,6 @@ "client_blocked": "Client \"{{ip}}\" successfully blocked", "client_unblocked": "Client \"{{ip}}\" successfully unblocked", "set_static_ip": "Set static IP address", - "install_static_error": "We failed to determine if this network interface is configured using static IP address. Please set a static IP address manually.", - "install_static_configure": "We have detected that a dynamic IP address is used — {{ip}}. We suggest that you set this IP as static." + "install_static_error": "We failed to determine if this network interface is configured by using static IP address. Please set a static IP address manually.", + "install_static_configure": "We have detected that a dynamic IP address is used — {{ip}}. We suggest you to set this IP address as static." } From f91cbea9d5d9f849faedaf583cf53a2e3523c747 Mon Sep 17 00:00:00 2001 From: Ildar Kamalov Date: Thu, 30 Jan 2020 19:05:44 +0300 Subject: [PATCH 06/31] - client: add divider --- client/src/install/Setup/Settings.js | 1 + 1 file changed, 1 insertion(+) diff --git a/client/src/install/Setup/Settings.js b/client/src/install/Setup/Settings.js index 3de151a9..0b3120b2 100644 --- a/client/src/install/Setup/Settings.js +++ b/client/src/install/Setup/Settings.js @@ -192,6 +192,7 @@ class Settings extends Component { fix } +
} From e528522ad51b4c1357b9d6a4f0c592176745c077 Mon Sep 17 00:00:00 2001 From: Simon Zolin Date: Fri, 31 Jan 2020 13:24:18 +0300 Subject: [PATCH 07/31] linter --- dhcpd/network_utils.go | 74 ------------------------------------------ 1 file changed, 74 deletions(-) diff --git a/dhcpd/network_utils.go b/dhcpd/network_utils.go index 16a3c7dd..30180df2 100644 --- a/dhcpd/network_utils.go +++ b/dhcpd/network_utils.go @@ -5,13 +5,9 @@ import ( "fmt" "io/ioutil" "net" - "os" "os/exec" "runtime" - "strconv" "strings" - "syscall" - "time" "github.com/AdguardTeam/golibs/file" "github.com/AdguardTeam/golibs/log" @@ -174,24 +170,6 @@ func getFullIP(ifaceName string) string { return fields[3] } -// Get interface name by its IP address. -func getInterfaceByIP(ip string) string { - ifaces, err := getValidNetInterfacesForWeb() - if err != nil { - return "" - } - - for _, iface := range ifaces { - for _, addr := range iface.Addresses { - if ip == addr { - return iface.Name - } - } - } - - return "" -} - // Get gateway IP address func getGatewayIP(ifaceName string) string { cmd := exec.Command("ip", "route", "show", "dev", ifaceName) @@ -263,55 +241,3 @@ func setStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, dnsIP string) string { return string(body) } - -// checkPortAvailable is not a cheap test to see if the port is bindable, because it's actually doing the bind momentarily -func checkPortAvailable(host string, port int) error { - ln, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port))) - if err != nil { - return err - } - ln.Close() - - // It seems that net.Listener.Close() doesn't close file descriptors right away. - // We wait for some time and hope that this fd will be closed. - time.Sleep(100 * time.Millisecond) - return nil -} - -func checkPacketPortAvailable(host string, port int) error { - ln, err := net.ListenPacket("udp", net.JoinHostPort(host, strconv.Itoa(port))) - if err != nil { - return err - } - ln.Close() - - // It seems that net.Listener.Close() doesn't close file descriptors right away. - // We wait for some time and hope that this fd will be closed. - time.Sleep(100 * time.Millisecond) - return err -} - -// check if error is "address already in use" -func errorIsAddrInUse(err error) bool { - errOpError, ok := err.(*net.OpError) - if !ok { - return false - } - - errSyscallError, ok := errOpError.Err.(*os.SyscallError) - if !ok { - return false - } - - errErrno, ok := errSyscallError.Err.(syscall.Errno) - if !ok { - return false - } - - if runtime.GOOS == "windows" { - const WSAEADDRINUSE = 10048 - return errErrno == WSAEADDRINUSE - } - - return errErrno == syscall.EADDRINUSE -} From 143a86a28a3465776f803f6b99b9f3c64b26400e Mon Sep 17 00:00:00 2001 From: Simon Zolin Date: Fri, 31 Jan 2020 13:26:47 +0300 Subject: [PATCH 08/31] linter --- dhcpd/network_utils.go | 1 + home/network_utils.go | 1 + 2 files changed, 2 insertions(+) diff --git a/dhcpd/network_utils.go b/dhcpd/network_utils.go index 30180df2..0c933cbc 100644 --- a/dhcpd/network_utils.go +++ b/dhcpd/network_utils.go @@ -47,6 +47,7 @@ func getValidNetInterfaces() ([]net.Interface, error) { // getValidNetInterfacesMap returns interfaces that are eligible for DNS and WEB only // we do not return link-local addresses here +// nolint:dupl func getValidNetInterfacesForWeb() ([]netInterface, error) { ifaces, err := getValidNetInterfaces() if err != nil { diff --git a/home/network_utils.go b/home/network_utils.go index 04112001..146e4623 100644 --- a/home/network_utils.go +++ b/home/network_utils.go @@ -51,6 +51,7 @@ func getValidNetInterfaces() ([]net.Interface, error) { // getValidNetInterfacesMap returns interfaces that are eligible for DNS and WEB only // we do not return link-local addresses here +// nolint:dupl func getValidNetInterfacesForWeb() ([]netInterface, error) { ifaces, err := getValidNetInterfaces() if err != nil { From 84b5708e71c88a9643d402ab630270f5e7bf35b8 Mon Sep 17 00:00:00 2001 From: Simon Zolin Date: Fri, 31 Jan 2020 13:27:53 +0300 Subject: [PATCH 09/31] linter --- dhcpd/network_utils.go | 52 ------------------------------------------ home/network_utils.go | 1 - 2 files changed, 53 deletions(-) diff --git a/dhcpd/network_utils.go b/dhcpd/network_utils.go index 0c933cbc..362378df 100644 --- a/dhcpd/network_utils.go +++ b/dhcpd/network_utils.go @@ -11,7 +11,6 @@ import ( "github.com/AdguardTeam/golibs/file" "github.com/AdguardTeam/golibs/log" - "github.com/joomcode/errorx" ) type netInterface struct { @@ -45,57 +44,6 @@ func getValidNetInterfaces() ([]net.Interface, error) { return netIfaces, nil } -// getValidNetInterfacesMap returns interfaces that are eligible for DNS and WEB only -// we do not return link-local addresses here -// nolint:dupl -func getValidNetInterfacesForWeb() ([]netInterface, error) { - ifaces, err := getValidNetInterfaces() - if err != nil { - return nil, errorx.Decorate(err, "Couldn't get interfaces") - } - if len(ifaces) == 0 { - return nil, errors.New("couldn't find any legible interface") - } - - var netInterfaces []netInterface - - for _, iface := range ifaces { - addrs, e := iface.Addrs() - if e != nil { - return nil, errorx.Decorate(e, "Failed to get addresses for interface %s", iface.Name) - } - - netIface := netInterface{ - Name: iface.Name, - MTU: iface.MTU, - HardwareAddr: iface.HardwareAddr.String(), - } - - if iface.Flags != 0 { - netIface.Flags = iface.Flags.String() - } - - // we don't want link-local addresses in json, so skip them - for _, addr := range addrs { - ipnet, ok := addr.(*net.IPNet) - if !ok { - // not an IPNet, should not happen - return nil, fmt.Errorf("got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr) - } - // ignore link-local - if ipnet.IP.IsLinkLocalUnicast() { - continue - } - netIface.Addresses = append(netIface.Addresses, ipnet.IP.String()) - } - if len(netIface.Addresses) != 0 { - netInterfaces = append(netInterfaces, netIface) - } - } - - return netInterfaces, nil -} - // Check if network interface has a static IP configured // Supports: Raspbian. func hasStaticIP(ifaceName string) (bool, error) { diff --git a/home/network_utils.go b/home/network_utils.go index 146e4623..04112001 100644 --- a/home/network_utils.go +++ b/home/network_utils.go @@ -51,7 +51,6 @@ func getValidNetInterfaces() ([]net.Interface, error) { // getValidNetInterfacesMap returns interfaces that are eligible for DNS and WEB only // we do not return link-local addresses here -// nolint:dupl func getValidNetInterfacesForWeb() ([]netInterface, error) { ifaces, err := getValidNetInterfaces() if err != nil { From fb476b011768367be51010c89754dcd23b383f5a Mon Sep 17 00:00:00 2001 From: Simon Zolin Date: Fri, 31 Jan 2020 13:29:03 +0300 Subject: [PATCH 10/31] linter --- dhcpd/network_utils.go | 8 -------- 1 file changed, 8 deletions(-) diff --git a/dhcpd/network_utils.go b/dhcpd/network_utils.go index 362378df..85fc261d 100644 --- a/dhcpd/network_utils.go +++ b/dhcpd/network_utils.go @@ -13,14 +13,6 @@ import ( "github.com/AdguardTeam/golibs/log" ) -type netInterface struct { - Name string - MTU int - HardwareAddr string - Addresses []string - Flags string -} - // getValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP // invalid interface is a ppp interface or the one that doesn't allow broadcasts func getValidNetInterfaces() ([]net.Interface, error) { From 8406d7d28827ce1ed9d9f6770ce1700681811535 Mon Sep 17 00:00:00 2001 From: Ildar Kamalov Date: Fri, 31 Jan 2020 16:52:22 +0300 Subject: [PATCH 11/31] - client: fix locales --- client/src/__locales/en.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index 4c837637..3d93a7de 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -463,7 +463,7 @@ "client_confirm_unblock": "Are you sure you want to unblock the client \"{{ip}}\"?", "client_blocked": "Client \"{{ip}}\" successfully blocked", "client_unblocked": "Client \"{{ip}}\" successfully unblocked", - "set_static_ip": "Set static IP address", - "install_static_error": "We failed to determine if this network interface is configured by using static IP address. Please set a static IP address manually.", + "set_static_ip": "Set a static IP address", + "install_static_error": "We failed to determine if this network interface is configured by using a static IP address. Please set the static IP address manually.", "install_static_configure": "We have detected that a dynamic IP address is used — {{ip}}. We suggest you to set this IP address as static." } From d9001ff84852d708e400d039503141929e06d774 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Fri, 7 Feb 2020 16:17:59 +0300 Subject: [PATCH 12/31] *(documentation): updated openapi --- openapi/openapi.yaml | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index 9ad2c5d1..5f3539f0 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -1831,6 +1831,9 @@ definitions: $ref: "#/definitions/CheckConfigRequestInfo" web: $ref: "#/definitions/CheckConfigRequestInfo" + set_static_ip: + type: "boolean" + example: false CheckConfigRequestInfo: type: "object" properties: @@ -1851,6 +1854,8 @@ definitions: $ref: "#/definitions/CheckConfigResponseInfo" web: $ref: "#/definitions/CheckConfigResponseInfo" + static_ip: + $ref: "#/definitions/CheckConfigStaticIpInfo" CheckConfigResponseInfo: type: "object" properties: @@ -1860,6 +1865,23 @@ definitions: can_autofix: type: "boolean" example: false + CheckConfigStaticIpInfo: + type: "object" + properties: + static: + type: "string" + example: "no" + description: "Can be: yes, no, error" + ip: + type: "string" + example: "192.168.1.1" + description: "Current dynamic IP address. Set if static=no" + error: + type: "string" + example: "" + description: "Error text. Set if static=error" + + InitialConfiguration: type: "object" description: "AdGuard Home initial configuration (for the first-install wizard)" From cb12babc4698d048478570303af8955a35e8531d Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Fri, 7 Feb 2020 17:08:39 +0300 Subject: [PATCH 13/31] *: lower log level for some commands --- home/control_install.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/home/control_install.go b/home/control_install.go index 18fe2a8b..45cf0c7c 100644 --- a/home/control_install.go +++ b/home/control_install.go @@ -178,7 +178,7 @@ func checkDNSStubListener() bool { log.Tracef("executing %s %v", cmd.Path, cmd.Args) _, err := cmd.Output() if err != nil || cmd.ProcessState.ExitCode() != 0 { - log.Error("command %s has failed: %v code:%d", + log.Info("command %s has failed: %v code:%d", cmd.Path, err, cmd.ProcessState.ExitCode()) return false } @@ -187,7 +187,7 @@ func checkDNSStubListener() bool { log.Tracef("executing %s %v", cmd.Path, cmd.Args) _, err = cmd.Output() if err != nil || cmd.ProcessState.ExitCode() != 0 { - log.Error("command %s has failed: %v code:%d", + log.Info("command %s has failed: %v code:%d", cmd.Path, err, cmd.ProcessState.ExitCode()) return false } From 26c47e59dd63317bdb959cb416e7c1c0bfdf7dc1 Mon Sep 17 00:00:00 2001 From: Ildar Kamalov Date: Fri, 7 Feb 2020 18:08:56 +0300 Subject: [PATCH 14/31] - client: fix static ip description --- client/src/__locales/en.json | 7 ++- client/src/install/Setup/Settings.js | 75 ++++++++++++++++------------ 2 files changed, 47 insertions(+), 35 deletions(-) diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index 3d93a7de..88442ac8 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -463,7 +463,10 @@ "client_confirm_unblock": "Are you sure you want to unblock the client \"{{ip}}\"?", "client_blocked": "Client \"{{ip}}\" successfully blocked", "client_unblocked": "Client \"{{ip}}\" successfully unblocked", + "static_ip": "Static IP Address", + "static_ip_desc": "AdGuard Home is a server so it needs a static IP address to function properly. Otherwise, at some point, your router may assign a different IP address to this device.", "set_static_ip": "Set a static IP address", - "install_static_error": "We failed to determine if this network interface is configured by using a static IP address. Please set the static IP address manually.", - "install_static_configure": "We have detected that a dynamic IP address is used — {{ip}}. We suggest you to set this IP address as static." + "install_static_ok": "Good news, Static IP already configured", + "install_static_error": "AdGuard Home cannot configure it automatically for your OS. Please look for an instruction on how to do this manually", + "install_static_configure": "We have detected that a dynamic IP address is used — <0>{{ip}}. Do you want to use your it as a static address?" } diff --git a/client/src/install/Setup/Settings.js b/client/src/install/Setup/Settings.js index 0b3120b2..22685ca5 100644 --- a/client/src/install/Setup/Settings.js +++ b/client/src/install/Setup/Settings.js @@ -7,7 +7,6 @@ import flow from 'lodash/flow'; import Controls from './Controls'; import AddressList from './AddressList'; -import Accordion from '../../components/ui/Accordion'; import { getInterfaceIp } from '../../helpers/helpers'; import { ALL_INTERFACES_IP } from '../../helpers/constants'; @@ -72,44 +71,41 @@ class Settings extends Component { } getStaticIpMessage = (staticIp, handleStaticIp) => { - const { static: status, ip, error } = staticIp; + const { static: status, ip } = staticIp; - if (!status || status === STATIC_STATUS.ENABLED) { + if (!status) { return ''; } return ( -
- + + {status === STATIC_STATUS.DISABLED && ( + + text]}> + install_static_configure + + + + )} + {status === STATIC_STATUS.ERROR && (
- {status === STATIC_STATUS.DISABLED && ( - - - install_static_configure - - - - )} - {status === STATIC_STATUS.ERROR && ( - - install_static_error -
- - {error} - -
-
- )} -
+ install_static_error
-
-
+ )} + {status === STATIC_STATUS.ENABLED && ( +
+ + install_static_ok + +
+ )} + ); }; @@ -208,6 +204,7 @@ class Settings extends Component { +
install_settings_dns @@ -248,7 +245,6 @@ class Settings extends Component { />
- {this.getStaticIpMessage(staticIp, handleStaticIp)}
{dnsStatus && @@ -292,6 +288,19 @@ class Settings extends Component {
+ +
+
+ static_ip +
+ +
+ static_ip_desc +
+ + {this.getStaticIpMessage(staticIp, handleStaticIp)} +
+ ); From 8d039c572f0e5f5245bd155a4e4d35400e6962c6 Mon Sep 17 00:00:00 2001 From: Ildar Kamalov Date: Fri, 7 Feb 2020 18:37:39 +0300 Subject: [PATCH 15/31] - client: fix button position --- client/src/install/Setup/Settings.js | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/client/src/install/Setup/Settings.js b/client/src/install/Setup/Settings.js index 22685ca5..66f1691e 100644 --- a/client/src/install/Setup/Settings.js +++ b/client/src/install/Setup/Settings.js @@ -81,12 +81,14 @@ class Settings extends Component { {status === STATIC_STATUS.DISABLED && ( - text]}> - install_static_configure - +
+ text]}> + install_static_configure + +
@@ -111,12 +111,48 @@ class Settings extends Component { ); }; + handleAutofix = (type) => { + const { + webIp, + webPort, + dnsIp, + dnsPort, + handleFix, + } = this.props; + + const web = { ip: webIp, port: webPort, autofix: false }; + const dns = { ip: dnsIp, port: dnsPort, autofix: false }; + const set_static_ip = false; + + if (type === 'web') { + web.autofix = true; + } else { + dns.autofix = true; + } + + handleFix(web, dns, set_static_ip); + }; + + handleStaticIp = () => { + const { + webIp, + webPort, + dnsIp, + dnsPort, + handleFix, + } = this.props; + + const web = { ip: webIp, port: webPort, autofix: false }; + const dns = { ip: dnsIp, port: dnsPort, autofix: false }; + const set_static_ip = true; + + handleFix(web, dns, set_static_ip); + }; + render() { const { handleSubmit, handleChange, - handleAutofix, - handleStaticIp, webIp, webPort, dnsIp, @@ -185,7 +221,7 @@ class Settings extends Component { @@ -256,7 +292,7 @@ class Settings extends Component { @@ -300,7 +336,7 @@ class Settings extends Component { static_ip_desc - {this.getStaticIpMessage(staticIp, handleStaticIp)} + {this.getStaticIpMessage(staticIp)} @@ -312,7 +348,7 @@ class Settings extends Component { Settings.propTypes = { handleSubmit: PropTypes.func.isRequired, handleChange: PropTypes.func, - handleAutofix: PropTypes.func, + handleFix: PropTypes.func.isRequired, validateForm: PropTypes.func, webIp: PropTypes.string.isRequired, dnsIp: PropTypes.string.isRequired, @@ -329,7 +365,6 @@ Settings.propTypes = { invalid: PropTypes.bool.isRequired, initialValues: PropTypes.object, t: PropTypes.func.isRequired, - handleStaticIp: PropTypes.func.isRequired, }; const selector = formValueSelector('install'); diff --git a/client/src/install/Setup/index.js b/client/src/install/Setup/index.js index 943fece9..82d8f84b 100644 --- a/client/src/install/Setup/index.js +++ b/client/src/install/Setup/index.js @@ -44,28 +44,8 @@ class Setup extends Component { } }, DEBOUNCE_TIMEOUT); - handleAutofix = (type, ip, port) => { - const data = { - ip, - port, - autofix: true, - }; - - if (type === 'web') { - this.props.checkConfig({ - web: { ...data }, - }); - } else { - this.props.checkConfig({ - dns: { ...data }, - }); - } - }; - - handleStaticIp = () => { - this.props.checkConfig({ - set_static_ip: true, - }); + handleFix = (web, dns, set_static_ip) => { + this.props.checkConfig({ web, dns, set_static_ip }); }; openDashboard = (ip, port) => { @@ -103,8 +83,7 @@ class Setup extends Component { onSubmit={this.nextStep} onChange={this.handleFormChange} validateForm={this.handleFormChange} - handleAutofix={this.handleAutofix} - handleStaticIp={this.handleStaticIp} + handleFix={this.handleFix} /> ); case 3: From 7afa16fbe76dff4485d166f6164bae171e0110c9 Mon Sep 17 00:00:00 2001 From: Ildar Kamalov Date: Thu, 13 Feb 2020 13:51:52 +0300 Subject: [PATCH 20/31] - client: fix text --- client/src/__locales/en.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index 88442ac8..138082c7 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -466,7 +466,7 @@ "static_ip": "Static IP Address", "static_ip_desc": "AdGuard Home is a server so it needs a static IP address to function properly. Otherwise, at some point, your router may assign a different IP address to this device.", "set_static_ip": "Set a static IP address", - "install_static_ok": "Good news, Static IP already configured", + "install_static_ok": "Good news! The static IP address is already configured", "install_static_error": "AdGuard Home cannot configure it automatically for your OS. Please look for an instruction on how to do this manually", - "install_static_configure": "We have detected that a dynamic IP address is used — <0>{{ip}}. Do you want to use your it as a static address?" + "install_static_configure": "We have detected that a dynamic IP address is used — <0>{{ip}}. Do you want to use it as your static address?" } From f61d5f0f85a954120b2676a5153f10a05662cf42 Mon Sep 17 00:00:00 2001 From: Ildar Kamalov Date: Thu, 13 Feb 2020 14:13:35 +0300 Subject: [PATCH 21/31] + client: show confirm before setting static IP --- client/src/__locales/en.json | 3 ++- client/src/install/Setup/Settings.js | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index 138082c7..e54c4a34 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -468,5 +468,6 @@ "set_static_ip": "Set a static IP address", "install_static_ok": "Good news! The static IP address is already configured", "install_static_error": "AdGuard Home cannot configure it automatically for your OS. Please look for an instruction on how to do this manually", - "install_static_configure": "We have detected that a dynamic IP address is used — <0>{{ip}}. Do you want to use it as your static address?" + "install_static_configure": "We have detected that a dynamic IP address is used — <0>{{ip}}. Do you want to use it as your static address?", + "confirm_static_ip": "AdGuard Home will configure {{ip}} to be your static IP address. Do you want to proceed?" } diff --git a/client/src/install/Setup/Settings.js b/client/src/install/Setup/Settings.js index 1c4da3fa..876aa05b 100644 --- a/client/src/install/Setup/Settings.js +++ b/client/src/install/Setup/Settings.js @@ -89,7 +89,7 @@ class Settings extends Component { @@ -133,7 +133,7 @@ class Settings extends Component { handleFix(web, dns, set_static_ip); }; - handleStaticIp = () => { + handleStaticIp = (ip) => { const { webIp, webPort, @@ -146,7 +146,9 @@ class Settings extends Component { const dns = { ip: dnsIp, port: dnsPort, autofix: false }; const set_static_ip = true; - handleFix(web, dns, set_static_ip); + if (window.confirm(this.props.t('confirm_static_ip', { ip }))) { + handleFix(web, dns, set_static_ip); + } }; render() { From c27852537d2f5ce62b16c43f4241a15d0fb8c9fd Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Thu, 13 Feb 2020 14:14:30 +0300 Subject: [PATCH 22/31] +(dhcpd): added static IP for MacOS --- client/src/__locales/en.json | 2 +- dhcpd/dhcp_http.go | 6 +- dhcpd/network_utils.go | 258 +++++++++++++----- dhcpd/network_utils_darwin.go | 8 - dhcpd/network_utils_test.go | 4 +- home/control.go | 125 ++++++++- home/control_install.go | 28 +- home/control_tls.go | 6 +- home/control_update.go | 4 +- home/control_update_test.go | 1 - home/filter.go | 3 +- home/filter_test.go | 7 +- home/helpers.go | 241 ---------------- home/home.go | 73 ++++- home/service.go | 3 +- home/upgrade.go | 14 +- home/whois.go | 4 +- util/helpers.go | 59 ++++ util/helpers_test.go | 14 + {home => util}/network_utils.go | 92 +++++-- .../network_utils_test.go | 15 +- {home => util}/os_freebsd.go | 6 +- {home => util}/os_unix.go | 6 +- {home => util}/os_windows.go | 6 +- {home => util}/syslog_others.go | 6 +- {home => util}/syslog_windows.go | 4 +- 26 files changed, 589 insertions(+), 406 deletions(-) delete mode 100644 dhcpd/network_utils_darwin.go delete mode 100644 home/helpers.go create mode 100644 util/helpers.go create mode 100644 util/helpers_test.go rename {home => util}/network_utils.go (52%) rename home/helpers_test.go => util/network_utils_test.go (52%) rename {home => util}/os_freebsd.go (86%) rename {home => util}/os_unix.go (87%) rename {home => util}/os_windows.go (87%) rename {home => util}/syslog_others.go (62%) rename {home => util}/syslog_windows.go (94%) diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index 138082c7..a4f8279c 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -467,6 +467,6 @@ "static_ip_desc": "AdGuard Home is a server so it needs a static IP address to function properly. Otherwise, at some point, your router may assign a different IP address to this device.", "set_static_ip": "Set a static IP address", "install_static_ok": "Good news! The static IP address is already configured", - "install_static_error": "AdGuard Home cannot configure it automatically for your OS. Please look for an instruction on how to do this manually", + "install_static_error": "AdGuard Home cannot configure it automatically for this network interface. Please look for an instruction on how to do this manually.", "install_static_configure": "We have detected that a dynamic IP address is used — <0>{{ip}}. Do you want to use it as your static address?" } diff --git a/dhcpd/dhcp_http.go b/dhcpd/dhcp_http.go index 6031b9a8..9105d76b 100644 --- a/dhcpd/dhcp_http.go +++ b/dhcpd/dhcp_http.go @@ -10,6 +10,8 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/log" ) @@ -121,7 +123,7 @@ type netInterfaceJSON struct { func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { response := map[string]interface{}{} - ifaces, err := GetValidNetInterfaces() + ifaces, err := util.GetValidNetInterfaces() if err != nil { httpError(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) return @@ -219,7 +221,7 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque staticIP["error"] = err.Error() } else if !isStaticIP { staticIPStatus = "no" - staticIP["ip"] = GetFullIP(interfaceName) + staticIP["ip"] = util.GetSubnet(interfaceName) } staticIP["static"] = staticIPStatus diff --git a/dhcpd/network_utils.go b/dhcpd/network_utils.go index 204105ae..8407e031 100644 --- a/dhcpd/network_utils.go +++ b/dhcpd/network_utils.go @@ -6,37 +6,17 @@ import ( "io/ioutil" "net" "os/exec" + "regexp" "runtime" "strings" + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/file" "github.com/AdguardTeam/golibs/log" ) -// GetValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP -// invalid interface is a ppp interface or the one that doesn't allow broadcasts -func GetValidNetInterfaces() ([]net.Interface, error) { - ifaces, err := net.Interfaces() - if err != nil { - return nil, fmt.Errorf("Couldn't get list of interfaces: %s", err) - } - - netIfaces := []net.Interface{} - - for i := range ifaces { - if ifaces[i].Flags&net.FlagPointToPoint != 0 { - // this interface is ppp, we're not interested in this one - continue - } - - iface := ifaces[i] - netIfaces = append(netIfaces, iface) - } - - return netIfaces, nil -} - // Check if network interface has a static IP configured // Supports: Raspbian. func HasStaticIP(ifaceName string) (bool, error) { @@ -56,54 +36,18 @@ func HasStaticIP(ifaceName string) (bool, error) { return false, fmt.Errorf("Cannot check if IP is static: not supported on %s", runtime.GOOS) } -// Get IP address with netmask -func GetFullIP(ifaceName string) string { - cmd := exec.Command("ip", "-oneline", "-family", "inet", "address", "show", ifaceName) - log.Tracef("executing %s %v", cmd.Path, cmd.Args) - d, err := cmd.Output() - if err != nil || cmd.ProcessState.ExitCode() != 0 { - return "" - } - - fields := strings.Fields(string(d)) - if len(fields) < 4 { - return "" - } - _, _, err = net.ParseCIDR(fields[3]) - if err != nil { - return "" - } - - return fields[3] -} - -// Set a static IP for network interface -// Supports: Raspbian. +// Set a static IP for the specified network interface func SetStaticIP(ifaceName string) error { - ip := GetFullIP(ifaceName) - if len(ip) == 0 { - return errors.New("Can't get IP address") + if runtime.GOOS == "linux" { + return setStaticIPDhcpdConf(ifaceName) } - ip4, _, err := net.ParseCIDR(ip) - if err != nil { - return err - } - gatewayIP := getGatewayIP(ifaceName) - add := setStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, ip4.String()) - - body, err := ioutil.ReadFile("/etc/dhcpcd.conf") - if err != nil { - return err + if runtime.GOOS == "darwin" { + return fmt.Errorf("cannot do that") + // return setStaticIPDarwin(ifaceName) } - body = append(body, []byte(add)...) - err = file.SafeWrite("/etc/dhcpcd.conf", body) - if err != nil { - return err - } - - return nil + return fmt.Errorf("Cannot set static IP on %s", runtime.GOOS) } // for dhcpcd.conf @@ -167,8 +111,37 @@ func getGatewayIP(ifaceName string) string { return fields[2] } +// setStaticIPDhcpdConf - updates /etc/dhcpd.conf and sets the current IP address to be static +func setStaticIPDhcpdConf(ifaceName string) error { + ip := util.GetSubnet(ifaceName) + if len(ip) == 0 { + return errors.New("Can't get IP address") + } + + ip4, _, err := net.ParseCIDR(ip) + if err != nil { + return err + } + gatewayIP := getGatewayIP(ifaceName) + add := updateStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, ip4.String()) + + body, err := ioutil.ReadFile("/etc/dhcpcd.conf") + if err != nil { + return err + } + + body = append(body, []byte(add)...) + err = file.SafeWrite("/etc/dhcpcd.conf", body) + if err != nil { + return err + } + + return nil +} + +// updates dhcpd.conf content -- sets static IP address there // for dhcpcd.conf -func setStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, dnsIP string) string { +func updateStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, dnsIP string) string { var body []byte add := fmt.Sprintf("\ninterface %s\nstatic ip_address=%s\n", @@ -187,3 +160,154 @@ func setStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, dnsIP string) string { return string(body) } + +// Check if network interface has a static IP configured +// Supports: MacOS. +func hasStaticIPDarwin(ifaceName string) (bool, error) { + portInfo, err := getCurrentHardwarePortInfo(ifaceName) + if err != nil { + return false, err + } + + return portInfo.static, nil +} + +// setStaticIPDarwin - uses networksetup util to set the current IP address to be static +// Additionally it configures the current DNS servers as well +func setStaticIPDarwin(ifaceName string) error { + portInfo, err := getCurrentHardwarePortInfo(ifaceName) + if err != nil { + return err + } + + if portInfo.static { + return errors.New("IP address is already static") + } + + dnsAddrs, err := getEtcResolvConfServers() + if err != nil { + return err + } + + args := make([]string, 0) + args = append(args, "-setdnsservers") + args = append(args, dnsAddrs...) + + // Setting DNS servers is necessary when configuring a static IP + code, _, err := util.RunCommand("networksetup", args...) + if err != nil { + return err + } + if code != 0 { + return fmt.Errorf("Failed to set DNS servers, code=%d", code) + } + + // Actually configures hardware port to have static IP + code, _, err = util.RunCommand("networksetup", "-setmanual", + portInfo.name, portInfo.ip, portInfo.subnet, portInfo.gatewayIP) + if err != nil { + return err + } + if code != 0 { + return fmt.Errorf("Failed to set DNS servers, code=%d", code) + } + + return nil +} + +// getCurrentHardwarePortInfo gets information the specified network interface +func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) { + // First of all we should find hardware port name + m := getNetworkSetupHardwareReports() + hardwarePort, ok := m[ifaceName] + if !ok { + return hardwarePortInfo{}, fmt.Errorf("Could not find hardware port for %s", ifaceName) + } + + return getHardwarePortInfo(hardwarePort) +} + +// getNetworkSetupHardwareReports parses the output of the `networksetup -listallhardwareports` command +// it returns a map where the key is the interface name, and the value is the "hardware port" +// returns nil if it fails to parse the output +func getNetworkSetupHardwareReports() map[string]string { + _, out, err := util.RunCommand("networksetup", "-listallhardwareports") + if err != nil { + return nil + } + + re, err := regexp.Compile("Hardware Port: (.*?)\nDevice: (.*?)\n") + if err != nil { + return nil + } + + m := make(map[string]string, 0) + + matches := re.FindAllStringSubmatch(out, -1) + for i := range matches { + port := matches[i][1] + device := matches[i][2] + m[device] = port + } + + return m +} + +// hardwarePortInfo - information obtained using MacOS networksetup +// about the current state of the internet connection +type hardwarePortInfo struct { + name string + ip string + subnet string + gatewayIP string + static bool +} + +func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) { + h := hardwarePortInfo{} + + _, out, err := util.RunCommand("networksetup", "-getinfo", hardwarePort) + if err != nil { + return h, err + } + + re := regexp.MustCompile("IP address: (.*?)\nSubnet mask: (.*?)\nRouter: (.*?)\n") + + match := re.FindStringSubmatch(out) + if len(match) == 0 { + return h, errors.New("Could not find hardware port info") + } + + h.name = hardwarePort + h.ip = match[1] + h.subnet = match[2] + h.gatewayIP = match[3] + + if strings.Index(out, "Manual Configuration") == 0 { + h.static = true + } + + return h, nil +} + +// Gets a list of nameservers currently configured in the /etc/resolv.conf +func getEtcResolvConfServers() ([]string, error) { + body, err := ioutil.ReadFile("/etc/resolv.conf") + if err != nil { + return nil, err + } + + re := regexp.MustCompile("nameserver ([a-zA-Z0-9.:]+)") + + matches := re.FindAllStringSubmatch(string(body), -1) + if len(matches) == 0 { + return nil, errors.New("Found no DNS servers in /etc/resolv.conf") + } + + addrs := make([]string, 0) + for i := range matches { + addrs = append(addrs, matches[i][1]) + } + + return addrs, nil +} diff --git a/dhcpd/network_utils_darwin.go b/dhcpd/network_utils_darwin.go deleted file mode 100644 index 9110b0b0..00000000 --- a/dhcpd/network_utils_darwin.go +++ /dev/null @@ -1,8 +0,0 @@ -package dhcpd - -// Check if network interface has a static IP configured -// Supports: Raspbian. -func hasStaticIPDarwin(ifaceName string) (bool, error) { - - return false, nil -} diff --git a/dhcpd/network_utils_test.go b/dhcpd/network_utils_test.go index 3948a224..2957a411 100644 --- a/dhcpd/network_utils_test.go +++ b/dhcpd/network_utils_test.go @@ -46,7 +46,7 @@ static routers=192.168.0.1 static domain_name_servers=192.168.0.2 ` - s := setStaticIPDhcpcdConf("wlan0", "192.168.0.2/24", "192.168.0.1", "192.168.0.2") + s := updateStaticIPDhcpcdConf("wlan0", "192.168.0.2/24", "192.168.0.1", "192.168.0.2") assert.Equal(t, dhcpcdConf, s) // without gateway @@ -56,6 +56,6 @@ static ip_address=192.168.0.2/24 static domain_name_servers=192.168.0.2 ` - s = setStaticIPDhcpcdConf("wlan0", "192.168.0.2/24", "", "192.168.0.2") + s = updateStaticIPDhcpcdConf("wlan0", "192.168.0.2/24", "", "192.168.0.2") assert.Equal(t, dhcpcdConf, s) } diff --git a/home/control.go b/home/control.go index d852ef66..031f7a36 100644 --- a/home/control.go +++ b/home/control.go @@ -3,7 +3,13 @@ package home import ( "encoding/json" "fmt" + "net" "net/http" + "net/url" + "strconv" + "strings" + + "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/golibs/log" @@ -54,8 +60,7 @@ func getDNSAddresses() []string { dnsAddresses := []string{} if config.DNS.BindHost == "0.0.0.0" { - - ifaces, e := getValidNetInterfacesForWeb() + ifaces, e := util.GetValidNetInterfacesForWeb() if e != nil { log.Error("Couldn't get network interfaces: %v", e) return []string{} @@ -66,7 +71,6 @@ func getDNSAddresses() []string { addDNSAddress(&dnsAddresses, addr) } } - } else { addDNSAddress(&dnsAddresses, config.DNS.BindHost) } @@ -180,3 +184,118 @@ func registerControlHandlers() { func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) { http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler))))) } + +// ---------------------------------- +// helper functions for HTTP handlers +// ---------------------------------- +func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + log.Debug("%s %v", r.Method, r.URL) + + if r.Method != method { + http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed) + return + } + + if method == "POST" || method == "PUT" || method == "DELETE" { + Context.controlLock.Lock() + defer Context.controlLock.Unlock() + } + + handler(w, r) + } +} + +func ensurePOST(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return ensure("POST", handler) +} + +func ensureGET(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return ensure("GET", handler) +} + +// Bridge between http.Handler object and Go function +type httpHandler struct { + handler func(http.ResponseWriter, *http.Request) +} + +func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.handler(w, r) +} + +func ensureHandler(method string, handler func(http.ResponseWriter, *http.Request)) http.Handler { + h := httpHandler{} + h.handler = ensure(method, handler) + return &h +} + +// preInstall lets the handler run only if firstRun is true, no redirects +func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if !Context.firstRun { + // if it's not first run, don't let users access it (for example /install.html when configuration is done) + http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) + return + } + handler(w, r) + } +} + +// preInstallStruct wraps preInstall into a struct that can be returned as an interface where necessary +type preInstallHandlerStruct struct { + handler http.Handler +} + +func (p *preInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) { + preInstall(p.handler.ServeHTTP)(w, r) +} + +// preInstallHandler returns http.Handler interface for preInstall wrapper +func preInstallHandler(handler http.Handler) http.Handler { + return &preInstallHandlerStruct{handler} +} + +// postInstall lets the handler run only if firstRun is false, and redirects to /install.html otherwise +// it also enforces HTTPS if it is enabled and configured +func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if Context.firstRun && + !strings.HasPrefix(r.URL.Path, "/install.") && + r.URL.Path != "/favicon.png" { + http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable + return + } + // enforce https? + if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil { + // yes, and we want host from host:port + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + // no port in host + host = r.Host + } + // construct new URL to redirect to + newURL := url.URL{ + Scheme: "https", + Host: net.JoinHostPort(host, strconv.Itoa(config.TLS.PortHTTPS)), + Path: r.URL.Path, + RawQuery: r.URL.RawQuery, + } + http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect) + return + } + w.Header().Set("Access-Control-Allow-Origin", "*") + handler(w, r) + } +} + +type postInstallHandlerStruct struct { + handler http.Handler +} + +func (p *postInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) { + postInstall(p.handler.ServeHTTP)(w, r) +} + +func postInstallHandler(handler http.Handler) http.Handler { + return &postInstallHandlerStruct{handler} +} diff --git a/home/control_install.go b/home/control_install.go index 8a595bb8..50d6e7c1 100644 --- a/home/control_install.go +++ b/home/control_install.go @@ -13,6 +13,8 @@ import ( "runtime" "strconv" + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/golibs/log" @@ -38,7 +40,7 @@ func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) { data.WebPort = 80 data.DNSPort = 53 - ifaces, err := getValidNetInterfacesForWeb() + ifaces, err := util.GetValidNetInterfacesForWeb() if err != nil { httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) return @@ -101,16 +103,16 @@ func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) { } if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort { - err = checkPortAvailable(reqData.Web.IP, reqData.Web.Port) + err = util.CheckPortAvailable(reqData.Web.IP, reqData.Web.Port) if err != nil { respData.Web.Status = fmt.Sprintf("%v", err) } } if reqData.DNS.Port != 0 { - err = checkPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port) + err = util.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port) - if errorIsAddrInUse(err) { + if util.ErrorIsAddrInUse(err) { canAutofix := checkDNSStubListener() if canAutofix && reqData.DNS.Autofix { @@ -119,7 +121,7 @@ func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) { log.Error("Couldn't disable DNSStubListener: %s", err) } - err = checkPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port) + err = util.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port) canAutofix = false } @@ -127,26 +129,22 @@ func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) { } if err == nil { - err = checkPortAvailable(reqData.DNS.IP, reqData.DNS.Port) + err = util.CheckPortAvailable(reqData.DNS.IP, reqData.DNS.Port) } if err != nil { respData.DNS.Status = fmt.Sprintf("%v", err) - } else { - - interfaceName := getInterfaceByIP(reqData.DNS.IP) + interfaceName := util.GetInterfaceByIP(reqData.DNS.IP) staticIPStatus := "yes" if len(interfaceName) == 0 { staticIPStatus = "error" respData.StaticIP.Error = fmt.Sprintf("Couldn't find network interface by IP %s", reqData.DNS.IP) - } else if reqData.DNS.SetStaticIP { err = dhcpd.SetStaticIP(interfaceName) staticIPStatus = "error" respData.StaticIP.Error = err.Error() - } else { // check if we have a static IP isStaticIP, err := dhcpd.HasStaticIP(interfaceName) @@ -155,7 +153,7 @@ func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) { respData.StaticIP.Error = err.Error() } else if !isStaticIP { staticIPStatus = "no" - respData.StaticIP.IP = dhcpd.GetFullIP(interfaceName) + respData.StaticIP.IP = util.GetSubnet(interfaceName) } } respData.StaticIP.Static = staticIPStatus @@ -279,7 +277,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { // validate that hosts and ports are bindable if restartHTTP { - err = checkPortAvailable(newSettings.Web.IP, newSettings.Web.Port) + err = util.CheckPortAvailable(newSettings.Web.IP, newSettings.Web.Port) if err != nil { httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s", net.JoinHostPort(newSettings.Web.IP, strconv.Itoa(newSettings.Web.Port)), err) @@ -287,13 +285,13 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { } } - err = checkPacketPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port) + err = util.CheckPacketPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port) if err != nil { httpError(w, http.StatusBadRequest, "%s", err) return } - err = checkPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port) + err = util.CheckPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port) if err != nil { httpError(w, http.StatusBadRequest, "%s", err) return diff --git a/home/control_tls.go b/home/control_tls.go index f0f4c655..0df8b729 100644 --- a/home/control_tls.go +++ b/home/control_tls.go @@ -20,6 +20,8 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/log" "github.com/joomcode/errorx" ) @@ -84,7 +86,7 @@ func handleTLSValidate(w http.ResponseWriter, r *http.Request) { alreadyRunning = true } if !alreadyRunning { - err = checkPortAvailable(config.BindHost, data.PortHTTPS) + err = util.CheckPortAvailable(config.BindHost, data.PortHTTPS) if err != nil { httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS) return @@ -114,7 +116,7 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) { alreadyRunning = true } if !alreadyRunning { - err = checkPortAvailable(config.BindHost, data.PortHTTPS) + err = util.CheckPortAvailable(config.BindHost, data.PortHTTPS) if err != nil { httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS) return diff --git a/home/control_update.go b/home/control_update.go index 10e94e50..87fe4034 100644 --- a/home/control_update.go +++ b/home/control_update.go @@ -17,6 +17,8 @@ import ( "syscall" "time" + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/log" ) @@ -196,7 +198,7 @@ func getUpdateInfo(jsonData []byte) (*updateInfo, error) { binName = "AdGuardHome.exe" } u.curBinName = filepath.Join(workDir, binName) - if !fileExists(u.curBinName) { + if !util.FileExists(u.curBinName) { return nil, fmt.Errorf("Executable file %s doesn't exist", u.curBinName) } u.bkpBinName = filepath.Join(u.backupDir, binName) diff --git a/home/control_update_test.go b/home/control_update_test.go index cfc13823..6ec4a186 100644 --- a/home/control_update_test.go +++ b/home/control_update_test.go @@ -8,7 +8,6 @@ import ( ) func TestDoUpdate(t *testing.T) { - config.DNS.Port = 0 Context.workDir = "..." // set absolute path newver := "v0.96" diff --git a/home/filter.go b/home/filter.go index 9e0a5abf..befdd873 100644 --- a/home/filter.go +++ b/home/filter.go @@ -13,6 +13,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/golibs/file" "github.com/AdguardTeam/golibs/log" ) @@ -401,7 +402,7 @@ func parseFilterContents(contents []byte) (int, string) { // Count lines in the filter for len(data) != 0 { - line := SplitNext(&data, '\n') + line := util.SplitNext(&data, '\n') if len(line) == 0 { continue } diff --git a/home/filter_test.go b/home/filter_test.go index fe1826d2..edda556a 100644 --- a/home/filter_test.go +++ b/home/filter_test.go @@ -10,6 +10,11 @@ import ( ) func TestFilters(t *testing.T) { + dir := prepareTestDir() + defer func() { _ = os.RemoveAll(dir) }() + + Context = homeContext{} + Context.workDir = dir Context.client = &http.Client{ Timeout: time.Minute * 5, } @@ -33,5 +38,5 @@ func TestFilters(t *testing.T) { assert.True(t, err == nil) f.unload() - os.Remove(f.Path()) + _ = os.Remove(f.Path()) } diff --git a/home/helpers.go b/home/helpers.go deleted file mode 100644 index 5a87aec7..00000000 --- a/home/helpers.go +++ /dev/null @@ -1,241 +0,0 @@ -package home - -import ( - "context" - "fmt" - "net" - "net/http" - "net/url" - "os" - "os/exec" - "path" - "path/filepath" - "runtime" - "strconv" - "strings" - "time" - - "github.com/AdguardTeam/golibs/log" - "github.com/joomcode/errorx" -) - -// ---------------------------------- -// helper functions for HTTP handlers -// ---------------------------------- -func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - log.Debug("%s %v", r.Method, r.URL) - - if r.Method != method { - http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed) - return - } - - if method == "POST" || method == "PUT" || method == "DELETE" { - Context.controlLock.Lock() - defer Context.controlLock.Unlock() - } - - handler(w, r) - } -} - -func ensurePOST(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { - return ensure("POST", handler) -} - -func ensureGET(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { - return ensure("GET", handler) -} - -// Bridge between http.Handler object and Go function -type httpHandler struct { - handler func(http.ResponseWriter, *http.Request) -} - -func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - h.handler(w, r) -} - -func ensureHandler(method string, handler func(http.ResponseWriter, *http.Request)) http.Handler { - h := httpHandler{} - h.handler = ensure(method, handler) - return &h -} - -// ------------------- -// first run / install -// ------------------- -func detectFirstRun() bool { - configfile := Context.configFilename - if !filepath.IsAbs(configfile) { - configfile = filepath.Join(Context.workDir, Context.configFilename) - } - _, err := os.Stat(configfile) - if !os.IsNotExist(err) { - // do nothing, file exists - return false - } - return true -} - -// preInstall lets the handler run only if firstRun is true, no redirects -func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - if !Context.firstRun { - // if it's not first run, don't let users access it (for example /install.html when configuration is done) - http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) - return - } - handler(w, r) - } -} - -// preInstallStruct wraps preInstall into a struct that can be returned as an interface where necessary -type preInstallHandlerStruct struct { - handler http.Handler -} - -func (p *preInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) { - preInstall(p.handler.ServeHTTP)(w, r) -} - -// preInstallHandler returns http.Handler interface for preInstall wrapper -func preInstallHandler(handler http.Handler) http.Handler { - return &preInstallHandlerStruct{handler} -} - -// postInstall lets the handler run only if firstRun is false, and redirects to /install.html otherwise -// it also enforces HTTPS if it is enabled and configured -func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - if Context.firstRun && - !strings.HasPrefix(r.URL.Path, "/install.") && - r.URL.Path != "/favicon.png" { - http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable - return - } - // enforce https? - if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil { - // yes, and we want host from host:port - host, _, err := net.SplitHostPort(r.Host) - if err != nil { - // no port in host - host = r.Host - } - // construct new URL to redirect to - newURL := url.URL{ - Scheme: "https", - Host: net.JoinHostPort(host, strconv.Itoa(config.TLS.PortHTTPS)), - Path: r.URL.Path, - RawQuery: r.URL.RawQuery, - } - http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect) - return - } - w.Header().Set("Access-Control-Allow-Origin", "*") - handler(w, r) - } -} - -type postInstallHandlerStruct struct { - handler http.Handler -} - -func (p *postInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) { - postInstall(p.handler.ServeHTTP)(w, r) -} - -func postInstallHandler(handler http.Handler) http.Handler { - return &postInstallHandlerStruct{handler} -} - -// Connect to a remote server resolving hostname using our own DNS server -func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) { - log.Tracef("network:%v addr:%v", network, addr) - - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - - dialer := &net.Dialer{ - Timeout: time.Minute * 5, - } - - if net.ParseIP(host) != nil || config.DNS.Port == 0 { - con, err := dialer.DialContext(ctx, network, addr) - return con, err - } - - addrs, e := Context.dnsServer.Resolve(host) - log.Debug("dnsServer.Resolve: %s: %v", host, addrs) - if e != nil { - return nil, e - } - - if len(addrs) == 0 { - return nil, fmt.Errorf("couldn't lookup host: %s", host) - } - - var dialErrs []error - for _, a := range addrs { - addr = net.JoinHostPort(a.String(), port) - con, err := dialer.DialContext(ctx, network, addr) - if err != nil { - dialErrs = append(dialErrs, err) - continue - } - return con, err - } - return nil, errorx.DecorateMany(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...) -} - -// --------------------- -// general helpers -// --------------------- - -// fileExists returns TRUE if file exists -func fileExists(fn string) bool { - _, err := os.Stat(fn) - if err != nil { - return false - } - return true -} - -// runCommand runs shell command -func runCommand(command string, arguments ...string) (int, string, error) { - cmd := exec.Command(command, arguments...) - out, err := cmd.Output() - if err != nil { - return 1, "", fmt.Errorf("exec.Command(%s) failed: %s", command, err) - } - - return cmd.ProcessState.ExitCode(), string(out), nil -} - -// --------------------- -// debug logging helpers -// --------------------- -func _Func() string { - pc := make([]uintptr, 10) // at least 1 entry needed - runtime.Callers(2, pc) - f := runtime.FuncForPC(pc[0]) - return path.Base(f.Name()) -} - -// SplitNext - split string by a byte and return the first chunk -// Whitespace is trimmed -func SplitNext(str *string, splitBy byte) string { - i := strings.IndexByte(*str, splitBy) - s := "" - if i != -1 { - s = (*str)[0:i] - *str = (*str)[i+1:] - } else { - s = *str - *str = "" - } - return strings.TrimSpace(s) -} diff --git a/home/home.go b/home/home.go index b31f1044..f7c697f6 100644 --- a/home/home.go +++ b/home/home.go @@ -20,6 +20,10 @@ import ( "syscall" "time" + "github.com/AdguardTeam/AdGuardHome/util" + + "github.com/joomcode/errorx" + "github.com/AdguardTeam/AdGuardHome/isdelve" "github.com/AdguardTeam/AdGuardHome/dhcpd" @@ -193,7 +197,7 @@ func run(args options) { if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && config.RlimitNoFile != 0 { - setRlimit(config.RlimitNoFile) + util.SetRlimit(config.RlimitNoFile) } // override bind host/port from the console @@ -327,7 +331,7 @@ func httpServerLoop() { // Check if the current user has root (administrator) rights // and if not, ask and try to run as root func requireAdminRights() { - admin, _ := haveAdminRights() + admin, _ := util.HaveAdminRights() if //noinspection ALL admin || isdelve.Enabled { return @@ -412,7 +416,7 @@ func configureLogger(args options) { if ls.LogFile == configSyslog { // Use syslog where it is possible and eventlog on Windows - err := configureSyslog() + err := util.ConfigureSyslog(serviceName) if err != nil { log.Fatalf("cannot initialize syslog: %s", err) } @@ -448,9 +452,9 @@ func stopHTTPServer() { log.Info("Stopping HTTP server...") Context.httpsServer.shutdown = true if Context.httpsServer.server != nil { - Context.httpsServer.server.Shutdown(context.TODO()) + _ = Context.httpsServer.server.Shutdown(context.TODO()) } - Context.httpServer.Shutdown(context.TODO()) + _ = Context.httpServer.Shutdown(context.TODO()) log.Info("Stopped HTTP server") } @@ -580,7 +584,7 @@ func printHTTPAddresses(proto string) { } } else if config.BindHost == "0.0.0.0" { log.Println("AdGuard Home is available on the following addresses:") - ifaces, err := getValidNetInterfacesForWeb() + ifaces, err := util.GetValidNetInterfacesForWeb() if err != nil { // That's weird, but we'll ignore it address = net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort)) @@ -597,3 +601,60 @@ func printHTTPAddresses(proto string) { log.Printf("Go to %s://%s", proto, address) } } + +// ------------------- +// first run / install +// ------------------- +func detectFirstRun() bool { + configfile := Context.configFilename + if !filepath.IsAbs(configfile) { + configfile = filepath.Join(Context.workDir, Context.configFilename) + } + _, err := os.Stat(configfile) + if !os.IsNotExist(err) { + // do nothing, file exists + return false + } + return true +} + +// Connect to a remote server resolving hostname using our own DNS server +func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) { + log.Tracef("network:%v addr:%v", network, addr) + + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + dialer := &net.Dialer{ + Timeout: time.Minute * 5, + } + + if net.ParseIP(host) != nil || config.DNS.Port == 0 { + con, err := dialer.DialContext(ctx, network, addr) + return con, err + } + + addrs, e := Context.dnsServer.Resolve(host) + log.Debug("dnsServer.Resolve: %s: %v", host, addrs) + if e != nil { + return nil, e + } + + if len(addrs) == 0 { + return nil, fmt.Errorf("couldn't lookup host: %s", host) + } + + var dialErrs []error + for _, a := range addrs { + addr = net.JoinHostPort(a.String(), port) + con, err := dialer.DialContext(ctx, network, addr) + if err != nil { + dialErrs = append(dialErrs, err) + continue + } + return con, err + } + return nil, errorx.DecorateMany(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...) +} diff --git a/home/service.go b/home/service.go index 1951ed28..d066e118 100644 --- a/home/service.go +++ b/home/service.go @@ -7,6 +7,7 @@ import ( "strings" "syscall" + "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/golibs/log" "github.com/kardianos/service" ) @@ -229,7 +230,7 @@ func configureService(c *service.Config) { // returns command code or error if any func runInitdCommand(action string) (int, error) { confPath := "/etc/init.d/" + serviceName - code, _, err := runCommand("sh", "-c", confPath+" "+action) + code, _, err := util.RunCommand("sh", "-c", confPath+" "+action) return code, err } diff --git a/home/upgrade.go b/home/upgrade.go index 7336f846..b2936a3e 100644 --- a/home/upgrade.go +++ b/home/upgrade.go @@ -5,6 +5,8 @@ import ( "os" "path/filepath" + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/file" "github.com/AdguardTeam/golibs/log" "golang.org/x/crypto/bcrypt" @@ -114,7 +116,7 @@ func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) err // The first schema upgrade: // No more "dnsfilter.txt", filters are now kept in data/filters/ func upgradeSchema0to1(diskConfig *map[string]interface{}) error { - log.Printf("%s(): called", _Func()) + log.Printf("%s(): called", util.FuncName()) dnsFilterPath := filepath.Join(Context.workDir, "dnsfilter.txt") if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) { @@ -135,7 +137,7 @@ func upgradeSchema0to1(diskConfig *map[string]interface{}) error { // coredns is now dns in config // delete 'Corefile', since we don't use that anymore func upgradeSchema1to2(diskConfig *map[string]interface{}) error { - log.Printf("%s(): called", _Func()) + log.Printf("%s(): called", util.FuncName()) coreFilePath := filepath.Join(Context.workDir, "Corefile") if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) { @@ -159,7 +161,7 @@ func upgradeSchema1to2(diskConfig *map[string]interface{}) error { // Third schema upgrade: // Bootstrap DNS becomes an array func upgradeSchema2to3(diskConfig *map[string]interface{}) error { - log.Printf("%s(): called", _Func()) + log.Printf("%s(): called", util.FuncName()) // Let's read dns configuration from diskConfig dnsConfig, ok := (*diskConfig)["dns"] @@ -196,7 +198,7 @@ func upgradeSchema2to3(diskConfig *map[string]interface{}) error { // Add use_global_blocked_services=true setting for existing "clients" array func upgradeSchema3to4(diskConfig *map[string]interface{}) error { - log.Printf("%s(): called", _Func()) + log.Printf("%s(): called", util.FuncName()) (*diskConfig)["schema_version"] = 4 @@ -233,7 +235,7 @@ func upgradeSchema3to4(diskConfig *map[string]interface{}) error { // password: "..." // ... func upgradeSchema4to5(diskConfig *map[string]interface{}) error { - log.Printf("%s(): called", _Func()) + log.Printf("%s(): called", util.FuncName()) (*diskConfig)["schema_version"] = 5 @@ -288,7 +290,7 @@ func upgradeSchema4to5(diskConfig *map[string]interface{}) error { // - 127.0.0.1 // - ... func upgradeSchema5to6(diskConfig *map[string]interface{}) error { - log.Printf("%s(): called", _Func()) + log.Printf("%s(): called", util.FuncName()) (*diskConfig)["schema_version"] = 6 diff --git a/home/whois.go b/home/whois.go index 25fe211a..321b4ef2 100644 --- a/home/whois.go +++ b/home/whois.go @@ -8,6 +8,8 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" ) @@ -61,7 +63,7 @@ func whoisParse(data string) map[string]string { descr := "" netname := "" for len(data) != 0 { - ln := SplitNext(&data, '\n') + ln := util.SplitNext(&data, '\n') if len(ln) == 0 || ln[0] == '#' || ln[0] == '%' { continue } diff --git a/util/helpers.go b/util/helpers.go new file mode 100644 index 00000000..730e0a2c --- /dev/null +++ b/util/helpers.go @@ -0,0 +1,59 @@ +package util + +import ( + "fmt" + "os" + "os/exec" + "path" + "runtime" + "strings" +) + +// --------------------- +// general helpers +// --------------------- + +// fileExists returns TRUE if file exists +func FileExists(fn string) bool { + _, err := os.Stat(fn) + if err != nil { + return false + } + return true +} + +// runCommand runs shell command +func RunCommand(command string, arguments ...string) (int, string, error) { + cmd := exec.Command(command, arguments...) + out, err := cmd.Output() + if err != nil { + return 1, "", fmt.Errorf("exec.Command(%s) failed: %s", command, err) + } + + return cmd.ProcessState.ExitCode(), string(out), nil +} + +// --------------------- +// debug logging helpers +// --------------------- +func FuncName() string { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + return path.Base(f.Name()) +} + +// SplitNext - split string by a byte and return the first chunk +// Whitespace is trimmed +func SplitNext(str *string, splitBy byte) string { + i := strings.IndexByte(*str, splitBy) + s := "" + if i != -1 { + s = (*str)[0:i] + *str = (*str)[i+1:] + } else { + s = *str + *str = "" + } + return strings.TrimSpace(s) +} diff --git a/util/helpers_test.go b/util/helpers_test.go new file mode 100644 index 00000000..d5e90637 --- /dev/null +++ b/util/helpers_test.go @@ -0,0 +1,14 @@ +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSplitNext(t *testing.T) { + s := " a,b , c " + assert.True(t, SplitNext(&s, ',') == "a") + assert.True(t, SplitNext(&s, ',') == "b") + assert.True(t, SplitNext(&s, ',') == "c" && len(s) == 0) +} diff --git a/home/network_utils.go b/util/network_utils.go similarity index 52% rename from home/network_utils.go rename to util/network_utils.go index 4277c1c2..af410201 100644 --- a/home/network_utils.go +++ b/util/network_utils.go @@ -1,4 +1,4 @@ -package home +package util import ( "errors" @@ -10,23 +10,48 @@ import ( "syscall" "time" - "github.com/AdguardTeam/AdGuardHome/dhcpd" + "github.com/AdguardTeam/golibs/log" "github.com/joomcode/errorx" ) -type netInterface struct { - Name string - MTU int - HardwareAddr string - Addresses []string - Flags string +// NetInterface represents a list of network interfaces +type NetInterface struct { + Name string // Network interface name + MTU int // MTU + HardwareAddr string // Hardware address + Addresses []string // Array with the network interface addresses + Subnets []string // Array with CIDR addresses of this network interface + Flags string // Network interface flags (up, broadcast, etc) +} + +// GetValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP +// invalid interface is a ppp interface or the one that doesn't allow broadcasts +func GetValidNetInterfaces() ([]net.Interface, error) { + ifaces, err := net.Interfaces() + if err != nil { + return nil, fmt.Errorf("Couldn't get list of interfaces: %s", err) + } + + netIfaces := []net.Interface{} + + for i := range ifaces { + if ifaces[i].Flags&net.FlagPointToPoint != 0 { + // this interface is ppp, we're not interested in this one + continue + } + + iface := ifaces[i] + netIfaces = append(netIfaces, iface) + } + + return netIfaces, nil } // getValidNetInterfacesMap returns interfaces that are eligible for DNS and WEB only // we do not return link-local addresses here -func getValidNetInterfacesForWeb() ([]netInterface, error) { - ifaces, err := dhcpd.GetValidNetInterfaces() +func GetValidNetInterfacesForWeb() ([]NetInterface, error) { + ifaces, err := GetValidNetInterfaces() if err != nil { return nil, errorx.Decorate(err, "Couldn't get interfaces") } @@ -34,7 +59,7 @@ func getValidNetInterfacesForWeb() ([]netInterface, error) { return nil, errors.New("couldn't find any legible interface") } - var netInterfaces []netInterface + var netInterfaces []NetInterface for _, iface := range ifaces { addrs, e := iface.Addrs() @@ -42,7 +67,7 @@ func getValidNetInterfacesForWeb() ([]netInterface, error) { return nil, errorx.Decorate(e, "Failed to get addresses for interface %s", iface.Name) } - netIface := netInterface{ + netIface := NetInterface{ Name: iface.Name, MTU: iface.MTU, HardwareAddr: iface.HardwareAddr.String(), @@ -52,19 +77,26 @@ func getValidNetInterfacesForWeb() ([]netInterface, error) { netIface.Flags = iface.Flags.String() } - // we don't want link-local addresses in json, so skip them + // Collect network interface addresses for _, addr := range addrs { - ipnet, ok := addr.(*net.IPNet) + ipNet, ok := addr.(*net.IPNet) if !ok { // not an IPNet, should not happen return nil, fmt.Errorf("got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr) } // ignore link-local - if ipnet.IP.IsLinkLocalUnicast() { + if ipNet.IP.IsLinkLocalUnicast() { continue } - netIface.Addresses = append(netIface.Addresses, ipnet.IP.String()) + // ignore IPv6 + if ipNet.IP.To4() == nil { + continue + } + netIface.Addresses = append(netIface.Addresses, ipNet.IP.String()) + netIface.Subnets = append(netIface.Subnets, ipNet.String()) } + + // Discard interfaces with no addresses if len(netIface.Addresses) != 0 { netInterfaces = append(netInterfaces, netIface) } @@ -74,8 +106,8 @@ func getValidNetInterfacesForWeb() ([]netInterface, error) { } // Get interface name by its IP address. -func getInterfaceByIP(ip string) string { - ifaces, err := getValidNetInterfacesForWeb() +func GetInterfaceByIP(ip string) string { + ifaces, err := GetValidNetInterfacesForWeb() if err != nil { return "" } @@ -91,8 +123,26 @@ func getInterfaceByIP(ip string) string { return "" } +// Get IP address with netmask for the specified interface +// Returns an empty string if it fails to find it +func GetSubnet(ifaceName string) string { + netIfaces, err := GetValidNetInterfacesForWeb() + if err != nil { + log.Error("Could not get network interfaces info: %v", err) + return "" + } + + for _, netIface := range netIfaces { + if netIface.Name == ifaceName && len(netIface.Subnets) > 0 { + return netIface.Subnets[0] + } + } + + return "" +} + // checkPortAvailable is not a cheap test to see if the port is bindable, because it's actually doing the bind momentarily -func checkPortAvailable(host string, port int) error { +func CheckPortAvailable(host string, port int) error { ln, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port))) if err != nil { return err @@ -105,7 +155,7 @@ func checkPortAvailable(host string, port int) error { return nil } -func checkPacketPortAvailable(host string, port int) error { +func CheckPacketPortAvailable(host string, port int) error { ln, err := net.ListenPacket("udp", net.JoinHostPort(host, strconv.Itoa(port))) if err != nil { return err @@ -119,7 +169,7 @@ func checkPacketPortAvailable(host string, port int) error { } // check if error is "address already in use" -func errorIsAddrInUse(err error) bool { +func ErrorIsAddrInUse(err error) bool { errOpError, ok := err.(*net.OpError) if !ok { return false diff --git a/home/helpers_test.go b/util/network_utils_test.go similarity index 52% rename from home/helpers_test.go rename to util/network_utils_test.go index c2095966..7feac0f2 100644 --- a/home/helpers_test.go +++ b/util/network_utils_test.go @@ -1,14 +1,12 @@ -package home +package util import ( + "log" "testing" - - "github.com/AdguardTeam/golibs/log" - "github.com/stretchr/testify/assert" ) func TestGetValidNetInterfacesForWeb(t *testing.T) { - ifaces, err := getValidNetInterfacesForWeb() + ifaces, err := GetValidNetInterfacesForWeb() if err != nil { t.Fatalf("Cannot get net interfaces: %s", err) } @@ -24,10 +22,3 @@ func TestGetValidNetInterfacesForWeb(t *testing.T) { log.Printf("%v", iface) } } - -func TestSplitNext(t *testing.T) { - s := " a,b , c " - assert.True(t, SplitNext(&s, ',') == "a") - assert.True(t, SplitNext(&s, ',') == "b") - assert.True(t, SplitNext(&s, ',') == "c" && len(s) == 0) -} diff --git a/home/os_freebsd.go b/util/os_freebsd.go similarity index 86% rename from home/os_freebsd.go rename to util/os_freebsd.go index 43ee223e..33311e16 100644 --- a/home/os_freebsd.go +++ b/util/os_freebsd.go @@ -1,6 +1,6 @@ // +build freebsd -package home +package util import ( "os" @@ -11,7 +11,7 @@ import ( // Set user-specified limit of how many fd's we can use // https://github.com/AdguardTeam/AdGuardHome/issues/659 -func setRlimit(val uint) { +func SetRlimit(val uint) { var rlim syscall.Rlimit rlim.Max = int64(val) rlim.Cur = int64(val) @@ -22,6 +22,6 @@ func setRlimit(val uint) { } // Check if the current user has root (administrator) rights -func haveAdminRights() (bool, error) { +func HaveAdminRights() (bool, error) { return os.Getuid() == 0, nil } diff --git a/home/os_unix.go b/util/os_unix.go similarity index 87% rename from home/os_unix.go rename to util/os_unix.go index 2623376e..338edfa8 100644 --- a/home/os_unix.go +++ b/util/os_unix.go @@ -1,6 +1,6 @@ // +build aix darwin dragonfly linux netbsd openbsd solaris -package home +package util import ( "os" @@ -11,7 +11,7 @@ import ( // Set user-specified limit of how many fd's we can use // https://github.com/AdguardTeam/AdGuardHome/issues/659 -func setRlimit(val uint) { +func SetRlimit(val uint) { var rlim syscall.Rlimit rlim.Max = uint64(val) rlim.Cur = uint64(val) @@ -22,6 +22,6 @@ func setRlimit(val uint) { } // Check if the current user has root (administrator) rights -func haveAdminRights() (bool, error) { +func HaveAdminRights() (bool, error) { return os.Getuid() == 0, nil } diff --git a/home/os_windows.go b/util/os_windows.go similarity index 87% rename from home/os_windows.go rename to util/os_windows.go index f6949d93..e081f758 100644 --- a/home/os_windows.go +++ b/util/os_windows.go @@ -1,12 +1,12 @@ -package home +package util import "golang.org/x/sys/windows" // Set user-specified limit of how many fd's we can use -func setRlimit(val uint) { +func SetRlimit(val uint) { } -func haveAdminRights() (bool, error) { +func HaveAdminRights() (bool, error) { var token windows.Token h, _ := windows.GetCurrentProcess() err := windows.OpenProcessToken(h, windows.TOKEN_QUERY, &token) diff --git a/home/syslog_others.go b/util/syslog_others.go similarity index 62% rename from home/syslog_others.go rename to util/syslog_others.go index 8aa0f8b0..f4ad9119 100644 --- a/home/syslog_others.go +++ b/util/syslog_others.go @@ -1,14 +1,14 @@ // +build !windows,!nacl,!plan9 -package home +package util import ( "log" "log/syslog" ) -// configureSyslog reroutes standard logger output to syslog -func configureSyslog() error { +// ConfigureSyslog reroutes standard logger output to syslog +func ConfigureSyslog(serviceName string) error { w, err := syslog.New(syslog.LOG_NOTICE|syslog.LOG_USER, serviceName) if err != nil { return err diff --git a/home/syslog_windows.go b/util/syslog_windows.go similarity index 94% rename from home/syslog_windows.go rename to util/syslog_windows.go index a80933bb..30ee7815 100644 --- a/home/syslog_windows.go +++ b/util/syslog_windows.go @@ -1,4 +1,4 @@ -package home +package util import ( "log" @@ -17,7 +17,7 @@ func (w *eventLogWriter) Write(b []byte) (int, error) { return len(b), w.el.Info(1, string(b)) } -func configureSyslog() error { +func ConfigureSyslog(serviceName string) error { // Note that the eventlog src is the same as the service name // Otherwise, we will get "the description for event id cannot be found" warning in every log record From 2df5f281c4f5949b92edd4747ece60ff73799e54 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Thu, 13 Feb 2020 18:35:54 +0300 Subject: [PATCH 23/31] *: lang fix --- client/src/__locales/en.json | 2 +- home/auth.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index 74ac7021..ee17dbc1 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -458,7 +458,7 @@ "check_reason": "Reason: {{reason}}", "check_rule": "Rule: {{rule}}", "check_service": "Service name: {{service}}", - "check_not_found": "Doesn't exist in any filter", + "check_not_found": "Not found in your filter lists", "client_confirm_block": "Are you sure you want to block the client \"{{ip}}\"?", "client_confirm_unblock": "Are you sure you want to unblock the client \"{{ip}}\"?", "client_blocked": "Client \"{{ip}}\" successfully blocked", diff --git a/home/auth.go b/home/auth.go index 3e0d4300..e6e4642a 100644 --- a/home/auth.go +++ b/home/auth.go @@ -152,7 +152,7 @@ func (a *Auth) addSession(data []byte, s *session) { a.sessions[name] = s a.lock.Unlock() if a.storeSession(data, s) { - log.Info("Auth: created session %s: expire=%d", name, s.expire) + log.Debug("Auth: created session %s: expire=%d", name, s.expire) } } From 9d5f201ddf20985599aed9bc486a5f01f3226cfc Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Tue, 18 Feb 2020 13:02:59 +0300 Subject: [PATCH 24/31] *(dnsforward): added QLogFile struct This struct is supposed to use an easier algorithm for scanning query log files. The end goal is allow us scanning multiple query log files. --- go.mod | 1 + querylog/qlog_file.go | 279 +++++++++++++++++++++++++++++++++++++ querylog/qlog_file_test.go | 59 ++++++++ querylog/querylog_file.go | 2 + 4 files changed, 341 insertions(+) create mode 100644 querylog/qlog_file.go create mode 100644 querylog/qlog_file_test.go diff --git a/go.mod b/go.mod index 28b2d1a4..18e8de6e 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/kardianos/service v0.0.0-20181115005516-4c239ee84e7b github.com/krolaw/dhcp4 v0.0.0-20180925202202-7cead472c414 github.com/miekg/dns v1.1.26 + github.com/pkg/errors v0.8.1 github.com/sparrc/go-ping v0.0.0-20181106165434-ef3ab45e41b0 github.com/stretchr/testify v1.4.0 go.etcd.io/bbolt v1.3.3 // indirect diff --git a/querylog/qlog_file.go b/querylog/qlog_file.go new file mode 100644 index 00000000..83f0d30d --- /dev/null +++ b/querylog/qlog_file.go @@ -0,0 +1,279 @@ +package querylog + +import ( + "io" + "os" + "sync" + "time" + + "github.com/pkg/errors" +) + +var ErrSeekNotFound = errors.New("Seek not found the record") + +const bufferSize = 64 * 1024 // 64 KB is the buffer size + +type QLogFile struct { + file *os.File // the query log file + position int64 // current position in the file + + buffer []byte // buffer that we've read from the file + bufferStart int64 // start of the buffer (in the file) + bufferLen int // buffer len + + lock sync.Mutex // We use mutex to make it thread-safe +} + +// NewQLogFile initializes a new instance of the QLogFile +func NewQLogFile(path string) (*QLogFile, error) { + f, err := os.OpenFile(path, os.O_RDONLY, 0644) + + if err != nil { + return nil, err + } + + return &QLogFile{ + file: f, + }, nil +} + +// Seek performs binary search in the query log file looking for a record +// with the specified timestamp. +// +// The algorithm is rather simple: +// 1. It starts with the position in the middle of a file +// 2. Shifts back to the beginning of the line +// 3. Checks the log record timestamp +// 4. If it is lower than the timestamp we are looking for, +// it shifts seek position to 3/4 of the file. Otherwise, to 1/4 of the file. +// 5. It performs the search again, every time the search scope is narrowed twice. +// +// It returns the position of the line with the timestamp we were looking for. +// If we could not find it, it returns 0 and ErrSeekNotFound +func (q *QLogFile) Seek(timestamp uint64) (int64, error) { + q.lock.Lock() + defer q.lock.Unlock() + + // First of all, check the file size + fileInfo, err := q.file.Stat() + if err != nil { + return 0, err + } + + // Define the search scope + start := int64(0) + end := fileInfo.Size() + probe := (end - start) / 2 + + // Get the line + line, _, err := q.readProbeLine(probe) + if err != nil { + return 0, err + } + + // Get the timestamp from the query log record + ts := q.readTimestamp(line) + + if ts == timestamp { + // Hurray, returning the result + return probe, nil + } + + // Narrow the scope and repeat the search + if ts > timestamp { + end := probe + probe = (end - start) / 2 + } else { + start := probe + probe = (end - start) / 2 + } + + // TODO: temp + q.position = probe + + // TODO: Check start/stop/probe values and loop this + return 0, ErrSeekNotFound +} + +// SeekStart changes the current position to the end of the file +// Please note that we're reading query log in the reverse order +// and that's why log start is actually the end of file +func (q *QLogFile) SeekStart() (int64, error) { + q.lock.Lock() + defer q.lock.Unlock() + + // First of all, check the file size + fileInfo, err := q.file.Stat() + if err != nil { + return 0, err + } + + // Place the position to the very end of file + q.position = fileInfo.Size() - 1 + if q.position < 0 { + // TODO: test empty file + q.position = 0 + } + return q.position, nil +} + +// ReadNext reads the next line (in the reverse order) from the file +// and shifts the current position left to the next (actually prev) line. +// returns io.EOF if there's nothing to read more +func (q *QLogFile) ReadNext() (string, error) { + q.lock.Lock() + defer q.lock.Unlock() + + if q.position == 0 { + return "", io.EOF + } + + line, lineIdx, err := q.readNextLine(q.position) + if err != nil { + return "", err + } + + // Shift position + if lineIdx == 0 { + q.position = 0 + } else { + // there's usually a line break before the line + // so we should shift one more char left from the line + // line\nline + q.position = lineIdx - 1 + } + return line, err +} + +// Close frees the underlying resources +func (q *QLogFile) Close() error { + return q.file.Close() +} + +// readNextLine reads the next line from the specified position +// this line actually have to END on that position. +// +// the algorithm is: +// 1. check if we have the buffer initialized +// 2. if it is, scan it and look for the line there +// 3. if we cannot find the line there, read the prev chunk into the buffer +// 4. read the line from the buffer +func (q *QLogFile) readNextLine(position int64) (string, int64, error) { + relativePos := position - q.bufferStart + if q.buffer == nil || relativePos < maxEntrySize { + // Time to re-init the buffer + err := q.initBuffer(position) + if err != nil { + return "", 0, err + } + } + + // Look for the end of the prev line + // This is where we'll read from + var startLine = int64(0) + for i := relativePos - 1; i >= 0; i-- { + if q.buffer[i] == '\n' { + startLine = i + 1 + break + } + } + + line := string(q.buffer[startLine:relativePos]) + lineIdx := q.bufferStart + startLine + return line, lineIdx, nil +} + +// initBuffer initializes the QLogFile buffer. +// the goal is to read a chunk of file that includes the line with the specified position. +func (q *QLogFile) initBuffer(position int64) error { + q.bufferStart = int64(0) + if (position - bufferSize) > 0 { + q.bufferStart = position - bufferSize + } + + // Seek to this position + _, err := q.file.Seek(q.bufferStart, io.SeekStart) + if err != nil { + return err + } + + if q.buffer == nil { + q.buffer = make([]byte, bufferSize) + } + q.bufferLen, err = q.file.Read(q.buffer) + // TODO: validate bufferLen + if err != nil { + return err + } + + return nil +} + +// readProbeLine reads a line that includes the specified position +// this method is supposed to be used when we use binary search in the Seek method +// in the case of consecutive reads, use readNext (it uses a better buffer) +func (q *QLogFile) readProbeLine(position int64) (string, int64, error) { + // First of all, we should read a buffer that will include the query log line + // In order to do this, we'll define the boundaries + seekPosition := int64(0) + relativePos := position // position relative to the buffer we're going to read + if (position - maxEntrySize) > 0 { + // TODO: cover this case in tests + seekPosition = position - maxEntrySize + relativePos = maxEntrySize + } + + // Seek to this position + _, err := q.file.Seek(seekPosition, io.SeekStart) + if err != nil { + return "", 0, err + } + + // The buffer size is 2*maxEntrySize + buffer := make([]byte, maxEntrySize*2) + bufferLen, err := q.file.Read(buffer) + if err != nil { + return "", 0, err + } + + // Now start looking for the new line character starting + // from the relativePos and going left + var startLine = int64(0) + for i := relativePos - 1; i >= 0; i-- { + if buffer[i] == '\n' { + startLine = i + 1 + break + } + } + // Looking for the end of line now + var endLine = int64(bufferLen) + for i := relativePos; i < int64(bufferLen); i++ { + if buffer[i] == '\n' { + endLine = i + break + } + } + + // Finally we can return the string we were looking for + lineIdx := startLine + seekPosition + return string(buffer[startLine:endLine]), lineIdx, nil +} + +// readTimestamp reads the timestamp field from the query log line +func (q *QLogFile) readTimestamp(str string) uint64 { + val := readJSONValue(str, "T") + if len(val) == 0 { + val = readJSONValue(str, "Time") + } + + if len(val) == 0 { + // TODO: log + return 0 + } + tm, err := time.Parse(time.RFC3339, val) + if err != nil { + // TODO: log + return 0 + } + return uint64(tm.UnixNano()) +} diff --git a/querylog/qlog_file_test.go b/querylog/qlog_file_test.go new file mode 100644 index 00000000..3d6d2600 --- /dev/null +++ b/querylog/qlog_file_test.go @@ -0,0 +1,59 @@ +package querylog + +import ( + "io" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Check adding and loading (with filtering) entries from disk and memory +func TestQLogFile(t *testing.T) { + conf := Config{ + Enabled: true, + Interval: 1, + MemSize: 100, + } + conf.BaseDir = prepareTestDir() + defer func() { _ = os.RemoveAll(conf.BaseDir) }() + l := newQueryLog(conf) + + // add disk entries + addEntry(l, "example.org", "1.2.3.4", "0.1.2.4") + addEntry(l, "example.org", "1.2.3.4", "0.1.2.5") + + // write to disk + _ = l.flushLogBuffer(true) + + // create the new QLogFile instance + q, err := NewQLogFile(l.logFile) + assert.Nil(t, err) + assert.NotNil(t, q) + + // seek to the start + pos, err := q.SeekStart() + assert.Nil(t, err) + assert.True(t, pos > 0) + + // read first line + line, err := q.ReadNext() + assert.Nil(t, err) + assert.True(t, strings.Contains(line, "0.1.2.5"), line) + assert.True(t, strings.HasPrefix(line, "{"), line) + assert.True(t, strings.HasSuffix(line, "}"), line) + + // read second line + line, err = q.ReadNext() + assert.Nil(t, err) + assert.Equal(t, int64(0), q.position) + assert.True(t, strings.Contains(line, "0.1.2.4"), line) + assert.True(t, strings.HasPrefix(line, "{"), line) + assert.True(t, strings.HasSuffix(line, "}"), line) + + // try reading again (there's nothing to read anymore) + line, err = q.ReadNext() + assert.Equal(t, io.EOF, err) + assert.Equal(t, "", line) +} diff --git a/querylog/querylog_file.go b/querylog/querylog_file.go index 02296a98..3cf2ae9e 100644 --- a/querylog/querylog_file.go +++ b/querylog/querylog_file.go @@ -18,6 +18,8 @@ import ( ) const enableGzip = false + +// TODO: Check this when we append a new line -- we don't want to have a line longer than this const maxEntrySize = 1000 // flushLogBuffer flushes the current buffer to file and resets the current buffer From 67dc7d7d888c7fe1576be842344649986411d8d3 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Thu, 20 Feb 2020 14:11:33 +0300 Subject: [PATCH 25/31] *: better tests for qlog --- querylog/qlog_file_test.go | 67 ++++++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 18 deletions(-) diff --git a/querylog/qlog_file_test.go b/querylog/qlog_file_test.go index 3d6d2600..6ba5780f 100644 --- a/querylog/qlog_file_test.go +++ b/querylog/qlog_file_test.go @@ -1,34 +1,38 @@ package querylog import ( + "encoding/binary" "io" + "io/ioutil" + "net" "os" "strings" "testing" + "time" "github.com/stretchr/testify/assert" ) +func TestQLogFileEmpty(t *testing.T) { + // TODO: test empty file +} + +func TestQLogFileLarge(t *testing.T) { + // TODO: test reading large file +} + +func TestQLogFileSeek(t *testing.T) { + // TODO: test seek method on a small file +} + // Check adding and loading (with filtering) entries from disk and memory func TestQLogFile(t *testing.T) { - conf := Config{ - Enabled: true, - Interval: 1, - MemSize: 100, - } - conf.BaseDir = prepareTestDir() - defer func() { _ = os.RemoveAll(conf.BaseDir) }() - l := newQueryLog(conf) - - // add disk entries - addEntry(l, "example.org", "1.2.3.4", "0.1.2.4") - addEntry(l, "example.org", "1.2.3.4", "0.1.2.5") - - // write to disk - _ = l.flushLogBuffer(true) + testDir := prepareTestDir() + defer func() { _ = os.RemoveAll(testDir) }() + testFile := prepareTestFile(testDir, 2) // create the new QLogFile instance - q, err := NewQLogFile(l.logFile) + q, err := NewQLogFile(testFile) assert.Nil(t, err) assert.NotNil(t, q) @@ -40,7 +44,7 @@ func TestQLogFile(t *testing.T) { // read first line line, err := q.ReadNext() assert.Nil(t, err) - assert.True(t, strings.Contains(line, "0.1.2.5"), line) + assert.True(t, strings.Contains(line, "0.0.0.2"), line) assert.True(t, strings.HasPrefix(line, "{"), line) assert.True(t, strings.HasSuffix(line, "}"), line) @@ -48,7 +52,7 @@ func TestQLogFile(t *testing.T) { line, err = q.ReadNext() assert.Nil(t, err) assert.Equal(t, int64(0), q.position) - assert.True(t, strings.Contains(line, "0.1.2.4"), line) + assert.True(t, strings.Contains(line, "0.0.0.1"), line) assert.True(t, strings.HasPrefix(line, "{"), line) assert.True(t, strings.HasSuffix(line, "}"), line) @@ -57,3 +61,30 @@ func TestQLogFile(t *testing.T) { assert.Equal(t, io.EOF, err) assert.Equal(t, "", line) } + +// prepareTestFile - prepares a test query log file with the specified number of lines +func prepareTestFile(dir string, linesCount int) string { + format := `{"IP":"${IP}","T":"${TIMESTAMP}","QH":"example.org","QT":"A","QC":"IN","Answer":"AAAAAAABAAEAAAAAB2V4YW1wbGUDb3JnAAABAAEHZXhhbXBsZQNvcmcAAAEAAQAAAAAABAECAwQ=","Result":{},"Elapsed":0,"Upstream":"upstream"}` + + lineTime, _ := time.Parse(time.RFC3339Nano, "2020-02-18T22:36:35.920973+03:00") + lineIP := uint32(0) + + f, _ := ioutil.TempFile(dir, "*.txt") + + for i := 0; i < linesCount; i++ { + lineIP += 1 + lineTime = lineTime.Add(time.Second) + + ip := make(net.IP, 4) + binary.BigEndian.PutUint32(ip, lineIP) + + line := format + line = strings.ReplaceAll(line, "${IP}", ip.String()) + line = strings.ReplaceAll(line, "${TIMESTAMP}", lineTime.Format(time.RFC3339Nano)) + + _, _ = f.WriteString(line) + _, _ = f.WriteString("\n") + } + + return f.Name() +} From 712023112d96a4c0a951442bfb1dcbbcd7179786 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Thu, 20 Feb 2020 19:38:11 +0300 Subject: [PATCH 26/31] *(dnsforward): finished new qlog_file implementation --- querylog/qlog_file.go | 107 +++++++++++++++++--------- querylog/qlog_file_test.go | 150 ++++++++++++++++++++++++++++++++++++- 2 files changed, 216 insertions(+), 41 deletions(-) diff --git a/querylog/qlog_file.go b/querylog/qlog_file.go index 83f0d30d..63557098 100644 --- a/querylog/qlog_file.go +++ b/querylog/qlog_file.go @@ -6,12 +6,14 @@ import ( "sync" "time" + "github.com/AdguardTeam/golibs/log" + "github.com/pkg/errors" ) var ErrSeekNotFound = errors.New("Seek not found the record") -const bufferSize = 64 * 1024 // 64 KB is the buffer size +const bufferSize = 256 * 1024 // 256 KB is the buffer size type QLogFile struct { file *os.File // the query log file @@ -48,12 +50,16 @@ func NewQLogFile(path string) (*QLogFile, error) { // it shifts seek position to 3/4 of the file. Otherwise, to 1/4 of the file. // 5. It performs the search again, every time the search scope is narrowed twice. // -// It returns the position of the line with the timestamp we were looking for. +// It returns the position of the the line with the timestamp we were looking for +// so that when we call "ReadNext" this line was returned. // If we could not find it, it returns 0 and ErrSeekNotFound func (q *QLogFile) Seek(timestamp uint64) (int64, error) { q.lock.Lock() defer q.lock.Unlock() + // Empty the buffer + q.buffer = nil + // First of all, check the file size fileInfo, err := q.file.Stat() if err != nil { @@ -61,38 +67,64 @@ func (q *QLogFile) Seek(timestamp uint64) (int64, error) { } // Define the search scope - start := int64(0) - end := fileInfo.Size() - probe := (end - start) / 2 + start := int64(0) // start of the search interval (position in the file) + end := fileInfo.Size() // end of the search interval (position in the file) + probe := (end - start) / 2 // probe -- approximate index of the line we'll try to check + var line string + var lineIdx int64 // index of the probe line in the file + var lastProbeLineIdx int64 // index of the last probe line - // Get the line - line, _, err := q.readProbeLine(probe) - if err != nil { - return 0, err + // Count seek depth in order to detect mistakes + // If depth is too large, we should stop the search + depth := 0 + + for { + // Get the line at the specified position + line, lineIdx, err = q.readProbeLine(probe) + if err != nil { + return 0, err + } + + // Get the timestamp from the query log record + ts := q.readTimestamp(line) + + if ts == 0 { + return 0, ErrSeekNotFound + } + + if ts == timestamp { + // Hurray, returning the result + break + } + + if lastProbeLineIdx == lineIdx { + // If we're testing the same line twice then most likely + // the scope is too narrow and we won't find anything anymore + return 0, ErrSeekNotFound + } + + // Narrow the scope and repeat the search + if ts > timestamp { + // If the timestamp we're looking for is OLDER than what we found + // Then the line is somewhere on the LEFT side from the current probe position + end = probe + probe = start + (end-start)/2 + } else { + // If the timestamp we're looking for is NEWER than what we found + // Then the line is somewhere on the RIGHT side from the current probe position + start = probe + probe = start + (end-start)/2 + } + + depth++ + if depth >= 100 { + log.Error("Seek depth is too high, aborting. File %s, ts %v", q.file.Name(), timestamp) + return 0, ErrSeekNotFound + } } - // Get the timestamp from the query log record - ts := q.readTimestamp(line) - - if ts == timestamp { - // Hurray, returning the result - return probe, nil - } - - // Narrow the scope and repeat the search - if ts > timestamp { - end := probe - probe = (end - start) / 2 - } else { - start := probe - probe = (end - start) / 2 - } - - // TODO: temp - q.position = probe - - // TODO: Check start/stop/probe values and loop this - return 0, ErrSeekNotFound + q.position = lineIdx + int64(len(line)) + return q.position, nil } // SeekStart changes the current position to the end of the file @@ -102,6 +134,9 @@ func (q *QLogFile) SeekStart() (int64, error) { q.lock.Lock() defer q.lock.Unlock() + // Empty the buffer + q.buffer = nil + // First of all, check the file size fileInfo, err := q.file.Stat() if err != nil { @@ -111,7 +146,6 @@ func (q *QLogFile) SeekStart() (int64, error) { // Place the position to the very end of file q.position = fileInfo.Size() - 1 if q.position < 0 { - // TODO: test empty file q.position = 0 } return q.position, nil @@ -160,12 +194,13 @@ func (q *QLogFile) Close() error { // 4. read the line from the buffer func (q *QLogFile) readNextLine(position int64) (string, int64, error) { relativePos := position - q.bufferStart - if q.buffer == nil || relativePos < maxEntrySize { + if q.buffer == nil || (relativePos < maxEntrySize && q.bufferStart != 0) { // Time to re-init the buffer err := q.initBuffer(position) if err != nil { return "", 0, err } + relativePos = position - q.bufferStart } // Look for the end of the prev line @@ -201,7 +236,6 @@ func (q *QLogFile) initBuffer(position int64) error { q.buffer = make([]byte, bufferSize) } q.bufferLen, err = q.file.Read(q.buffer) - // TODO: validate bufferLen if err != nil { return err } @@ -218,7 +252,6 @@ func (q *QLogFile) readProbeLine(position int64) (string, int64, error) { seekPosition := int64(0) relativePos := position // position relative to the buffer we're going to read if (position - maxEntrySize) > 0 { - // TODO: cover this case in tests seekPosition = position - maxEntrySize relativePos = maxEntrySize } @@ -267,12 +300,12 @@ func (q *QLogFile) readTimestamp(str string) uint64 { } if len(val) == 0 { - // TODO: log + log.Error("Couldn't find timestamp in %s: %s", q.file.Name(), str) return 0 } tm, err := time.Parse(time.RFC3339, val) if err != nil { - // TODO: log + log.Error("Couldn't parse timestamp in %s: %s", q.file.Name(), val) return 0 } return uint64(tm.UnixNano()) diff --git a/querylog/qlog_file_test.go b/querylog/qlog_file_test.go index 6ba5780f..e9bc70e8 100644 --- a/querylog/qlog_file_test.go +++ b/querylog/qlog_file_test.go @@ -14,15 +14,157 @@ import ( ) func TestQLogFileEmpty(t *testing.T) { - // TODO: test empty file + testDir := prepareTestDir() + defer func() { _ = os.RemoveAll(testDir) }() + testFile := prepareTestFile(testDir, 0) + + // create the new QLogFile instance + q, err := NewQLogFile(testFile) + assert.Nil(t, err) + assert.NotNil(t, q) + + // seek to the start + pos, err := q.SeekStart() + assert.Nil(t, err) + assert.Equal(t, int64(0), pos) + + // try reading anyway + line, err := q.ReadNext() + assert.Equal(t, io.EOF, err) + assert.Equal(t, "", line) } func TestQLogFileLarge(t *testing.T) { - // TODO: test reading large file + // should be large enough + count := 50000 + + testDir := prepareTestDir() + defer func() { _ = os.RemoveAll(testDir) }() + testFile := prepareTestFile(testDir, count) + + // create the new QLogFile instance + q, err := NewQLogFile(testFile) + assert.Nil(t, err) + assert.NotNil(t, q) + + // seek to the start + pos, err := q.SeekStart() + assert.Nil(t, err) + assert.NotEqual(t, int64(0), pos) + + read := 0 + var line string + for err == nil { + line, err = q.ReadNext() + if err == nil { + assert.True(t, len(line) > 0) + read += 1 + } + } + + assert.Equal(t, count, read) + assert.Equal(t, io.EOF, err) } -func TestQLogFileSeek(t *testing.T) { - // TODO: test seek method on a small file +func TestQLogFileSeekLargeFile(t *testing.T) { + // more or less big file + count := 10000 + + testDir := prepareTestDir() + defer func() { _ = os.RemoveAll(testDir) }() + testFile := prepareTestFile(testDir, count) + + // create the new QLogFile instance + q, err := NewQLogFile(testFile) + assert.Nil(t, err) + assert.NotNil(t, q) + + // CASE 1: NOT TOO OLD LINE + testSeekLine(t, q, 300) + + // CASE 2: OLD LINE + testSeekLine(t, q, count-300) + + // CASE 3: FIRST LINE + testSeekLine(t, q, 0) + + // CASE 4: LAST LINE + testSeekLine(t, q, count) + + // CASE 5: Seek non-existent (too low) + _, err = q.Seek(123) + assert.NotNil(t, err) + + // CASE 6: Seek non-existent (too high) + ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00") + _, err = q.Seek(uint64(ts.UnixNano())) + assert.NotNil(t, err) +} + +func TestQLogFileSeekSmallFile(t *testing.T) { + // more or less big file + count := 10 + + testDir := prepareTestDir() + defer func() { _ = os.RemoveAll(testDir) }() + testFile := prepareTestFile(testDir, count) + + // create the new QLogFile instance + q, err := NewQLogFile(testFile) + assert.Nil(t, err) + assert.NotNil(t, q) + + // CASE 1: NOT TOO OLD LINE + testSeekLine(t, q, 2) + + // CASE 2: OLD LINE + testSeekLine(t, q, count-2) + + // CASE 3: FIRST LINE + testSeekLine(t, q, 0) + + // CASE 4: LAST LINE + testSeekLine(t, q, count) + + // CASE 5: Seek non-existent (too low) + _, err = q.Seek(123) + assert.NotNil(t, err) + + // CASE 6: Seek non-existent (too high) + ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00") + _, err = q.Seek(uint64(ts.UnixNano())) + assert.NotNil(t, err) +} + +func testSeekLine(t *testing.T, q *QLogFile, lineNumber int) { + line, err := getQLogLine(q, lineNumber) + assert.Nil(t, err) + ts := q.readTimestamp(line) + assert.NotEqual(t, uint64(0), ts) + + // try seeking to that line now + pos, err := q.Seek(ts) + assert.Nil(t, err) + assert.NotEqual(t, int64(0), pos) + + testLine, err := q.ReadNext() + assert.Nil(t, err) + assert.Equal(t, line, testLine) +} + +func getQLogLine(q *QLogFile, lineNumber int) (string, error) { + _, err := q.SeekStart() + if err != nil { + return "", err + } + + for i := 1; i < lineNumber; i++ { + _, err := q.ReadNext() + if err != nil { + return "", err + } + } + return q.ReadNext() } // Check adding and loading (with filtering) entries from disk and memory From 90f2c1835348c333fb4933990a43586e4349ebf3 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Thu, 20 Feb 2020 21:12:51 +0300 Subject: [PATCH 27/31] *(dnsforward): qlogreader implementation --- querylog/qlog_file.go | 26 ++++-- querylog/qlog_file_test.go | 63 ++++++++------ querylog/qlog_reader.go | 139 +++++++++++++++++++++++++++++++ querylog/qlog_reader_test.go | 157 +++++++++++++++++++++++++++++++++++ querylog/querylog_file.go | 33 +------- 5 files changed, 355 insertions(+), 63 deletions(-) create mode 100644 querylog/qlog_reader.go create mode 100644 querylog/qlog_reader_test.go diff --git a/querylog/qlog_file.go b/querylog/qlog_file.go index 63557098..14ae2b32 100644 --- a/querylog/qlog_file.go +++ b/querylog/qlog_file.go @@ -11,10 +11,18 @@ import ( "github.com/pkg/errors" ) +// ErrSeekNotFound is returned from the Seek method +// if we failed to find the desired record var ErrSeekNotFound = errors.New("Seek not found the record") const bufferSize = 256 * 1024 // 256 KB is the buffer size +// QLogFile represents a single query log file +// It allows reading from the file in the reverse order +// +// Please note that this is a stateful object. +// Internally, it contains a pointer to a specific position in the file, +// and it reads lines in reverse order starting from that position. type QLogFile struct { file *os.File // the query log file position int64 // current position in the file @@ -40,7 +48,8 @@ func NewQLogFile(path string) (*QLogFile, error) { } // Seek performs binary search in the query log file looking for a record -// with the specified timestamp. +// with the specified timestamp. Once the record is found, it sets +// "position" so that the next ReadNext call returned that record. // // The algorithm is rather simple: // 1. It starts with the position in the middle of a file @@ -86,7 +95,7 @@ func (q *QLogFile) Seek(timestamp uint64) (int64, error) { } // Get the timestamp from the query log record - ts := q.readTimestamp(line) + ts := readQLogTimestamp(line) if ts == 0 { return 0, ErrSeekNotFound @@ -130,6 +139,9 @@ func (q *QLogFile) Seek(timestamp uint64) (int64, error) { // SeekStart changes the current position to the end of the file // Please note that we're reading query log in the reverse order // and that's why log start is actually the end of file +// +// Returns nil if we were able to change the current position. +// Returns error in any other case. func (q *QLogFile) SeekStart() (int64, error) { q.lock.Lock() defer q.lock.Unlock() @@ -292,20 +304,20 @@ func (q *QLogFile) readProbeLine(position int64) (string, int64, error) { return string(buffer[startLine:endLine]), lineIdx, nil } -// readTimestamp reads the timestamp field from the query log line -func (q *QLogFile) readTimestamp(str string) uint64 { +// readQLogTimestamp reads the timestamp field from the query log line +func readQLogTimestamp(str string) uint64 { val := readJSONValue(str, "T") if len(val) == 0 { val = readJSONValue(str, "Time") } if len(val) == 0 { - log.Error("Couldn't find timestamp in %s: %s", q.file.Name(), str) + log.Error("Couldn't find timestamp: %s", str) return 0 } - tm, err := time.Parse(time.RFC3339, val) + tm, err := time.Parse(time.RFC3339Nano, val) if err != nil { - log.Error("Couldn't parse timestamp in %s: %s", q.file.Name(), val) + log.Error("Couldn't parse timestamp: %s", val) return 0 } return uint64(tm.UnixNano()) diff --git a/querylog/qlog_file_test.go b/querylog/qlog_file_test.go index e9bc70e8..2ba6f38d 100644 --- a/querylog/qlog_file_test.go +++ b/querylog/qlog_file_test.go @@ -22,6 +22,7 @@ func TestQLogFileEmpty(t *testing.T) { q, err := NewQLogFile(testFile) assert.Nil(t, err) assert.NotNil(t, q) + defer q.Close() // seek to the start pos, err := q.SeekStart() @@ -46,6 +47,7 @@ func TestQLogFileLarge(t *testing.T) { q, err := NewQLogFile(testFile) assert.Nil(t, err) assert.NotNil(t, q) + defer q.Close() // seek to the start pos, err := q.SeekStart() @@ -78,18 +80,19 @@ func TestQLogFileSeekLargeFile(t *testing.T) { q, err := NewQLogFile(testFile) assert.Nil(t, err) assert.NotNil(t, q) + defer q.Close() // CASE 1: NOT TOO OLD LINE - testSeekLine(t, q, 300) + testSeekLineQLogFile(t, q, 300) // CASE 2: OLD LINE - testSeekLine(t, q, count-300) + testSeekLineQLogFile(t, q, count-300) // CASE 3: FIRST LINE - testSeekLine(t, q, 0) + testSeekLineQLogFile(t, q, 0) // CASE 4: LAST LINE - testSeekLine(t, q, count) + testSeekLineQLogFile(t, q, count) // CASE 5: Seek non-existent (too low) _, err = q.Seek(123) @@ -113,18 +116,19 @@ func TestQLogFileSeekSmallFile(t *testing.T) { q, err := NewQLogFile(testFile) assert.Nil(t, err) assert.NotNil(t, q) + defer q.Close() // CASE 1: NOT TOO OLD LINE - testSeekLine(t, q, 2) + testSeekLineQLogFile(t, q, 2) // CASE 2: OLD LINE - testSeekLine(t, q, count-2) + testSeekLineQLogFile(t, q, count-2) // CASE 3: FIRST LINE - testSeekLine(t, q, 0) + testSeekLineQLogFile(t, q, 0) // CASE 4: LAST LINE - testSeekLine(t, q, count) + testSeekLineQLogFile(t, q, count) // CASE 5: Seek non-existent (too low) _, err = q.Seek(123) @@ -136,10 +140,10 @@ func TestQLogFileSeekSmallFile(t *testing.T) { assert.NotNil(t, err) } -func testSeekLine(t *testing.T, q *QLogFile, lineNumber int) { - line, err := getQLogLine(q, lineNumber) +func testSeekLineQLogFile(t *testing.T, q *QLogFile, lineNumber int) { + line, err := getQLogFileLine(q, lineNumber) assert.Nil(t, err) - ts := q.readTimestamp(line) + ts := readQLogTimestamp(line) assert.NotEqual(t, uint64(0), ts) // try seeking to that line now @@ -152,7 +156,7 @@ func testSeekLine(t *testing.T, q *QLogFile, lineNumber int) { assert.Equal(t, line, testLine) } -func getQLogLine(q *QLogFile, lineNumber int) (string, error) { +func getQLogFileLine(q *QLogFile, lineNumber int) (string, error) { _, err := q.SeekStart() if err != nil { return "", err @@ -177,6 +181,7 @@ func TestQLogFile(t *testing.T) { q, err := NewQLogFile(testFile) assert.Nil(t, err) assert.NotNil(t, q) + defer q.Close() // seek to the start pos, err := q.SeekStart() @@ -206,27 +211,37 @@ func TestQLogFile(t *testing.T) { // prepareTestFile - prepares a test query log file with the specified number of lines func prepareTestFile(dir string, linesCount int) string { + return prepareTestFiles(dir, 1, linesCount)[0] +} + +// prepareTestFiles - prepares several test query log files +// each of them -- with the specified linesCount +func prepareTestFiles(dir string, filesCount, linesCount int) []string { format := `{"IP":"${IP}","T":"${TIMESTAMP}","QH":"example.org","QT":"A","QC":"IN","Answer":"AAAAAAABAAEAAAAAB2V4YW1wbGUDb3JnAAABAAEHZXhhbXBsZQNvcmcAAAEAAQAAAAAABAECAwQ=","Result":{},"Elapsed":0,"Upstream":"upstream"}` lineTime, _ := time.Parse(time.RFC3339Nano, "2020-02-18T22:36:35.920973+03:00") lineIP := uint32(0) - f, _ := ioutil.TempFile(dir, "*.txt") + files := make([]string, 0) + for j := 0; j < filesCount; j++ { + f, _ := ioutil.TempFile(dir, "*.txt") + files = append(files, f.Name()) - for i := 0; i < linesCount; i++ { - lineIP += 1 - lineTime = lineTime.Add(time.Second) + for i := 0; i < linesCount; i++ { + lineIP += 1 + lineTime = lineTime.Add(time.Second) - ip := make(net.IP, 4) - binary.BigEndian.PutUint32(ip, lineIP) + ip := make(net.IP, 4) + binary.BigEndian.PutUint32(ip, lineIP) - line := format - line = strings.ReplaceAll(line, "${IP}", ip.String()) - line = strings.ReplaceAll(line, "${TIMESTAMP}", lineTime.Format(time.RFC3339Nano)) + line := format + line = strings.ReplaceAll(line, "${IP}", ip.String()) + line = strings.ReplaceAll(line, "${TIMESTAMP}", lineTime.Format(time.RFC3339Nano)) - _, _ = f.WriteString(line) - _, _ = f.WriteString("\n") + _, _ = f.WriteString(line) + _, _ = f.WriteString("\n") + } } - return f.Name() + return files } diff --git a/querylog/qlog_reader.go b/querylog/qlog_reader.go new file mode 100644 index 00000000..6b081cc2 --- /dev/null +++ b/querylog/qlog_reader.go @@ -0,0 +1,139 @@ +package querylog + +import ( + "io" + + "github.com/joomcode/errorx" +) + +// QLogReader allows reading from multiple query log files in the reverse order. +// +// Please note that this is a stateful object. +// Internally, it contains a pointer to a particular query log file, and +// to a specific position in this file, and it reads lines in reverse order +// starting from that position. +type QLogReader struct { + // qFiles - array with the query log files + // The order is - from oldest to newest + qFiles []*QLogFile + + currentFile int // Index of the current file +} + +// NewQLogReader initializes a QLogReader instance +// with the specified files +func NewQLogReader(files []string) (*QLogReader, error) { + qFiles := make([]*QLogFile, 0) + + for _, f := range files { + q, err := NewQLogFile(f) + if err != nil { + // Close what we've already opened + _ = closeQFiles(qFiles) + return nil, err + } + + qFiles = append(qFiles, q) + } + + return &QLogReader{ + qFiles: qFiles, + currentFile: (len(qFiles) - 1), + }, nil +} + +// Seek performs binary search of a query log record with the specified timestamp. +// If the record is found, it sets QLogReader's position to point to that line, +// so that the next ReadNext call returned this line. +// +// Returns nil if the record is successfully found. +// Returns an error if for some reason we could not find a record with the specified timestamp. +func (r *QLogReader) Seek(timestamp uint64) error { + for i := len(r.qFiles) - 1; i >= 0; i-- { + q := r.qFiles[i] + _, err := q.Seek(timestamp) + if err == nil { + // Our search is finished, we found the element we were looking for + // Update currentFile only, position is already set properly in the QLogFile + r.currentFile = i + return nil + } + } + + return ErrSeekNotFound +} + +// SeekStart changes the current position to the end of the newest file +// Please note that we're reading query log in the reverse order +// and that's why log start is actually the end of file +// +// Returns nil if we were able to change the current position. +// Returns error in any other case. +func (r *QLogReader) SeekStart() error { + if len(r.qFiles) == 0 { + return nil + } + + r.currentFile = len(r.qFiles) - 1 + _, err := r.qFiles[r.currentFile].SeekStart() + return err +} + +// ReadNext reads the next line (in the reverse order) from the query log files. +// and shifts the current position left to the next (actually prev) line (or the next file). +// returns io.EOF if there's nothing to read more. +func (r *QLogReader) ReadNext() (string, error) { + if len(r.qFiles) == 0 { + return "", io.EOF + } + + for r.currentFile >= 0 { + q := r.qFiles[r.currentFile] + line, err := q.ReadNext() + if err != nil { + // Shift to the older file + r.currentFile-- + if r.currentFile < 0 { + break + } + + q = r.qFiles[r.currentFile] + + // Set it's position to the start right away + _, err = q.SeekStart() + + // This is unexpected, return an error right away + if err != nil { + return "", err + } + } else { + return line, nil + } + } + + // Nothing to read anymore + return "", io.EOF +} + +// Close closes the QLogReader +func (r *QLogReader) Close() error { + return closeQFiles(r.qFiles) +} + +// closeQFiles - helper method to close multiple QLogFile instances +func closeQFiles(qFiles []*QLogFile) error { + var errs []error + + for _, q := range qFiles { + err := q.Close() + if err != nil { + errs = append(errs, err) + } + } + + if len(errs) > 0 { + return errorx.DecorateMany("Error while closing QLogReader", errs...) + } + + return nil +} diff --git a/querylog/qlog_reader_test.go b/querylog/qlog_reader_test.go new file mode 100644 index 00000000..592ab4b9 --- /dev/null +++ b/querylog/qlog_reader_test.go @@ -0,0 +1,157 @@ +package querylog + +import ( + "io" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestQLogReaderEmpty(t *testing.T) { + r, err := NewQLogReader([]string{}) + assert.Nil(t, err) + assert.NotNil(t, r) + defer r.Close() + + // seek to the start + err = r.SeekStart() + assert.Nil(t, err) + + line, err := r.ReadNext() + assert.Equal(t, "", line) + assert.Equal(t, io.EOF, err) +} + +func TestQLogReaderOneFile(t *testing.T) { + // let's do one small file + count := 10 + filesCount := 1 + + testDir := prepareTestDir() + defer func() { _ = os.RemoveAll(testDir) }() + testFiles := prepareTestFiles(testDir, filesCount, count) + + r, err := NewQLogReader(testFiles) + assert.Nil(t, err) + assert.NotNil(t, r) + defer r.Close() + + // seek to the start + err = r.SeekStart() + assert.Nil(t, err) + + // read everything + read := 0 + var line string + for err == nil { + line, err = r.ReadNext() + if err == nil { + assert.True(t, len(line) > 0) + read += 1 + } + } + + assert.Equal(t, count*filesCount, read) + assert.Equal(t, io.EOF, err) +} + +func TestQLogReaderMultipleFiles(t *testing.T) { + // should be large enough + count := 10000 + filesCount := 5 + + testDir := prepareTestDir() + defer func() { _ = os.RemoveAll(testDir) }() + testFiles := prepareTestFiles(testDir, filesCount, count) + + r, err := NewQLogReader(testFiles) + assert.Nil(t, err) + assert.NotNil(t, r) + defer r.Close() + + // seek to the start + err = r.SeekStart() + assert.Nil(t, err) + + // read everything + read := 0 + var line string + for err == nil { + line, err = r.ReadNext() + if err == nil { + assert.True(t, len(line) > 0) + read += 1 + } + } + + assert.Equal(t, count*filesCount, read) + assert.Equal(t, io.EOF, err) +} + +func TestQLogReaderSeek(t *testing.T) { + // more or less big file + count := 10000 + filesCount := 2 + + testDir := prepareTestDir() + defer func() { _ = os.RemoveAll(testDir) }() + testFiles := prepareTestFiles(testDir, filesCount, count) + + r, err := NewQLogReader(testFiles) + assert.Nil(t, err) + assert.NotNil(t, r) + defer r.Close() + + // CASE 1: NOT TOO OLD LINE + testSeekLineQLogReader(t, r, 300) + + // CASE 2: OLD LINE + testSeekLineQLogReader(t, r, count-300) + + // CASE 3: FIRST LINE + testSeekLineQLogReader(t, r, 0) + + // CASE 4: LAST LINE + testSeekLineQLogReader(t, r, count) + + // CASE 5: Seek non-existent (too low) + err = r.Seek(123) + assert.NotNil(t, err) + + // CASE 6: Seek non-existent (too high) + ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00") + err = r.Seek(uint64(ts.UnixNano())) + assert.NotNil(t, err) +} + +func testSeekLineQLogReader(t *testing.T, r *QLogReader, lineNumber int) { + line, err := getQLogReaderLine(r, lineNumber) + assert.Nil(t, err) + ts := readQLogTimestamp(line) + assert.NotEqual(t, uint64(0), ts) + + // try seeking to that line now + err = r.Seek(ts) + assert.Nil(t, err) + + testLine, err := r.ReadNext() + assert.Nil(t, err) + assert.Equal(t, line, testLine) +} + +func getQLogReaderLine(r *QLogReader, lineNumber int) (string, error) { + err := r.SeekStart() + if err != nil { + return "", err + } + + for i := 1; i < lineNumber; i++ { + _, err := r.ReadNext() + if err != nil { + return "", err + } + } + return r.ReadNext() +} diff --git a/querylog/querylog_file.go b/querylog/querylog_file.go index 3cf2ae9e..e5544fc4 100644 --- a/querylog/querylog_file.go +++ b/querylog/querylog_file.go @@ -3,7 +3,6 @@ package querylog import ( "bufio" "bytes" - "compress/gzip" "encoding/base64" "encoding/json" "io" @@ -17,8 +16,6 @@ import ( "github.com/miekg/dns" ) -const enableGzip = false - // TODO: Check this when we append a new line -- we don't want to have a line longer than this const maxEntrySize = 1000 @@ -70,29 +67,7 @@ func (l *queryLog) flushToFile(buffer []*logEntry) error { var err error var zb bytes.Buffer filename := l.logFile - - // gzip enabled? - if enableGzip { - filename += ".gz" - - zw := gzip.NewWriter(&zb) - zw.Name = l.logFile - zw.ModTime = time.Now() - - _, err = zw.Write(b.Bytes()) - if err != nil { - log.Error("Couldn't compress to gzip: %s", err) - zw.Close() - return err - } - - if err = zw.Close(); err != nil { - log.Error("Couldn't close gzip writer: %s", err) - return err - } - } else { - zb = b - } + zb = b l.fileWriteLock.Lock() defer l.fileWriteLock.Unlock() @@ -118,11 +93,6 @@ func (l *queryLog) rotate() error { from := l.logFile to := l.logFile + ".1" - if enableGzip { - from = l.logFile + ".gz" - to = l.logFile + ".gz.1" - } - if _, err := os.Stat(from); os.IsNotExist(err) { // do nothing, file doesn't exist return nil @@ -135,7 +105,6 @@ func (l *queryLog) rotate() error { } log.Debug("Rotated from %s to %s successfully", from, to) - return nil } From 696a6df6ba9992bdc0693bcbd38baceeaca5634d Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Fri, 21 Feb 2020 01:07:30 +0300 Subject: [PATCH 28/31] *(dnsforward): replaced old search with the new one --- querylog/qlog.go | 217 +++++-------- querylog/qlog_file.go | 6 +- querylog/qlog_file_test.go | 4 +- querylog/qlog_reader.go | 2 +- querylog/qlog_reader_test.go | 2 +- querylog/querylog_file.go | 597 +---------------------------------- querylog/querylog_search.go | 366 +++++++++++++++++++++ querylog/querylog_test.go | 60 ++-- 8 files changed, 481 insertions(+), 773 deletions(-) create mode 100644 querylog/querylog_search.go diff --git a/querylog/qlog.go b/querylog/qlog.go index 8a34130e..cafcdb5c 100644 --- a/querylog/qlog.go +++ b/querylog/qlog.go @@ -171,82 +171,6 @@ func (l *queryLog) Add(params AddParams) { } } -// Return TRUE if this entry is needed -func isNeeded(entry *logEntry, params getDataParams) bool { - if params.ResponseStatus == responseStatusFiltered && !entry.Result.IsFiltered { - return false - } - - if len(params.QuestionType) != 0 { - if entry.QType != params.QuestionType { - return false - } - } - - if len(params.Domain) != 0 { - if (params.StrictMatchDomain && entry.QHost != params.Domain) || - (!params.StrictMatchDomain && strings.Index(entry.QHost, params.Domain) == -1) { - return false - } - } - - if len(params.Client) != 0 { - if (params.StrictMatchClient && entry.IP != params.Client) || - (!params.StrictMatchClient && strings.Index(entry.IP, params.Client) == -1) { - return false - } - } - - return true -} - -func (l *queryLog) readFromFile(params getDataParams) ([]*logEntry, time.Time, int) { - entries := []*logEntry{} - oldest := time.Time{} - - r := l.OpenReader() - if r == nil { - return entries, time.Time{}, 0 - } - r.BeginRead(params.OlderThan, getDataLimit, ¶ms) - total := uint64(0) - for total <= maxSearchEntries { - newEntries := []*logEntry{} - for { - entry := r.Next() - if entry == nil { - break - } - - if !isNeeded(entry, params) { - continue - } - if len(newEntries) == getDataLimit { - newEntries = newEntries[1:] - } - newEntries = append(newEntries, entry) - } - - log.Debug("entries: +%d (%d) [%d]", len(newEntries), len(entries), r.Total()) - - entries = append(newEntries, entries...) - if len(entries) > getDataLimit { - toremove := len(entries) - getDataLimit - entries = entries[toremove:] - break - } - if r.Total() == 0 || len(entries) == getDataLimit { - break - } - total += r.Total() - oldest = r.Oldest() - r.BeginReadPrev(getDataLimit) - } - - r.Close() - return entries, oldest, int(total) -} - // Parameters for getData() type getDataParams struct { OlderThan time.Time // return entries that are older than this value @@ -267,17 +191,12 @@ const ( responseStatusFiltered ) -// Get log entries +// Gets log entries func (l *queryLog) getData(params getDataParams) map[string]interface{} { - var data = []map[string]interface{}{} - - var oldest time.Time now := time.Now() - entries := []*logEntry{} - total := 0 // add from file - entries, oldest, total = l.readFromFile(params) + fileEntries, oldest, total := l.searchFiles(params) if params.OlderThan.IsZero() { params.OlderThan = now @@ -286,9 +205,9 @@ func (l *queryLog) getData(params getDataParams) map[string]interface{} { // add from memory buffer l.bufferLock.Lock() total += len(l.buffer) + memoryEntries := make([]*logEntry, 0) for _, entry := range l.buffer { - - if !isNeeded(entry, params) { + if !matchesGetDataParams(entry, params) { continue } @@ -296,68 +215,24 @@ func (l *queryLog) getData(params getDataParams) map[string]interface{} { break } - if len(entries) == getDataLimit { - entries = entries[1:] - } - entries = append(entries, entry) + memoryEntries = append(memoryEntries, entry) } l.bufferLock.Unlock() - // process the elements from latest to oldest - for i := len(entries) - 1; i >= 0; i-- { + // now let's get a unified collection + entries := append(memoryEntries, fileEntries...) + if len(entries) > getDataLimit { + // remove extra records + entries = entries[(len(entries) - getDataLimit):] + } + + // init the response object + var data = []map[string]interface{}{} + + // the elements order is already reversed (from newer to older) + for i := 0; i < len(entries); i++ { entry := entries[i] - var a *dns.Msg - - if len(entry.Answer) > 0 { - a = new(dns.Msg) - if err := a.Unpack(entry.Answer); err != nil { - log.Debug("Failed to unpack dns message answer: %s: %s", err, string(entry.Answer)) - a = nil - } - } - - jsonEntry := map[string]interface{}{ - "reason": entry.Result.Reason.String(), - "elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64), - "time": entry.Time.Format(time.RFC3339Nano), - "client": entry.IP, - } - jsonEntry["question"] = map[string]interface{}{ - "host": entry.QHost, - "type": entry.QType, - "class": entry.QClass, - } - - if a != nil { - jsonEntry["status"] = dns.RcodeToString[a.Rcode] - } - if len(entry.Result.Rule) > 0 { - jsonEntry["rule"] = entry.Result.Rule - jsonEntry["filterId"] = entry.Result.FilterID - } - - if len(entry.Result.ServiceName) != 0 { - jsonEntry["service_name"] = entry.Result.ServiceName - } - - answers := answerToMap(a) - if answers != nil { - jsonEntry["answer"] = answers - } - - if len(entry.OrigAnswer) != 0 { - a := new(dns.Msg) - err := a.Unpack(entry.OrigAnswer) - if err == nil { - answers = answerToMap(a) - if answers != nil { - jsonEntry["original_answer"] = answers - } - } else { - log.Debug("Querylog: a.Unpack(entry.OrigAnswer): %s: %s", err, string(entry.OrigAnswer)) - } - } - + jsonEntry := logEntryToJSONEntry(entry) data = append(data, jsonEntry) } @@ -376,6 +251,62 @@ func (l *queryLog) getData(params getDataParams) map[string]interface{} { return result } +func logEntryToJSONEntry(entry *logEntry) map[string]interface{} { + var msg *dns.Msg + + if len(entry.Answer) > 0 { + msg = new(dns.Msg) + if err := msg.Unpack(entry.Answer); err != nil { + log.Debug("Failed to unpack dns message answer: %s: %s", err, string(entry.Answer)) + msg = nil + } + } + + jsonEntry := map[string]interface{}{ + "reason": entry.Result.Reason.String(), + "elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64), + "time": entry.Time.Format(time.RFC3339Nano), + "client": entry.IP, + } + jsonEntry["question"] = map[string]interface{}{ + "host": entry.QHost, + "type": entry.QType, + "class": entry.QClass, + } + + if msg != nil { + jsonEntry["status"] = dns.RcodeToString[msg.Rcode] + } + if len(entry.Result.Rule) > 0 { + jsonEntry["rule"] = entry.Result.Rule + jsonEntry["filterId"] = entry.Result.FilterID + } + + if len(entry.Result.ServiceName) != 0 { + jsonEntry["service_name"] = entry.Result.ServiceName + } + + answers := answerToMap(msg) + if answers != nil { + jsonEntry["answer"] = answers + } + + if len(entry.OrigAnswer) != 0 { + a := new(dns.Msg) + err := a.Unpack(entry.OrigAnswer) + if err == nil { + answers = answerToMap(a) + if answers != nil { + jsonEntry["original_answer"] = answers + } + } else { + log.Debug("Querylog: msg.Unpack(entry.OrigAnswer): %s: %s", err, string(entry.OrigAnswer)) + } + } + + return jsonEntry +} + func answerToMap(a *dns.Msg) []map[string]interface{} { if a == nil || len(a.Answer) == 0 { return nil diff --git a/querylog/qlog_file.go b/querylog/qlog_file.go index 14ae2b32..24c5a4e2 100644 --- a/querylog/qlog_file.go +++ b/querylog/qlog_file.go @@ -62,7 +62,7 @@ func NewQLogFile(path string) (*QLogFile, error) { // It returns the position of the the line with the timestamp we were looking for // so that when we call "ReadNext" this line was returned. // If we could not find it, it returns 0 and ErrSeekNotFound -func (q *QLogFile) Seek(timestamp uint64) (int64, error) { +func (q *QLogFile) Seek(timestamp int64) (int64, error) { q.lock.Lock() defer q.lock.Unlock() @@ -305,7 +305,7 @@ func (q *QLogFile) readProbeLine(position int64) (string, int64, error) { } // readQLogTimestamp reads the timestamp field from the query log line -func readQLogTimestamp(str string) uint64 { +func readQLogTimestamp(str string) int64 { val := readJSONValue(str, "T") if len(val) == 0 { val = readJSONValue(str, "Time") @@ -320,5 +320,5 @@ func readQLogTimestamp(str string) uint64 { log.Error("Couldn't parse timestamp: %s", val) return 0 } - return uint64(tm.UnixNano()) + return tm.UnixNano() } diff --git a/querylog/qlog_file_test.go b/querylog/qlog_file_test.go index 2ba6f38d..c5871475 100644 --- a/querylog/qlog_file_test.go +++ b/querylog/qlog_file_test.go @@ -100,7 +100,7 @@ func TestQLogFileSeekLargeFile(t *testing.T) { // CASE 6: Seek non-existent (too high) ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00") - _, err = q.Seek(uint64(ts.UnixNano())) + _, err = q.Seek(ts.UnixNano()) assert.NotNil(t, err) } @@ -136,7 +136,7 @@ func TestQLogFileSeekSmallFile(t *testing.T) { // CASE 6: Seek non-existent (too high) ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00") - _, err = q.Seek(uint64(ts.UnixNano())) + _, err = q.Seek(ts.UnixNano()) assert.NotNil(t, err) } diff --git a/querylog/qlog_reader.go b/querylog/qlog_reader.go index 6b081cc2..5f4f4dfa 100644 --- a/querylog/qlog_reader.go +++ b/querylog/qlog_reader.go @@ -48,7 +48,7 @@ func NewQLogReader(files []string) (*QLogReader, error) { // // Returns nil if the record is successfully found. // Returns an error if for some reason we could not find a record with the specified timestamp. -func (r *QLogReader) Seek(timestamp uint64) error { +func (r *QLogReader) Seek(timestamp int64) error { for i := len(r.qFiles) - 1; i >= 0; i-- { q := r.qFiles[i] _, err := q.Seek(timestamp) diff --git a/querylog/qlog_reader_test.go b/querylog/qlog_reader_test.go index 592ab4b9..357b4f9d 100644 --- a/querylog/qlog_reader_test.go +++ b/querylog/qlog_reader_test.go @@ -122,7 +122,7 @@ func TestQLogReaderSeek(t *testing.T) { // CASE 6: Seek non-existent (too high) ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00") - err = r.Seek(uint64(ts.UnixNano())) + err = r.Seek(ts.UnixNano()) assert.NotNil(t, err) } diff --git a/querylog/querylog_file.go b/querylog/querylog_file.go index e5544fc4..3b32b7a2 100644 --- a/querylog/querylog_file.go +++ b/querylog/querylog_file.go @@ -1,23 +1,16 @@ package querylog import ( - "bufio" "bytes" - "encoding/base64" "encoding/json" - "io" "os" - "strconv" - "strings" "time" - "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/golibs/log" - "github.com/miekg/dns" ) // TODO: Check this when we append a new line -- we don't want to have a line longer than this -const maxEntrySize = 1000 +const maxEntrySize = 1024 // flushLogBuffer flushes the current buffer to file and resets the current buffer func (l *queryLog) flushLogBuffer(fullFlush bool) error { @@ -117,591 +110,3 @@ func (l *queryLog) periodicRotate() { } } } - -// Reader is the DB reader context -type Reader struct { - ql *queryLog - search *getDataParams - - f *os.File - reader *bufio.Reader // reads file line by line - now time.Time - validFrom int64 // UNIX time (ns) - olderThan int64 // UNIX time (ns) - oldest time.Time - - files []string - ifile int - - limit uint64 - count uint64 // counter for returned elements - latest bool // return the latest entries - filePrepared bool - - seeking bool // we're seaching for an entry with exact time stamp - fseeker fileSeeker // file seeker object - fpos uint64 // current file offset - nSeekRequests uint32 // number of Seek() requests made (finding a new line doesn't count) -} - -type fileSeeker struct { - target uint64 // target value - - pos uint64 // current offset, may be adjusted by user for increased accuracy - lastpos uint64 // the last offset returned - lo uint64 // low boundary offset - hi uint64 // high boundary offset -} - -// OpenReader - return reader object -func (l *queryLog) OpenReader() *Reader { - r := Reader{} - r.ql = l - r.now = time.Now() - r.validFrom = r.now.Unix() - int64(l.conf.Interval*24*60*60) - r.validFrom *= 1000000000 - r.files = []string{ - r.ql.logFile, - r.ql.logFile + ".1", - } - return &r -} - -// Close - close the reader -func (r *Reader) Close() { - elapsed := time.Since(r.now) - var perunit time.Duration - if r.count > 0 { - perunit = elapsed / time.Duration(r.count) - } - log.Debug("querylog: read %d entries in %v, %v/entry, seek-reqs:%d", - r.count, elapsed, perunit, r.nSeekRequests) - - if r.f != nil { - r.f.Close() - } -} - -// BeginRead - start reading -// olderThan: stop returning entries when an entry with this time is reached -// count: minimum number of entries to return -func (r *Reader) BeginRead(olderThan time.Time, count uint64, search *getDataParams) { - r.olderThan = olderThan.UnixNano() - r.latest = olderThan.IsZero() - r.oldest = time.Time{} - r.search = search - r.limit = count - if r.latest { - r.olderThan = r.now.UnixNano() - } - r.filePrepared = false - r.seeking = false -} - -// BeginReadPrev - start reading the previous data chunk -func (r *Reader) BeginReadPrev(count uint64) { - r.olderThan = r.oldest.UnixNano() - r.oldest = time.Time{} - r.latest = false - r.limit = count - r.count = 0 - - off := r.fpos - maxEntrySize*(r.limit+1) - if int64(off) < maxEntrySize { - off = 0 - } - r.fpos = off - log.Debug("QueryLog: seek: %x", off) - _, err := r.f.Seek(int64(off), io.SeekStart) - if err != nil { - log.Error("file.Seek: %s: %s", r.files[r.ifile], err) - return - } - r.nSeekRequests++ - - r.seekToNewLine() - r.fseeker.pos = r.fpos - - r.filePrepared = true - r.seeking = false -} - -// Perform binary seek -// Return 0: success; 1: seek reqiured; -1: error -func (fs *fileSeeker) seekBinary(cur uint64) int32 { - log.Debug("QueryLog: seek: tgt=%x cur=%x, %x: [%x..%x]", fs.target, cur, fs.pos, fs.lo, fs.hi) - - off := uint64(0) - if fs.pos >= fs.lo && fs.pos < fs.hi { - if cur == fs.target { - return 0 - } else if cur < fs.target { - fs.lo = fs.pos + 1 - } else { - fs.hi = fs.pos - } - off = fs.lo + (fs.hi-fs.lo)/2 - } else { - // we didn't find another entry from the last file offset: now return the boundary beginning - off = fs.lo - } - - if off == fs.lastpos { - return -1 - } - - fs.lastpos = off - fs.pos = off - return 1 -} - -// Seek to a new line -func (r *Reader) seekToNewLine() bool { - r.reader = bufio.NewReader(r.f) - b, err := r.reader.ReadBytes('\n') - if err != nil { - r.reader = nil - log.Error("QueryLog: file.Read: %s: %s", r.files[r.ifile], err) - return false - } - - off := len(b) - r.fpos += uint64(off) - log.Debug("QueryLog: seek: %x (+%d)", r.fpos, off) - return true -} - -// Open a file -func (r *Reader) openFile() bool { - var err error - fn := r.files[r.ifile] - - r.f, err = os.Open(fn) - if err != nil { - if !os.IsNotExist(err) { - log.Error("QueryLog: Failed to open file \"%s\": %s", fn, err) - } - return false - } - return true -} - -// Seek to the needed position -func (r *Reader) prepareRead() bool { - fn := r.files[r.ifile] - - fi, err := r.f.Stat() - if err != nil { - log.Error("QueryLog: file.Stat: %s: %s", fn, err) - return false - } - fsize := uint64(fi.Size()) - - off := uint64(0) - if r.latest { - // read data from the end of file - off = fsize - maxEntrySize*(r.limit+1) - if int64(off) < maxEntrySize { - off = 0 - } - r.fpos = off - log.Debug("QueryLog: seek: %x", off) - _, err = r.f.Seek(int64(off), io.SeekStart) - if err != nil { - log.Error("QueryLog: file.Seek: %s: %s", fn, err) - return false - } - } else { - // start searching in file: we'll read the first chunk of data from the middle of file - r.seeking = true - r.fseeker = fileSeeker{} - r.fseeker.target = uint64(r.olderThan) - r.fseeker.hi = fsize - rc := r.fseeker.seekBinary(0) - r.fpos = r.fseeker.pos - if rc == 1 { - _, err = r.f.Seek(int64(r.fpos), io.SeekStart) - if err != nil { - log.Error("QueryLog: file.Seek: %s: %s", fn, err) - return false - } - } - } - r.nSeekRequests++ - - if !r.seekToNewLine() { - return false - } - r.fseeker.pos = r.fpos - return true -} - -// Get bool value from "key":bool -func readJSONBool(s, name string) (bool, bool) { - i := strings.Index(s, "\""+name+"\":") - if i == -1 { - return false, false - } - start := i + 1 + len(name) + 2 - b := false - if strings.HasPrefix(s[start:], "true") { - b = true - } else if !strings.HasPrefix(s[start:], "false") { - return false, false - } - return b, true -} - -// Get value from "key":"value" -func readJSONValue(s, name string) string { - i := strings.Index(s, "\""+name+"\":\"") - if i == -1 { - return "" - } - start := i + 1 + len(name) + 3 - i = strings.IndexByte(s[start:], '"') - if i == -1 { - return "" - } - end := start + i - return s[start:end] -} - -// nolint (gocyclo) -func (r *Reader) applySearch(str string) bool { - if r.search.ResponseStatus == responseStatusFiltered { - boolVal, ok := readJSONBool(str, "IsFiltered") - if !ok || !boolVal { - return false - } - } - - mq := dns.Msg{} - - if len(r.search.Domain) != 0 { - val := readJSONValue(str, "QH") - if len(val) == 0 { - // pre-v0.99.3 compatibility - val = readJSONValue(str, "Question") - if len(val) == 0 { - return false - } - bval, err := base64.StdEncoding.DecodeString(val) - if err != nil { - return false - } - err = mq.Unpack(bval) - if err != nil { - return false - } - val = strings.TrimSuffix(mq.Question[0].Name, ".") - } - if len(val) == 0 { - return false - } - - if (r.search.StrictMatchDomain && val != r.search.Domain) || - (!r.search.StrictMatchDomain && strings.Index(val, r.search.Domain) == -1) { - return false - } - } - - if len(r.search.QuestionType) != 0 { - val := readJSONValue(str, "QT") - if len(val) == 0 { - // pre-v0.99.3 compatibility - if len(mq.Question) == 0 { - val = readJSONValue(str, "Question") - if len(val) == 0 { - return false - } - bval, err := base64.StdEncoding.DecodeString(val) - if err != nil { - return false - } - err = mq.Unpack(bval) - if err != nil { - return false - } - } - ok := false - val, ok = dns.TypeToString[mq.Question[0].Qtype] - if !ok { - return false - } - } - if val != r.search.QuestionType { - return false - } - } - - if len(r.search.Client) != 0 { - val := readJSONValue(str, "IP") - if len(val) == 0 { - log.Debug("QueryLog: failed to decode") - return false - } - - if (r.search.StrictMatchClient && val != r.search.Client) || - (!r.search.StrictMatchClient && strings.Index(val, r.search.Client) == -1) { - return false - } - } - - return true -} - -const ( - jsonTErr = iota - jsonTObj - jsonTStr - jsonTNum - jsonTBool -) - -// Parse JSON key-value pair -// e.g.: "key":VALUE where VALUE is "string", true|false (boolean), or 123.456 (number) -// Note the limitations: -// . doesn't support whitespace -// . doesn't support "null" -// . doesn't validate boolean or number -// . no proper handling of {} braces -// . no handling of [] brackets -// Return (key, value, type) -func readJSON(ps *string) (string, string, int32) { - s := *ps - k := "" - v := "" - t := int32(jsonTErr) - - q1 := strings.IndexByte(s, '"') - if q1 == -1 { - return k, v, t - } - q2 := strings.IndexByte(s[q1+1:], '"') - if q2 == -1 { - return k, v, t - } - k = s[q1+1 : q1+1+q2] - s = s[q1+1+q2+1:] - - if len(s) < 2 || s[0] != ':' { - return k, v, t - } - - if s[1] == '"' { - q2 = strings.IndexByte(s[2:], '"') - if q2 == -1 { - return k, v, t - } - v = s[2 : 2+q2] - t = jsonTStr - s = s[2+q2+1:] - - } else if s[1] == '{' { - t = jsonTObj - s = s[1+1:] - - } else { - sep := strings.IndexAny(s[1:], ",}") - if sep == -1 { - return k, v, t - } - v = s[1 : 1+sep] - if s[1] == 't' || s[1] == 'f' { - t = jsonTBool - } else if s[1] == '.' || (s[1] >= '0' && s[1] <= '9') { - t = jsonTNum - } - s = s[1+sep+1:] - } - - *ps = s - return k, v, t -} - -// nolint (gocyclo) -func decode(ent *logEntry, str string) { - var b bool - var i int - var err error - for { - k, v, t := readJSON(&str) - if t == jsonTErr { - break - } - switch k { - case "IP": - if len(ent.IP) == 0 { - ent.IP = v - } - case "T": - ent.Time, err = time.Parse(time.RFC3339, v) - - case "QH": - ent.QHost = v - case "QT": - ent.QType = v - case "QC": - ent.QClass = v - - case "Answer": - ent.Answer, err = base64.StdEncoding.DecodeString(v) - case "OrigAnswer": - ent.OrigAnswer, err = base64.StdEncoding.DecodeString(v) - - case "IsFiltered": - b, err = strconv.ParseBool(v) - ent.Result.IsFiltered = b - case "Rule": - ent.Result.Rule = v - case "FilterID": - i, err = strconv.Atoi(v) - ent.Result.FilterID = int64(i) - case "Reason": - i, err = strconv.Atoi(v) - ent.Result.Reason = dnsfilter.Reason(i) - - case "Upstream": - ent.Upstream = v - case "Elapsed": - i, err = strconv.Atoi(v) - ent.Elapsed = time.Duration(i) - - // pre-v0.99.3 compatibility: - case "Question": - var qstr []byte - qstr, err = base64.StdEncoding.DecodeString(v) - if err != nil { - break - } - q := new(dns.Msg) - err = q.Unpack(qstr) - if err != nil { - break - } - ent.QHost = q.Question[0].Name - if len(ent.QHost) == 0 { - break - } - ent.QHost = ent.QHost[:len(ent.QHost)-1] - ent.QType = dns.TypeToString[q.Question[0].Qtype] - ent.QClass = dns.ClassToString[q.Question[0].Qclass] - case "Time": - ent.Time, err = time.Parse(time.RFC3339, v) - } - - if err != nil { - log.Debug("decode err: %s", err) - break - } - } -} - -// Next - return the next entry or nil if reading is finished -func (r *Reader) Next() *logEntry { // nolint - for { - // open file if needed - if r.f == nil { - if r.ifile == len(r.files) { - return nil - } - if !r.openFile() { - r.ifile++ - continue - } - } - - if !r.filePrepared { - if !r.prepareRead() { - return nil - } - r.filePrepared = true - } - - b, err := r.reader.ReadBytes('\n') - if err != nil { - return nil - } - str := string(b) - - val := readJSONValue(str, "T") - if len(val) == 0 { - val = readJSONValue(str, "Time") - } - if len(val) == 0 { - log.Debug("QueryLog: failed to decode") - continue - } - tm, err := time.Parse(time.RFC3339, val) - if err != nil { - log.Debug("QueryLog: failed to decode") - continue - } - t := tm.UnixNano() - - if r.seeking { - - r.reader = nil - rr := r.fseeker.seekBinary(uint64(t)) - r.fpos = r.fseeker.pos - if rr < 0 { - log.Error("QueryLog: File seek error: can't find the target entry: %s", r.files[r.ifile]) - return nil - } else if rr == 0 { - // We found the target entry. - // We'll start reading the previous chunk of data. - r.seeking = false - - off := r.fpos - (maxEntrySize * (r.limit + 1)) - if int64(off) < maxEntrySize { - off = 0 - } - r.fpos = off - } - - _, err = r.f.Seek(int64(r.fpos), io.SeekStart) - if err != nil { - log.Error("QueryLog: file.Seek: %s: %s", r.files[r.ifile], err) - return nil - } - r.nSeekRequests++ - - if !r.seekToNewLine() { - return nil - } - r.fseeker.pos = r.fpos - continue - } - - if r.oldest.IsZero() { - r.oldest = tm - } - - if t < r.validFrom { - continue - } - if t >= r.olderThan { - return nil - } - r.count++ - - if !r.applySearch(str) { - continue - } - - var ent logEntry - decode(&ent, str) - return &ent - } -} - -// Total returns the total number of processed items -func (r *Reader) Total() uint64 { - return r.count -} - -// Oldest returns the time of the oldest processed entry -func (r *Reader) Oldest() time.Time { - return r.oldest -} diff --git a/querylog/querylog_search.go b/querylog/querylog_search.go new file mode 100644 index 00000000..236c1940 --- /dev/null +++ b/querylog/querylog_search.go @@ -0,0 +1,366 @@ +package querylog + +import ( + "encoding/base64" + "io" + "strconv" + "strings" + "time" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/log" + "github.com/miekg/dns" +) + +// searchFiles reads log entries from all log files and applies the specified search criteria. +// IMPORTANT: this method does not scan more than "maxSearchEntries" so you +// may need to call it many times. +// +// it returns: +// * an array of log entries that we have read +// * time of the oldest processed entry (even if it was discarded) +// * total number of processed entries (including discarded). +func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, int) { + entries := make([]*logEntry, 0) + oldest := time.Time{} + + r, err := l.openReader() + if err != nil { + log.Error("Failed to open qlog reader: %v", err) + return entries, oldest, 0 + } + defer r.Close() + + if params.OlderThan.IsZero() { + err = r.SeekStart() + } else { + err = r.Seek(params.OlderThan.UnixNano()) + } + + if err != nil { + log.Error("Failed to Seek(): %v", err) + return entries, oldest, 0 + } + + total := 0 + oldestNano := int64(0) + // Do not scan more than 50k at once + for total <= maxSearchEntries { + entry, ts, err := l.readNextEntry(r, params) + + if err == io.EOF { + // there's nothing to read anymore + break + } + + if entry != nil { + entries = append(entries, entry) + } + + oldestNano = ts + total++ + } + + oldest = time.Unix(0, oldestNano) + return entries, oldest, total +} + +// readNextEntry - reads the next log entry and checks if it matches the search criteria (getDataParams) +// +// returns: +// * log entry that matches search criteria or null if it was discarded (or if there's nothing to read) +// * timestamp of the processed log entry +// * error if we can't read anymore +func (l *queryLog) readNextEntry(r *QLogReader, params getDataParams) (*logEntry, int64, error) { + line, err := r.ReadNext() + if err != nil { + return nil, 0, err + } + + // Read the log record timestamp right away + timestamp := readQLogTimestamp(line) + + // Quick check without deserializing log entry + if !quickMatchesGetDataParams(line, params) { + return nil, timestamp, nil + } + + entry := logEntry{} + decodeLogEntry(&entry, line) + + // Full check of the deserialized log entry + if !matchesGetDataParams(&entry, params) { + return nil, timestamp, nil + } + + return &entry, timestamp, nil +} + +// openReader - opens QLogReader instance +func (l *queryLog) openReader() (*QLogReader, error) { + files := make([]string, 0) + + if util.FileExists(l.logFile + ".1") { + files = append(files, l.logFile+".1") + } + if util.FileExists(l.logFile) { + files = append(files, l.logFile) + } + + return NewQLogReader(files) +} + +// quickMatchesGetDataParams - quickly checks if the line matches getDataParams +// this method does not guarantee anything and the reason is to do a quick check +// without deserializing anything +func quickMatchesGetDataParams(line string, params getDataParams) bool { + if params.ResponseStatus == responseStatusFiltered { + boolVal, ok := readJSONBool(line, "IsFiltered") + if !ok || !boolVal { + return false + } + } + + if len(params.Domain) != 0 { + val := readJSONValue(line, "QH") + if len(val) == 0 { + return false + } + + if (params.StrictMatchDomain && val != params.Domain) || + (!params.StrictMatchDomain && strings.Index(val, params.Domain) == -1) { + return false + } + } + + if len(params.QuestionType) != 0 { + val := readJSONValue(line, "QT") + if val != params.QuestionType { + return false + } + } + + if len(params.Client) != 0 { + val := readJSONValue(line, "IP") + if len(val) == 0 { + log.Debug("QueryLog: failed to decodeLogEntry") + return false + } + + if (params.StrictMatchClient && val != params.Client) || + (!params.StrictMatchClient && strings.Index(val, params.Client) == -1) { + return false + } + } + + return true +} + +// matchesGetDataParams - returns true if the entry matches the search parameters +func matchesGetDataParams(entry *logEntry, params getDataParams) bool { + if params.ResponseStatus == responseStatusFiltered && !entry.Result.IsFiltered { + return false + } + + if len(params.QuestionType) != 0 { + if entry.QType != params.QuestionType { + return false + } + } + + if len(params.Domain) != 0 { + if (params.StrictMatchDomain && entry.QHost != params.Domain) || + (!params.StrictMatchDomain && strings.Index(entry.QHost, params.Domain) == -1) { + return false + } + } + + if len(params.Client) != 0 { + if (params.StrictMatchClient && entry.IP != params.Client) || + (!params.StrictMatchClient && strings.Index(entry.IP, params.Client) == -1) { + return false + } + } + + return true +} + +// decodeLogEntry - decodes query log entry from a line +// nolint (gocyclo) +func decodeLogEntry(ent *logEntry, str string) { + var b bool + var i int + var err error + for { + k, v, t := readJSON(&str) + if t == jsonTErr { + break + } + switch k { + case "IP": + if len(ent.IP) == 0 { + ent.IP = v + } + case "T": + ent.Time, err = time.Parse(time.RFC3339, v) + + case "QH": + ent.QHost = v + case "QT": + ent.QType = v + case "QC": + ent.QClass = v + + case "Answer": + ent.Answer, err = base64.StdEncoding.DecodeString(v) + case "OrigAnswer": + ent.OrigAnswer, err = base64.StdEncoding.DecodeString(v) + + case "IsFiltered": + b, err = strconv.ParseBool(v) + ent.Result.IsFiltered = b + case "Rule": + ent.Result.Rule = v + case "FilterID": + i, err = strconv.Atoi(v) + ent.Result.FilterID = int64(i) + case "Reason": + i, err = strconv.Atoi(v) + ent.Result.Reason = dnsfilter.Reason(i) + + case "Upstream": + ent.Upstream = v + case "Elapsed": + i, err = strconv.Atoi(v) + ent.Elapsed = time.Duration(i) + + // pre-v0.99.3 compatibility: + case "Question": + var qstr []byte + qstr, err = base64.StdEncoding.DecodeString(v) + if err != nil { + break + } + q := new(dns.Msg) + err = q.Unpack(qstr) + if err != nil { + break + } + ent.QHost = q.Question[0].Name + if len(ent.QHost) == 0 { + break + } + ent.QHost = ent.QHost[:len(ent.QHost)-1] + ent.QType = dns.TypeToString[q.Question[0].Qtype] + ent.QClass = dns.ClassToString[q.Question[0].Qclass] + case "Time": + ent.Time, err = time.Parse(time.RFC3339, v) + } + + if err != nil { + log.Debug("decodeLogEntry err: %s", err) + break + } + } +} + +// Get bool value from "key":bool +func readJSONBool(s, name string) (bool, bool) { + i := strings.Index(s, "\""+name+"\":") + if i == -1 { + return false, false + } + start := i + 1 + len(name) + 2 + b := false + if strings.HasPrefix(s[start:], "true") { + b = true + } else if !strings.HasPrefix(s[start:], "false") { + return false, false + } + return b, true +} + +// Get value from "key":"value" +func readJSONValue(s, name string) string { + i := strings.Index(s, "\""+name+"\":\"") + if i == -1 { + return "" + } + start := i + 1 + len(name) + 3 + i = strings.IndexByte(s[start:], '"') + if i == -1 { + return "" + } + end := start + i + return s[start:end] +} + +const ( + jsonTErr = iota + jsonTObj + jsonTStr + jsonTNum + jsonTBool +) + +// Parse JSON key-value pair +// e.g.: "key":VALUE where VALUE is "string", true|false (boolean), or 123.456 (number) +// Note the limitations: +// . doesn't support whitespace +// . doesn't support "null" +// . doesn't validate boolean or number +// . no proper handling of {} braces +// . no handling of [] brackets +// Return (key, value, type) +func readJSON(ps *string) (string, string, int32) { + s := *ps + k := "" + v := "" + t := int32(jsonTErr) + + q1 := strings.IndexByte(s, '"') + if q1 == -1 { + return k, v, t + } + q2 := strings.IndexByte(s[q1+1:], '"') + if q2 == -1 { + return k, v, t + } + k = s[q1+1 : q1+1+q2] + s = s[q1+1+q2+1:] + + if len(s) < 2 || s[0] != ':' { + return k, v, t + } + + if s[1] == '"' { + q2 = strings.IndexByte(s[2:], '"') + if q2 == -1 { + return k, v, t + } + v = s[2 : 2+q2] + t = jsonTStr + s = s[2+q2+1:] + + } else if s[1] == '{' { + t = jsonTObj + s = s[1+1:] + + } else { + sep := strings.IndexAny(s[1:], ",}") + if sep == -1 { + return k, v, t + } + v = s[1 : 1+sep] + if s[1] == 't' || s[1] == 'f' { + t = jsonTBool + } else if s[1] == '.' || (s[1] >= '0' && s[1] <= '9') { + t = jsonTNum + } + s = s[1+sep+1:] + } + + *ps = s + return k, v, t +} diff --git a/querylog/querylog_test.go b/querylog/querylog_test.go index 8c8b9bb4..8d587b7c 100644 --- a/querylog/querylog_test.go +++ b/querylog/querylog_test.go @@ -30,14 +30,17 @@ func TestQueryLog(t *testing.T) { l := newQueryLog(conf) // add disk entries - addEntry(l, "example.org", "1.2.3.4", "0.1.2.3") - addEntry(l, "example.org", "1.2.3.4", "0.1.2.3") - + addEntry(l, "example.org", "1.1.1.1", "2.2.2.1") + // write to disk (first file) + _ = l.flushLogBuffer(true) + // start writing to the second file + _ = l.rotate() + // add disk entries + addEntry(l, "example.org", "1.1.1.2", "2.2.2.2") // write to disk - l.flushLogBuffer(true) - + _ = l.flushLogBuffer(true) // add memory entries - addEntry(l, "test.example.org", "2.2.3.4", "0.1.2.4") + addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3") // get all entries params := getDataParams{ @@ -45,9 +48,10 @@ func TestQueryLog(t *testing.T) { } d := l.getData(params) mdata := d["data"].([]map[string]interface{}) - assert.True(t, len(mdata) == 2) - assert.True(t, checkEntry(t, mdata[0], "test.example.org", "2.2.3.4", "0.1.2.4")) - assert.True(t, checkEntry(t, mdata[1], "example.org", "1.2.3.4", "0.1.2.3")) + assert.Equal(t, 3, len(mdata)) + assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3")) + assert.True(t, checkEntry(t, mdata[1], "example.org", "1.1.1.2", "2.2.2.2")) + assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.1", "2.2.2.1")) // search by domain (strict) params = getDataParams{ @@ -58,9 +62,9 @@ func TestQueryLog(t *testing.T) { d = l.getData(params) mdata = d["data"].([]map[string]interface{}) assert.True(t, len(mdata) == 1) - assert.True(t, checkEntry(t, mdata[0], "test.example.org", "2.2.3.4", "0.1.2.4")) + assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3")) - // search by domain + // search by domain (not strict) params = getDataParams{ OlderThan: time.Time{}, Domain: "example.org", @@ -68,32 +72,34 @@ func TestQueryLog(t *testing.T) { } d = l.getData(params) mdata = d["data"].([]map[string]interface{}) - assert.True(t, len(mdata) == 2) - assert.True(t, checkEntry(t, mdata[0], "test.example.org", "2.2.3.4", "0.1.2.4")) - assert.True(t, checkEntry(t, mdata[1], "example.org", "1.2.3.4", "0.1.2.3")) + assert.Equal(t, 3, len(mdata)) + assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3")) + assert.True(t, checkEntry(t, mdata[1], "example.org", "1.1.1.2", "2.2.2.2")) + assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.1", "2.2.2.1")) // search by client IP (strict) params = getDataParams{ OlderThan: time.Time{}, - Client: "0.1.2.3", + Client: "2.2.2.2", StrictMatchClient: true, } d = l.getData(params) mdata = d["data"].([]map[string]interface{}) - assert.True(t, len(mdata) == 1) - assert.True(t, checkEntry(t, mdata[0], "example.org", "1.2.3.4", "0.1.2.3")) + assert.Equal(t, 1, len(mdata)) + assert.True(t, checkEntry(t, mdata[0], "example.org", "1.1.1.2", "2.2.2.2")) - // search by client IP + // search by client IP (part of) params = getDataParams{ OlderThan: time.Time{}, - Client: "0.1.2", + Client: "2.2.2", StrictMatchClient: false, } d = l.getData(params) mdata = d["data"].([]map[string]interface{}) - assert.True(t, len(mdata) == 2) - assert.True(t, checkEntry(t, mdata[0], "test.example.org", "2.2.3.4", "0.1.2.4")) - assert.True(t, checkEntry(t, mdata[1], "example.org", "1.2.3.4", "0.1.2.3")) + assert.Equal(t, 3, len(mdata)) + assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3")) + assert.True(t, checkEntry(t, mdata[1], "example.org", "1.1.1.2", "2.2.2.2")) + assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.1", "2.2.2.1")) } func addEntry(l *queryLog, host, answerStr, client string) { @@ -129,11 +135,11 @@ func checkEntry(t *testing.T, m map[string]interface{}, host, answer, client str mq := m["question"].(map[string]interface{}) ma := m["answer"].([]map[string]interface{}) ma0 := ma[0] - if !assert.True(t, mq["host"].(string) == host) || - !assert.True(t, mq["class"].(string) == "IN") || - !assert.True(t, mq["type"].(string) == "A") || - !assert.True(t, ma0["value"].(string) == answer) || - !assert.True(t, m["client"].(string) == client) { + if !assert.Equal(t, host, mq["host"].(string)) || + !assert.Equal(t, "IN", mq["class"].(string)) || + !assert.Equal(t, "A", mq["type"].(string)) || + !assert.Equal(t, answer, ma0["value"].(string)) || + !assert.Equal(t, client, m["client"].(string)) { return false } return true From 87c04109146a0956261f7caa4af28167c406cc6d Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Fri, 21 Feb 2020 12:57:12 +0300 Subject: [PATCH 29/31] -(dnsforward): start using lastProbeLineIdx in the Seek --- querylog/qlog_file.go | 23 ++++++++++++++--------- querylog/qlog_file_test.go | 31 ++++++++++++++++++++++++++----- querylog/qlog_reader.go | 2 +- 3 files changed, 41 insertions(+), 15 deletions(-) diff --git a/querylog/qlog_file.go b/querylog/qlog_file.go index 24c5a4e2..54d1896c 100644 --- a/querylog/qlog_file.go +++ b/querylog/qlog_file.go @@ -59,10 +59,12 @@ func NewQLogFile(path string) (*QLogFile, error) { // it shifts seek position to 3/4 of the file. Otherwise, to 1/4 of the file. // 5. It performs the search again, every time the search scope is narrowed twice. // -// It returns the position of the the line with the timestamp we were looking for +// Returns: +// * It returns the position of the the line with the timestamp we were looking for // so that when we call "ReadNext" this line was returned. -// If we could not find it, it returns 0 and ErrSeekNotFound -func (q *QLogFile) Seek(timestamp int64) (int64, error) { +// * Depth of the search (how many times we compared timestamps). +// * If we could not find it, it returns ErrSeekNotFound +func (q *QLogFile) Seek(timestamp int64) (int64, int, error) { q.lock.Lock() defer q.lock.Unlock() @@ -72,7 +74,7 @@ func (q *QLogFile) Seek(timestamp int64) (int64, error) { // First of all, check the file size fileInfo, err := q.file.Stat() if err != nil { - return 0, err + return 0, 0, err } // Define the search scope @@ -91,14 +93,14 @@ func (q *QLogFile) Seek(timestamp int64) (int64, error) { // Get the line at the specified position line, lineIdx, err = q.readProbeLine(probe) if err != nil { - return 0, err + return 0, depth, err } // Get the timestamp from the query log record ts := readQLogTimestamp(line) if ts == 0 { - return 0, ErrSeekNotFound + return 0, depth, ErrSeekNotFound } if ts == timestamp { @@ -109,9 +111,12 @@ func (q *QLogFile) Seek(timestamp int64) (int64, error) { if lastProbeLineIdx == lineIdx { // If we're testing the same line twice then most likely // the scope is too narrow and we won't find anything anymore - return 0, ErrSeekNotFound + return 0, depth, ErrSeekNotFound } + // Save the last found idx + lastProbeLineIdx = lineIdx + // Narrow the scope and repeat the search if ts > timestamp { // If the timestamp we're looking for is OLDER than what we found @@ -128,12 +133,12 @@ func (q *QLogFile) Seek(timestamp int64) (int64, error) { depth++ if depth >= 100 { log.Error("Seek depth is too high, aborting. File %s, ts %v", q.file.Name(), timestamp) - return 0, ErrSeekNotFound + return 0, depth, ErrSeekNotFound } } q.position = lineIdx + int64(len(line)) - return q.position, nil + return q.position, depth, nil } // SeekStart changes the current position to the end of the file diff --git a/querylog/qlog_file_test.go b/querylog/qlog_file_test.go index c5871475..a0fa07f1 100644 --- a/querylog/qlog_file_test.go +++ b/querylog/qlog_file_test.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "io" "io/ioutil" + "math" "net" "os" "strings" @@ -95,13 +96,23 @@ func TestQLogFileSeekLargeFile(t *testing.T) { testSeekLineQLogFile(t, q, count) // CASE 5: Seek non-existent (too low) - _, err = q.Seek(123) + _, _, err = q.Seek(123) assert.NotNil(t, err) // CASE 6: Seek non-existent (too high) ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00") - _, err = q.Seek(ts.UnixNano()) + _, _, err = q.Seek(ts.UnixNano()) assert.NotNil(t, err) + + // CASE 7: "Almost" found + line, err := getQLogFileLine(q, count/2) + assert.Nil(t, err) + // ALMOST the record we need + timestamp := readQLogTimestamp(line) - 1 + assert.NotEqual(t, uint64(0), timestamp) + _, depth, err := q.Seek(timestamp) + assert.NotNil(t, err) + assert.True(t, depth <= int(math.Log2(float64(count))+3)) } func TestQLogFileSeekSmallFile(t *testing.T) { @@ -131,13 +142,23 @@ func TestQLogFileSeekSmallFile(t *testing.T) { testSeekLineQLogFile(t, q, count) // CASE 5: Seek non-existent (too low) - _, err = q.Seek(123) + _, _, err = q.Seek(123) assert.NotNil(t, err) // CASE 6: Seek non-existent (too high) ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00") - _, err = q.Seek(ts.UnixNano()) + _, _, err = q.Seek(ts.UnixNano()) assert.NotNil(t, err) + + // CASE 7: "Almost" found + line, err := getQLogFileLine(q, count/2) + assert.Nil(t, err) + // ALMOST the record we need + timestamp := readQLogTimestamp(line) - 1 + assert.NotEqual(t, uint64(0), timestamp) + _, depth, err := q.Seek(timestamp) + assert.NotNil(t, err) + assert.True(t, depth <= int(math.Log2(float64(count))+3)) } func testSeekLineQLogFile(t *testing.T, q *QLogFile, lineNumber int) { @@ -147,7 +168,7 @@ func testSeekLineQLogFile(t *testing.T, q *QLogFile, lineNumber int) { assert.NotEqual(t, uint64(0), ts) // try seeking to that line now - pos, err := q.Seek(ts) + pos, _, err := q.Seek(ts) assert.Nil(t, err) assert.NotEqual(t, int64(0), pos) diff --git a/querylog/qlog_reader.go b/querylog/qlog_reader.go index 5f4f4dfa..ee2f617d 100644 --- a/querylog/qlog_reader.go +++ b/querylog/qlog_reader.go @@ -51,7 +51,7 @@ func NewQLogReader(files []string) (*QLogReader, error) { func (r *QLogReader) Seek(timestamp int64) error { for i := len(r.qFiles) - 1; i >= 0; i-- { q := r.qFiles[i] - _, err := q.Seek(timestamp) + _, _, err := q.Seek(timestamp) if err == nil { // Our search is finished, we found the element we were looking for // Update currentFile only, position is already set properly in the QLogFile From 9d8a95f83602fae47d5efd6d1df4e4806ed12580 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Fri, 21 Feb 2020 13:01:41 +0300 Subject: [PATCH 30/31] *(dnsforward): fix in-memory log entries sort orderr --- querylog/qlog.go | 5 ++++- querylog/querylog_test.go | 19 +++++++++++-------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/querylog/qlog.go b/querylog/qlog.go index cafcdb5c..fb49fa35 100644 --- a/querylog/qlog.go +++ b/querylog/qlog.go @@ -206,7 +206,10 @@ func (l *queryLog) getData(params getDataParams) map[string]interface{} { l.bufferLock.Lock() total += len(l.buffer) memoryEntries := make([]*logEntry, 0) - for _, entry := range l.buffer { + + // go through the buffer in the reverse order + for i := len(l.buffer) - 1; i >= 0; i-- { + entry := l.buffer[i] if !matchesGetDataParams(entry, params) { continue } diff --git a/querylog/querylog_test.go b/querylog/querylog_test.go index 8d587b7c..06de4101 100644 --- a/querylog/querylog_test.go +++ b/querylog/querylog_test.go @@ -41,6 +41,7 @@ func TestQueryLog(t *testing.T) { _ = l.flushLogBuffer(true) // add memory entries addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3") + addEntry(l, "example.com", "1.1.1.4", "2.2.2.4") // get all entries params := getDataParams{ @@ -48,10 +49,11 @@ func TestQueryLog(t *testing.T) { } d := l.getData(params) mdata := d["data"].([]map[string]interface{}) - assert.Equal(t, 3, len(mdata)) - assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3")) - assert.True(t, checkEntry(t, mdata[1], "example.org", "1.1.1.2", "2.2.2.2")) - assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.1", "2.2.2.1")) + assert.Equal(t, 4, len(mdata)) + assert.True(t, checkEntry(t, mdata[0], "example.com", "1.1.1.4", "2.2.2.4")) + assert.True(t, checkEntry(t, mdata[1], "test.example.org", "1.1.1.3", "2.2.2.3")) + assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.2", "2.2.2.2")) + assert.True(t, checkEntry(t, mdata[3], "example.org", "1.1.1.1", "2.2.2.1")) // search by domain (strict) params = getDataParams{ @@ -96,10 +98,11 @@ func TestQueryLog(t *testing.T) { } d = l.getData(params) mdata = d["data"].([]map[string]interface{}) - assert.Equal(t, 3, len(mdata)) - assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3")) - assert.True(t, checkEntry(t, mdata[1], "example.org", "1.1.1.2", "2.2.2.2")) - assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.1", "2.2.2.1")) + assert.Equal(t, 4, len(mdata)) + assert.True(t, checkEntry(t, mdata[0], "example.com", "1.1.1.4", "2.2.2.4")) + assert.True(t, checkEntry(t, mdata[1], "test.example.org", "1.1.1.3", "2.2.2.3")) + assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.2", "2.2.2.2")) + assert.True(t, checkEntry(t, mdata[3], "example.org", "1.1.1.1", "2.2.2.1")) } func addEntry(l *queryLog, host, answerStr, client string) { From df427b6822c843a233ec0cad48624d2b1de87843 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Fri, 21 Feb 2020 16:50:20 +0300 Subject: [PATCH 31/31] *(dnsforward): fix reading in-memory entries --- querylog/qlog.go | 9 ++++++--- querylog/qlog_file.go | 6 +++++- querylog/querylog_file.go | 3 --- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/querylog/qlog.go b/querylog/qlog.go index fb49fa35..a0570011 100644 --- a/querylog/qlog.go +++ b/querylog/qlog.go @@ -208,14 +208,17 @@ func (l *queryLog) getData(params getDataParams) map[string]interface{} { memoryEntries := make([]*logEntry, 0) // go through the buffer in the reverse order + // from NEWER to OLDER for i := len(l.buffer) - 1; i >= 0; i-- { entry := l.buffer[i] - if !matchesGetDataParams(entry, params) { + + if entry.Time.UnixNano() >= params.OlderThan.UnixNano() { + // Ignore entries newer than what was requested continue } - if entry.Time.UnixNano() >= params.OlderThan.UnixNano() { - break + if !matchesGetDataParams(entry, params) { + continue } memoryEntries = append(memoryEntries, entry) diff --git a/querylog/qlog_file.go b/querylog/qlog_file.go index 54d1896c..c1eeefa2 100644 --- a/querylog/qlog_file.go +++ b/querylog/qlog_file.go @@ -15,7 +15,11 @@ import ( // if we failed to find the desired record var ErrSeekNotFound = errors.New("Seek not found the record") -const bufferSize = 256 * 1024 // 256 KB is the buffer size +// TODO: Find a way to grow buffer instead of relying on this value when reading strings +const maxEntrySize = 16 * 1024 + +// buffer should be enough for at least this number of entries +const bufferSize = 100 * maxEntrySize // QLogFile represents a single query log file // It allows reading from the file in the reverse order diff --git a/querylog/querylog_file.go b/querylog/querylog_file.go index 3b32b7a2..1eeeea7c 100644 --- a/querylog/querylog_file.go +++ b/querylog/querylog_file.go @@ -9,9 +9,6 @@ import ( "github.com/AdguardTeam/golibs/log" ) -// TODO: Check this when we append a new line -- we don't want to have a line longer than this -const maxEntrySize = 1024 - // flushLogBuffer flushes the current buffer to file and resets the current buffer func (l *queryLog) flushLogBuffer(fullFlush bool) error { l.fileFlushLock.Lock()