diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ba5072c..3b1ad28e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to ### Added +- The ability to completely purge DHCP leases ([#1691]). - The ability to set the timeout for querying the upstream servers ([#2280]). - The ability to change group and user ID on startup on Unix ([#2763]). - Experimental OpenBSD support for AMD64 and 64-bit ARM CPUs ([#2439]). @@ -62,6 +63,7 @@ released by then. - Go 1.15 support. +[#1691]: https://github.com/AdguardTeam/AdGuardHome/issues/1691 [#2280]: https://github.com/AdguardTeam/AdGuardHome/issues/2280 [#2439]: https://github.com/AdguardTeam/AdGuardHome/issues/2439 [#2441]: https://github.com/AdguardTeam/AdGuardHome/issues/2441 diff --git a/internal/aghnet/addr.go b/internal/aghnet/addr.go index c76ef093..4eb957c5 100644 --- a/internal/aghnet/addr.go +++ b/internal/aghnet/addr.go @@ -10,6 +10,24 @@ import ( "golang.org/x/net/idna" ) +// CloneIP returns a clone of an IP address. +func CloneIP(ip net.IP) (clone net.IP) { + if ip != nil && len(ip) == 0 { + return net.IP{} + } + + return append(clone, ip...) +} + +// CloneMAC returns a clone of a MAC address. +func CloneMAC(mac net.HardwareAddr) (clone net.HardwareAddr) { + if mac != nil && len(mac) == 0 { + return net.HardwareAddr{} + } + + return append(clone, mac...) +} + // IPFromAddr returns an IP address from addr. If addr is neither // a *net.TCPAddr nor a *net.UDPAddr, it returns nil. func IPFromAddr(addr net.Addr) (ip net.IP) { @@ -31,6 +49,12 @@ func IsValidHostOuterRune(r rune) (ok bool) { (r >= '0' && r <= '9') } +// JoinHostPort is a convinient wrapper for net.JoinHostPort with port of type +// int. +func JoinHostPort(host string, port int) (hostport string) { + return net.JoinHostPort(host, strconv.Itoa(port)) +} + // isValidHostRune returns true if r is a valid rune for a hostname label. func isValidHostRune(r rune) (ok bool) { return r == '-' || IsValidHostOuterRune(r) diff --git a/internal/aghnet/addr_test.go b/internal/aghnet/addr_test.go index df9cd740..0711a597 100644 --- a/internal/aghnet/addr_test.go +++ b/internal/aghnet/addr_test.go @@ -9,6 +9,26 @@ import ( "github.com/stretchr/testify/require" ) +func TestCloneIP(t *testing.T) { + assert.Equal(t, net.IP(nil), CloneIP(nil)) + assert.Equal(t, net.IP{}, CloneIP(net.IP{})) + + ip := net.IP{1, 2, 3, 4} + clone := CloneIP(ip) + assert.Equal(t, ip, clone) + assert.NotSame(t, &ip[0], &clone[0]) +} + +func TestCloneMAC(t *testing.T) { + assert.Equal(t, net.HardwareAddr(nil), CloneMAC(nil)) + assert.Equal(t, net.HardwareAddr{}, CloneMAC(net.HardwareAddr{})) + + mac := net.HardwareAddr{0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC} + clone := CloneMAC(mac) + assert.Equal(t, mac, clone) + assert.NotSame(t, &mac[0], &clone[0]) +} + func TestIPFromAddr(t *testing.T) { ip := net.IP{1, 2, 3, 4} assert.Equal(t, net.IP(nil), IPFromAddr(nil)) @@ -66,6 +86,14 @@ func TestValidateHardwareAddress(t *testing.T) { } } +func TestJoinHostPort(t *testing.T) { + assert.Equal(t, ":0", JoinHostPort("", 0)) + assert.Equal(t, "host:12345", JoinHostPort("host", 12345)) + assert.Equal(t, "1.2.3.4:12345", JoinHostPort("1.2.3.4", 12345)) + assert.Equal(t, "[1234::5678]:12345", JoinHostPort("1234::5678", 12345)) + assert.Equal(t, "[1234::5678%lo]:12345", JoinHostPort("1234::5678%lo", 12345)) +} + func repeatStr(b *strings.Builder, s string, n int) { for i := 0; i < n; i++ { _, _ = b.WriteString(s) diff --git a/internal/aghnet/net.go b/internal/aghnet/net.go index 13ace026..5a25dc35 100644 --- a/internal/aghnet/net.go +++ b/internal/aghnet/net.go @@ -424,9 +424,3 @@ func CollectAllIfacesAddrs() (addrs []string, err error) { return addrs, nil } - -// JoinHostPort is a convinient wrapper for net.JoinHostPort with port of type -// int. -func JoinHostPort(host string, port int) (hostport string) { - return net.JoinHostPort(host, strconv.Itoa(port)) -} diff --git a/internal/dhcpd/db.go b/internal/dhcpd/db.go index e9fd97c0..af6bf72e 100644 --- a/internal/dhcpd/db.go +++ b/internal/dhcpd/db.go @@ -4,6 +4,7 @@ package dhcpd import ( "encoding/json" + "fmt" "net" "os" "time" @@ -31,7 +32,7 @@ func normalizeIP(ip net.IP) net.IP { } // Load lease table from DB -func (s *Server) dbLoad() { +func (s *Server) dbLoad() (err error) { dynLeases := []*Lease{} staticLeases := []*Lease{} v6StaticLeases := []*Lease{} @@ -40,17 +41,16 @@ func (s *Server) dbLoad() { data, err := os.ReadFile(s.conf.DBFilePath) if err != nil { if !errors.Is(err, os.ErrNotExist) { - log.Error("dhcp: can't read file %q: %v", s.conf.DBFilePath, err) + return fmt.Errorf("reading db: %w", err) } - return + return nil } obj := []leaseJSON{} err = json.Unmarshal(data, &obj) if err != nil { - log.Error("dhcp: invalid DB: %v", err) - return + return fmt.Errorf("decoding db: %w", err) } numLeases := len(obj) @@ -85,15 +85,23 @@ func (s *Server) dbLoad() { } leases4 := normalizeLeases(staticLeases, dynLeases) - s.srv4.ResetLeases(leases4) + err = s.srv4.ResetLeases(leases4) + if err != nil { + return fmt.Errorf("resetting dhcpv4 leases: %w", err) + } leases6 := normalizeLeases(v6StaticLeases, v6DynLeases) if s.srv6 != nil { - s.srv6.ResetLeases(leases6) + err = s.srv6.ResetLeases(leases6) + if err != nil { + return fmt.Errorf("resetting dhcpv6 leases: %w", err) + } } log.Info("dhcp: loaded leases v4:%d v6:%d total-read:%d from DB", len(leases4), len(leases6), numLeases) + + return nil } // Skip duplicate leases @@ -124,20 +132,24 @@ func normalizeLeases(staticLeases, dynLeases []*Lease) []*Lease { } // Store lease table in DB -func (s *Server) dbStore() { - var leases []leaseJSON +func (s *Server) dbStore() (err error) { + // Use an empty slice here as opposed to nil so that it doesn't write + // "null" into the database file if leases are empty. + leases := []leaseJSON{} leases4 := s.srv4.getLeasesRef() for _, l := range leases4 { if l.Expiry.Unix() == 0 { continue } + lease := leaseJSON{ HWAddr: l.HWAddr, IP: l.IP, Hostname: l.Hostname, Expiry: l.Expiry.Unix(), } + leases = append(leases, lease) } @@ -147,29 +159,30 @@ func (s *Server) dbStore() { if l.Expiry.Unix() == 0 { continue } + lease := leaseJSON{ HWAddr: l.HWAddr, IP: l.IP, Hostname: l.Hostname, Expiry: l.Expiry.Unix(), } + leases = append(leases, lease) } } - data, err := json.Marshal(leases) + var data []byte + data, err = json.Marshal(leases) if err != nil { - log.Error("json.Marshal: %v", err) - return + return fmt.Errorf("encoding db: %w", err) } err = maybe.WriteFile(s.conf.DBFilePath, data, 0o644) if err != nil { - log.Error("dhcp: can't store lease table on disk: %v filename: %s", - err, s.conf.DBFilePath) - - return + return fmt.Errorf("writing db: %w", err) } - log.Info("dhcp: stored %d leases in DB", len(leases)) + log.Info("dhcp: stored %d leases in db", len(leases)) + + return nil } diff --git a/internal/dhcpd/dhcpd.go b/internal/dhcpd/dhcpd.go index f652a42e..12bec4b7 100644 --- a/internal/dhcpd/dhcpd.go +++ b/internal/dhcpd/dhcpd.go @@ -10,6 +10,7 @@ import ( "runtime" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/log" ) @@ -36,6 +37,37 @@ type Lease struct { IP net.IP `json:"ip"` } +// Clone returns a deep copy of l. +func (l *Lease) Clone() (clone *Lease) { + if l == nil { + return nil + } + + return &Lease{ + Expiry: l.Expiry, + Hostname: l.Hostname, + HWAddr: aghnet.CloneMAC(l.HWAddr), + IP: aghnet.CloneIP(l.IP), + } +} + +// IsBlocklisted returns true if the lease is blocklisted. +// +// TODO(a.garipov): Just make it a boolean field. +func (l *Lease) IsBlocklisted() (ok bool) { + if len(l.HWAddr) == 0 { + return false + } + + for _, b := range l.HWAddr { + if b != 0 { + return false + } + } + + return true +} + // IsStatic returns true if the lease is static. // // TODO(a.garipov): Just make it a boolean field. @@ -131,16 +163,27 @@ type Server struct { onLeaseChanged []OnLeaseChangedT } +// GetLeasesFlags are the flags for GetLeases. +type GetLeasesFlags uint8 + +// GetLeasesFlags values +const ( + LeasesDynamic GetLeasesFlags = 0b0001 + LeasesStatic GetLeasesFlags = 0b0010 + + LeasesAll = LeasesDynamic | LeasesStatic +) + // ServerInterface is an interface for servers. type ServerInterface interface { Enabled() (ok bool) - Leases(flags int) []Lease + Leases(flags GetLeasesFlags) (leases []*Lease) SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT) } // Create - create object -func Create(conf ServerConfig) *Server { - s := &Server{} +func Create(conf ServerConfig) (s *Server, err error) { + s = &Server{} s.conf.Enabled = conf.Enabled s.conf.InterfaceName = conf.InterfaceName @@ -166,15 +209,18 @@ func Create(conf ServerConfig) *Server { webHandlersRegistered = true } - var err4, err6 error v4conf := conf.Conf4 v4conf.Enabled = s.conf.Enabled if len(v4conf.RangeStart) == 0 { v4conf.Enabled = false } + v4conf.InterfaceName = s.conf.InterfaceName v4conf.notify = s.onNotify - s.srv4, err4 = v4Create(v4conf) + s.srv4, err = v4Create(v4conf) + if err != nil { + return nil, fmt.Errorf("creating dhcpv4 srv: %w", err) + } v6conf := conf.Conf6 v6conf.Enabled = s.conf.Enabled @@ -183,29 +229,26 @@ func Create(conf ServerConfig) *Server { } v6conf.InterfaceName = s.conf.InterfaceName v6conf.notify = s.onNotify - s.srv6, err6 = v6Create(v6conf) - - if err4 != nil { - log.Error("%s", err4) - return nil - } - if err6 != nil { - log.Error("%s", err6) - return nil + s.srv6, err = v6Create(v6conf) + if err != nil { + return nil, fmt.Errorf("creating dhcpv6 srv: %w", err) } s.conf.Conf4 = conf.Conf4 s.conf.Conf6 = conf.Conf6 if s.conf.Enabled && !v4conf.Enabled && !v6conf.Enabled { - log.Error("Can't enable DHCP server because neither DHCPv4 nor DHCPv6 servers are configured") - return nil + return nil, fmt.Errorf("neither dhcpv4 nor dhcpv6 srv is configured") } - // we can't delay database loading until DHCP server is started, - // because we need static leases functionality available beforehand - s.dbLoad() - return s + // Don't delay database loading until the DHCP server is started, + // because we need static leases functionality available beforehand. + err = s.dbLoad() + if err != nil { + return nil, fmt.Errorf("loading db: %w", err) + } + + return s, nil } // Enabled returns true when the server is enabled. @@ -213,10 +256,30 @@ func (s *Server) Enabled() (ok bool) { return s.conf.Enabled } +// resetLeases resets all leases in the lease database. +func (s *Server) resetLeases() (err error) { + err = s.srv4.ResetLeases(nil) + if err != nil { + return err + } + + if s.srv6 != nil { + err = s.srv6.ResetLeases(nil) + if err != nil { + return err + } + } + + return s.dbStore() +} + // server calls this function after DB is updated func (s *Server) onNotify(flags uint32) { if flags == LeaseChangedDBStore { - s.dbStore() + err := s.dbStore() + if err != nil { + log.Error("updating db: %s", err) + } return } @@ -263,21 +326,23 @@ func (s *Server) Start() (err error) { } // Stop closes the listening UDP socket -func (s *Server) Stop() { - s.srv4.Stop() - s.srv6.Stop() -} +func (s *Server) Stop() (err error) { + err = s.srv4.Stop() + if err != nil { + return err + } -// flags for Leases() function -const ( - LeasesDynamic = 1 - LeasesStatic = 2 - LeasesAll = LeasesDynamic | LeasesStatic -) + err = s.srv6.Stop() + if err != nil { + return err + } + + return nil +} // Leases returns the list of active IPv4 and IPv6 DHCP leases. It's safe for // concurrent use. -func (s *Server) Leases(flags int) (leases []Lease) { +func (s *Server) Leases(flags GetLeasesFlags) (leases []*Lease) { return append(s.srv4.GetLeases(flags), s.srv6.GetLeases(flags)...) } @@ -290,6 +355,6 @@ func (s *Server) FindMACbyIP(ip net.IP) net.HardwareAddr { } // AddStaticLease - add static v4 lease -func (s *Server) AddStaticLease(lease Lease) error { - return s.srv4.AddStaticLease(lease) +func (s *Server) AddStaticLease(l *Lease) error { + return s.srv4.AddStaticLease(l) } diff --git a/internal/dhcpd/dhcpd_test.go b/internal/dhcpd/dhcpd_test.go index 8cbc8787..faae9e8d 100644 --- a/internal/dhcpd/dhcpd_test.go +++ b/internal/dhcpd/dhcpd_test.go @@ -43,7 +43,7 @@ func TestDB(t *testing.T) { s.srv6, err = v6Create(V6ServerConf{}) require.NoError(t, err) - leases := []Lease{{ + leases := []*Lease{{ Expiry: time.Now().Add(time.Hour), Hostname: "static-1.local", HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, @@ -57,18 +57,24 @@ func TestDB(t *testing.T) { srv4, ok := s.srv4.(*v4Server) require.True(t, ok) - err = srv4.addLease(&leases[0]) + err = srv4.addLease(leases[0]) require.NoError(t, err) err = s.srv4.AddStaticLease(leases[1]) require.NoError(t, err) - s.dbStore() + err = s.dbStore() + require.NoError(t, err) + t.Cleanup(func() { assert.NoError(t, os.Remove(dbFilename)) }) - s.srv4.ResetLeases(nil) - s.dbLoad() + + err = s.srv4.ResetLeases(nil) + require.NoError(t, err) + + err = s.dbLoad() + require.NoError(t, err) ll := s.srv4.GetLeases(LeasesAll) require.Len(t, ll, len(leases)) diff --git a/internal/dhcpd/http.go b/internal/dhcpd/http.go index c7de4e82..aec55c94 100644 --- a/internal/dhcpd/http.go +++ b/internal/dhcpd/http.go @@ -60,12 +60,12 @@ func v6JSONToServerConf(j *v6ServerConfJSON) V6ServerConf { // dhcpStatusResponse is the response for /control/dhcp/status endpoint. type dhcpStatusResponse struct { - Enabled bool `json:"enabled"` IfaceName string `json:"interface_name"` V4 V4ServerConf `json:"v4"` V6 V6ServerConf `json:"v6"` - Leases []Lease `json:"leases"` - StaticLeases []Lease `json:"static_leases"` + Leases []*Lease `json:"leases"` + StaticLeases []*Lease `json:"static_leases"` + Enabled bool `json:"enabled"` } func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) { @@ -146,12 +146,68 @@ type dhcpServerConfigJSON struct { Enabled nullBool `json:"enabled"` } +func (s *Server) handleDHCPSetConfigV4( + conf *dhcpServerConfigJSON, +) (srv4 DHCPServer, enabled bool, err error) { + if conf.V4 == nil { + return nil, false, nil + } + + v4Conf := v4JSONToServerConf(conf.V4) + v4Conf.Enabled = conf.Enabled == nbTrue + if len(v4Conf.RangeStart) == 0 { + v4Conf.Enabled = false + } + + enabled = v4Conf.Enabled + v4Conf.InterfaceName = conf.InterfaceName + + c4 := V4ServerConf{} + s.srv4.WriteDiskConfig4(&c4) + v4Conf.notify = c4.notify + v4Conf.ICMPTimeout = c4.ICMPTimeout + v4Conf.Options = c4.Options + + srv4, err = v4Create(v4Conf) + + return srv4, enabled, err +} + +func (s *Server) handleDHCPSetConfigV6( + conf *dhcpServerConfigJSON, +) (srv6 DHCPServer, enabled bool, err error) { + if conf.V6 == nil { + return nil, false, nil + } + + v6Conf := v6JSONToServerConf(conf.V6) + v6Conf.Enabled = conf.Enabled == nbTrue + if len(v6Conf.RangeStart) == 0 { + v6Conf.Enabled = false + } + + // Don't overwrite the RA/SLAAC settings from the config file. + // + // TODO(a.garipov): Perhaps include them into the request to allow + // changing them from the HTTP API? + v6Conf.RASLAACOnly = s.conf.Conf6.RASLAACOnly + v6Conf.RAAllowSLAAC = s.conf.Conf6.RAAllowSLAAC + + enabled = v6Conf.Enabled + v6Conf.InterfaceName = conf.InterfaceName + v6Conf.notify = s.onNotify + + srv6, err = v6Create(v6Conf) + + return srv6, enabled, err +} + func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { - conf := dhcpServerConfigJSON{} + conf := &dhcpServerConfigJSON{} conf.Enabled = boolToNullBool(s.conf.Enabled) conf.InterfaceName = s.conf.InterfaceName - err := json.NewDecoder(r.Body).Decode(&conf) + err := json.NewDecoder(r.Body).Decode(conf) if err != nil { httpError(r, w, http.StatusBadRequest, "failed to parse new dhcp config json: %s", err) @@ -159,61 +215,18 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { return } - var s4 DHCPServer - var s6 DHCPServer - v4Enabled := false - v6Enabled := false + srv4, v4Enabled, err := s.handleDHCPSetConfigV4(conf) + if err != nil { + httpError(r, w, http.StatusBadRequest, "bad dhcpv4 configuration: %s", err) - if conf.V4 != nil { - v4Conf := v4JSONToServerConf(conf.V4) - v4Conf.Enabled = conf.Enabled == nbTrue - if len(v4Conf.RangeStart) == 0 { - v4Conf.Enabled = false - } - - v4Enabled = v4Conf.Enabled - v4Conf.InterfaceName = conf.InterfaceName - - c4 := V4ServerConf{} - s.srv4.WriteDiskConfig4(&c4) - v4Conf.notify = c4.notify - v4Conf.ICMPTimeout = c4.ICMPTimeout - v4Conf.Options = c4.Options - - s4, err = v4Create(v4Conf) - if err != nil { - httpError(r, w, http.StatusBadRequest, - "invalid dhcpv4 configuration: %s", err) - - return - } + return } - if conf.V6 != nil { - v6Conf := v6JSONToServerConf(conf.V6) - v6Conf.Enabled = conf.Enabled == nbTrue - if len(v6Conf.RangeStart) == 0 { - v6Conf.Enabled = false - } + srv6, v6Enabled, err := s.handleDHCPSetConfigV6(conf) + if err != nil { + httpError(r, w, http.StatusBadRequest, "bad dhcpv6 configuration: %s", err) - // Don't overwrite the RA/SLAAC settings from the config file. - // - // TODO(a.garipov): Perhaps include them into the request to - // allow changing them from the HTTP API? - v6Conf.RASLAACOnly = s.conf.Conf6.RASLAACOnly - v6Conf.RAAllowSLAAC = s.conf.Conf6.RAAllowSLAAC - - v6Enabled = v6Conf.Enabled - v6Conf.InterfaceName = conf.InterfaceName - v6Conf.notify = s.onNotify - - s6, err = v6Create(v6Conf) - if err != nil { - httpError(r, w, http.StatusBadRequest, - "invalid dhcpv6 configuration: %s", err) - - return - } + return } if conf.Enabled == nbTrue && !v4Enabled && !v6Enabled { @@ -223,7 +236,12 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { return } - s.Stop() + err = s.Stop() + if err != nil { + httpError(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err) + + return + } if conf.Enabled != nbNull { s.conf.Enabled = conf.Enabled == nbTrue @@ -233,16 +251,22 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { s.conf.InterfaceName = conf.InterfaceName } - if s4 != nil { - s.srv4 = s4 + if srv4 != nil { + s.srv4 = srv4 } - if s6 != nil { - s.srv6 = s6 + if srv6 != nil { + s.srv6 = srv6 } s.conf.ConfigModified() - s.dbLoad() + + err = s.dbLoad() + if err != nil { + httpError(r, w, http.StatusInternalServerError, "loading leases db: %s", err) + + return + } if s.conf.Enabled { var code int @@ -431,26 +455,26 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque } func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) { - lj := Lease{} - err := json.NewDecoder(r.Body).Decode(&lj) + l := &Lease{} + err := json.NewDecoder(r.Body).Decode(l) if err != nil { httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) return } - if lj.IP == nil { + if l.IP == nil { httpError(r, w, http.StatusBadRequest, "invalid IP") return } - ip4 := lj.IP.To4() + ip4 := l.IP.To4() if ip4 == nil { - lj.IP = lj.IP.To16() + l.IP = l.IP.To16() - err = s.srv6.AddStaticLease(lj) + err = s.srv6.AddStaticLease(l) if err != nil { httpError(r, w, http.StatusBadRequest, "%s", err) } @@ -458,8 +482,8 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request return } - lj.IP = ip4 - err = s.srv4.AddStaticLease(lj) + l.IP = ip4 + err = s.srv4.AddStaticLease(l) if err != nil { httpError(r, w, http.StatusBadRequest, "%s", err) @@ -468,26 +492,26 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request } func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) { - lj := Lease{} - err := json.NewDecoder(r.Body).Decode(&lj) + l := &Lease{} + err := json.NewDecoder(r.Body).Decode(l) if err != nil { httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) return } - if lj.IP == nil { + if l.IP == nil { httpError(r, w, http.StatusBadRequest, "invalid IP") return } - ip4 := lj.IP.To4() + ip4 := l.IP.To4() if ip4 == nil { - lj.IP = lj.IP.To16() + l.IP = l.IP.To16() - err = s.srv6.RemoveStaticLease(lj) + err = s.srv6.RemoveStaticLease(l) if err != nil { httpError(r, w, http.StatusBadRequest, "%s", err) } @@ -495,8 +519,8 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ return } - lj.IP = ip4 - err = s.srv4.RemoveStaticLease(lj) + l.IP = ip4 + err = s.srv4.RemoveStaticLease(l) if err != nil { httpError(r, w, http.StatusBadRequest, "%s", err) @@ -505,11 +529,16 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ } func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) { - s.Stop() + err := s.Stop() + if err != nil { + httpError(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err) - err := os.Remove(s.conf.DBFilePath) + return + } + + err = os.Remove(s.conf.DBFilePath) if err != nil && !errors.Is(err, os.ErrNotExist) { - log.Error("dhcp: removing %q: %s", s.conf.DBFilePath, err) + log.Error("dhcp: removing db: %s", err) } oldconf := s.conf @@ -531,6 +560,16 @@ func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) { s.conf.ConfigModified() } +func (s *Server) handleResetLeases(w http.ResponseWriter, r *http.Request) { + err := s.resetLeases() + if err != nil { + msg := "resetting leases: %s" + httpError(r, w, http.StatusInternalServerError, msg, err) + + return + } +} + func (s *Server) registerHandlers() { s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/status", s.handleDHCPStatus) s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/interfaces", s.handleDHCPInterfaces) @@ -539,6 +578,7 @@ func (s *Server) registerHandlers() { s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", s.handleDHCPAddStaticLease) s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", s.handleDHCPRemoveStaticLease) s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", s.handleReset) + s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", s.handleResetLeases) } // jsonError is a generic JSON error response. @@ -579,4 +619,5 @@ func (s *Server) registerNotImplementedHandlers() { s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", h) s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", h) s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", h) + s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", h) } diff --git a/internal/dhcpd/server.go b/internal/dhcpd/server.go index 51395535..13c3b91b 100644 --- a/internal/dhcpd/server.go +++ b/internal/dhcpd/server.go @@ -7,14 +7,14 @@ import ( // DHCPServer - DHCP server interface type DHCPServer interface { - // ResetLeases - reset leases - ResetLeases(leases []*Lease) - // GetLeases - get leases - GetLeases(flags int) []Lease + // ResetLeases resets leases. + ResetLeases(leases []*Lease) (err error) + // GetLeases returns deep clones of the current leases. + GetLeases(flags GetLeasesFlags) (leases []*Lease) // AddStaticLease - add a static lease - AddStaticLease(lease Lease) error + AddStaticLease(l *Lease) (err error) // RemoveStaticLease - remove a static lease - RemoveStaticLease(l Lease) error + RemoveStaticLease(l *Lease) (err error) // FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases FindMACbyIP(ip net.IP) net.HardwareAddr @@ -24,9 +24,9 @@ type DHCPServer interface { WriteDiskConfig6(c *V6ServerConf) // Start - start server - Start() error + Start() (err error) // Stop - stop server - Stop() + Stop() (err error) getLeasesRef() []*Lease } diff --git a/internal/dhcpd/v4.go b/internal/dhcpd/v4.go index 7c915d67..9725a05f 100644 --- a/internal/dhcpd/v4.go +++ b/internal/dhcpd/v4.go @@ -96,9 +96,9 @@ func (s *v4Server) validHostnameForClient(cliHostname string, ip net.IP) (hostna return hostname } -// ResetLeases - reset leases -func (s *v4Server) ResetLeases(leases []*Lease) { - var err error +// ResetLeases resets leases. +func (s *v4Server) ResetLeases(leases []*Lease) (err error) { + defer func() { err = errors.Annotate(err, "dhcpv4: %w") }() if !s.conf.Enabled { return @@ -125,6 +125,8 @@ func (s *v4Server) ResetLeases(leases []*Lease) { continue } } + + return nil } // getLeasesRef returns the actual leases slice. For internal use only. @@ -154,14 +156,12 @@ func (s *v4Server) isBlocklisted(l *Lease) (ok bool) { // GetLeases returns the list of current DHCP leases. It is safe for concurrent // use. -func (s *v4Server) GetLeases(flags int) (res []Lease) { +func (s *v4Server) GetLeases(flags GetLeasesFlags) (leases []*Lease) { // The function shouldn't return nil, because zero-length slice behaves // differently in cases like marshalling. Our front-end also requires // a non-nil value in the response. - res = []Lease{} + leases = []*Lease{} - // TODO(a.garipov): Remove the silly bit twiddling and make GetLeases - // accept booleans. Seriously, this doesn't even save stack space. getDynamic := flags&LeasesDynamic != 0 getStatic := flags&LeasesStatic != 0 @@ -171,17 +171,17 @@ func (s *v4Server) GetLeases(flags int) (res []Lease) { now := time.Now() for _, l := range s.leases { if getDynamic && l.Expiry.After(now) && !s.isBlocklisted(l) { - res = append(res, *l) + leases = append(leases, l.Clone()) continue } if getStatic && l.IsStatic() { - res = append(res, *l) + leases = append(leases, l.Clone()) } } - return res + return leases } // FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases @@ -305,7 +305,7 @@ func (s *v4Server) addLease(l *Lease) (err error) { } // rmLease removes a lease with the same properties. -func (s *v4Server) rmLease(lease Lease) (err error) { +func (s *v4Server) rmLease(lease *Lease) (err error) { if len(s.leases) == 0 { return nil } @@ -326,7 +326,7 @@ func (s *v4Server) rmLease(lease Lease) (err error) { } // AddStaticLease adds a static lease. It is safe for concurrent use. -func (s *v4Server) AddStaticLease(l Lease) (err error) { +func (s *v4Server) AddStaticLease(l *Lease) (err error) { defer func() { err = errors.Annotate(err, "dhcpv4: adding static lease: %w") }() if ip4 := l.IP.To4(); ip4 == nil { @@ -365,7 +365,7 @@ func (s *v4Server) AddStaticLease(l Lease) (err error) { s.leasesLock.Lock() defer s.leasesLock.Unlock() - err = s.rmDynamicLease(&l) + err = s.rmDynamicLease(l) if err != nil { err = fmt.Errorf( "removing dynamic leases for %s (%s): %w", @@ -377,7 +377,7 @@ func (s *v4Server) AddStaticLease(l Lease) (err error) { return } - err = s.addLease(&l) + err = s.addLease(l) if err != nil { err = fmt.Errorf("adding static lease for %s (%s): %w", l.IP, l.HWAddr, err) @@ -395,7 +395,7 @@ func (s *v4Server) AddStaticLease(l Lease) (err error) { } // RemoveStaticLease removes a static lease. It is safe for concurrent use. -func (s *v4Server) RemoveStaticLease(l Lease) (err error) { +func (s *v4Server) RemoveStaticLease(l *Lease) (err error) { defer func() { err = errors.Annotate(err, "dhcpv4: %w") }() if len(l.IP) != 4 { @@ -993,15 +993,15 @@ func (s *v4Server) Start() (err error) { } // Stop - stop server -func (s *v4Server) Stop() { +func (s *v4Server) Stop() (err error) { if s.srv == nil { return } log.Debug("dhcpv4: stopping") - err := s.srv.Close() + err = s.srv.Close() if err != nil { - log.Error("dhcpv4: srv.Close: %s", err) + return fmt.Errorf("closing dhcpv4 srv: %w", err) } // Signal to the clients containers in packages home and dnsforward that @@ -1009,6 +1009,8 @@ func (s *v4Server) Stop() { s.conf.notify(LeaseChangedRemovedAll) s.srv = nil + + return nil } // Create DHCPv4 server diff --git a/internal/dhcpd/v46_windows.go b/internal/dhcpd/v46_windows.go index 66511515..7f2c4121 100644 --- a/internal/dhcpd/v46_windows.go +++ b/internal/dhcpd/v46_windows.go @@ -9,16 +9,15 @@ import "net" type winServer struct{} -func (s *winServer) ResetLeases(leases []*Lease) {} -func (s *winServer) GetLeases(flags int) []Lease { return nil } -func (s *winServer) getLeasesRef() []*Lease { return nil } -func (s *winServer) AddStaticLease(lease Lease) error { return nil } -func (s *winServer) RemoveStaticLease(l Lease) error { return nil } -func (s *winServer) FindMACbyIP(ip net.IP) net.HardwareAddr { return nil } -func (s *winServer) WriteDiskConfig4(c *V4ServerConf) {} -func (s *winServer) WriteDiskConfig6(c *V6ServerConf) {} -func (s *winServer) Start() error { return nil } -func (s *winServer) Stop() {} -func (s *winServer) Reset() {} -func v4Create(conf V4ServerConf) (DHCPServer, error) { return &winServer{}, nil } -func v6Create(conf V6ServerConf) (DHCPServer, error) { return &winServer{}, nil } +func (s *winServer) ResetLeases(_ []*Lease) (err error) { return nil } +func (s *winServer) GetLeases(_ GetLeasesFlags) (leases []*Lease) { return nil } +func (s *winServer) getLeasesRef() []*Lease { return nil } +func (s *winServer) AddStaticLease(_ *Lease) (err error) { return nil } +func (s *winServer) RemoveStaticLease(_ *Lease) (err error) { return nil } +func (s *winServer) FindMACbyIP(ip net.IP) (mac net.HardwareAddr) { return nil } +func (s *winServer) WriteDiskConfig4(c *V4ServerConf) {} +func (s *winServer) WriteDiskConfig6(c *V6ServerConf) {} +func (s *winServer) Start() (err error) { return nil } +func (s *winServer) Stop() (err error) { return nil } +func v4Create(conf V4ServerConf) (DHCPServer, error) { return &winServer{}, nil } +func v6Create(conf V6ServerConf) (DHCPServer, error) { return &winServer{}, nil } diff --git a/internal/dhcpd/v4_test.go b/internal/dhcpd/v4_test.go index b9a1defb..9630d8cc 100644 --- a/internal/dhcpd/v4_test.go +++ b/internal/dhcpd/v4_test.go @@ -30,7 +30,7 @@ func TestV4_AddRemove_static(t *testing.T) { assert.Empty(t, ls) // Add static lease. - l := Lease{ + l := &Lease{ Hostname: "static-1.local", HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, IP: net.IP{192, 168, 10, 150}, @@ -50,7 +50,7 @@ func TestV4_AddRemove_static(t *testing.T) { assert.True(t, ls[0].IsStatic()) // Try to remove static lease. - err = s.RemoveStaticLease(Lease{ + err = s.RemoveStaticLease(&Lease{ IP: net.IP{192, 168, 10, 110}, HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, }) @@ -92,7 +92,7 @@ func TestV4_AddReplace(t *testing.T) { require.NoError(t, err) } - stLeases := []Lease{{ + stLeases := []*Lease{{ Hostname: "static-1.local", HWAddr: net.HardwareAddr{0x33, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, IP: net.IP{192, 168, 10, 150}, @@ -134,7 +134,7 @@ func TestV4StaticLease_Get(t *testing.T) { s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}} - l := Lease{ + l := &Lease{ Hostname: "static-1.local", HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, IP: net.IP{192, 168, 10, 150}, diff --git a/internal/dhcpd/v6.go b/internal/dhcpd/v6.go index 6d7d6a96..ff5be152 100644 --- a/internal/dhcpd/v6.go +++ b/internal/dhcpd/v6.go @@ -57,10 +57,12 @@ func ip6InRange(start, ip net.IP) bool { return start[15] <= ip[15] } -// ResetLeases - reset leases -func (s *v6Server) ResetLeases(ll []*Lease) { +// ResetLeases resets leases. +func (s *v6Server) ResetLeases(leases []*Lease) (err error) { + defer func() { err = errors.Annotate(err, "dhcpv4: %w") }() + s.leases = nil - for _, l := range ll { + for _, l := range leases { if l.Expiry.Unix() != leaseExpireStatic && !ip6InRange(s.conf.ipStart, l.IP) { @@ -72,28 +74,31 @@ func (s *v6Server) ResetLeases(ll []*Lease) { s.addLease(l) } + + return nil } -// GetLeases - get current leases -func (s *v6Server) GetLeases(flags int) []Lease { +// GetLeases returns the list of current DHCP leases. It is safe for concurrent +// use. +func (s *v6Server) GetLeases(flags GetLeasesFlags) (leases []*Lease) { // The function shouldn't return nil value because zero-length slice // behaves differently in cases like marshalling. Our front-end also // requires non-nil value in the response. - result := []Lease{} + leases = []*Lease{} s.leasesLock.Lock() - for _, lease := range s.leases { - if lease.Expiry.Unix() == leaseExpireStatic { + for _, l := range s.leases { + if l.Expiry.Unix() == leaseExpireStatic { if (flags & LeasesStatic) != 0 { - result = append(result, *lease) + leases = append(leases, l.Clone()) } } else { if (flags & LeasesDynamic) != 0 { - result = append(result, *lease) + leases = append(leases, l.Clone()) } } } s.leasesLock.Unlock() - return result + return leases } // getLeasesRef returns the actual leases slice. For internal use only. @@ -133,12 +138,11 @@ func (s *v6Server) leaseRemoveSwapByIndex(i int) { // Remove a dynamic lease with the same properties // Return error if a static lease is found -func (s *v6Server) rmDynamicLease(lease Lease) error { +func (s *v6Server) rmDynamicLease(lease *Lease) (err error) { for i := 0; i < len(s.leases); i++ { l := s.leases[i] if bytes.Equal(l.HWAddr, lease.HWAddr) { - if l.Expiry.Unix() == leaseExpireStatic { return fmt.Errorf("static lease already exists") } @@ -147,11 +151,11 @@ func (s *v6Server) rmDynamicLease(lease Lease) error { if i == len(s.leases) { break } + l = s.leases[i] } if net.IP.Equal(l.IP, lease.IP) { - if l.Expiry.Unix() == leaseExpireStatic { return fmt.Errorf("static lease already exists") } @@ -159,11 +163,12 @@ func (s *v6Server) rmDynamicLease(lease Lease) error { s.leaseRemoveSwapByIndex(i) } } + return nil } // AddStaticLease adds a static lease. It is safe for concurrent use. -func (s *v6Server) AddStaticLease(l Lease) (err error) { +func (s *v6Server) AddStaticLease(l *Lease) (err error) { defer func() { err = errors.Annotate(err, "dhcpv6: %w") }() if len(l.IP) != 16 { @@ -181,18 +186,21 @@ func (s *v6Server) AddStaticLease(l Lease) (err error) { err = s.rmDynamicLease(l) if err != nil { s.leasesLock.Unlock() + return err } - s.addLease(&l) + + s.addLease(l) s.conf.notify(LeaseChangedDBStore) s.leasesLock.Unlock() s.conf.notify(LeaseChangedAddedStatic) + return nil } // RemoveStaticLease removes a static lease. It is safe for concurrent use. -func (s *v6Server) RemoveStaticLease(l Lease) (err error) { +func (s *v6Server) RemoveStaticLease(l *Lease) (err error) { defer func() { err = errors.Annotate(err, "dhcpv6: %w") }() if len(l.IP) != 16 { @@ -224,19 +232,20 @@ func (s *v6Server) addLease(l *Lease) { } // Remove a lease with the same properties -func (s *v6Server) rmLease(lease Lease) error { +func (s *v6Server) rmLease(lease *Lease) (err error) { for i, l := range s.leases { if net.IP.Equal(l.IP, lease.IP) { - if !bytes.Equal(l.HWAddr, lease.HWAddr) || l.Hostname != lease.Hostname { return fmt.Errorf("lease not found") } s.leaseRemoveSwapByIndex(i) + return nil } } + return fmt.Errorf("lease not found") } @@ -654,10 +663,10 @@ func (s *v6Server) Start() (err error) { } // Stop - stop server -func (s *v6Server) Stop() { - err := s.ra.Close() +func (s *v6Server) Stop() (err error) { + err = s.ra.Close() if err != nil { - log.Error("dhcpv6: s.ra.Close: %s", err) + return fmt.Errorf("closing ra ctx: %w", err) } // DHCPv6 server may not be initialized if ra_slaac_only=true @@ -668,11 +677,13 @@ func (s *v6Server) Stop() { log.Debug("dhcpv6: stopping") err = s.srv.Close() if err != nil { - log.Error("dhcpv6: srv.Close: %s", err) + return fmt.Errorf("closing dhcpv6 srv: %w", err) } // now server.Serve() will return s.srv = nil + + return nil } // Create DHCPv6 server diff --git a/internal/dhcpd/v6_test.go b/internal/dhcpd/v6_test.go index 1467909f..794b98e4 100644 --- a/internal/dhcpd/v6_test.go +++ b/internal/dhcpd/v6_test.go @@ -22,34 +22,39 @@ func TestV6_AddRemove_static(t *testing.T) { RangeStart: net.ParseIP("2001::1"), notify: notify6, }) - require.Nil(t, err) + require.NoError(t, err) require.Empty(t, s.GetLeases(LeasesStatic)) // Add static lease. - l := Lease{ + l := &Lease{ IP: net.ParseIP("2001::1"), HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, } - require.Nil(t, s.AddStaticLease(l)) + err = s.AddStaticLease(l) + require.NoError(t, err) // Try to add the same static lease. - require.NotNil(t, s.AddStaticLease(l)) + err = s.AddStaticLease(l) + require.Error(t, err) ls := s.GetLeases(LeasesStatic) require.Len(t, ls, 1) + assert.Equal(t, l.IP, ls[0].IP) assert.Equal(t, l.HWAddr, ls[0].HWAddr) assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix()) // Try to remove non-existent static lease. - require.NotNil(t, s.RemoveStaticLease(Lease{ + err = s.RemoveStaticLease(&Lease{ IP: net.ParseIP("2001::2"), HWAddr: l.HWAddr, - })) + }) + require.Error(t, err) // Remove static lease. - require.Nil(t, s.RemoveStaticLease(l)) + err = s.RemoveStaticLease(l) + require.NoError(t, err) assert.Empty(t, s.GetLeases(LeasesStatic)) } @@ -60,7 +65,8 @@ func TestV6_AddReplace(t *testing.T) { RangeStart: net.ParseIP("2001::1"), notify: notify6, }) - require.Nil(t, err) + require.NoError(t, err) + s, ok := sIface.(*v6Server) require.True(t, ok) @@ -77,7 +83,7 @@ func TestV6_AddReplace(t *testing.T) { s.addLease(l) } - stLeases := []Lease{{ + stLeases := []*Lease{{ IP: net.ParseIP("2001::1"), HWAddr: net.HardwareAddr{0x33, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, }, { @@ -86,7 +92,8 @@ func TestV6_AddReplace(t *testing.T) { }} for _, l := range stLeases { - require.Nil(t, s.AddStaticLease(l)) + err = s.AddStaticLease(l) + require.NoError(t, err) } ls := s.GetLeases(LeasesStatic) @@ -106,8 +113,9 @@ func TestV6GetLease(t *testing.T) { RangeStart: net.ParseIP("2001::1"), notify: notify6, }) - require.Nil(t, err) + require.NoError(t, err) s, ok := sIface.(*v6Server) + require.True(t, ok) dnsAddr := net.ParseIP("2000::1") @@ -118,33 +126,36 @@ func TestV6GetLease(t *testing.T) { LinkLayerAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, } - l := Lease{ + l := &Lease{ IP: net.ParseIP("2001::1"), HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, } - require.Nil(t, s.AddStaticLease(l)) + err = s.AddStaticLease(l) + require.NoError(t, err) var req, resp, msg *dhcpv6.Message mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA} t.Run("solicit", func(t *testing.T) { req, err = dhcpv6.NewSolicit(mac) - require.Nil(t, err) + require.NoError(t, err) msg, err = req.GetInnerMessage() - require.Nil(t, err) + require.NoError(t, err) resp, err = dhcpv6.NewAdvertiseFromSolicit(msg) - require.Nil(t, err) + require.NoError(t, err) assert.True(t, s.process(msg, req, resp)) }) - require.Nil(t, err) + require.NoError(t, err) + resp.AddOption(dhcpv6.OptServerID(s.sid)) var oia *dhcpv6.OptIANA var oiaAddr *dhcpv6.OptIAAddress t.Run("advertise", func(t *testing.T) { require.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type()) + oia = resp.Options.OneIANA() oiaAddr = oia.Options.OneAddress() @@ -154,20 +165,21 @@ func TestV6GetLease(t *testing.T) { t.Run("request", func(t *testing.T) { req, err = dhcpv6.NewRequestFromAdvertise(resp) - require.Nil(t, err) + require.NoError(t, err) msg, err = req.GetInnerMessage() - require.Nil(t, err) + require.NoError(t, err) resp, err = dhcpv6.NewReplyFromMessage(msg) - require.Nil(t, err) + require.NoError(t, err) assert.True(t, s.process(msg, req, resp)) }) - require.Nil(t, err) + require.NoError(t, err) t.Run("reply", func(t *testing.T) { require.Equal(t, dhcpv6.MessageTypeReply, resp.Type()) + oia = resp.Options.OneIANA() oiaAddr = oia.Options.OneAddress() @@ -182,6 +194,7 @@ func TestV6GetLease(t *testing.T) { t.Run("lease", func(t *testing.T) { ls := s.GetLeases(LeasesStatic) require.Len(t, ls, 1) + assert.Equal(t, l.IP, ls[0].IP) assert.Equal(t, l.HWAddr, ls[0].HWAddr) }) @@ -193,7 +206,8 @@ func TestV6GetDynamicLease(t *testing.T) { RangeStart: net.ParseIP("2001::2"), notify: notify6, }) - require.Nil(t, err) + require.NoError(t, err) + s, ok := sIface.(*v6Server) require.True(t, ok) @@ -209,23 +223,25 @@ func TestV6GetDynamicLease(t *testing.T) { mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA} t.Run("solicit", func(t *testing.T) { req, err = dhcpv6.NewSolicit(mac) - require.Nil(t, err) + require.NoError(t, err) msg, err = req.GetInnerMessage() - require.Nil(t, err) + require.NoError(t, err) resp, err = dhcpv6.NewAdvertiseFromSolicit(msg) - require.Nil(t, err) + require.NoError(t, err) assert.True(t, s.process(msg, req, resp)) }) - require.Nil(t, err) + require.NoError(t, err) + resp.AddOption(dhcpv6.OptServerID(s.sid)) var oia *dhcpv6.OptIANA var oiaAddr *dhcpv6.OptIAAddress t.Run("advertise", func(t *testing.T) { require.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type()) + oia = resp.Options.OneIANA() oiaAddr = oia.Options.OneAddress() assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String()) @@ -233,20 +249,21 @@ func TestV6GetDynamicLease(t *testing.T) { t.Run("request", func(t *testing.T) { req, err = dhcpv6.NewRequestFromAdvertise(resp) - require.Nil(t, err) + require.NoError(t, err) msg, err = req.GetInnerMessage() - require.Nil(t, err) + require.NoError(t, err) resp, err = dhcpv6.NewReplyFromMessage(msg) - require.Nil(t, err) + require.NoError(t, err) assert.True(t, s.process(msg, req, resp)) }) - require.Nil(t, err) + require.NoError(t, err) t.Run("reply", func(t *testing.T) { require.Equal(t, dhcpv6.MessageTypeReply, resp.Type()) + oia = resp.Options.OneIANA() oiaAddr = oia.Options.OneAddress() assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String()) @@ -254,11 +271,13 @@ func TestV6GetDynamicLease(t *testing.T) { dnsAddrs := resp.Options.DNS() require.Len(t, dnsAddrs, 1) + assert.Equal(t, dnsAddr, dnsAddrs[0]) t.Run("lease", func(t *testing.T) { ls := s.GetLeases(LeasesDynamic) require.Len(t, ls, 1) + assert.Equal(t, "2001::2", ls[0].IP.String()) assert.Equal(t, mac, ls[0].HWAddr) }) diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index a03cc87f..9bc4f53e 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -993,15 +993,14 @@ type testDHCP struct{} func (d *testDHCP) Enabled() (ok bool) { return true } -func (d *testDHCP) Leases(flags int) []dhcpd.Lease { - l := dhcpd.Lease{ +func (d *testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) { + return []*dhcpd.Lease{{ IP: net.IP{192, 168, 12, 34}, HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, Hostname: "myhost", - } - - return []dhcpd.Lease{l} + }} } + func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {} func TestPTRResponseFromDHCPLeases(t *testing.T) { diff --git a/internal/home/clients_test.go b/internal/home/clients_test.go index c609721c..3d817e5c 100644 --- a/internal/home/clients_test.go +++ b/internal/home/clients_test.go @@ -243,6 +243,8 @@ func TestClientsAddExisting(t *testing.T) { }) t.Run("complicated", func(t *testing.T) { + var err error + testIP := net.IP{1, 2, 3, 4} // First, init a DHCP server with a single static lease. @@ -258,10 +260,12 @@ func TestClientsAddExisting(t *testing.T) { }, } - clients.dhcpServer = dhcpd.Create(config) + clients.dhcpServer, err = dhcpd.Create(config) + require.NoError(t, err) + t.Cleanup(func() { _ = os.Remove("leases.db") }) - err := clients.dhcpServer.AddStaticLease(dhcpd.Lease{ + err = clients.dhcpServer.AddStaticLease(&dhcpd.Lease{ HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, IP: testIP, Hostname: "testhost", diff --git a/internal/home/home.go b/internal/home/home.go index 61c613c9..77910436 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -241,8 +241,8 @@ func setupConfig(args options) (err error) { config.DHCP.HTTPRegister = httpRegister config.DHCP.ConfigModified = onConfigModified - Context.dhcpServer = dhcpd.Create(config.DHCP) - if Context.dhcpServer == nil { + Context.dhcpServer, err = dhcpd.Create(config.DHCP) + if Context.dhcpServer == nil || err != nil { // TODO(a.garipov): There are a lot of places in the code right // now which assume that the DHCP server can be nil despite this // condition. Inspect them and perhaps rewrite them to use @@ -630,7 +630,7 @@ func configureLogger(args options) { // cleanup stops and resets all the modules. func cleanup(ctx context.Context) { - log.Info("Stopping AdGuard Home") + log.Info("stopping AdGuard Home") if Context.web != nil { Context.web.Close(ctx) @@ -643,11 +643,14 @@ func cleanup(ctx context.Context) { err := stopDNSServer() if err != nil { - log.Error("Couldn't stop DNS server: %s", err) + log.Error("stopping dns server: %s", err) } if Context.dhcpServer != nil { - Context.dhcpServer.Stop() + err = Context.dhcpServer.Stop() + if err != nil { + log.Error("stopping dhcp server: %s", err) + } } Context.etcHosts.Close() diff --git a/openapi/CHANGELOG.md b/openapi/CHANGELOG.md index 2f34c442..5b0d28c0 100644 --- a/openapi/CHANGELOG.md +++ b/openapi/CHANGELOG.md @@ -4,6 +4,11 @@ ## v0.107: API changes +### `POST /control/dhcp/reset_leases` + +* The new `POST /control/dhcp/reset_leases` HTTP API allows removing all leases + from the DHCP server's database without erasing its configuration. + ### The parameter `"host"` in `GET /apple/*.mobileconfig` is now required. * The parameter `"host"` in `GET` requests for `/apple/doh.mobileconfig` and diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index 07194338..313c739f 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -473,6 +473,21 @@ 'schema': '$ref': '#/components/schemas/Error' 'description': 'Not implemented (for example, on Windows).' + '/dhcp/reset_leases': + 'post': + 'tags': + - 'dhcp' + 'operationId': 'dhcpResetLeases' + 'summary': 'Reset DHCP leases' + 'responses': + '200': + 'description': 'OK.' + '501': + 'content': + 'application/json': + 'schema': + '$ref': '#/components/schemas/Error' + 'description': 'Not implemented (for example, on Windows).' '/filtering/status': 'get': 'tags':