package dhcpd

import (
	"fmt"
	"math"
	"math/big"
	"net"

	"github.com/AdguardTeam/golibs/errors"
)

// ipRange is an inclusive range of IP addresses.  A nil range is a range that
// doesn't contain any IP addresses.
//
// It is safe for concurrent use.
//
// TODO(a.garipov): Perhaps create an optimised version with uint32 for
// IPv4 ranges?  Or use one of uint128 packages?
type ipRange struct {
	start *big.Int
	end   *big.Int
}

// maxRangeLen is the maximum IP range length.  The bitsets used in servers only
// accept uints, which can have the size of 32 bit.
const maxRangeLen = math.MaxUint32

// newIPRange creates a new IP address range.  start must be less than end.  The
// resulting range must not be greater than maxRangeLen.
func newIPRange(start, end net.IP) (r *ipRange, err error) {
	defer func() { err = errors.Annotate(err, "invalid ip range: %w") }()

	// Make sure that both are 16 bytes long to simplify handling in
	// methods.
	start, end = start.To16(), end.To16()

	startInt := (&big.Int{}).SetBytes(start)
	endInt := (&big.Int{}).SetBytes(end)
	diff := (&big.Int{}).Sub(endInt, startInt)

	if diff.Sign() <= 0 {
		return nil, fmt.Errorf("start is greater than or equal to end")
	} else if !diff.IsUint64() || diff.Uint64() > maxRangeLen {
		return nil, fmt.Errorf("range is too large")
	}

	r = &ipRange{
		start: startInt,
		end:   endInt,
	}

	return r, nil
}

// contains returns true if r contains ip.
func (r *ipRange) contains(ip net.IP) (ok bool) {
	if r == nil {
		return false
	}

	ipInt := (&big.Int{}).SetBytes(ip.To16())

	return r.containsInt(ipInt)
}

// containsInt returns true if r contains ipInt.  For internal use only.
func (r *ipRange) containsInt(ipInt *big.Int) (ok bool) {
	return ipInt.Cmp(r.start) >= 0 && ipInt.Cmp(r.end) <= 0
}

// ipPredicate is a function that is called on every IP address in
// (*ipRange).find.  ip is given in the 16-byte form.
type ipPredicate func(ip net.IP) (ok bool)

// find finds the first IP address in r for which p returns true.  ip is in the
// 16-byte form.
func (r *ipRange) find(p ipPredicate) (ip net.IP) {
	if r == nil {
		return nil
	}

	ip = make(net.IP, net.IPv6len)
	_1 := big.NewInt(1)
	for i := (&big.Int{}).Set(r.start); i.Cmp(r.end) <= 0; i.Add(i, _1) {
		i.FillBytes(ip)
		if p(ip) {
			return ip
		}
	}

	return nil
}

// offset returns the offset of ip from the beginning of r.  It returns 0 and
// false if ip is not in r.
func (r *ipRange) offset(ip net.IP) (offset uint64, ok bool) {
	if r == nil {
		return 0, false
	}

	ip = ip.To16()
	ipInt := (&big.Int{}).SetBytes(ip)
	if !r.containsInt(ipInt) {
		return 0, false
	}

	offsetInt := (&big.Int{}).Sub(ipInt, r.start)

	// Assume that the range was checked against maxRangeLen during
	// construction.
	return offsetInt.Uint64(), true
}

// String implements the fmt.Stringer interface for *ipRange.
func (r *ipRange) String() (s string) {
	return fmt.Sprintf("%s-%s", r.start, r.end)
}