Pull request: dhcpd: fix ip ranges

Updates #2541.

Squashed commit of the following:

commit c81299991876f48836d24872d9145331a0bc9e6e
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Mar 16 18:10:07 2021 +0300

    agherr: imp docs

commit f43a5f5cde0ea16dd38dd533e16e415a1d306cb2
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Mar 16 17:35:59 2021 +0300

    all: imp err handling, fix code

commit ed26ad0ff53882725f7747264f8094e6fb9b0423
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Mar 16 12:24:17 2021 +0300

    dhcpd: fix ip ranges
This commit is contained in:
Ainar Garipov 2021-03-16 19:11:32 +03:00
parent e6a8fe452c
commit 9736123483
13 changed files with 463 additions and 137 deletions

1
go.mod
View File

@ -32,6 +32,7 @@ require (
github.com/stretchr/testify v1.6.1 github.com/stretchr/testify v1.6.1
github.com/ti-mo/netfilter v0.4.0 github.com/ti-mo/netfilter v0.4.0
github.com/u-root/u-root v7.0.0+incompatible github.com/u-root/u-root v7.0.0+incompatible
github.com/willf/bitset v1.1.11
go.etcd.io/bbolt v1.3.5 go.etcd.io/bbolt v1.3.5
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110

2
go.sum
View File

@ -425,6 +425,8 @@ github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGr
github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU=
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
github.com/willf/bitset v1.1.11 h1:N7Z7E9UvjW+sGsEl7k/SJrvY2reP1A07MrGuCjIOjRE=
github.com/willf/bitset v1.1.11/go.mod h1:83CECat5yLh5zVOf4P1ErAgKA5UDvKtgyUABdr3+MjI=
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=

View File

@ -1,5 +1,4 @@
// Package agherr contains the extended error type, and the function for // Package agherr contains AdGuard Home's error handling helpers.
// wrapping several errors.
package agherr package agherr
import ( import (
@ -23,8 +22,10 @@ type manyError struct {
} }
// Many wraps several errors and returns a single error. // Many wraps several errors and returns a single error.
func Many(message string, underlying ...error) error { //
err := &manyError{ // TODO(a.garipov): Add formatting to message.
func Many(message string, underlying ...error) (err error) {
err = &manyError{
message: message, message: message,
underlying: underlying, underlying: underlying,
} }
@ -33,7 +34,7 @@ func Many(message string, underlying ...error) error {
} }
// Error implements the error interface for *manyError. // Error implements the error interface for *manyError.
func (e *manyError) Error() string { func (e *manyError) Error() (msg string) {
switch len(e.underlying) { switch len(e.underlying) {
case 0: case 0:
return e.message return e.message
@ -58,7 +59,7 @@ func (e *manyError) Error() string {
} }
// Unwrap implements the hidden errors.wrapper interface for *manyError. // Unwrap implements the hidden errors.wrapper interface for *manyError.
func (e *manyError) Unwrap() error { func (e *manyError) Unwrap() (err error) {
if len(e.underlying) == 0 { if len(e.underlying) == 0 {
return nil return nil
} }
@ -71,3 +72,38 @@ func (e *manyError) Unwrap() error {
type wrapper interface { type wrapper interface {
Unwrap() error Unwrap() error
} }
// Annotate annotates the error with the message, unless the error is nil. This
// is a helper function to simplify code like this:
//
// func (f *foo) doStuff(s string) (err error) {
// defer func() {
// if err != nil {
// err = fmt.Errorf("bad foo string %q: %w", s, err)
// }
// }()
//
// // …
// }
//
// Instead, write:
//
// func (f *foo) doStuff(s string) (err error) {
// defer agherr.Annotate("bad foo string %q: %w", &err, s)
//
// // …
// }
//
// msg must contain the final ": %w" verb.
func Annotate(msg string, errPtr *error, args ...interface{}) {
if errPtr == nil {
return
}
err := *errPtr
if err != nil {
args = append(args, err)
*errPtr = fmt.Errorf(msg, args...)
}
}

View File

@ -6,30 +6,32 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestError_Error(t *testing.T) { func TestError_Error(t *testing.T) {
testCases := []struct { testCases := []struct {
err error
name string name string
want string want string
err error
}{{ }{{
err: Many("a"),
name: "simple", name: "simple",
want: "a", want: "a",
err: Many("a"),
}, { }, {
err: Many("a", errors.New("b")),
name: "wrapping", name: "wrapping",
want: "a: b", want: "a: b",
err: Many("a", errors.New("b")),
}, { }, {
err: Many("a", errors.New("b"), errors.New("c"), errors.New("d")),
name: "wrapping several", name: "wrapping several",
want: "a: b (hidden: c, d)", want: "a: b (hidden: c, d)",
err: Many("a", errors.New("b"), errors.New("c"), errors.New("d")),
}, { }, {
err: Many("a", Many("b", errors.New("c"), errors.New("d"))),
name: "wrapping wrapper", name: "wrapping wrapper",
want: "a: b: c (hidden: d)", want: "a: b: c (hidden: d)",
err: Many("a", Many("b", errors.New("c"), errors.New("d"))),
}} }}
for _, tc := range testCases { for _, tc := range testCases {
assert.Equal(t, tc.want, tc.err.Error(), tc.name) assert.Equal(t, tc.want, tc.err.Error(), tc.name)
} }
@ -43,33 +45,78 @@ func TestError_Unwrap(t *testing.T) {
errWrapped errWrapped
errNil errNil
) )
errs := []error{ errs := []error{
errSimple: errors.New("a"), errSimple: errors.New("a"),
errWrapped: fmt.Errorf("err: %w", errors.New("nested")), errWrapped: fmt.Errorf("err: %w", errors.New("nested")),
errNil: nil, errNil: nil,
} }
testCases := []struct { testCases := []struct {
name string
want error want error
wrapped error wrapped error
name string
}{{ }{{
name: "simple",
want: errs[errSimple], want: errs[errSimple],
wrapped: Many("a", errs[errSimple]), wrapped: Many("a", errs[errSimple]),
name: "simple",
}, { }, {
name: "nested",
want: errs[errWrapped], want: errs[errWrapped],
wrapped: Many("b", errs[errWrapped]), wrapped: Many("b", errs[errWrapped]),
name: "nested",
}, { }, {
name: "nil passed",
want: errs[errNil], want: errs[errNil],
wrapped: Many("c", errs[errNil]), wrapped: Many("c", errs[errNil]),
name: "nil passed",
}, { }, {
name: "nil not passed",
want: nil, want: nil,
wrapped: Many("d"), wrapped: Many("d"),
name: "nil not passed",
}} }}
for _, tc := range testCases { for _, tc := range testCases {
assert.Equal(t, tc.want, errors.Unwrap(tc.wrapped), tc.name) assert.Equal(t, tc.want, errors.Unwrap(tc.wrapped), tc.name)
} }
} }
func TestAnnotate(t *testing.T) {
const s = "1234"
const wantMsg = `bad string "1234": test`
// Don't use const, because we can't take a pointer of a constant.
var errTest error = Error("test")
t.Run("nil", func(t *testing.T) {
var errPtr *error
assert.NotPanics(t, func() {
Annotate("bad string %q: %w", errPtr, s)
})
})
t.Run("non_nil", func(t *testing.T) {
errPtr := &errTest
assert.NotPanics(t, func() {
Annotate("bad string %q: %w", errPtr, s)
})
require.NotNil(t, errPtr)
err := *errPtr
require.NotNil(t, err)
assert.Equal(t, wantMsg, err.Error())
})
t.Run("defer", func(t *testing.T) {
f := func() (err error) {
defer Annotate("bad string %q: %w", &errTest, s)
return errTest
}
err := f()
require.NotNil(t, err)
assert.Equal(t, wantMsg, err.Error())
})
}

99
internal/dhcpd/iprange.go Normal file
View File

@ -0,0 +1,99 @@
package dhcpd
import (
"fmt"
"math"
"math/big"
"net"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
)
// ipRange is an inclusive range of IP addresses.
//
// It is safe for concurrent use.
//
// TODO(a.garipov): Perhaps create an optimised version with uint32 for
// IPv4 ranges? Or use one of uint128 packages?
type ipRange struct {
start *big.Int
end *big.Int
}
// maxRangeLen is the maximum IP range length. The bitsets used in servers only
// accept uints, which can have the size of 32 bit.
const maxRangeLen = math.MaxUint32
// newIPRange creates a new IP address range. start must be less than end. The
// resulting range must not be greater than maxRangeLen.
func newIPRange(start, end net.IP) (r *ipRange, err error) {
defer agherr.Annotate("invalid ip range: %w", &err)
// Make sure that both are 16 bytes long to simplify handling in
// methods.
start, end = start.To16(), end.To16()
startInt := (&big.Int{}).SetBytes(start)
endInt := (&big.Int{}).SetBytes(end)
diff := (&big.Int{}).Sub(endInt, startInt)
if diff.Sign() <= 0 {
return nil, fmt.Errorf("start is greater than or equal to end")
} else if !diff.IsUint64() || diff.Uint64() > maxRangeLen {
return nil, fmt.Errorf("range is too large")
}
r = &ipRange{
start: startInt,
end: endInt,
}
return r, nil
}
// contains returns true if r contains ip.
func (r *ipRange) contains(ip net.IP) (ok bool) {
ipInt := (&big.Int{}).SetBytes(ip.To16())
return r.containsInt(ipInt)
}
// containsInt returns true if r contains ipInt.
func (r *ipRange) containsInt(ipInt *big.Int) (ok bool) {
return ipInt.Cmp(r.start) >= 0 && ipInt.Cmp(r.end) <= 0
}
// ipPredicate is a function that is called on every IP address in
// (*ipRange).find. ip is given in the 16-byte form.
type ipPredicate func(ip net.IP) (ok bool)
// find finds the first IP address in r for which p returns true. ip is in the
// 16-byte form.
func (r *ipRange) find(p ipPredicate) (ip net.IP) {
ip = make(net.IP, net.IPv6len)
_1 := big.NewInt(1)
for i := (&big.Int{}).Set(r.start); i.Cmp(r.end) <= 0; i.Add(i, _1) {
i.FillBytes(ip)
if p(ip) {
return ip
}
}
return nil
}
// offset returns the offset of ip from the beginning of r. It returns 0 and
// false if ip is not in r.
func (r *ipRange) offset(ip net.IP) (offset uint, ok bool) {
ip = ip.To16()
ipInt := (&big.Int{}).SetBytes(ip)
if !r.containsInt(ipInt) {
return 0, false
}
offsetInt := (&big.Int{}).Sub(ipInt, r.start)
// Assume that the range was checked against maxRangeLen during
// construction.
return uint(offsetInt.Uint64()), true
}

View File

@ -0,0 +1,154 @@
package dhcpd
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewIPRange(t *testing.T) {
start4 := net.IP{0, 0, 0, 1}
end4 := net.IP{0, 0, 0, 3}
start6 := net.IP{
0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x01,
}
end6 := net.IP{
0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x03,
}
end6Large := net.IP{
0x02, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x03,
}
testCases := []struct {
name string
wantErrMsg string
start net.IP
end net.IP
}{{
name: "success_ipv4",
wantErrMsg: "",
start: start4,
end: end4,
}, {
name: "success_ipv6",
wantErrMsg: "",
start: start6,
end: end6,
}, {
name: "start_gt_end",
wantErrMsg: "invalid ip range: start is greater than or equal to end",
start: end4,
end: start4,
}, {
name: "start_eq_end",
wantErrMsg: "invalid ip range: start is greater than or equal to end",
start: start4,
end: start4,
}, {
name: "too_large",
wantErrMsg: "invalid ip range: range is too large",
start: start6,
end: end6Large,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r, err := newIPRange(tc.start, tc.end)
if tc.wantErrMsg == "" {
assert.Nil(t, err)
assert.NotNil(t, r)
} else {
require.NotNil(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
})
}
}
func TestIPRange_Contains(t *testing.T) {
start, end := net.IP{0, 0, 0, 1}, net.IP{0, 0, 0, 3}
r, err := newIPRange(start, end)
require.Nil(t, err)
assert.True(t, r.contains(start))
assert.True(t, r.contains(net.IP{0, 0, 0, 2}))
assert.True(t, r.contains(end))
assert.False(t, r.contains(net.IP{0, 0, 0, 0}))
assert.False(t, r.contains(net.IP{0, 0, 0, 4}))
}
func TestIPRange_Find(t *testing.T) {
start, end := net.IP{0, 0, 0, 1}, net.IP{0, 0, 0, 5}
r, err := newIPRange(start, end)
require.Nil(t, err)
want := net.IPv4(0, 0, 0, 2)
got := r.find(func(ip net.IP) (ok bool) {
return ip[len(ip)-1]%2 == 0
})
assert.Equal(t, want, got)
got = r.find(func(ip net.IP) (ok bool) {
return ip[len(ip)-1]%10 == 0
})
assert.Nil(t, got)
}
func TestIPRange_Offset(t *testing.T) {
start, end := net.IP{0, 0, 0, 1}, net.IP{0, 0, 0, 5}
r, err := newIPRange(start, end)
require.Nil(t, err)
testCases := []struct {
name string
in net.IP
wantOffset uint
wantOK bool
}{{
name: "in",
in: net.IP{0, 0, 0, 2},
wantOffset: 1,
wantOK: true,
}, {
name: "in_start",
in: start,
wantOffset: 0,
wantOK: true,
}, {
name: "in_end",
in: end,
wantOffset: 4,
wantOK: true,
}, {
name: "out_after",
in: net.IP{0, 0, 0, 6},
wantOffset: 0,
wantOK: false,
}, {
name: "out_before",
in: net.IP{0, 0, 0, 0},
wantOffset: 0,
wantOK: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
offset, ok := r.offset(tc.in)
assert.Equal(t, tc.wantOffset, offset)
assert.Equal(t, tc.wantOK, ok)
})
}
}

View File

@ -100,11 +100,7 @@ func newDHCPOptionParser() (p *dhcpOptionParser) {
// parse parses an option. See the handlers' documentation for more info. // parse parses an option. See the handlers' documentation for more info.
func (p *dhcpOptionParser) parse(s string) (code uint8, data []byte, err error) { func (p *dhcpOptionParser) parse(s string) (code uint8, data []byte, err error) {
defer func() { defer agherr.Annotate("invalid option string %q: %w", &err, s)
if err != nil {
err = fmt.Errorf("invalid option string %q: %w", s, err)
}
}()
s = strings.TrimSpace(s) s = strings.TrimSpace(s)
parts := strings.SplitN(s, " ", 3) parts := strings.SplitN(s, " ", 3)

View File

@ -8,7 +8,7 @@ import (
) )
func TestDHCPOptionParser(t *testing.T) { func TestDHCPOptionParser(t *testing.T) {
testCasesA := []struct { testCases := []struct {
name string name string
in string in string
wantErrMsg string wantErrMsg string
@ -92,7 +92,7 @@ func TestDHCPOptionParser(t *testing.T) {
p := newDHCPOptionParser() p := newDHCPOptionParser()
for _, tc := range testCasesA { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
code, data, err := p.parse(tc.in) code, data, err := p.parse(tc.in)
if tc.wantErrMsg == "" { if tc.wantErrMsg == "" {

View File

@ -60,8 +60,8 @@ type V4ServerConf struct {
// DEC_CODE ip IP_ADDR // DEC_CODE ip IP_ADDR
Options []string `yaml:"options" json:"-"` Options []string `yaml:"options" json:"-"`
ipStart net.IP // starting IP address for dynamic leases ipRange *ipRange
ipEnd net.IP // ending IP address for dynamic leases
leaseTime time.Duration // the time during which a dynamic lease is considered valid leaseTime time.Duration // the time during which a dynamic lease is considered valid
dnsIPAddrs []net.IP // IPv4 addresses to return to DHCP clients as DNS server addresses dnsIPAddrs []net.IP // IPv4 addresses to return to DHCP clients as DNS server addresses
routerIP net.IP // value for Option Router routerIP net.IP // value for Option Router

View File

@ -4,7 +4,6 @@ package dhcpd
import ( import (
"bytes" "bytes"
"encoding/binary"
"fmt" "fmt"
"net" "net"
"sync" "sync"
@ -14,19 +13,25 @@ import (
"github.com/go-ping/ping" "github.com/go-ping/ping"
"github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/server4" "github.com/insomniacslk/dhcp/dhcpv4/server4"
"github.com/willf/bitset"
) )
// v4Server is a DHCPv4 server. // v4Server is a DHCPv4 server.
// //
// TODO(a.garipov): Think about unifying this and v6Server. // TODO(a.garipov): Think about unifying this and v6Server.
type v4Server struct { type v4Server struct {
srv *server4.Server
leasesLock sync.Mutex
leases []*Lease
// TODO(e.burkov): This field type should be a normal bitmap.
ipAddrs [256]byte
conf V4ServerConf conf V4ServerConf
srv *server4.Server
// leasedOffsets contains offsets from conf.ipRange.start that have been
// leased.
leasedOffsets *bitset.BitSet
// leases contains all dynamic and static leases.
leases []*Lease
// leasesLock protects leases and leasedOffsets.
leasesLock sync.Mutex
} }
// WriteDiskConfig4 - write configuration // WriteDiskConfig4 - write configuration
@ -38,27 +43,14 @@ func (s *v4Server) WriteDiskConfig4(c *V4ServerConf) {
func (s *v4Server) WriteDiskConfig6(c *V6ServerConf) { func (s *v4Server) WriteDiskConfig6(c *V6ServerConf) {
} }
// Return TRUE if IP address is within range [start..stop]
func ip4InRange(start, stop, ip net.IP) bool {
if len(start) != 4 || len(stop) != 4 {
return false
}
from := binary.BigEndian.Uint32(start)
to := binary.BigEndian.Uint32(stop)
check := binary.BigEndian.Uint32(ip)
return from <= check && check <= to
}
// ResetLeases - reset leases // ResetLeases - reset leases
func (s *v4Server) ResetLeases(leases []*Lease) { func (s *v4Server) ResetLeases(leases []*Lease) {
s.leases = nil s.leases = nil
for _, l := range leases { for _, l := range leases {
if l.Expiry.Unix() != leaseExpireStatic && !s.conf.ipRange.contains(l.IP) {
log.Debug("dhcpv4: skipping a lease with ip %v: not within current ip range", l.IP)
if l.Expiry.Unix() != leaseExpireStatic &&
!ip4InRange(s.conf.ipStart, s.conf.ipEnd, l.IP) {
log.Debug("dhcpv4: skipping a lease with IP %v: not within current IP range", l.IP)
continue continue
} }
@ -127,16 +119,18 @@ func (s *v4Server) blacklistLease(lease *Lease) {
lease.Expiry = time.Now().Add(s.conf.leaseTime) lease.Expiry = time.Now().Add(s.conf.leaseTime)
} }
// Remove (swap) lease by index // rmLeaseByIndex removes a lease by its index in the leases slice.
func (s *v4Server) leaseRemoveSwapByIndex(i int) { func (s *v4Server) rmLeaseByIndex(i int) {
s.ipAddrs[s.leases[i].IP[3]] = 0 l := s.leases[i]
log.Debug("dhcpv4: removed lease %s", s.leases[i].HWAddr) s.leases = append(s.leases[:i], s.leases[i+1:]...)
n := len(s.leases) r := s.conf.ipRange
if i != n-1 { offset, ok := r.offset(l.IP)
s.leases[i] = s.leases[n-1] // swap with the last element if ok {
s.leasedOffsets.Clear(offset)
} }
s.leases = s.leases[:n-1]
log.Debug("dhcpv4: removed lease %s (%s)", l.IP, l.HWAddr)
} }
// Remove a dynamic lease with the same properties // Remove a dynamic lease with the same properties
@ -146,51 +140,61 @@ func (s *v4Server) rmDynamicLease(lease Lease) error {
l := s.leases[i] l := s.leases[i]
if bytes.Equal(l.HWAddr, lease.HWAddr) { if bytes.Equal(l.HWAddr, lease.HWAddr) {
if l.Expiry.Unix() == leaseExpireStatic { if l.Expiry.Unix() == leaseExpireStatic {
return fmt.Errorf("static lease already exists") return fmt.Errorf("static lease already exists")
} }
s.leaseRemoveSwapByIndex(i) s.rmLeaseByIndex(i)
if i == len(s.leases) { if i == len(s.leases) {
break break
} }
l = s.leases[i] l = s.leases[i]
} }
if net.IP.Equal(l.IP, lease.IP) { if net.IP.Equal(l.IP, lease.IP) {
if l.Expiry.Unix() == leaseExpireStatic { if l.Expiry.Unix() == leaseExpireStatic {
return fmt.Errorf("static lease already exists") return fmt.Errorf("static lease already exists")
} }
s.leaseRemoveSwapByIndex(i) s.rmLeaseByIndex(i)
} }
} }
return nil return nil
} }
// Add a lease // addLease adds a lease.
func (s *v4Server) addLease(l *Lease) { func (s *v4Server) addLease(l *Lease) {
r := s.conf.ipRange
offset, ok := r.offset(l.IP)
if !ok {
// TODO(a.garipov): Better error handling.
log.Debug("dhcpv4: lease %s (%s) out of range, not adding", l.IP, l.HWAddr)
return
}
s.leases = append(s.leases, l) s.leases = append(s.leases, l)
s.ipAddrs[l.IP[3]] = 1 s.leasedOffsets.Set(uint(offset))
log.Debug("dhcpv4: added lease %s <-> %s", l.IP, l.HWAddr)
log.Debug("dhcpv4: added lease %s (%s)", l.IP, l.HWAddr)
} }
// Remove a lease with the same properties // Remove a lease with the same properties
func (s *v4Server) rmLease(lease Lease) error { func (s *v4Server) rmLease(lease Lease) error {
for i, l := range s.leases { for i, l := range s.leases {
if net.IP.Equal(l.IP, lease.IP) { if l.IP.Equal(lease.IP) {
if !bytes.Equal(l.HWAddr, lease.HWAddr) || if !bytes.Equal(l.HWAddr, lease.HWAddr) ||
l.Hostname != lease.Hostname { l.Hostname != lease.Hostname {
return fmt.Errorf("lease not found") return fmt.Errorf("lease not found")
} }
s.leaseRemoveSwapByIndex(i) s.rmLeaseByIndex(i)
return nil return nil
} }
} }
return fmt.Errorf("lease not found") return fmt.Errorf("lease not found")
} }
@ -258,7 +262,7 @@ func (s *v4Server) addrAvailable(target net.IP) bool {
pinger.Timeout = time.Duration(s.conf.ICMPTimeout) * time.Millisecond pinger.Timeout = time.Duration(s.conf.ICMPTimeout) * time.Millisecond
pinger.Count = 1 pinger.Count = 1
reply := false reply := false
pinger.OnRecv = func(pkt *ping.Packet) { pinger.OnRecv = func(_ *ping.Packet) {
reply = true reply = true
} }
log.Debug("dhcpv4: Sending ICMP Echo to %v", target) log.Debug("dhcpv4: Sending ICMP Echo to %v", target)
@ -278,30 +282,31 @@ func (s *v4Server) addrAvailable(target net.IP) bool {
return true return true
} }
// Find lease by MAC // findLease finds a lease by its MAC-address.
func (s *v4Server) findLease(mac net.HardwareAddr) *Lease { func (s *v4Server) findLease(mac net.HardwareAddr) (l *Lease) {
for i := range s.leases { for _, l = range s.leases {
if bytes.Equal(mac, s.leases[i].HWAddr) { if bytes.Equal(mac, l.HWAddr) {
return s.leases[i] return l
} }
} }
return nil return nil
} }
// Get next free IP // nextIP generates a new free IP.
func (s *v4Server) findFreeIP() net.IP { func (s *v4Server) nextIP() (ip net.IP) {
for i := s.conf.ipStart[3]; ; i++ { r := s.conf.ipRange
if s.ipAddrs[i] == 0 { ip = r.find(func(next net.IP) (ok bool) {
ip := make([]byte, 4) offset, ok := r.offset(next)
copy(ip, s.conf.ipStart) if !ok {
ip[3] = i // Shouldn't happen.
return ip return false
} }
if i == s.conf.ipEnd[3] {
break return !s.leasedOffsets.Test(uint(offset))
} })
}
return nil return ip.To4()
} }
// Find an expired lease and return its index or -1 // Find an expired lease and return its index or -1
@ -316,24 +321,30 @@ func (s *v4Server) findExpiredLease() int {
return -1 return -1
} }
// Reserve lease for MAC // reserveLease reserves a lease for a client by its MAC-address. It returns
func (s *v4Server) reserveLease(mac net.HardwareAddr) *Lease { // nil if it couldn't allocate a new lease.
l := Lease{} func (s *v4Server) reserveLease(mac net.HardwareAddr) (l *Lease) {
l.HWAddr = make([]byte, 6) l = &Lease{
HWAddr: make([]byte, 6),
}
copy(l.HWAddr, mac) copy(l.HWAddr, mac)
l.IP = s.findFreeIP() l.IP = s.nextIP()
if l.IP == nil { if l.IP == nil {
i := s.findExpiredLease() i := s.findExpiredLease()
if i < 0 { if i < 0 {
return nil return nil
} }
copy(s.leases[i].HWAddr, mac) copy(s.leases[i].HWAddr, mac)
return s.leases[i] return s.leases[i]
} }
s.addLease(&l) s.addLease(l)
return &l
return l
} }
func (s *v4Server) commitLease(l *Lease) { func (s *v4Server) commitLease(l *Lease) {
@ -650,22 +661,12 @@ func v4Create(conf V4ServerConf) (srv DHCPServer, err error) {
s.conf.subnetMask = make([]byte, 4) s.conf.subnetMask = make([]byte, 4)
copy(s.conf.subnetMask, s.conf.SubnetMask.To4()) copy(s.conf.subnetMask, s.conf.SubnetMask.To4())
s.conf.ipStart, err = tryTo4(conf.RangeStart) s.conf.ipRange, err = newIPRange(conf.RangeStart, conf.RangeEnd)
if s.conf.ipStart == nil { if err != nil {
return s, fmt.Errorf("dhcpv4: %w", err) return s, fmt.Errorf("dhcpv4: %w", err)
} }
if s.conf.ipStart[0] == 0 {
return s, fmt.Errorf("dhcpv4: invalid range start IP")
}
s.conf.ipEnd, err = tryTo4(conf.RangeEnd) s.leasedOffsets = &bitset.BitSet{}
if s.conf.ipEnd == nil {
return s, fmt.Errorf("dhcpv4: %w", err)
}
if !net.IP.Equal(s.conf.ipStart[:3], s.conf.ipEnd[:3]) ||
s.conf.ipStart[3] > s.conf.ipEnd[3] {
return s, fmt.Errorf("dhcpv4: range end IP should match range start IP")
}
if conf.LeaseDuration == 0 { if conf.LeaseDuration == 0 {
s.conf.leaseTime = time.Hour * 24 s.conf.leaseTime = time.Hour * 24

View File

@ -212,18 +212,26 @@ func TestV4DynamicLease_Get(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, 1, s.process(req, resp)) assert.Equal(t, 1, s.process(req, resp))
}) })
// Don't continue if we got any errors in the previous subtest.
require.Nil(t, err) require.Nil(t, err)
t.Run("offer", func(t *testing.T) { t.Run("offer", func(t *testing.T) {
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType()) assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
assert.Equal(t, mac, resp.ClientHWAddr) assert.Equal(t, mac, resp.ClientHWAddr)
assert.True(t, s.conf.RangeStart.Equal(resp.YourIPAddr))
assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0])) assert.Equal(t, s.conf.RangeStart, resp.YourIPAddr)
assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier())) assert.Equal(t, s.conf.GatewayIP, resp.ServerIdentifier())
router := resp.Router()
require.Len(t, router, 1)
assert.Equal(t, s.conf.GatewayIP, router[0])
assert.Equal(t, s.conf.subnetMask, resp.SubnetMask()) assert.Equal(t, s.conf.subnetMask, resp.SubnetMask())
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
assert.Equal(t, []byte("012"), resp.Options[uint8(dhcpv4.OptionFQDN)]) assert.Equal(t, []byte("012"), resp.Options[uint8(dhcpv4.OptionFQDN)])
assert.True(t, net.IP{1, 2, 3, 4}.Equal(net.IP(resp.Options[uint8(dhcpv4.OptionRelayAgentInformation)])))
assert.Equal(t, net.IP{1, 2, 3, 4}, net.IP(resp.RelayAgentInfo().ToBytes()))
}) })
t.Run("request", func(t *testing.T) { t.Run("request", func(t *testing.T) {
@ -260,31 +268,3 @@ func TestV4DynamicLease_Get(t *testing.T) {
assert.Equal(t, mac, ls[0].HWAddr) assert.Equal(t, mac, ls[0].HWAddr)
}) })
} }
func TestIP4InRange(t *testing.T) {
start := net.IP{192, 168, 10, 100}
stop := net.IP{192, 168, 10, 200}
testCases := []struct {
ip net.IP
want bool
}{{
ip: net.IP{192, 168, 10, 99},
want: false,
}, {
ip: net.IP{192, 168, 11, 100},
want: false,
}, {
ip: net.IP{192, 168, 11, 201},
want: false,
}, {
ip: start,
want: true,
}}
for _, tc := range testCases {
t.Run(tc.ip.String(), func(t *testing.T) {
assert.Equal(t, tc.want, ip4InRange(start, stop, tc.ip))
})
}
}

View File

@ -231,8 +231,17 @@ func TestClientsAddExisting(t *testing.T) {
// First, init a DHCP server with a single static lease. // First, init a DHCP server with a single static lease.
config := dhcpd.ServerConfig{ config := dhcpd.ServerConfig{
Enabled: true,
DBFilePath: "leases.db", DBFilePath: "leases.db",
Conf4: dhcpd.V4ServerConf{
Enabled: true,
GatewayIP: net.IP{1, 2, 3, 1},
SubnetMask: net.IP{255, 255, 255, 0},
RangeStart: net.IP{1, 2, 3, 2},
RangeEnd: net.IP{1, 2, 3, 10},
},
} }
clients.dhcpServer = dhcpd.Create(config) clients.dhcpServer = dhcpd.Create(config)
t.Cleanup(func() { _ = os.Remove("leases.db") }) t.Cleanup(func() { _ = os.Remove("leases.db") })

View File

@ -1,6 +1,8 @@
checks = ["all"] checks = ["all"]
initialisms = [ initialisms = [
# See https://github.com/dominikh/go-tools/blob/master/config/config.go. # See https://github.com/dominikh/go-tools/blob/master/config/config.go.
#
# Do not add "PTR" since we use "Ptr" as a suffix.
"inherit" "inherit"
, "DHCP" , "DHCP"
, "DOH" , "DOH"
@ -8,7 +10,6 @@ initialisms = [
, "DOT" , "DOT"
, "EDNS" , "EDNS"
, "MX" , "MX"
, "PTR"
, "QUIC" , "QUIC"
, "RA" , "RA"
, "SDNS" , "SDNS"