diff --git a/dhcp.go b/dhcp.go index 48e3029f..8c966ad6 100644 --- a/dhcp.go +++ b/dhcp.go @@ -451,10 +451,6 @@ func stopDHCPServer() error { return nil } - if !dhcpServer.Enabled { - return nil - } - err := dhcpServer.Stop() if err != nil { return errorx.Decorate(err, "Couldn't stop DHCP server") diff --git a/dhcpd/dhcpd.go b/dhcpd/dhcpd.go index 073ec0fa..13089f7d 100644 --- a/dhcpd/dhcpd.go +++ b/dhcpd/dhcpd.go @@ -57,6 +57,7 @@ type Server struct { // leases leases []*Lease + leasesLock sync.RWMutex leaseStart net.IP // parsed from config RangeStart leaseStop net.IP // parsed from config RangeEnd leaseTime time.Duration // parsed from config LeaseDuration @@ -65,8 +66,7 @@ type Server struct { // IP address pool -- if entry is in the pool, then it's attached to a lease IPpool map[[4]byte]net.HardwareAddr - ServerConfig - sync.RWMutex + conf ServerConfig } // Print information about the available network interfaces @@ -96,49 +96,48 @@ func (s *Server) Init(config ServerConfig) error { } func (s *Server) setConfig(config ServerConfig) error { - s.ServerConfig = config + s.conf = config - iface, err := net.InterfaceByName(s.InterfaceName) + iface, err := net.InterfaceByName(config.InterfaceName) if err != nil { printInterfaces() - return wrapErrPrint(err, "Couldn't find interface by name %s", s.InterfaceName) + return wrapErrPrint(err, "Couldn't find interface by name %s", config.InterfaceName) } // get ipv4 address of an interface s.ipnet = getIfaceIPv4(iface) if s.ipnet == nil { - return wrapErrPrint(err, "Couldn't find IPv4 address of interface %s %+v", s.InterfaceName, iface) + return wrapErrPrint(err, "Couldn't find IPv4 address of interface %s %+v", config.InterfaceName, iface) } - if s.LeaseDuration == 0 { + if config.LeaseDuration == 0 { s.leaseTime = time.Hour * 2 - s.LeaseDuration = uint(s.leaseTime.Seconds()) } else { - s.leaseTime = time.Second * time.Duration(s.LeaseDuration) + s.leaseTime = time.Second * time.Duration(config.LeaseDuration) } - s.leaseStart, err = parseIPv4(s.RangeStart) + s.leaseStart, err = parseIPv4(config.RangeStart) if err != nil { - return wrapErrPrint(err, "Failed to parse range start address %s", s.RangeStart) + return wrapErrPrint(err, "Failed to parse range start address %s", config.RangeStart) } - s.leaseStop, err = parseIPv4(s.RangeEnd) + s.leaseStop, err = parseIPv4(config.RangeEnd) if err != nil { - return wrapErrPrint(err, "Failed to parse range end address %s", s.RangeEnd) + return wrapErrPrint(err, "Failed to parse range end address %s", config.RangeEnd) } - subnet, err := parseIPv4(s.SubnetMask) + subnet, err := parseIPv4(config.SubnetMask) if err != nil { - return wrapErrPrint(err, "Failed to parse subnet mask %s", s.SubnetMask) + return wrapErrPrint(err, "Failed to parse subnet mask %s", config.SubnetMask) } // if !bytes.Equal(subnet, s.ipnet.Mask) { // return wrapErrPrint(err, "specified subnet mask %s does not meatch interface %s subnet mask %s", s.SubnetMask, s.InterfaceName, s.ipnet.Mask) // } - router, err := parseIPv4(s.GatewayIP) + router, err := parseIPv4(config.GatewayIP) if err != nil { - return wrapErrPrint(err, "Failed to parse gateway IP %s", s.GatewayIP) + return wrapErrPrint(err, "Failed to parse gateway IP %s", config.GatewayIP) } s.leaseOptions = dhcp4.Options{ @@ -158,9 +157,9 @@ func (s *Server) Start() error { s.closeConn() } - iface, err := net.InterfaceByName(s.InterfaceName) + iface, err := net.InterfaceByName(s.conf.InterfaceName) if err != nil { - return wrapErrPrint(err, "Couldn't find interface by name %s", s.InterfaceName) + return wrapErrPrint(err, "Couldn't find interface by name %s", s.conf.InterfaceName) } c, err := newFilterConn(*iface, ":67") // it has to be bound to 0.0.0.0:67, otherwise it won't see DHCP discover/request packets @@ -246,9 +245,9 @@ func (s *Server) reserveLease(p dhcp4.Packet) (*Lease, error) { log.Tracef("Assigning IP address %s to %s (lease for %s expired at %s)", s.leases[i].IP, hwaddr, s.leases[i].HWAddr, s.leases[i].Expiry) lease.IP = s.leases[i].IP - s.Lock() + s.leasesLock.Lock() s.leases[i] = lease - s.Unlock() + s.leasesLock.Unlock() s.reserveIP(lease.IP, hwaddr) return lease, nil @@ -256,9 +255,9 @@ func (s *Server) reserveLease(p dhcp4.Packet) (*Lease, error) { log.Tracef("Assigning to %s IP address %s", hwaddr, ip.String()) lease.IP = ip - s.Lock() + s.leasesLock.Lock() s.leases = append(s.leases, lease) - s.Unlock() + s.leasesLock.Unlock() return lease, nil } @@ -373,7 +372,7 @@ func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dh // Return TRUE if it doesn't reply, which probably means that the IP is available func (s *Server) addrAvailable(target net.IP) bool { - if s.ICMPTimeout == 0 { + if s.conf.ICMPTimeout == 0 { return true } @@ -384,7 +383,7 @@ func (s *Server) addrAvailable(target net.IP) bool { } pinger.SetPrivileged(true) - pinger.Timeout = time.Duration(s.ICMPTimeout) * time.Millisecond + pinger.Timeout = time.Duration(s.conf.ICMPTimeout) * time.Millisecond pinger.Count = 1 reply := false pinger.OnRecv = func(pkt *ping.Packet) { @@ -407,11 +406,11 @@ func (s *Server) addrAvailable(target net.IP) bool { func (s *Server) blacklistLease(lease *Lease) { hw := make(net.HardwareAddr, 6) s.reserveIP(lease.IP, hw) - s.Lock() + s.leasesLock.Lock() lease.HWAddr = hw lease.Hostname = "" lease.Expiry = time.Now().Add(s.leaseTime) - s.Unlock() + s.leasesLock.Unlock() } // Return TRUE if DHCP packet is correct @@ -595,33 +594,33 @@ func (s *Server) RemoveStaticLease(l Lease) error { func (s *Server) Leases() []Lease { var result []Lease now := time.Now().Unix() - s.RLock() + s.leasesLock.RLock() for _, lease := range s.leases { if lease.Expiry.Unix() > now { result = append(result, *lease) } } - s.RUnlock() + s.leasesLock.RUnlock() return result } // StaticLeases returns the list of statically-configured DHCP leases (thread-safe) func (s *Server) StaticLeases() []Lease { - s.Lock() + s.leasesLock.Lock() if s.IPpool == nil { s.dbLoad() } - s.Unlock() + s.leasesLock.Unlock() var result []Lease - s.RLock() + s.leasesLock.RLock() for _, lease := range s.leases { if lease.Expiry.Unix() == 1 { result = append(result, *lease) } } - s.RUnlock() + s.leasesLock.RUnlock() return result } @@ -637,8 +636,8 @@ func (s *Server) printLeases() { // FindIPbyMAC finds an IP address by MAC address in the currently active DHCP leases func (s *Server) FindIPbyMAC(mac net.HardwareAddr) net.IP { now := time.Now().Unix() - s.RLock() - defer s.RUnlock() + s.leasesLock.RLock() + defer s.leasesLock.RUnlock() for _, l := range s.leases { if l.Expiry.Unix() > now && bytes.Equal(mac, l.HWAddr) { return l.IP @@ -649,8 +648,8 @@ func (s *Server) FindIPbyMAC(mac net.HardwareAddr) net.IP { // Reset internal state func (s *Server) reset() { - s.Lock() + s.leasesLock.Lock() s.leases = nil - s.Unlock() + s.leasesLock.Unlock() s.IPpool = make(map[[4]byte]net.HardwareAddr) }