diff --git a/internal/aghio/limitedreadcloser_test.go b/internal/aghio/limitedreadcloser_test.go index 1f10e32b..9cccda17 100644 --- a/internal/aghio/limitedreadcloser_test.go +++ b/internal/aghio/limitedreadcloser_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestLimitReadCloser(t *testing.T) { @@ -78,11 +79,11 @@ func TestLimitedReadCloser_Read(t *testing.T) { buf := make([]byte, tc.limit+1) lreader, err := LimitReadCloser(readCloser, tc.limit) - assert.Nil(t, err) + require.Nil(t, err) n, err := lreader.Read(buf) - assert.Equal(t, n, tc.want) - assert.Equal(t, tc.err, err) + require.Equal(t, tc.err, err) + assert.Equal(t, tc.want, n) }) } } diff --git a/internal/dhcpd/dhcpd_test.go b/internal/dhcpd/dhcpd_test.go index 1aa1b9a6..e9b44518 100644 --- a/internal/dhcpd/dhcpd_test.go +++ b/internal/dhcpd/dhcpd_test.go @@ -11,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMain(m *testing.M) { @@ -20,116 +21,156 @@ func TestMain(m *testing.M) { func testNotify(flags uint32) { } -// Leases database store/load +// Leases database store/load. func TestDB(t *testing.T) { var err error - s := Server{} - s.conf.DBFilePath = dbFilename + s := Server{ + conf: ServerConfig{ + DBFilePath: dbFilename, + }, + } - conf := V4ServerConf{ + s.srv4, err = v4Create(V4ServerConf{ Enabled: true, RangeStart: net.IP{192, 168, 10, 100}, RangeEnd: net.IP{192, 168, 10, 200}, GatewayIP: net.IP{192, 168, 10, 1}, SubnetMask: net.IP{255, 255, 255, 0}, notify: testNotify, - } - s.srv4, err = v4Create(conf) - assert.Nil(t, err) + }) + require.Nil(t, err) s.srv6, err = v6Create(V6ServerConf{}) - assert.Nil(t, err) + require.Nil(t, err) - l := Lease{} - l.IP = net.IP{192, 168, 10, 100} - l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - exp1 := time.Now().Add(time.Hour) - l.Expiry = exp1 + leases := []Lease{{ + IP: net.IP{192, 168, 10, 100}, + HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + Expiry: time.Now().Add(time.Hour), + }, { + IP: net.IP{192, 168, 10, 101}, + HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xBB}, + }} srv4, ok := s.srv4.(*v4Server) - assert.True(t, ok) + require.True(t, ok) - srv4.addLease(&l) + srv4.addLease(&leases[0]) + require.Nil(t, s.srv4.AddStaticLease(leases[1])) - l2 := Lease{} - l2.IP = net.IP{192, 168, 10, 101} - l2.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:bb") - err = s.srv4.AddStaticLease(l2) - assert.Nil(t, err) - - _ = os.Remove("leases.db") s.dbStore() + t.Cleanup(func() { + assert.Nil(t, os.Remove(dbFilename)) + }) s.srv4.ResetLeases(nil) - s.dbLoad() ll := s.srv4.GetLeases(LeasesAll) + require.Len(t, ll, len(leases)) - assert.Equal(t, "aa:aa:aa:aa:aa:bb", ll[0].HWAddr.String()) - assert.True(t, net.IP{192, 168, 10, 101}.Equal(ll[0].IP)) + assert.Equal(t, leases[1].HWAddr, ll[0].HWAddr) + assert.Equal(t, leases[1].IP, ll[0].IP) assert.EqualValues(t, leaseExpireStatic, ll[0].Expiry.Unix()) - assert.Equal(t, "aa:aa:aa:aa:aa:aa", ll[1].HWAddr.String()) - assert.True(t, net.IP{192, 168, 10, 100}.Equal(ll[1].IP)) - assert.Equal(t, exp1.Unix(), ll[1].Expiry.Unix()) - - _ = os.Remove("leases.db") + assert.Equal(t, leases[0].HWAddr, ll[1].HWAddr) + assert.Equal(t, leases[0].IP, ll[1].IP) + assert.Equal(t, leases[0].Expiry.Unix(), ll[1].Expiry.Unix()) } func TestIsValidSubnetMask(t *testing.T) { - assert.True(t, isValidSubnetMask([]byte{255, 255, 255, 0})) - assert.True(t, isValidSubnetMask([]byte{255, 255, 254, 0})) - assert.True(t, isValidSubnetMask([]byte{255, 255, 252, 0})) - assert.False(t, isValidSubnetMask([]byte{255, 255, 253, 0})) - assert.False(t, isValidSubnetMask([]byte{255, 255, 255, 1})) + testCases := []struct { + mask net.IP + want bool + }{{ + mask: net.IP{255, 255, 255, 0}, + want: true, + }, { + mask: net.IP{255, 255, 254, 0}, + want: true, + }, { + mask: net.IP{255, 255, 252, 0}, + want: true, + }, { + mask: net.IP{255, 255, 253, 0}, + }, { + mask: net.IP{255, 255, 255, 1}, + }} + + for _, tc := range testCases { + t.Run(tc.mask.String(), func(t *testing.T) { + assert.Equal(t, tc.want, isValidSubnetMask(tc.mask)) + }) + } } func TestNormalizeLeases(t *testing.T) { - dynLeases := []*Lease{} - staticLeases := []*Lease{} + dynLeases := []*Lease{{ + HWAddr: net.HardwareAddr{1, 2, 3, 4}, + }, { + HWAddr: net.HardwareAddr{1, 2, 3, 5}, + }} - lease := &Lease{} - lease.HWAddr = []byte{1, 2, 3, 4} - dynLeases = append(dynLeases, lease) - lease = new(Lease) - lease.HWAddr = []byte{1, 2, 3, 5} - dynLeases = append(dynLeases, lease) - - lease = new(Lease) - lease.HWAddr = []byte{1, 2, 3, 4} - lease.IP = []byte{0, 2, 3, 4} - staticLeases = append(staticLeases, lease) - lease = new(Lease) - lease.HWAddr = []byte{2, 2, 3, 4} - staticLeases = append(staticLeases, lease) + staticLeases := []*Lease{{ + HWAddr: net.HardwareAddr{1, 2, 3, 4}, + IP: net.IP{0, 2, 3, 4}, + }, { + HWAddr: net.HardwareAddr{2, 2, 3, 4}, + }} leases := normalizeLeases(staticLeases, dynLeases) + require.Len(t, leases, 3) - assert.Len(t, leases, 3) - assert.True(t, bytes.Equal(leases[0].HWAddr, []byte{1, 2, 3, 4})) - assert.True(t, bytes.Equal(leases[0].IP, []byte{0, 2, 3, 4})) - assert.True(t, bytes.Equal(leases[1].HWAddr, []byte{2, 2, 3, 4})) - assert.True(t, bytes.Equal(leases[2].HWAddr, []byte{1, 2, 3, 5})) + assert.Equal(t, leases[0].HWAddr, dynLeases[0].HWAddr) + assert.Equal(t, leases[0].IP, staticLeases[0].IP) + assert.Equal(t, leases[1].HWAddr, staticLeases[1].HWAddr) + assert.Equal(t, leases[2].HWAddr, dynLeases[1].HWAddr) } func TestOptions(t *testing.T) { - code, val := parseOptionString(" 12 hex abcdef ") - assert.EqualValues(t, 12, code) - assert.True(t, bytes.Equal([]byte{0xab, 0xcd, 0xef}, val)) + testCases := []struct { + name string + optStr string + wantCode uint8 + wantVal []byte + }{{ + name: "all_right_hex", + optStr: " 12 hex abcdef ", + wantCode: 12, + wantVal: []byte{0xab, 0xcd, 0xef}, + }, { + name: "bad_hex", + optStr: " 12 hex abcdef1 ", + wantCode: 0, + }, { + name: "all_right_ip", + optStr: "123 ip 1.2.3.4", + wantCode: 123, + wantVal: net.IPv4(1, 2, 3, 4), + }, { + name: "bad_code", + optStr: "256 ip 1.1.1.1", + wantCode: 0, + }, { + name: "negative_code", + optStr: "-1 ip 1.1.1.1", + wantCode: 0, + }, { + name: "bad_ip", + optStr: "12 ip 1.1.1.1x", + wantCode: 0, + }, { + name: "bad_mode", + optStr: "12 x 1.1.1.1", + wantCode: 0, + }} - code, _ = parseOptionString(" 12 hex abcdef1 ") - assert.EqualValues(t, 0, code) - - code, val = parseOptionString("123 ip 1.2.3.4") - assert.EqualValues(t, 123, code) - assert.True(t, net.IP{1, 2, 3, 4}.Equal(net.IP(val))) - - code, _ = parseOptionString("256 ip 1.1.1.1") - assert.EqualValues(t, 0, code) - code, _ = parseOptionString("-1 ip 1.1.1.1") - assert.EqualValues(t, 0, code) - code, _ = parseOptionString("12 ip 1.1.1.1x") - assert.EqualValues(t, 0, code) - code, _ = parseOptionString("12 x 1.1.1.1") - assert.EqualValues(t, 0, code) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + code, val := parseOptionString(tc.optStr) + require.EqualValues(t, tc.wantCode, code) + if tc.wantVal != nil { + assert.True(t, bytes.Equal(tc.wantVal, val)) + } + }) + } } diff --git a/internal/dhcpd/dhcphttp_test.go b/internal/dhcpd/dhcphttp_test.go index 47b926dc..36a89a6e 100644 --- a/internal/dhcpd/dhcphttp_test.go +++ b/internal/dhcpd/dhcphttp_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestServer_notImplemented(t *testing.T) { @@ -14,7 +15,7 @@ func TestServer_notImplemented(t *testing.T) { w := httptest.NewRecorder() r, err := http.NewRequest(http.MethodGet, "/unsupported", nil) - assert.Nil(t, err) + require.Nil(t, err) h(w, r) assert.Equal(t, http.StatusNotImplemented, w.Code) diff --git a/internal/dhcpd/routeradv_test.go b/internal/dhcpd/routeradv_test.go index 95f3d4fa..4a0f4c5b 100644 --- a/internal/dhcpd/routeradv_test.go +++ b/internal/dhcpd/routeradv_test.go @@ -1,7 +1,6 @@ package dhcpd import ( - "bytes" "net" "testing" @@ -9,7 +8,7 @@ import ( ) func TestRA(t *testing.T) { - ra := icmpv6RA{ + data := createICMPv6RAPacket(icmpv6RA{ managedAddressConfiguration: false, otherConfiguration: true, mtu: 1500, @@ -17,8 +16,7 @@ func TestRA(t *testing.T) { prefixLen: 64, recursiveDNSServer: net.ParseIP("fe80::800:27ff:fe00:0"), sourceLinkLayerAddress: []byte{0x0a, 0x00, 0x27, 0x00, 0x00, 0x00}, - } - data := createICMPv6RAPacket(ra) + }) dataCorrect := []byte{ 0x86, 0x00, 0x00, 0x00, 0x40, 0x40, 0x07, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x04, 0x40, 0xc0, 0x00, 0x00, 0x0e, 0x10, 0x00, 0x00, 0x0e, 0x10, 0x00, 0x00, 0x00, 0x00, @@ -27,5 +25,5 @@ func TestRA(t *testing.T) { 0x19, 0x03, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x10, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x27, 0xff, 0xfe, 0x00, 0x00, 0x00, } - assert.True(t, bytes.Equal(data, dataCorrect)) + assert.Equal(t, dataCorrect, data) } diff --git a/internal/dhcpd/v4.go b/internal/dhcpd/v4.go index 2f5484a2..7d24699e 100644 --- a/internal/dhcpd/v4.go +++ b/internal/dhcpd/v4.go @@ -23,7 +23,8 @@ type v4Server struct { srv *server4.Server leasesLock sync.Mutex leases []*Lease - ipAddrs [256]byte + // TODO(e.burkov): This field type should be a normal bitmap. + ipAddrs [256]byte conf V4ServerConf } diff --git a/internal/dhcpd/v46_test.go b/internal/dhcpd/v46_test.go index 6007205d..6495eeee 100644 --- a/internal/dhcpd/v46_test.go +++ b/internal/dhcpd/v46_test.go @@ -7,6 +7,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type fakeIface struct { @@ -79,8 +80,8 @@ func TestIfaceIPAddrs(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { got, gotErr := ifaceIPAddrs(tc.iface, tc.ipv) + require.True(t, errors.Is(gotErr, tc.wantErr)) assert.Equal(t, tc.want, got) - assert.True(t, errors.Is(gotErr, tc.wantErr)) }) } } @@ -140,12 +141,8 @@ func TestIfaceDNSIPAddrs(t *testing.T) { want: nil, wantErr: errTest, }, { - name: "ipv4_wait", - iface: &waitingFakeIface{ - addrs: []net.Addr{addr4}, - err: nil, - n: 1, - }, + name: "ipv4_wait", + iface: &waitingFakeIface{addrs: []net.Addr{addr4}, err: nil, n: 1}, ipv: ipVersion4, want: []net.IP{ip4, ip4}, wantErr: nil, @@ -168,12 +165,8 @@ func TestIfaceDNSIPAddrs(t *testing.T) { want: nil, wantErr: errTest, }, { - name: "ipv6_wait", - iface: &waitingFakeIface{ - addrs: []net.Addr{addr6}, - err: nil, - n: 1, - }, + name: "ipv6_wait", + iface: &waitingFakeIface{addrs: []net.Addr{addr6}, err: nil, n: 1}, ipv: ipVersion6, want: []net.IP{ip6, ip6}, wantErr: nil, @@ -182,8 +175,8 @@ func TestIfaceDNSIPAddrs(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { got, gotErr := ifaceDNSIPAddrs(tc.iface, tc.ipv, 2, 0) + require.True(t, errors.Is(gotErr, tc.wantErr)) assert.Equal(t, tc.want, got) - assert.True(t, errors.Is(gotErr, tc.wantErr)) }) } } diff --git a/internal/dhcpd/v4_test.go b/internal/dhcpd/v4_test.go index 8edb3113..d204a200 100644 --- a/internal/dhcpd/v4_test.go +++ b/internal/dhcpd/v4_test.go @@ -8,172 +8,182 @@ import ( "github.com/insomniacslk/dhcp/dhcpv4" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func notify4(flags uint32) { } -func TestV4StaticLeaseAddRemove(t *testing.T) { - conf := V4ServerConf{ +func TestV4_AddRemove_static(t *testing.T) { + s, err := v4Create(V4ServerConf{ Enabled: true, RangeStart: net.IP{192, 168, 10, 100}, RangeEnd: net.IP{192, 168, 10, 200}, GatewayIP: net.IP{192, 168, 10, 1}, SubnetMask: net.IP{255, 255, 255, 0}, notify: notify4, - } - s, err := v4Create(conf) - assert.Nil(t, err) + }) + require.Nil(t, err) ls := s.GetLeases(LeasesStatic) assert.Empty(t, ls) - // add static lease - l := Lease{} - l.IP = net.IP{192, 168, 10, 150} - l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - assert.Nil(t, s.AddStaticLease(l)) - - // try to add the same static lease - fail + // Add static lease. + l := Lease{ + IP: net.IP{192, 168, 10, 150}, + HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + } + require.Nil(t, s.AddStaticLease(l)) assert.NotNil(t, s.AddStaticLease(l)) - // check ls = s.GetLeases(LeasesStatic) - assert.Len(t, ls, 1) - assert.True(t, net.IP{192, 168, 10, 150}.Equal(ls[0].IP)) - assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) + require.Len(t, ls, 1) + assert.True(t, l.IP.Equal(ls[0].IP)) + assert.Equal(t, l.HWAddr, ls[0].HWAddr) assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix()) - // try to remove static lease - fail - l.IP = net.IP{192, 168, 10, 110} - l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - assert.NotNil(t, s.RemoveStaticLease(l)) + // Try to remove static lease. + assert.NotNil(t, s.RemoveStaticLease(Lease{ + IP: net.IP{192, 168, 10, 110}, + HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + })) - // remove static lease - l.IP = net.IP{192, 168, 10, 150} - l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - assert.Nil(t, s.RemoveStaticLease(l)) - - // check + // Remove static lease. + require.Nil(t, s.RemoveStaticLease(l)) ls = s.GetLeases(LeasesStatic) assert.Empty(t, ls) } -func TestV4StaticLeaseAddReplaceDynamic(t *testing.T) { - conf := V4ServerConf{ +func TestV4_AddReplace(t *testing.T) { + sIface, err := v4Create(V4ServerConf{ Enabled: true, RangeStart: net.IP{192, 168, 10, 100}, RangeEnd: net.IP{192, 168, 10, 200}, GatewayIP: net.IP{192, 168, 10, 1}, SubnetMask: net.IP{255, 255, 255, 0}, notify: notify4, - } - sIface, err := v4Create(conf) + }) + require.Nil(t, err) + s, ok := sIface.(*v4Server) - assert.True(t, ok) - assert.Nil(t, err) + require.True(t, ok) - // add dynamic lease - ld := Lease{} - ld.IP = net.IP{192, 168, 10, 150} - ld.HWAddr, _ = net.ParseMAC("11:aa:aa:aa:aa:aa") - s.addLease(&ld) + dynLeases := []Lease{{ + IP: net.IP{192, 168, 10, 150}, + HWAddr: net.HardwareAddr{0x11, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + }, { + IP: net.IP{192, 168, 10, 151}, + HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + }} - // add dynamic lease - { - ld := Lease{} - ld.IP = net.IP{192, 168, 10, 151} - ld.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa") - s.addLease(&ld) + for i := range dynLeases { + s.addLease(&dynLeases[i]) } - // add static lease with the same IP - l := Lease{} - l.IP = net.IP{192, 168, 10, 150} - l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa") - assert.Nil(t, s.AddStaticLease(l)) + stLeases := []Lease{{ + IP: net.IP{192, 168, 10, 150}, + HWAddr: net.HardwareAddr{0x33, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + }, { + IP: net.IP{192, 168, 10, 152}, + HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + }} - // add static lease with the same MAC - l = Lease{} - l.IP = net.IP{192, 168, 10, 152} - l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa") - assert.Nil(t, s.AddStaticLease(l)) + for _, l := range stLeases { + require.Nil(t, s.AddStaticLease(l)) + } - // check ls := s.GetLeases(LeasesStatic) - assert.Len(t, ls, 2) + require.Len(t, ls, 2) - assert.True(t, net.IP{192, 168, 10, 150}.Equal(ls[0].IP)) - assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) - assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix()) - - assert.True(t, net.IP{192, 168, 10, 152}.Equal(ls[1].IP)) - assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String()) - assert.EqualValues(t, leaseExpireStatic, ls[1].Expiry.Unix()) + for i, l := range ls { + assert.True(t, stLeases[i].IP.Equal(l.IP)) + assert.Equal(t, stLeases[i].HWAddr, l.HWAddr) + assert.EqualValues(t, leaseExpireStatic, l.Expiry.Unix()) + } } -func TestV4StaticLeaseGet(t *testing.T) { - conf := V4ServerConf{ +func TestV4StaticLease_Get(t *testing.T) { + var err error + sIface, err := v4Create(V4ServerConf{ Enabled: true, RangeStart: net.IP{192, 168, 10, 100}, RangeEnd: net.IP{192, 168, 10, 200}, GatewayIP: net.IP{192, 168, 10, 1}, SubnetMask: net.IP{255, 255, 255, 0}, notify: notify4, - } - sIface, err := v4Create(conf) + }) + require.Nil(t, err) + s, ok := sIface.(*v4Server) - assert.True(t, ok) - assert.Nil(t, err) + require.True(t, ok) s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}} - l := Lease{} - l.IP = net.IP{192, 168, 10, 150} - l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - assert.Nil(t, s.AddStaticLease(l)) + l := Lease{ + IP: net.IP{192, 168, 10, 150}, + HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + } + require.Nil(t, s.AddStaticLease(l)) - // "Discover" - mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") - req, _ := dhcpv4.NewDiscovery(mac) - resp, _ := dhcpv4.NewReplyFromRequest(req) - assert.Equal(t, 1, s.process(req, resp)) + var req, resp *dhcpv4.DHCPv4 + mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA} - // check "Offer" - assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType()) - assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String()) - assert.True(t, net.IP{192, 168, 10, 150}.Equal(resp.YourIPAddr)) - assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.Router()[0])) - assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.ServerIdentifier())) - assert.True(t, net.IP{255, 255, 255, 0}.Equal(net.IP(resp.SubnetMask()))) - assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) + t.Run("discover", func(t *testing.T) { + var err error - // "Request" - req, _ = dhcpv4.NewRequestFromOffer(resp) - resp, _ = dhcpv4.NewReplyFromRequest(req) - assert.Equal(t, 1, s.process(req, resp)) + req, err = dhcpv4.NewDiscovery(mac) + require.Nil(t, err) - // check "Ack" - assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType()) - assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String()) - assert.True(t, net.IP{192, 168, 10, 150}.Equal(resp.YourIPAddr)) - assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.Router()[0])) - assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.ServerIdentifier())) - assert.True(t, net.IP{255, 255, 255, 0}.Equal(net.IP(resp.SubnetMask()))) - assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) + resp, err = dhcpv4.NewReplyFromRequest(req) + require.Nil(t, err) + assert.Equal(t, 1, s.process(req, resp)) + }) + require.Nil(t, err) + + t.Run("offer", func(t *testing.T) { + assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType()) + assert.Equal(t, mac, resp.ClientHWAddr) + assert.True(t, l.IP.Equal(resp.YourIPAddr)) + assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0])) + assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier())) + assert.Equal(t, s.conf.subnetMask, resp.SubnetMask()) + assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) + }) + + t.Run("request", func(t *testing.T) { + req, err = dhcpv4.NewRequestFromOffer(resp) + require.Nil(t, err) + + resp, err = dhcpv4.NewReplyFromRequest(req) + require.Nil(t, err) + assert.Equal(t, 1, s.process(req, resp)) + }) + require.Nil(t, err) + + t.Run("ack", func(t *testing.T) { + assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType()) + assert.Equal(t, mac, resp.ClientHWAddr) + assert.True(t, l.IP.Equal(resp.YourIPAddr)) + assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0])) + assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier())) + assert.Equal(t, s.conf.subnetMask, resp.SubnetMask()) + assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) + }) dnsAddrs := resp.DNS() - assert.Len(t, dnsAddrs, 1) - assert.True(t, net.IP{192, 168, 10, 1}.Equal(dnsAddrs[0])) + require.Len(t, dnsAddrs, 1) + assert.True(t, s.conf.GatewayIP.Equal(dnsAddrs[0])) - // check lease - ls := s.GetLeases(LeasesStatic) - assert.Len(t, ls, 1) - assert.True(t, net.IP{192, 168, 10, 150}.Equal(ls[0].IP)) - assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) + t.Run("check_lease", func(t *testing.T) { + ls := s.GetLeases(LeasesStatic) + require.Len(t, ls, 1) + assert.True(t, l.IP.Equal(ls[0].IP)) + assert.Equal(t, mac, ls[0].HWAddr) + }) } -func TestV4DynamicLeaseGet(t *testing.T) { - conf := V4ServerConf{ +func TestV4DynamicLease_Get(t *testing.T) { + var err error + sIface, err := v4Create(V4ServerConf{ Enabled: true, RangeStart: net.IP{192, 168, 10, 100}, RangeEnd: net.IP{192, 168, 10, 200}, @@ -184,58 +194,97 @@ func TestV4DynamicLeaseGet(t *testing.T) { "81 hex 303132", "82 ip 1.2.3.4", }, - } - sIface, err := v4Create(conf) + }) + require.Nil(t, err) + s, ok := sIface.(*v4Server) - assert.True(t, ok) - assert.Nil(t, err) + require.True(t, ok) s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}} - // "Discover" - mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") - req, _ := dhcpv4.NewDiscovery(mac) - resp, _ := dhcpv4.NewReplyFromRequest(req) - assert.Equal(t, 1, s.process(req, resp)) + var req, resp *dhcpv4.DHCPv4 + mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA} - // check "Offer" - assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType()) - assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String()) - assert.True(t, net.IP{192, 168, 10, 100}.Equal(resp.YourIPAddr)) - assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.Router()[0])) - assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.ServerIdentifier())) - assert.True(t, net.IP{255, 255, 255, 0}.Equal(net.IP(resp.SubnetMask()))) - assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) - assert.Equal(t, []byte("012"), resp.Options[uint8(dhcpv4.OptionFQDN)]) - assert.True(t, net.IP{1, 2, 3, 4}.Equal(net.IP(resp.Options[uint8(dhcpv4.OptionRelayAgentInformation)]))) + t.Run("discover", func(t *testing.T) { + req, err = dhcpv4.NewDiscovery(mac) + require.Nil(t, err) - // "Request" - req, _ = dhcpv4.NewRequestFromOffer(resp) - resp, _ = dhcpv4.NewReplyFromRequest(req) - assert.Equal(t, 1, s.process(req, resp)) + resp, err = dhcpv4.NewReplyFromRequest(req) + require.Nil(t, err) + assert.Equal(t, 1, s.process(req, resp)) + }) + require.Nil(t, err) - // check "Ack" - assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType()) - assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String()) - assert.True(t, net.IP{192, 168, 10, 100}.Equal(resp.YourIPAddr)) - assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.Router()[0])) - assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.ServerIdentifier())) - assert.True(t, net.IP{255, 255, 255, 0}.Equal(net.IP(resp.SubnetMask()))) - assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) + t.Run("offer", func(t *testing.T) { + assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType()) + assert.Equal(t, mac, resp.ClientHWAddr) + assert.True(t, s.conf.RangeStart.Equal(resp.YourIPAddr)) + assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0])) + assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier())) + assert.Equal(t, s.conf.subnetMask, resp.SubnetMask()) + assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) + assert.Equal(t, []byte("012"), resp.Options[uint8(dhcpv4.OptionFQDN)]) + assert.True(t, net.IP{1, 2, 3, 4}.Equal(net.IP(resp.Options[uint8(dhcpv4.OptionRelayAgentInformation)]))) + }) + + t.Run("request", func(t *testing.T) { + var err error + + req, err = dhcpv4.NewRequestFromOffer(resp) + require.Nil(t, err) + + resp, err = dhcpv4.NewReplyFromRequest(req) + require.Nil(t, err) + assert.Equal(t, 1, s.process(req, resp)) + }) + require.Nil(t, err) + + t.Run("ack", func(t *testing.T) { + assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType()) + assert.Equal(t, mac, resp.ClientHWAddr) + assert.True(t, s.conf.RangeStart.Equal(resp.YourIPAddr)) + assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0])) + assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier())) + assert.Equal(t, s.conf.subnetMask, resp.SubnetMask()) + assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) + }) dnsAddrs := resp.DNS() - assert.Len(t, dnsAddrs, 1) + require.Len(t, dnsAddrs, 1) assert.True(t, net.IP{192, 168, 10, 1}.Equal(dnsAddrs[0])) // check lease - ls := s.GetLeases(LeasesDynamic) - assert.Len(t, ls, 1) - assert.True(t, net.IP{192, 168, 10, 100}.Equal(ls[0].IP)) - assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) + t.Run("check_lease", func(t *testing.T) { + ls := s.GetLeases(LeasesDynamic) + assert.Len(t, ls, 1) + assert.True(t, net.IP{192, 168, 10, 100}.Equal(ls[0].IP)) + assert.Equal(t, mac, ls[0].HWAddr) + }) +} +func TestIP4InRange(t *testing.T) { start := net.IP{192, 168, 10, 100} stop := net.IP{192, 168, 10, 200} - assert.False(t, ip4InRange(start, stop, net.IP{192, 168, 10, 99})) - assert.False(t, ip4InRange(start, stop, net.IP{192, 168, 11, 100})) - assert.False(t, ip4InRange(start, stop, net.IP{192, 168, 11, 201})) - assert.True(t, ip4InRange(start, stop, net.IP{192, 168, 10, 100})) + + testCases := []struct { + ip net.IP + want bool + }{{ + ip: net.IP{192, 168, 10, 99}, + want: false, + }, { + ip: net.IP{192, 168, 11, 100}, + want: false, + }, { + ip: net.IP{192, 168, 11, 201}, + want: false, + }, { + ip: start, + want: true, + }} + + for _, tc := range testCases { + t.Run(tc.ip.String(), func(t *testing.T) { + assert.Equal(t, tc.want, ip4InRange(start, stop, tc.ip)) + }) + } } diff --git a/internal/dhcpd/v6_test.go b/internal/dhcpd/v6_test.go index 9cdf3ee4..3eb06a89 100644 --- a/internal/dhcpd/v6_test.go +++ b/internal/dhcpd/v6_test.go @@ -9,220 +9,283 @@ import ( "github.com/insomniacslk/dhcp/dhcpv6" "github.com/insomniacslk/dhcp/iana" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func notify6(flags uint32) { } -func TestV6StaticLeaseAddRemove(t *testing.T) { - conf := V6ServerConf{ +func TestV6_AddRemove_static(t *testing.T) { + s, err := v6Create(V6ServerConf{ Enabled: true, RangeStart: net.ParseIP("2001::1"), notify: notify6, + }) + require.Nil(t, err) + + require.Empty(t, s.GetLeases(LeasesStatic)) + + // Add static lease. + l := Lease{ + IP: net.ParseIP("2001::1"), + HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, } - s, err := v6Create(conf) - assert.Nil(t, err) + require.Nil(t, s.AddStaticLease(l)) + + // Try to add the same static lease. + require.NotNil(t, s.AddStaticLease(l)) ls := s.GetLeases(LeasesStatic) - assert.Empty(t, ls) - - // add static lease - l := Lease{} - l.IP = net.ParseIP("2001::1") - l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - assert.Nil(t, s.AddStaticLease(l)) - - // try to add static lease - fail - assert.NotNil(t, s.AddStaticLease(l)) - - // check - ls = s.GetLeases(LeasesStatic) - assert.Len(t, ls, 1) - assert.Equal(t, "2001::1", ls[0].IP.String()) - assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) + require.Len(t, ls, 1) + assert.Equal(t, l.IP, ls[0].IP) + assert.Equal(t, l.HWAddr, ls[0].HWAddr) assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix()) - // try to remove static lease - fail - l.IP = net.ParseIP("2001::2") - l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - assert.NotNil(t, s.RemoveStaticLease(l)) + // Try to remove non-existent static lease. + require.NotNil(t, s.RemoveStaticLease(Lease{ + IP: net.ParseIP("2001::2"), + HWAddr: l.HWAddr, + })) - // remove static lease - l.IP = net.ParseIP("2001::1") - l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - assert.Nil(t, s.RemoveStaticLease(l)) + // Remove static lease. + require.Nil(t, s.RemoveStaticLease(l)) - // check - ls = s.GetLeases(LeasesStatic) - assert.Empty(t, ls) + assert.Empty(t, s.GetLeases(LeasesStatic)) } -func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) { - conf := V6ServerConf{ +func TestV6_AddReplace(t *testing.T) { + sIface, err := v6Create(V6ServerConf{ Enabled: true, RangeStart: net.ParseIP("2001::1"), notify: notify6, - } - sIface, err := v6Create(conf) + }) + require.Nil(t, err) s, ok := sIface.(*v6Server) - assert.True(t, ok) - assert.Nil(t, err) + require.True(t, ok) - // add dynamic lease - ld := Lease{} - ld.IP = net.ParseIP("2001::1") - ld.HWAddr, _ = net.ParseMAC("11:aa:aa:aa:aa:aa") - s.addLease(&ld) + // Add dynamic leases. + dynLeases := []*Lease{{ + IP: net.ParseIP("2001::1"), + HWAddr: net.HardwareAddr{0x11, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + }, { + IP: net.ParseIP("2001::2"), + HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + }} - // add dynamic lease - { - ld := Lease{} - ld.IP = net.ParseIP("2001::2") - ld.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa") - s.addLease(&ld) + for _, l := range dynLeases { + s.addLease(l) } - // add static lease with the same IP - l := Lease{} - l.IP = net.ParseIP("2001::1") - l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa") - assert.Nil(t, s.AddStaticLease(l)) + stLeases := []Lease{{ + IP: net.ParseIP("2001::1"), + HWAddr: net.HardwareAddr{0x33, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + }, { + IP: net.ParseIP("2001::3"), + HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + }} - // add static lease with the same MAC - l = Lease{} - l.IP = net.ParseIP("2001::3") - l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa") - assert.Nil(t, s.AddStaticLease(l)) + for _, l := range stLeases { + require.Nil(t, s.AddStaticLease(l)) + } - // check ls := s.GetLeases(LeasesStatic) - assert.Len(t, ls, 2) + require.Len(t, ls, 2) - assert.Equal(t, "2001::1", ls[0].IP.String()) - assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) - assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix()) - - assert.Equal(t, "2001::3", ls[1].IP.String()) - assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String()) - assert.EqualValues(t, leaseExpireStatic, ls[1].Expiry.Unix()) + for i, l := range ls { + assert.True(t, stLeases[i].IP.Equal(l.IP)) + assert.Equal(t, stLeases[i].HWAddr, l.HWAddr) + assert.EqualValues(t, leaseExpireStatic, l.Expiry.Unix()) + } } func TestV6GetLease(t *testing.T) { - conf := V6ServerConf{ + var err error + sIface, err := v6Create(V6ServerConf{ Enabled: true, RangeStart: net.ParseIP("2001::1"), notify: notify6, - } - sIface, err := v6Create(conf) + }) + require.Nil(t, err) s, ok := sIface.(*v6Server) - assert.True(t, ok) - assert.Nil(t, err) - s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")} + require.True(t, ok) + + dnsAddr := net.ParseIP("2000::1") + s.conf.dnsIPAddrs = []net.IP{dnsAddr} s.sid = dhcpv6.Duid{ - Type: dhcpv6.DUID_LLT, - HwType: iana.HWTypeEthernet, + Type: dhcpv6.DUID_LLT, + HwType: iana.HWTypeEthernet, + LinkLayerAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, } - s.sid.LinkLayerAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - l := Lease{} - l.IP = net.ParseIP("2001::1") - l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") - assert.Nil(t, s.AddStaticLease(l)) + l := Lease{ + IP: net.ParseIP("2001::1"), + HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + } + require.Nil(t, s.AddStaticLease(l)) - // "Solicit" - mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") - req, _ := dhcpv6.NewSolicit(mac) - msg, _ := req.GetInnerMessage() - resp, _ := dhcpv6.NewAdvertiseFromSolicit(msg) - assert.True(t, s.process(msg, req, resp)) + var req, resp, msg *dhcpv6.Message + mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA} + t.Run("solicit", func(t *testing.T) { + req, err = dhcpv6.NewSolicit(mac) + require.Nil(t, err) + + msg, err = req.GetInnerMessage() + require.Nil(t, err) + + resp, err = dhcpv6.NewAdvertiseFromSolicit(msg) + require.Nil(t, err) + + assert.True(t, s.process(msg, req, resp)) + }) + require.Nil(t, err) resp.AddOption(dhcpv6.OptServerID(s.sid)) - // check "Advertise" - assert.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type()) - oia := resp.Options.OneIANA() - oiaAddr := oia.Options.OneAddress() - assert.Equal(t, "2001::1", oiaAddr.IPv6Addr.String()) - assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds()) + var oia *dhcpv6.OptIANA + var oiaAddr *dhcpv6.OptIAAddress + t.Run("advertise", func(t *testing.T) { + require.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type()) + oia = resp.Options.OneIANA() + oiaAddr = oia.Options.OneAddress() - // "Request" - req, _ = dhcpv6.NewRequestFromAdvertise(resp) - msg, _ = req.GetInnerMessage() - resp, _ = dhcpv6.NewReplyFromMessage(msg) - assert.True(t, s.process(msg, req, resp)) + assert.Equal(t, l.IP, oiaAddr.IPv6Addr) + assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds()) + }) - // check "Reply" - assert.Equal(t, dhcpv6.MessageTypeReply, resp.Type()) - oia = resp.Options.OneIANA() - oiaAddr = oia.Options.OneAddress() - assert.Equal(t, "2001::1", oiaAddr.IPv6Addr.String()) - assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds()) + t.Run("request", func(t *testing.T) { + req, err = dhcpv6.NewRequestFromAdvertise(resp) + require.Nil(t, err) + + msg, err = req.GetInnerMessage() + require.Nil(t, err) + + resp, err = dhcpv6.NewReplyFromMessage(msg) + require.Nil(t, err) + + assert.True(t, s.process(msg, req, resp)) + }) + require.Nil(t, err) + + t.Run("reply", func(t *testing.T) { + require.Equal(t, dhcpv6.MessageTypeReply, resp.Type()) + oia = resp.Options.OneIANA() + oiaAddr = oia.Options.OneAddress() + + assert.Equal(t, l.IP, oiaAddr.IPv6Addr) + assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds()) + }) dnsAddrs := resp.Options.DNS() - assert.Len(t, dnsAddrs, 1) - assert.Equal(t, "2000::1", dnsAddrs[0].String()) + require.Len(t, dnsAddrs, 1) + assert.Equal(t, dnsAddr, dnsAddrs[0]) - // check lease - ls := s.GetLeases(LeasesStatic) - assert.Len(t, ls, 1) - assert.Equal(t, "2001::1", ls[0].IP.String()) - assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) + t.Run("lease", func(t *testing.T) { + ls := s.GetLeases(LeasesStatic) + require.Len(t, ls, 1) + assert.Equal(t, l.IP, ls[0].IP) + assert.Equal(t, l.HWAddr, ls[0].HWAddr) + }) } func TestV6GetDynamicLease(t *testing.T) { - conf := V6ServerConf{ + sIface, err := v6Create(V6ServerConf{ Enabled: true, RangeStart: net.ParseIP("2001::2"), notify: notify6, - } - sIface, err := v6Create(conf) + }) + require.Nil(t, err) s, ok := sIface.(*v6Server) - assert.True(t, ok) - assert.Nil(t, err) - s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")} - s.sid = dhcpv6.Duid{ - Type: dhcpv6.DUID_LLT, - HwType: iana.HWTypeEthernet, - } - s.sid.LinkLayerAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") + require.True(t, ok) - // "Solicit" - mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") - req, _ := dhcpv6.NewSolicit(mac) - msg, _ := req.GetInnerMessage() - resp, _ := dhcpv6.NewAdvertiseFromSolicit(msg) - assert.True(t, s.process(msg, req, resp)) + dnsAddr := net.ParseIP("2000::1") + s.conf.dnsIPAddrs = []net.IP{dnsAddr} + s.sid = dhcpv6.Duid{ + Type: dhcpv6.DUID_LLT, + HwType: iana.HWTypeEthernet, + LinkLayerAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + } + + var req, resp, msg *dhcpv6.Message + mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA} + t.Run("solicit", func(t *testing.T) { + req, err = dhcpv6.NewSolicit(mac) + require.Nil(t, err) + + msg, err = req.GetInnerMessage() + require.Nil(t, err) + + resp, err = dhcpv6.NewAdvertiseFromSolicit(msg) + require.Nil(t, err) + + assert.True(t, s.process(msg, req, resp)) + }) + require.Nil(t, err) resp.AddOption(dhcpv6.OptServerID(s.sid)) - // check "Advertise" - assert.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type()) - oia := resp.Options.OneIANA() - oiaAddr := oia.Options.OneAddress() - assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String()) + var oia *dhcpv6.OptIANA + var oiaAddr *dhcpv6.OptIAAddress + t.Run("advertise", func(t *testing.T) { + require.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type()) + oia = resp.Options.OneIANA() + oiaAddr = oia.Options.OneAddress() + assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String()) + }) - // "Request" - req, _ = dhcpv6.NewRequestFromAdvertise(resp) - msg, _ = req.GetInnerMessage() - resp, _ = dhcpv6.NewReplyFromMessage(msg) - assert.True(t, s.process(msg, req, resp)) + t.Run("request", func(t *testing.T) { + req, err = dhcpv6.NewRequestFromAdvertise(resp) + require.Nil(t, err) - // check "Reply" - assert.Equal(t, dhcpv6.MessageTypeReply, resp.Type()) - oia = resp.Options.OneIANA() - oiaAddr = oia.Options.OneAddress() - assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String()) + msg, err = req.GetInnerMessage() + require.Nil(t, err) + + resp, err = dhcpv6.NewReplyFromMessage(msg) + require.Nil(t, err) + + assert.True(t, s.process(msg, req, resp)) + }) + require.Nil(t, err) + + t.Run("reply", func(t *testing.T) { + require.Equal(t, dhcpv6.MessageTypeReply, resp.Type()) + oia = resp.Options.OneIANA() + oiaAddr = oia.Options.OneAddress() + assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String()) + }) dnsAddrs := resp.Options.DNS() - assert.Len(t, dnsAddrs, 1) - assert.Equal(t, "2000::1", dnsAddrs[0].String()) + require.Len(t, dnsAddrs, 1) + assert.Equal(t, dnsAddr, dnsAddrs[0]) - // check lease - ls := s.GetLeases(LeasesDynamic) - assert.Len(t, ls, 1) - assert.Equal(t, "2001::2", ls[0].IP.String()) - assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) - - assert.False(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::1"))) - assert.False(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2002::2"))) - assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::2"))) - assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::3"))) + t.Run("lease", func(t *testing.T) { + ls := s.GetLeases(LeasesDynamic) + require.Len(t, ls, 1) + assert.Equal(t, "2001::2", ls[0].IP.String()) + assert.Equal(t, mac, ls[0].HWAddr) + }) +} + +func TestIP6InRange(t *testing.T) { + start := net.ParseIP("2001::2") + + testCases := []struct { + ip net.IP + want bool + }{{ + ip: net.ParseIP("2001::1"), + want: false, + }, { + ip: net.ParseIP("2002::2"), + want: false, + }, { + ip: start, + want: true, + }, { + ip: net.ParseIP("2001::3"), + want: true, + }} + + for _, tc := range testCases { + t.Run(tc.ip.String(), func(t *testing.T) { + assert.Equal(t, tc.want, ip6InRange(start, tc.ip)) + }) + } } diff --git a/internal/home/ipdetector_test.go b/internal/home/ipdetector_test.go index ee20612f..6609ba08 100644 --- a/internal/home/ipdetector_test.go +++ b/internal/home/ipdetector_test.go @@ -5,16 +5,15 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestIPDetector_detectSpecialNetwork(t *testing.T) { var ipd *ipDetector + var err error - t.Run("newIPDetector", func(t *testing.T) { - var err error - ipd, err = newIPDetector() - assert.Nil(t, err) - }) + ipd, err = newIPDetector() + require.Nil(t, err) testCases := []struct { name string diff --git a/internal/home/middlewares_test.go b/internal/home/middlewares_test.go index 53b7a933..fbd7a214 100644 --- a/internal/home/middlewares_test.go +++ b/internal/home/middlewares_test.go @@ -9,6 +9,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestLimitRequestBody(t *testing.T) { @@ -60,8 +61,8 @@ func TestLimitRequestBody(t *testing.T) { lim.ServeHTTP(res, req) + require.Equal(t, tc.wantErr, err) assert.Equal(t, tc.want, res.Body.Bytes()) - assert.Equal(t, tc.wantErr, err) }) } }