diff --git a/dhcp.go b/dhcp.go index 3b77a855..2d28ae9c 100644 --- a/dhcp.go +++ b/dhcp.go @@ -13,6 +13,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/dhcpd" + "github.com/AdguardTeam/golibs/file" "github.com/AdguardTeam/golibs/log" "github.com/joomcode/errorx" ) @@ -61,7 +62,17 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { } if newconfig.Enabled { - err := dhcpServer.Start(&newconfig) + + staticIP, err := hasStaticIP(newconfig.InterfaceName) + if !staticIP && err == nil { + err = setStaticIP(newconfig.InterfaceName) + if err != nil { + httpError(w, http.StatusInternalServerError, "Failed to configure static IP: %s", err) + return + } + } + + err = dhcpServer.Start(&newconfig) if err != nil { httpError(w, http.StatusBadRequest, "Failed to start DHCP server: %s", err) return @@ -258,6 +269,68 @@ func getFullIP(ifaceName string) string { return fields[3] } +// Get gateway IP address +func getGatewayIP(ifaceName string) string { + cmd := exec.Command("ip", "route", "show", "dev", ifaceName) + log.Tracef("executing %s %v", cmd.Path, cmd.Args) + d, err := cmd.Output() + if err != nil || cmd.ProcessState.ExitCode() != 0 { + return "" + } + + fields := strings.Fields(string(d)) + if len(fields) < 3 || fields[0] != "default" { + return "" + } + + ip := net.ParseIP(fields[2]) + if ip == nil { + return "" + } + + return fields[2] +} + +// Set a static IP for network interface +func setStaticIP(ifaceName string) error { + ip := getFullIP(ifaceName) + if len(ip) == 0 { + return errors.New("Can't get IP address") + } + + body, err := ioutil.ReadFile("/etc/dhcpcd.conf") + if err != nil { + return err + } + + ip4, _, err := net.ParseCIDR(ip) + if err != nil { + return err + } + + add := fmt.Sprintf("\ninterface %s\nstatic ip_address=%s\n", + ifaceName, ip) + body = append(body, []byte(add)...) + + gatewayIP := getGatewayIP(ifaceName) + if len(gatewayIP) != 0 { + add = fmt.Sprintf("static routers=%s\n", + gatewayIP) + body = append(body, []byte(add)...) + } + + add = fmt.Sprintf("static domain_name_servers=%s\n\n", + ip4) + body = append(body, []byte(add)...) + + err = file.SafeWrite("/etc/dhcpcd.conf", body) + if err != nil { + return err + } + + return nil +} + func startDHCPServer() error { if !config.DHCP.Enabled { // not enabled, don't do anything