diff --git a/control.go b/control.go index e9890575..dd368380 100644 --- a/control.go +++ b/control.go @@ -57,10 +57,10 @@ func returnOK(w http.ResponseWriter) { } } -func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) { +func httpError(w *http.ResponseWriter, code int, format string, args ...interface{}) { text := fmt.Sprintf(format, args...) log.Println(text) - http.Error(w, text, code) + http.Error(*w, text, code) } // --------------- @@ -78,7 +78,7 @@ func writeAllConfigsAndReloadDNS() error { func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) { err := writeAllConfigsAndReloadDNS() if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't write config file: %s", err) + httpError(&w, http.StatusInternalServerError, "Couldn't write config file: %s", err) return } returnOK(w) @@ -334,9 +334,7 @@ func handleSetBootstrapDNS(w http.ResponseWriter, r *http.Request) { func setDNSServers(w *http.ResponseWriter, r *http.Request, upstreams bool) { body, err := ioutil.ReadAll(r.Body) if err != nil { - errorText := fmt.Sprintf("Failed to read request body: %s", err) - log.Println(errorText) - http.Error(*w, errorText, http.StatusBadRequest) + httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err) return } // if empty body -- user is asking for default servers @@ -379,23 +377,17 @@ func setDNSServers(w *http.ResponseWriter, r *http.Request, upstreams bool) { err = writeAllConfigs() if err != nil { - errorText := fmt.Sprintf("Couldn't write config file: %s", err) - log.Println(errorText) - http.Error(*w, errorText, http.StatusInternalServerError) + httpError(w, http.StatusInternalServerError, "Couldn't write config file: %s", err) return } err = reconfigureDNSServer() if err != nil { - errorText := fmt.Sprintf("Couldn't reconfigure the DNS server: %s", err) - log.Println(errorText) - http.Error(*w, errorText, http.StatusInternalServerError) + httpError(w, http.StatusInternalServerError, "Couldn't reconfigure the DNS server: %s", err) return } _, err = fmt.Fprintf(*w, "OK %d servers\n", count) if err != nil { - errorText := fmt.Sprintf("Couldn't write body: %s", err) - log.Println(errorText) - http.Error(*w, errorText, http.StatusInternalServerError) + httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) } } @@ -574,7 +566,7 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { f := filter{} err := json.NewDecoder(r.Body).Decode(&f) if err != nil { - httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err) + httpError(&w, http.StatusBadRequest, "Failed to parse request body json: %s", err) return } @@ -975,7 +967,7 @@ func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) { ifaces, err := getValidNetInterfacesForWeb() if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) + httpError(&w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) return } @@ -987,7 +979,7 @@ func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w).Encode(data) if err != nil { - httpError(w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err) + httpError(&w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err) return } } @@ -996,7 +988,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { newSettings := firstRunData{} err := json.NewDecoder(r.Body).Decode(&newSettings) if err != nil { - httpError(w, http.StatusBadRequest, "Failed to parse new config json: %s", err) + httpError(&w, http.StatusBadRequest, "Failed to parse new config json: %s", err) return } @@ -1010,14 +1002,14 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { if restartHTTP { err = checkPortAvailable(newSettings.Web.IP, newSettings.Web.Port) if err != nil { - httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s", net.JoinHostPort(newSettings.Web.IP, strconv.Itoa(newSettings.Web.Port)), err) + httpError(&w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s", net.JoinHostPort(newSettings.Web.IP, strconv.Itoa(newSettings.Web.Port)), err) return } } err = checkPacketPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port) if err != nil { - httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s", net.JoinHostPort(newSettings.DNS.IP, strconv.Itoa(newSettings.DNS.Port)), err) + httpError(&w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s", net.JoinHostPort(newSettings.DNS.IP, strconv.Itoa(newSettings.DNS.Port)), err) return } @@ -1032,7 +1024,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { if config.DNS.Port != 0 { err = startDNSServer() if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't start DNS server: %s", err) + httpError(&w, http.StatusInternalServerError, "Couldn't start DNS server: %s", err) return } } @@ -1057,7 +1049,7 @@ func handleTLSStatus(w http.ResponseWriter, r *http.Request) { func handleTLSValidate(w http.ResponseWriter, r *http.Request) { data, err := unmarshalTLS(r) if err != nil { - httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) + httpError(&w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) return } @@ -1070,7 +1062,7 @@ func handleTLSValidate(w http.ResponseWriter, r *http.Request) { if !alreadyRunning { err = checkPortAvailable(config.BindHost, data.PortHTTPS) if err != nil { - httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS) + httpError(&w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS) return } } @@ -1082,7 +1074,7 @@ func handleTLSValidate(w http.ResponseWriter, r *http.Request) { func handleTLSConfigure(w http.ResponseWriter, r *http.Request) { data, err := unmarshalTLS(r) if err != nil { - httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) + httpError(&w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) return } @@ -1095,7 +1087,7 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) { if !alreadyRunning { err = checkPortAvailable(config.BindHost, data.PortHTTPS) if err != nil { - httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS) + httpError(&w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS) return } } @@ -1109,7 +1101,7 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) { config.TLS = data err = writeAllConfigsAndReloadDNS() if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't write config file: %s", err) + httpError(&w, http.StatusInternalServerError, "Couldn't write config file: %s", err) return } marshalTLS(w, data) @@ -1328,7 +1320,7 @@ func marshalTLS(w http.ResponseWriter, data tlsConfig) { } err := json.NewEncoder(w).Encode(data) if err != nil { - httpError(w, http.StatusInternalServerError, "Failed to marshal json with TLS status: %s", err) + httpError(&w, http.StatusInternalServerError, "Failed to marshal json with TLS status: %s", err) return } } @@ -1338,12 +1330,12 @@ func marshalTLS(w http.ResponseWriter, data tlsConfig) { // -------------- func handleDOH(w http.ResponseWriter, r *http.Request) { if r.TLS == nil { - httpError(w, http.StatusNotFound, "Not Found") + httpError(&w, http.StatusNotFound, "Not Found") return } if !isRunning() { - httpError(w, http.StatusInternalServerError, "DNS server is not running") + httpError(&w, http.StatusInternalServerError, "DNS server is not running") return } diff --git a/dhcp.go b/dhcp.go index a67b0ef6..6098a1a2 100644 --- a/dhcp.go +++ b/dhcp.go @@ -37,7 +37,7 @@ func handleDHCPStatus(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") err := json.NewEncoder(w).Encode(status) if err != nil { - httpError(w, http.StatusInternalServerError, "Unable to marshal DHCP status json: %s", err) + httpError(&w, http.StatusInternalServerError, "Unable to marshal DHCP status json: %s", err) return } } @@ -46,14 +46,14 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { newconfig := dhcpd.ServerConfig{} err := json.NewDecoder(r.Body).Decode(&newconfig) if err != nil { - httpError(w, http.StatusBadRequest, "Failed to parse new DHCP config json: %s", err) + httpError(&w, http.StatusBadRequest, "Failed to parse new DHCP config json: %s", err) return } if newconfig.Enabled { err := dhcpServer.Start(&newconfig) if err != nil { - httpError(w, http.StatusBadRequest, "Failed to start DHCP server: %s", err) + httpError(&w, http.StatusBadRequest, "Failed to start DHCP server: %s", err) return } } @@ -72,7 +72,7 @@ func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { ifaces, err := getValidNetInterfaces() if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) + httpError(&w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) return } @@ -87,7 +87,7 @@ func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { } addrs, err := iface.Addrs() if err != nil { - httpError(w, http.StatusInternalServerError, "Failed to get addresses for interface %s: %s", iface.Name, err) + httpError(&w, http.StatusInternalServerError, "Failed to get addresses for interface %s: %s", iface.Name, err) return } @@ -105,7 +105,7 @@ func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { ipnet, ok := addr.(*net.IPNet) if !ok { // not an IPNet, should not happen - httpError(w, http.StatusInternalServerError, "SHOULD NOT HAPPEN: got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr) + httpError(&w, http.StatusInternalServerError, "SHOULD NOT HAPPEN: got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr) return } // ignore link-local @@ -122,7 +122,7 @@ func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { err = json.NewEncoder(w).Encode(response) if err != nil { - httpError(w, http.StatusInternalServerError, "Failed to marshal json with available interfaces: %s", err) + httpError(&w, http.StatusInternalServerError, "Failed to marshal json with available interfaces: %s", err) return } } @@ -153,7 +153,7 @@ func handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w).Encode(result) if err != nil { - httpError(w, http.StatusInternalServerError, "Failed to marshal DHCP found json: %s", err) + httpError(&w, http.StatusInternalServerError, "Failed to marshal DHCP found json: %s", err) return } } diff --git a/go.mod b/go.mod index cba63387..ebc1b9f3 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/kardianos/service v0.0.0-20181115005516-4c239ee84e7b github.com/krolaw/dhcp4 v0.0.0-20180925202202-7cead472c414 github.com/miekg/dns v1.1.1 + github.com/pkg/errors v0.8.0 github.com/shirou/gopsutil v2.18.10+incompatible github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 // indirect github.com/stretchr/testify v1.2.2 diff --git a/upgrade.go b/upgrade.go index 0b3ddc5c..b431dc71 100644 --- a/upgrade.go +++ b/upgrade.go @@ -1,6 +1,7 @@ package main import ( + "errors" "fmt" "io/ioutil" "os" @@ -10,7 +11,7 @@ import ( yaml "gopkg.in/yaml.v2" ) -const currentSchemaVersion = 2 // used for upgrading from old configs to new config +const currentSchemaVersion = 3 // used for upgrading from old configs to new config // Performs necessary upgrade operations if needed func upgradeConfig() error { @@ -59,12 +60,17 @@ func upgradeConfig() error { func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) error { switch oldVersion { case 0: - err := upgradeSchema0to2(diskConfig) + err := upgradeSchema0to3(diskConfig) if err != nil { return err } case 1: - err := upgradeSchema1to2(diskConfig) + err := upgradeSchema1to3(diskConfig) + if err != nil { + return err + } + case 2: + err := upgradeSchema2to3(diskConfig) if err != nil { return err } @@ -135,12 +141,60 @@ func upgradeSchema1to2(diskConfig *map[string]interface{}) error { return nil } -// jump two schemas at once -- this time we just do it sequentially -func upgradeSchema0to2(diskConfig *map[string]interface{}) error { +// Third schema upgrade: +// Bootstrap DNS becomes an array +func upgradeSchema2to3(diskConfig *map[string]interface{}) error { + log.Printf("%s(): called", _Func()) + + // Let's read dns configuration from diskConfig + dnsConfig, ok := (*diskConfig)["dns"] + if !ok { + return errors.New("no DNS configuration in config file") + } + + // Convert interface{} to map[string]interface{} + newDNSConfig := make(map[string]interface{}) + + switch v := dnsConfig.(type) { + case map[interface{}]interface{}: + for k, v := range v { + newDNSConfig[fmt.Sprint(k)] = v + } + default: + return errors.New("DNS configuration is not a map") + } + + // Replace bootstrap_dns value filed with new array contains old bootstrap_dns inside + if bootstrapDNS, ok := (newDNSConfig)["bootstrap_dns"]; ok { + newBootstrapConfig := []string{fmt.Sprint(bootstrapDNS)} + (newDNSConfig)["bootstrap_dns"] = newBootstrapConfig + (*diskConfig)["dns"] = newDNSConfig + } else { + return errors.New("no bootstrap DNS in DNS config") + } + + // Bump schema version + (*diskConfig)["schema_version"] = 3 + + return nil +} + +// jump three schemas at once -- this time we just do it sequentially +func upgradeSchema0to3(diskConfig *map[string]interface{}) error { err := upgradeSchema0to1(diskConfig) if err != nil { return err } - return upgradeSchema1to2(diskConfig) + return upgradeSchema1to3(diskConfig) +} + +// jump two schemas at once -- this time we just do it sequentially +func upgradeSchema1to3(diskConfig *map[string]interface{}) error { + err := upgradeSchema1to2(diskConfig) + if err != nil { + return err + } + + return upgradeSchema2to3(diskConfig) }