Merge: DHCP: refactoring, unit-test, readme, on-disk cache, safe restart, expiration handling; #584 #539

* commit 'ef637e1313ce548eab28c9b9f36c6db20a1c62c9':
  + DHCP: step-by-step guide for test setup with Virtual Box
  * control: add logs
  + dhcp: handle lease expiration
  + dhcp: use ICMP for IP conflict detection
  * dhcp: don't allocate a new lease when processing Request message
  * dhcp: don't process Discover/Request packets with empty client HW address
  * dhcp: refactor
  * DHCP: Stop(): wait until the worker is stopped
  * control: safely restart DHCP server
  + DHCP: On-disk database for lease table
  * use golibs v0.1.1: file.SafeWrite()
  + dhcp: test
  * dhcp: remove code which forces an update of current lease's IP in  Request message handler
  * dhcp: refactor;  log client's HW addr
This commit is contained in:
Simon Zolin 2019-03-15 14:52:18 +03:00
commit 6e16654344
12 changed files with 583 additions and 127 deletions

View File

@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/golibs/file"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
yaml "gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"
) )
@ -217,7 +218,7 @@ func (c *configuration) write() error {
log.Error("Couldn't generate YAML file: %s", err) log.Error("Couldn't generate YAML file: %s", err)
return err return err
} }
err = safeWriteFile(configFile, yamlText) err = file.SafeWrite(configFile, yamlText)
if err != nil { if err != nil {
log.Error("Couldn't save YAML config: %s", err) log.Error("Couldn't save YAML config: %s", err)
return err return err

View File

@ -12,6 +12,7 @@ import (
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/dnsforward"
@ -36,6 +37,8 @@ var client = &http.Client{
Timeout: time.Second * 30, Timeout: time.Second * 30,
} }
var controlLock sync.Mutex
// ---------------- // ----------------
// helper functions // helper functions
// ---------------- // ----------------

16
dhcp.go
View File

@ -17,6 +17,7 @@ import (
var dhcpServer = dhcpd.Server{} var dhcpServer = dhcpd.Server{}
func handleDHCPStatus(w http.ResponseWriter, r *http.Request) { func handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
rawLeases := dhcpServer.Leases() rawLeases := dhcpServer.Leases()
leases := []map[string]string{} leases := []map[string]string{}
for i := range rawLeases { for i := range rawLeases {
@ -43,6 +44,7 @@ func handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
} }
func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
newconfig := dhcpd.ServerConfig{} newconfig := dhcpd.ServerConfig{}
err := json.NewDecoder(r.Body).Decode(&newconfig) err := json.NewDecoder(r.Body).Decode(&newconfig)
if err != nil { if err != nil {
@ -50,6 +52,11 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
return return
} }
err = dhcpServer.Stop()
if err != nil {
log.Error("failed to stop the DHCP server: %s", err)
}
if newconfig.Enabled { if newconfig.Enabled {
err := dhcpServer.Start(&newconfig) err := dhcpServer.Start(&newconfig)
if err != nil { if err != nil {
@ -57,17 +64,13 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
return return
} }
} }
if !newconfig.Enabled {
err := dhcpServer.Stop()
if err != nil {
log.Error("failed to stop the DHCP server: %s", err)
}
}
config.DHCP = newconfig config.DHCP = newconfig
httpUpdateConfigReloadDNSReturnOK(w, r) httpUpdateConfigReloadDNSReturnOK(w, r)
} }
func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
response := map[string]interface{}{} response := map[string]interface{}{}
ifaces, err := getValidNetInterfaces() ifaces, err := getValidNetInterfaces()
@ -128,6 +131,7 @@ func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
} }
func handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) { func handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
errorText := fmt.Sprintf("failed to read request body: %s", err) errorText := fmt.Sprintf("failed to read request body: %s", err)

53
dhcpd/README.md Normal file
View File

@ -0,0 +1,53 @@
# DHCP server
Contents:
* [Test setup with Virtual Box](#vbox)
<a id="vbox"></a>
## Test setup with Virtual Box
To set up a test environment for DHCP server you need:
* Linux host machine
* Virtual Box
* Virtual machine (guest OS doesn't matter)
### Configure client
1. Install Virtual Box and run the following command to create a Host-Only network:
$ VBoxManage hostonlyif create
You can check its status by `ip a` command.
You can also set up Host-Only network using Virtual Box menu:
File -> Host Network Manager...
2. Create your virtual machine and set up its network:
VM Settings -> Network -> Host-only Adapter
3. Start your VM, install an OS. Configure your network interface to use DHCP and the OS should ask for a IP address from our DHCP server.
### Configure server
1. Edit server configuration file 'AdGuardHome.yaml', for example:
dhcp:
enabled: true
interface_name: vboxnet0
gateway_ip: 192.168.56.1
subnet_mask: 255.255.255.0
range_start: 192.168.56.2
range_end: 192.168.56.2
lease_duration: 86400
icmp_timeout_msec: 1000
2. Start the server
./AdGuardHome
There should be a message in log which shows that DHCP server is ready:
[info] DHCP: listening on 0.0.0.0:67

98
dhcpd/db.go Normal file
View File

@ -0,0 +1,98 @@
// On-disk database for lease table
package dhcpd
import (
"encoding/json"
"io/ioutil"
"net"
"os"
"time"
"github.com/AdguardTeam/golibs/file"
"github.com/AdguardTeam/golibs/log"
"github.com/krolaw/dhcp4"
)
const dbFilename = "leases.db"
type leaseJSON struct {
HWAddr []byte `json:"mac"`
IP []byte `json:"ip"`
Hostname string `json:"host"`
Expiry int64 `json:"exp"`
}
// Load lease table from DB
func (s *Server) dbLoad() {
data, err := ioutil.ReadFile(dbFilename)
if err != nil {
if !os.IsNotExist(err) {
log.Error("DHCP: can't read file %s: %v", dbFilename, err)
}
return
}
obj := []leaseJSON{}
err = json.Unmarshal(data, &obj)
if err != nil {
log.Error("DHCP: invalid DB: %v", err)
return
}
s.leases = nil
s.IPpool = make(map[[4]byte]net.HardwareAddr)
numLeases := len(obj)
for i := range obj {
if !dhcp4.IPInRange(s.leaseStart, s.leaseStop, obj[i].IP) {
log.Tracef("Skipping a lease with IP %s: not within current IP range", obj[i].IP)
continue
}
lease := Lease{
HWAddr: obj[i].HWAddr,
IP: obj[i].IP,
Hostname: obj[i].Hostname,
Expiry: time.Unix(obj[i].Expiry, 0),
}
s.leases = append(s.leases, &lease)
s.reserveIP(lease.IP, lease.HWAddr)
}
log.Info("DHCP: loaded %d leases from DB", numLeases)
}
// Store lease table in DB
func (s *Server) dbStore() {
var leases []leaseJSON
for i := range s.leases {
if s.leases[i].Expiry.Unix() == 0 {
continue
}
lease := leaseJSON{
HWAddr: s.leases[i].HWAddr,
IP: s.leases[i].IP,
Hostname: s.leases[i].Hostname,
Expiry: s.leases[i].Expiry.Unix(),
}
leases = append(leases, lease)
}
data, err := json.Marshal(leases)
if err != nil {
log.Error("json.Marshal: %v", err)
return
}
err = file.SafeWrite(dbFilename, data)
if err != nil {
log.Error("DHCP: can't store lease table on disk: %v filename: %s",
err, dbFilename)
return
}
log.Info("DHCP: stored %d leases in DB", len(leases))
}

View File

@ -4,11 +4,13 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"net" "net"
"strings"
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/krolaw/dhcp4" "github.com/krolaw/dhcp4"
"github.com/sparrc/go-ping"
) )
const defaultDiscoverTime = time.Second * 3 const defaultDiscoverTime = time.Second * 3
@ -32,6 +34,10 @@ type ServerConfig struct {
RangeStart string `json:"range_start" yaml:"range_start"` RangeStart string `json:"range_start" yaml:"range_start"`
RangeEnd string `json:"range_end" yaml:"range_end"` RangeEnd string `json:"range_end" yaml:"range_end"`
LeaseDuration uint `json:"lease_duration" yaml:"lease_duration"` // in seconds LeaseDuration uint `json:"lease_duration" yaml:"lease_duration"` // in seconds
// IP conflict detector: time (ms) to wait for ICMP reply.
// 0: disable
ICMPTimeout uint `json:"icmp_timeout_msec" yaml:"icmp_timeout_msec"`
} }
// Server - the current state of the DHCP server // Server - the current state of the DHCP server
@ -40,6 +46,11 @@ type Server struct {
ipnet *net.IPNet // if interface name changes, this needs to be reset ipnet *net.IPNet // if interface name changes, this needs to be reset
cond *sync.Cond // Synchronize worker thread with main thread
mutex sync.Mutex // Mutex for 'cond'
running bool // Set if the worker thread is running
stopping bool // Set if the worker thread should be stopped
// leases // leases
leases []*Lease leases []*Lease
leaseStart net.IP // parsed from config RangeStart leaseStart net.IP // parsed from config RangeStart
@ -54,6 +65,16 @@ type Server struct {
sync.RWMutex sync.RWMutex
} }
// Print information about the available network interfaces
func printInterfaces() {
ifaces, _ := net.Interfaces()
var buf strings.Builder
for i := range ifaces {
buf.WriteString(fmt.Sprintf("\"%s\", ", ifaces[i].Name))
}
log.Info("Available network interfaces: %s", buf.String())
}
// Start will listen on port 67 and serve DHCP requests. // Start will listen on port 67 and serve DHCP requests.
// Even though config can be nil, it is not optional (at least for now), since there are no default values (yet). // Even though config can be nil, it is not optional (at least for now), since there are no default values (yet).
func (s *Server) Start(config *ServerConfig) error { func (s *Server) Start(config *ServerConfig) error {
@ -64,6 +85,7 @@ func (s *Server) Start(config *ServerConfig) error {
iface, err := net.InterfaceByName(s.InterfaceName) iface, err := net.InterfaceByName(s.InterfaceName)
if err != nil { if err != nil {
s.closeConn() // in case it was already started s.closeConn() // in case it was already started
printInterfaces()
return wrapErrPrint(err, "Couldn't find interface by name %s", s.InterfaceName) return wrapErrPrint(err, "Couldn't find interface by name %s", s.InterfaceName)
} }
@ -122,20 +144,27 @@ func (s *Server) Start(config *ServerConfig) error {
s.closeConn() s.closeConn()
} }
s.dbLoad()
c, err := newFilterConn(*iface, ":67") // it has to be bound to 0.0.0.0:67, otherwise it won't see DHCP discover/request packets c, err := newFilterConn(*iface, ":67") // it has to be bound to 0.0.0.0:67, otherwise it won't see DHCP discover/request packets
if err != nil { if err != nil {
return wrapErrPrint(err, "Couldn't start listening socket on 0.0.0.0:67") return wrapErrPrint(err, "Couldn't start listening socket on 0.0.0.0:67")
} }
log.Info("DHCP: listening on 0.0.0.0:67")
s.conn = c s.conn = c
s.cond = sync.NewCond(&s.mutex)
s.running = true
go func() { go func() {
// operate on c instead of c.conn because c.conn can change over time // operate on c instead of c.conn because c.conn can change over time
err := dhcp4.Serve(c, s) err := dhcp4.Serve(c, s)
if err != nil { if err != nil && !s.stopping {
log.Printf("dhcp4.Serve() returned with error: %s", err) log.Printf("dhcp4.Serve() returned with error: %s", err)
} }
c.Close() // in case Serve() exits for other reason than listening socket closure c.Close() // in case Serve() exits for other reason than listening socket closure
s.running = false
s.cond.Signal()
}() }()
return nil return nil
@ -147,11 +176,23 @@ func (s *Server) Stop() error {
// nothing to do, return silently // nothing to do, return silently
return nil return nil
} }
s.stopping = true
err := s.closeConn() err := s.closeConn()
if err != nil { if err != nil {
return wrapErrPrint(err, "Couldn't close UDP listening socket") return wrapErrPrint(err, "Couldn't close UDP listening socket")
} }
// We've just closed the listening socket.
// Worker thread should exit right after it tries to read from the socket.
s.mutex.Lock()
for s.running {
s.cond.Wait()
}
s.mutex.Unlock()
s.dbStore()
return nil return nil
} }
@ -165,6 +206,7 @@ func (s *Server) closeConn() error {
return err return err
} }
// Reserve a lease for the client
func (s *Server) reserveLease(p dhcp4.Packet) (*Lease, error) { func (s *Server) reserveLease(p dhcp4.Packet) (*Lease, error) {
// WARNING: do not remove copy() // WARNING: do not remove copy()
// the given hwaddr by p.CHAddr() in the packet survives only during ServeDHCP() call // the given hwaddr by p.CHAddr() in the packet survives only during ServeDHCP() call
@ -172,27 +214,39 @@ func (s *Server) reserveLease(p dhcp4.Packet) (*Lease, error) {
hwaddrCOW := p.CHAddr() hwaddrCOW := p.CHAddr()
hwaddr := make(net.HardwareAddr, len(hwaddrCOW)) hwaddr := make(net.HardwareAddr, len(hwaddrCOW))
copy(hwaddr, hwaddrCOW) copy(hwaddr, hwaddrCOW)
foundLease := s.locateLease(p)
if foundLease != nil {
// log.Tracef("found lease for %s: %+v", hwaddr, foundLease)
return foundLease, nil
}
// not assigned a lease, create new one, find IP from LRU // not assigned a lease, create new one, find IP from LRU
hostname := p.ParseOptions()[dhcp4.OptionHostName]
lease := &Lease{HWAddr: hwaddr, Hostname: string(hostname)}
log.Tracef("Lease not found for %s: creating new one", hwaddr) log.Tracef("Lease not found for %s: creating new one", hwaddr)
ip, err := s.findFreeIP(hwaddr) ip, err := s.findFreeIP(hwaddr)
if err != nil { if err != nil {
return nil, wrapErrPrint(err, "Couldn't find free IP for the lease %s", hwaddr.String()) i := s.findExpiredLease()
if i < 0 {
return nil, wrapErrPrint(err, "Couldn't find free IP for the lease %s", hwaddr.String())
}
log.Tracef("Assigning IP address %s to %s (lease for %s expired at %s)",
s.leases[i].IP, hwaddr, s.leases[i].HWAddr, s.leases[i].Expiry)
lease.IP = s.leases[i].IP
s.Lock()
s.leases[i] = lease
s.Unlock()
s.reserveIP(lease.IP, hwaddr)
return lease, nil
} }
log.Tracef("Assigning to %s IP address %s", hwaddr, ip.String()) log.Tracef("Assigning to %s IP address %s", hwaddr, ip.String())
hostname := p.ParseOptions()[dhcp4.OptionHostName] lease.IP = ip
lease := &Lease{HWAddr: hwaddr, IP: ip, Hostname: string(hostname)}
s.Lock() s.Lock()
s.leases = append(s.leases, lease) s.leases = append(s.leases, lease)
s.Unlock() s.Unlock()
return lease, nil return lease, nil
} }
func (s *Server) locateLease(p dhcp4.Packet) *Lease { // Find a lease for the client
func (s *Server) findLease(p dhcp4.Packet) *Lease {
hwaddr := p.CHAddr() hwaddr := p.CHAddr()
for i := range s.leases { for i := range s.leases {
if bytes.Equal([]byte(hwaddr), []byte(s.leases[i].HWAddr)) { if bytes.Equal([]byte(hwaddr), []byte(s.leases[i].HWAddr)) {
@ -203,6 +257,17 @@ func (s *Server) locateLease(p dhcp4.Packet) *Lease {
return nil return nil
} }
// Find an expired lease and return its index or -1
func (s *Server) findExpiredLease() int {
now := time.Now().Unix()
for i, lease := range s.leases {
if lease.Expiry.Unix() <= now {
return i
}
}
return -1
}
func (s *Server) findFreeIP(hwaddr net.HardwareAddr) (net.IP, error) { func (s *Server) findFreeIP(hwaddr net.HardwareAddr) (net.IP, error) {
// if IP pool is nil, lazy initialize it // if IP pool is nil, lazy initialize it
if s.IPpool == nil { if s.IPpool == nil {
@ -213,13 +278,12 @@ func (s *Server) findFreeIP(hwaddr net.HardwareAddr) (net.IP, error) {
var foundIP net.IP var foundIP net.IP
for i := 0; i < dhcp4.IPRange(s.leaseStart, s.leaseStop); i++ { for i := 0; i < dhcp4.IPRange(s.leaseStart, s.leaseStop); i++ {
newIP := dhcp4.IPAdd(s.leaseStart, i) newIP := dhcp4.IPAdd(s.leaseStart, i)
foundHWaddr := s.getIPpool(newIP) foundHWaddr := s.findReservedHWaddr(newIP)
log.Tracef("tried IP %v, got hwaddr %v", newIP, foundHWaddr) log.Tracef("tried IP %v, got hwaddr %v", newIP, foundHWaddr)
if foundHWaddr != nil && len(foundHWaddr) != 0 { if foundHWaddr != nil && len(foundHWaddr) != 0 {
// if !bytes.Equal(foundHWaddr, hwaddr) { // if !bytes.Equal(foundHWaddr, hwaddr) {
// log.Tracef("SHOULD NOT HAPPEN: hwaddr in IP pool %s is not equal to hwaddr in lease %s", foundHWaddr, hwaddr) // log.Tracef("SHOULD NOT HAPPEN: hwaddr in IP pool %s is not equal to hwaddr in lease %s", foundHWaddr, hwaddr)
// } // }
log.Tracef("will try again")
continue continue
} }
foundIP = newIP foundIP = newIP
@ -236,7 +300,7 @@ func (s *Server) findFreeIP(hwaddr net.HardwareAddr) (net.IP, error) {
return foundIP, nil return foundIP, nil
} }
func (s *Server) getIPpool(ip net.IP) net.HardwareAddr { func (s *Server) findReservedHWaddr(ip net.IP) net.HardwareAddr {
rawIP := []byte(ip) rawIP := []byte(ip)
IP4 := [4]byte{rawIP[0], rawIP[1], rawIP[2], rawIP[3]} IP4 := [4]byte{rawIP[0], rawIP[1], rawIP[2], rawIP[3]}
return s.IPpool[IP4] return s.IPpool[IP4]
@ -256,133 +320,223 @@ func (s *Server) unreserveIP(ip net.IP) {
// ServeDHCP handles an incoming DHCP request // ServeDHCP handles an incoming DHCP request
func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dhcp4.Options) dhcp4.Packet { func (s *Server) ServeDHCP(p dhcp4.Packet, msgType dhcp4.MessageType, options dhcp4.Options) dhcp4.Packet {
log.Tracef("Got %v message", msgType) s.printLeases()
log.Tracef("Leases:")
for i, lease := range s.leases {
log.Tracef("Lease #%d: hwaddr %s, ip %s, expiry %s", i, lease.HWAddr, lease.IP, lease.Expiry)
}
log.Tracef("IP pool:")
for ip, hwaddr := range s.IPpool {
log.Tracef("IP pool entry %s -> %s", net.IPv4(ip[0], ip[1], ip[2], ip[3]), hwaddr)
}
switch msgType { switch msgType {
case dhcp4.Discover: // Broadcast Packet From Client - Can I have an IP? case dhcp4.Discover: // Broadcast Packet From Client - Can I have an IP?
// find a lease, but don't update lease time return s.handleDiscover(p, options)
log.Tracef("Got from client: Discover")
lease, err := s.reserveLease(p)
if err != nil {
log.Tracef("Couldn't find free lease: %s", err)
// couldn't find lease, don't respond
return nil
}
reply := dhcp4.ReplyPacket(p, dhcp4.Offer, s.ipnet.IP, lease.IP, s.leaseTime, s.leaseOptions.SelectOrderOrAll(options[dhcp4.OptionParameterRequestList]))
log.Tracef("Replying with offer: offered IP %v for %v with options %+v", lease.IP, s.leaseTime, reply.ParseOptions())
return reply
case dhcp4.Request: // Broadcast From Client - I'll take that IP (Also start for renewals) case dhcp4.Request: // Broadcast From Client - I'll take that IP (Also start for renewals)
// start/renew a lease -- update lease time // start/renew a lease -- update lease time
// some clients (OSX) just go right ahead and do Request first from previously known IP, if they get NAK, they restart full cycle with Discover then Request // some clients (OSX) just go right ahead and do Request first from previously known IP, if they get NAK, they restart full cycle with Discover then Request
return s.handleDHCP4Request(p, options) return s.handleDHCP4Request(p, options)
case dhcp4.Decline: // Broadcast From Client - Sorry I can't use that IP case dhcp4.Decline: // Broadcast From Client - Sorry I can't use that IP
log.Tracef("Got from client: Decline") return s.handleDecline(p, options)
case dhcp4.Release: // From Client, I don't need that IP anymore case dhcp4.Release: // From Client, I don't need that IP anymore
log.Tracef("Got from client: Release") return s.handleRelease(p, options)
case dhcp4.Inform: // From Client, I have this IP and there's nothing you can do about it case dhcp4.Inform: // From Client, I have this IP and there's nothing you can do about it
log.Tracef("Got from client: Inform") return s.handleInform(p, options)
// do nothing
// from server -- ignore those but enumerate just in case // from server -- ignore those but enumerate just in case
case dhcp4.Offer: // Broadcast From Server - Here's an IP case dhcp4.Offer: // Broadcast From Server - Here's an IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: Offer") log.Printf("DHCP: received message from %s: Offer", p.CHAddr())
case dhcp4.ACK: // From Server, Yes you can have that IP case dhcp4.ACK: // From Server, Yes you can have that IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: ACK") log.Printf("DHCP: received message from %s: ACK", p.CHAddr())
case dhcp4.NAK: // From Server, No you cannot have that IP case dhcp4.NAK: // From Server, No you cannot have that IP
log.Printf("SHOULD NOT HAPPEN -- FROM ANOTHER DHCP SERVER: NAK") log.Printf("DHCP: received message from %s: NAK", p.CHAddr())
default: default:
log.Printf("Unknown DHCP packet detected, ignoring: %v", msgType) log.Printf("DHCP: unknown packet %v from %s", msgType, p.CHAddr())
return nil return nil
} }
return nil return nil
} }
// Send ICMP to the specified machine
// Return TRUE if it doesn't reply, which probably means that the IP is available
func (s *Server) addrAvailable(target net.IP) bool {
if s.ICMPTimeout == 0 {
return true
}
pinger, err := ping.NewPinger(target.String())
if err != nil {
log.Error("ping.NewPinger(): %v", err)
return true
}
pinger.SetPrivileged(true)
pinger.Timeout = time.Duration(s.ICMPTimeout) * time.Millisecond
pinger.Count = 1
reply := false
pinger.OnRecv = func(pkt *ping.Packet) {
// log.Tracef("Received ICMP Reply from %v", target)
reply = true
}
log.Tracef("Sending ICMP Echo to %v", target)
pinger.Run()
if reply {
log.Info("DHCP: IP conflict: %v is already used by another device", target)
return false
}
log.Tracef("ICMP procedure is complete: %v", target)
return true
}
// Add the specified IP to the black list for a time period
func (s *Server) blacklistLease(lease *Lease) {
hw := make(net.HardwareAddr, 6)
s.reserveIP(lease.IP, hw)
s.Lock()
lease.HWAddr = hw
lease.Hostname = ""
lease.Expiry = time.Now().Add(s.leaseTime)
s.Unlock()
}
// Return TRUE if DHCP packet is correct
func isValidPacket(p dhcp4.Packet) bool {
hw := p.CHAddr()
zeroes := make([]byte, len(hw))
if bytes.Equal(hw, zeroes) {
log.Tracef("Packet has empty CHAddr")
return false
}
return true
}
func (s *Server) handleDiscover(p dhcp4.Packet, options dhcp4.Options) dhcp4.Packet {
// find a lease, but don't update lease time
var lease *Lease
var err error
reqIP := net.IP(options[dhcp4.OptionRequestedIPAddress])
hostname := p.ParseOptions()[dhcp4.OptionHostName]
log.Tracef("Message from client: Discover. ReqIP: %s HW: %s Hostname: %s",
reqIP, p.CHAddr(), hostname)
if !isValidPacket(p) {
return nil
}
lease = s.findLease(p)
for lease == nil {
lease, err = s.reserveLease(p)
if err != nil {
log.Error("Couldn't find free lease: %s", err)
return nil
}
if !s.addrAvailable(lease.IP) {
s.blacklistLease(lease)
lease = nil
continue
}
break
}
opt := s.leaseOptions.SelectOrderOrAll(options[dhcp4.OptionParameterRequestList])
reply := dhcp4.ReplyPacket(p, dhcp4.Offer, s.ipnet.IP, lease.IP, s.leaseTime, opt)
log.Tracef("Replying with offer: offered IP %v for %v with options %+v", lease.IP, s.leaseTime, reply.ParseOptions())
return reply
}
func (s *Server) handleDHCP4Request(p dhcp4.Packet, options dhcp4.Options) dhcp4.Packet { func (s *Server) handleDHCP4Request(p dhcp4.Packet, options dhcp4.Options) dhcp4.Packet {
log.Tracef("Got from client: Request") var lease *Lease
if server, ok := options[dhcp4.OptionServerIdentifier]; ok && !net.IP(server).Equal(s.ipnet.IP) {
reqIP := net.IP(options[dhcp4.OptionRequestedIPAddress])
log.Tracef("Message from client: Request. IP: %s ReqIP: %s HW: %s",
p.CIAddr(), reqIP, p.CHAddr())
if !isValidPacket(p) {
return nil
}
server := options[dhcp4.OptionServerIdentifier]
if server != nil && !net.IP(server).Equal(s.ipnet.IP) {
log.Tracef("Request message not for this DHCP server (%v vs %v)", server, s.ipnet.IP) log.Tracef("Request message not for this DHCP server (%v vs %v)", server, s.ipnet.IP)
return nil // Message not for this dhcp server return nil // Message not for this dhcp server
} }
reqIP := net.IP(options[dhcp4.OptionRequestedIPAddress])
if reqIP == nil { if reqIP == nil {
reqIP = p.CIAddr() reqIP = p.CIAddr()
}
if reqIP.To4() == nil { } else if reqIP == nil || reqIP.To4() == nil {
log.Tracef("Replying with NAK: request IP isn't valid IPv4: %s", reqIP) log.Tracef("Requested IP isn't a valid IPv4: %s", reqIP)
return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil) return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil)
} }
if reqIP.Equal(net.IPv4zero) { lease = s.findLease(p)
log.Tracef("Replying with NAK: request IP is 0.0.0.0") if lease == nil {
log.Tracef("Lease for %s isn't found", p.CHAddr())
return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil) return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil)
} }
log.Tracef("requested IP is %s", reqIP) if !lease.IP.Equal(reqIP) {
lease, err := s.reserveLease(p) log.Tracef("Lease for %s doesn't match requested/client IP: %s vs %s",
if err != nil { lease.HWAddr, lease.IP, reqIP)
log.Tracef("Couldn't find free lease: %s", err)
// couldn't find lease, don't respond
return nil
}
if lease.IP.Equal(reqIP) {
// IP matches lease IP, nothing else to do
lease.Expiry = time.Now().Add(s.leaseTime)
log.Tracef("Replying with ACK: request IP matches lease IP, nothing else to do. IP %v for %v", lease.IP, p.CHAddr())
return dhcp4.ReplyPacket(p, dhcp4.ACK, s.ipnet.IP, lease.IP, s.leaseTime, s.leaseOptions.SelectOrderOrAll(options[dhcp4.OptionParameterRequestList]))
}
//
// requested IP different from lease
//
log.Tracef("lease IP is different from requested IP: %s vs %s", lease.IP, reqIP)
hwaddr := s.getIPpool(reqIP)
if hwaddr == nil {
// not in pool, check if it's in DHCP range
if dhcp4.IPInRange(s.leaseStart, s.leaseStop, reqIP) {
// okay, we can give it to our client -- it's in our DHCP range and not taken, so let them use their IP
log.Tracef("Replying with ACK: request IP %v is not taken, so assigning lease IP %v to it, for %v", reqIP, lease.IP, p.CHAddr())
s.unreserveIP(lease.IP)
lease.IP = reqIP
s.reserveIP(reqIP, p.CHAddr())
lease.Expiry = time.Now().Add(s.leaseTime)
return dhcp4.ReplyPacket(p, dhcp4.ACK, s.ipnet.IP, lease.IP, s.leaseTime, s.leaseOptions.SelectOrderOrAll(options[dhcp4.OptionParameterRequestList]))
}
}
if hwaddr != nil && !bytes.Equal(hwaddr, lease.HWAddr) {
log.Printf("SHOULD NOT HAPPEN: IP pool hwaddr does not match lease hwaddr: %s vs %s", hwaddr, lease.HWAddr)
}
// requsted IP is not sufficient, reply with NAK
if hwaddr != nil {
log.Tracef("Replying with NAK: request IP %s is taken, asked by %v", reqIP, p.CHAddr())
return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil) return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil)
} }
// requested IP is outside of DHCP range lease.Expiry = time.Now().Add(s.leaseTime)
log.Tracef("Replying with NAK: request IP %s is outside of DHCP range [%s, %s], asked by %v", reqIP, s.leaseStart, s.leaseStop, p.CHAddr()) log.Tracef("Replying with ACK. IP: %s HW: %s Expire: %s",
return dhcp4.ReplyPacket(p, dhcp4.NAK, s.ipnet.IP, nil, 0, nil) lease.IP, lease.HWAddr, lease.Expiry)
opt := s.leaseOptions.SelectOrderOrAll(options[dhcp4.OptionParameterRequestList])
return dhcp4.ReplyPacket(p, dhcp4.ACK, s.ipnet.IP, lease.IP, s.leaseTime, opt)
} }
// Leases returns the list of current DHCP leases func (s *Server) handleInform(p dhcp4.Packet, options dhcp4.Options) dhcp4.Packet {
log.Tracef("Message from client: Inform. IP: %s HW: %s",
p.CIAddr(), p.CHAddr())
return nil
}
func (s *Server) handleRelease(p dhcp4.Packet, options dhcp4.Options) dhcp4.Packet {
log.Tracef("Message from client: Release. IP: %s HW: %s",
p.CIAddr(), p.CHAddr())
return nil
}
func (s *Server) handleDecline(p dhcp4.Packet, options dhcp4.Options) dhcp4.Packet {
reqIP := net.IP(options[dhcp4.OptionRequestedIPAddress])
log.Tracef("Message from client: Decline. IP: %s HW: %s",
reqIP, p.CHAddr())
return nil
}
// Leases returns the list of current DHCP leases (thread-safe)
func (s *Server) Leases() []*Lease { func (s *Server) Leases() []*Lease {
s.RLock() s.RLock()
result := s.leases result := s.leases
s.RUnlock() s.RUnlock()
return result return result
} }
// Print information about the current leases
func (s *Server) printLeases() {
log.Tracef("Leases:")
for i, lease := range s.leases {
log.Tracef("Lease #%d: hwaddr %s, ip %s, expiry %s",
i, lease.HWAddr, lease.IP, lease.Expiry)
}
}
// Reset internal state
func (s *Server) reset() {
s.Lock()
s.leases = nil
s.Unlock()
s.IPpool = make(map[[4]byte]net.HardwareAddr)
}

152
dhcpd/dhcpd_test.go Normal file
View File

@ -0,0 +1,152 @@
package dhcpd
import (
"bytes"
"net"
"os"
"testing"
"time"
"github.com/krolaw/dhcp4"
)
func check(t *testing.T, result bool, msg string) {
if !result {
t.Fatal(msg)
}
}
// Tests performed:
// . Handle Discover message (lease reserve)
// . Handle Request message (lease commit)
func TestDHCP(t *testing.T) {
var s = Server{}
var p, p2 dhcp4.Packet
var hw net.HardwareAddr
var lease *Lease
var opt dhcp4.Options
s.leaseStart = []byte{1, 1, 1, 1}
s.leaseStop = []byte{1, 1, 1, 2}
s.leaseTime = 5 * time.Second
s.leaseOptions = dhcp4.Options{}
s.ipnet = &net.IPNet{
IP: []byte{1, 2, 3, 4},
Mask: []byte{0xff, 0xff, 0xff, 0xff},
}
p = make(dhcp4.Packet, 241)
// Reserve an IP
hw = []byte{1, 2, 3, 4, 5, 6}
p.SetCHAddr(hw)
lease, _ = s.reserveLease(p)
check(t, bytes.Equal(lease.HWAddr, hw), "lease.HWAddr")
check(t, bytes.Equal(lease.IP, []byte{1, 1, 1, 1}), "lease.IP")
lease = s.findLease(p)
check(t, bytes.Equal(lease.HWAddr, hw), "lease.HWAddr")
check(t, bytes.Equal(lease.IP, []byte{1, 1, 1, 1}), "lease.IP")
// Reserve an IP - the next IP from the range
hw = []byte{2, 2, 3, 4, 5, 6}
p.SetCHAddr(hw)
lease, _ = s.reserveLease(p)
check(t, bytes.Equal(lease.HWAddr, hw), "lease.HWAddr")
check(t, bytes.Equal(lease.IP, []byte{1, 1, 1, 2}), "lease.IP")
// Reserve an IP - we have no more available IPs
p.SetCHAddr([]byte{3, 2, 3, 4, 5, 6})
lease, _ = s.reserveLease(p)
check(t, lease == nil, "lease == nil")
// Decline request for a lease which doesn't match our internal state
hw = []byte{1, 2, 3, 4, 5, 6}
p.SetCHAddr(hw)
p.SetCIAddr([]byte{0, 0, 0, 0})
opt = make(dhcp4.Options, 10)
// ask a different IP
opt[dhcp4.OptionRequestedIPAddress] = []byte{1, 1, 1, 2}
p2 = s.handleDHCP4Request(p, opt)
opt = p2.ParseOptions()
check(t, bytes.Equal(opt[dhcp4.OptionDHCPMessageType], []byte{byte(dhcp4.NAK)}), "dhcp4.NAK")
// Commit the previously reserved lease
hw = []byte{1, 2, 3, 4, 5, 6}
p.SetCHAddr(hw)
p.SetCIAddr([]byte{0, 0, 0, 0})
opt = make(dhcp4.Options, 10)
opt[dhcp4.OptionRequestedIPAddress] = []byte{1, 1, 1, 1}
p2 = s.handleDHCP4Request(p, opt)
opt = p2.ParseOptions()
check(t, bytes.Equal(opt[dhcp4.OptionDHCPMessageType], []byte{byte(dhcp4.ACK)}), "dhcp4.ACK")
check(t, bytes.Equal(p2.YIAddr(), []byte{1, 1, 1, 1}), "p2.YIAddr")
check(t, bytes.Equal(p2.CHAddr(), hw), "p2.CHAddr")
check(t, bytes.Equal(opt[dhcp4.OptionIPAddressLeaseTime], dhcp4.OptionsLeaseTime(5*time.Second)), "OptionIPAddressLeaseTime")
check(t, bytes.Equal(opt[dhcp4.OptionServerIdentifier], s.ipnet.IP), "OptionServerIdentifier")
s.reset()
misc(t, &s)
}
// Small tests that don't require a static server's state
func misc(t *testing.T, s *Server) {
var p, p2 dhcp4.Packet
var hw net.HardwareAddr
var opt dhcp4.Options
p = make(dhcp4.Packet, 241)
// Try to commit a lease for an IP without prior Discover-Offer packets
hw = []byte{2, 2, 3, 4, 5, 6}
p.SetCHAddr(hw)
p.SetCIAddr([]byte{0, 0, 0, 0})
opt = make(dhcp4.Options, 10)
opt[dhcp4.OptionRequestedIPAddress] = []byte{1, 1, 1, 1}
p2 = s.handleDHCP4Request(p, opt)
opt = p2.ParseOptions()
check(t, bytes.Equal(opt[dhcp4.OptionDHCPMessageType], []byte{byte(dhcp4.NAK)}), "dhcp4.NAK")
}
// Leases database store/load
func TestDB(t *testing.T) {
var s = Server{}
var p dhcp4.Packet
var hw1, hw2 net.HardwareAddr
var lease *Lease
s.leaseStart = []byte{1, 1, 1, 1}
s.leaseStop = []byte{1, 1, 1, 2}
s.leaseTime = 5 * time.Second
s.leaseOptions = dhcp4.Options{}
s.ipnet = &net.IPNet{
IP: []byte{1, 2, 3, 4},
Mask: []byte{0xff, 0xff, 0xff, 0xff},
}
p = make(dhcp4.Packet, 241)
hw1 = []byte{1, 2, 3, 4, 5, 6}
p.SetCHAddr(hw1)
lease, _ = s.reserveLease(p)
lease.Expiry = time.Unix(4000000001, 0)
hw2 = []byte{2, 2, 3, 4, 5, 6}
p.SetCHAddr(hw2)
lease, _ = s.reserveLease(p)
lease.Expiry = time.Unix(4000000002, 0)
os.Remove("leases.db")
s.dbStore()
s.reset()
s.dbLoad()
check(t, bytes.Equal(s.leases[0].HWAddr, hw1), "leases[0].HWAddr")
check(t, bytes.Equal(s.leases[0].IP, []byte{1, 1, 1, 1}), "leases[0].IP")
check(t, s.leases[0].Expiry.Unix() == 4000000001, "leases[0].Expiry")
check(t, bytes.Equal(s.leases[1].HWAddr, hw2), "leases[1].HWAddr")
check(t, bytes.Equal(s.leases[1].IP, []byte{1, 1, 1, 2}), "leases[1].IP")
check(t, s.leases[1].Expiry.Unix() == 4000000002, "leases[1].Expiry")
os.Remove("leases.db")
}

View File

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/golibs/file"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -220,7 +221,7 @@ func (filter *filter) save() error {
log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath) log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath)
body := []byte(strings.Join(filter.Rules, "\n")) body := []byte(strings.Join(filter.Rules, "\n"))
err := safeWriteFile(filterFilePath, body) err := file.SafeWrite(filterFilePath, body)
// update LastUpdated field after saving the file // update LastUpdated field after saving the file
filter.LastUpdated = filter.LastTimeUpdated() filter.LastUpdated = filter.LastTimeUpdated()

3
go.mod
View File

@ -4,7 +4,7 @@ go 1.12
require ( require (
github.com/AdguardTeam/dnsproxy v0.11.2 github.com/AdguardTeam/dnsproxy v0.11.2
github.com/AdguardTeam/golibs v0.1.0 github.com/AdguardTeam/golibs v0.1.1
github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f // indirect github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f // indirect
github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7 github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7
github.com/go-ole/go-ole v1.2.1 // indirect github.com/go-ole/go-ole v1.2.1 // indirect
@ -17,6 +17,7 @@ require (
github.com/miekg/dns v1.1.1 github.com/miekg/dns v1.1.1
github.com/shirou/gopsutil v2.18.10+incompatible github.com/shirou/gopsutil v2.18.10+incompatible
github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 // indirect github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 // indirect
github.com/sparrc/go-ping v0.0.0-20181106165434-ef3ab45e41b0
github.com/stretchr/testify v1.2.2 github.com/stretchr/testify v1.2.2
go.uber.org/goleak v0.10.0 go.uber.org/goleak v0.10.0
golang.org/x/net v0.0.0-20190119204137-ed066c81e75e golang.org/x/net v0.0.0-20190119204137-ed066c81e75e

5
go.sum
View File

@ -1,7 +1,8 @@
github.com/AdguardTeam/dnsproxy v0.11.2 h1:S/Ag2q9qoZsmW1fvMohPZP7/5amEtz8NmFCp8kxUalQ= github.com/AdguardTeam/dnsproxy v0.11.2 h1:S/Ag2q9qoZsmW1fvMohPZP7/5amEtz8NmFCp8kxUalQ=
github.com/AdguardTeam/dnsproxy v0.11.2/go.mod h1:EPp92b5cYR7HZpO+OQu6xC7AyhUoBaXW3sfa3exq/0I= github.com/AdguardTeam/dnsproxy v0.11.2/go.mod h1:EPp92b5cYR7HZpO+OQu6xC7AyhUoBaXW3sfa3exq/0I=
github.com/AdguardTeam/golibs v0.1.0 h1:Mo1QNKC8eSbqczhxfdBXYCrUMwvgCyCwZFyWv+2Gdng=
github.com/AdguardTeam/golibs v0.1.0/go.mod h1:zhi6xGwK4cMpjDocybhhLgvcGkstiSIjlpKbvyxC5Yc= github.com/AdguardTeam/golibs v0.1.0/go.mod h1:zhi6xGwK4cMpjDocybhhLgvcGkstiSIjlpKbvyxC5Yc=
github.com/AdguardTeam/golibs v0.1.1 h1:aepIN7yulf8I4Ub2c0cAaIizfSHPVXB2wrh8j4BJxl4=
github.com/AdguardTeam/golibs v0.1.1/go.mod h1:b0XkhgIcn2TxwX6C5AQMtpIFAgjPehNgxJErWkwA3ko=
github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f h1:5ZfJxyXo8KyX8DgGXC5B7ILL8y51fci/qYz2B4j8iLY= github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f h1:5ZfJxyXo8KyX8DgGXC5B7ILL8y51fci/qYz2B4j8iLY=
github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
@ -55,6 +56,8 @@ github.com/shirou/gopsutil v2.18.10+incompatible h1:cy84jW6EVRPa5g9HAHrlbxMSIjBh
github.com/shirou/gopsutil v2.18.10+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/gopsutil v2.18.10+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 h1:udFKJ0aHUL60LboW/A+DfgoHVedieIzIXE8uylPue0U= github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 h1:udFKJ0aHUL60LboW/A+DfgoHVedieIzIXE8uylPue0U=
github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4/go.mod h1:qsXQc7+bwAM3Q1u/4XEfrquwF8Lw7D7y5cD8CuHnfIc= github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4/go.mod h1:qsXQc7+bwAM3Q1u/4XEfrquwF8Lw7D7y5cD8CuHnfIc=
github.com/sparrc/go-ping v0.0.0-20181106165434-ef3ab45e41b0 h1:mu7brOsdaH5Dqf93vdch+mr/0To8Sgc+yInt/jE/RJM=
github.com/sparrc/go-ping v0.0.0-20181106165434-ef3ab45e41b0/go.mod h1:eMyUVp6f/5jnzM+3zahzl7q6UXLbgSc3MKg/+ow9QW0=
github.com/spf13/cobra v0.0.3 h1:ZlrZ4XsMRm04Fr5pSFxBgfND2EBVa1nLpiy1stUsX/8= github.com/spf13/cobra v0.0.3 h1:ZlrZ4XsMRm04Fr5pSFxBgfND2EBVa1nLpiy1stUsX/8=
github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ=
github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg= github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg=

View File

@ -5,7 +5,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -19,26 +18,6 @@ import (
"github.com/joomcode/errorx" "github.com/joomcode/errorx"
) )
// ----------------------------------
// helper functions for working with files
// ----------------------------------
// Writes data first to a temporary file and then renames it to what's specified in path
func safeWriteFile(path string, data []byte) error {
dir := filepath.Dir(path)
err := os.MkdirAll(dir, 0755)
if err != nil {
return err
}
tmpPath := path + ".tmp"
err = ioutil.WriteFile(tmpPath, data, 0644)
if err != nil {
return err
}
return os.Rename(tmpPath, path)
}
// ---------------------------------- // ----------------------------------
// helper functions for HTTP handlers // helper functions for HTTP handlers
// ---------------------------------- // ----------------------------------
@ -48,6 +27,12 @@ func ensure(method string, handler func(http.ResponseWriter, *http.Request)) fun
http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed) http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed)
return return
} }
if method == "POST" || method == "PUT" || method == "DELETE" {
controlLock.Lock()
defer controlLock.Unlock()
}
handler(w, r) handler(w, r)
} }
} }

View File

@ -6,6 +6,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"github.com/AdguardTeam/golibs/file"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
yaml "gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"
) )
@ -86,7 +87,7 @@ func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) err
return err return err
} }
err = safeWriteFile(configFile, body) err = file.SafeWrite(configFile, body)
if err != nil { if err != nil {
log.Printf("Couldn't save YAML config: %s", err) log.Printf("Couldn't save YAML config: %s", err)
return err return err