diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b0f572a..c4b7ccad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,7 @@ and this project adheres to ### Changed +- Port bindings are now checked for uniqueness ([#3835]). - The DNSSEC check now simply checks against the AD flag in the response ([#3904]). - Client objects in the configuration file are now sorted ([#3933]). @@ -241,6 +242,7 @@ In this release, the schema version has changed from 10 to 12. [#3772]: https://github.com/AdguardTeam/AdGuardHome/issues/3772 [#3778]: https://github.com/AdguardTeam/AdGuardHome/issues/3778 [#3815]: https://github.com/AdguardTeam/AdGuardHome/issues/3815 +[#3835]: https://github.com/AdguardTeam/AdGuardHome/issues/3835 [#3887]: https://github.com/AdguardTeam/AdGuardHome/issues/3887 [#3890]: https://github.com/AdguardTeam/AdGuardHome/issues/3890 [#3904]: https://github.com/AdguardTeam/AdGuardHome/issues/3904 diff --git a/internal/aghhttp/aghhttp.go b/internal/aghhttp/aghhttp.go new file mode 100644 index 00000000..666342ea --- /dev/null +++ b/internal/aghhttp/aghhttp.go @@ -0,0 +1,24 @@ +// Package aghhttp provides some common methods to work with HTTP. +package aghhttp + +import ( + "fmt" + "io" + "net/http" + + "github.com/AdguardTeam/golibs/log" +) + +// OK responds with word OK. +func OK(w http.ResponseWriter) { + if _, err := io.WriteString(w, "OK\n"); err != nil { + log.Error("couldn't write body: %s", err) + } +} + +// Error writes formatted message to w and also logs it. +func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) { + text := fmt.Sprintf(format, args...) + log.Error("%s %s: %s", r.Method, r.URL, text) + http.Error(w, text, code) +} diff --git a/internal/aghnet/hostscontainer.go b/internal/aghnet/hostscontainer.go index e4c1678b..85d6103d 100644 --- a/internal/aghnet/hostscontainer.go +++ b/internal/aghnet/hostscontainer.go @@ -431,7 +431,7 @@ func (hp *hostsParser) writeMainHostRule(host string, ip net.IP) (added, addedPt rwSuccess = "^$dnsrewrite=NOERROR;" rwSuccessPTR = "^$dnsrewrite=NOERROR;PTR;" - modLen = len("||") + len(rwSuccess) + modLen = len("||") + len(rwSuccess) + len(";") modLenPTR = len("||") + len(rwSuccessPTR) ) diff --git a/internal/aghnet/interfaces.go b/internal/aghnet/interfaces.go index c8efc57b..b0511b0d 100644 --- a/internal/aghnet/interfaces.go +++ b/internal/aghnet/interfaces.go @@ -118,3 +118,11 @@ func IfaceDNSIPAddrs( return addrs, nil } + +// interfaceName is a string containing network interface's name. The name is +// used in file walking methods. +type interfaceName string + +// Use interfaceName in the OS-independent code since it's actually only used in +// several OS-dependent implementations which causes linting issues. +var _ = interfaceName("") diff --git a/internal/aghnet/net.go b/internal/aghnet/net.go index 393778ad..f8e36406 100644 --- a/internal/aghnet/net.go +++ b/internal/aghnet/net.go @@ -4,13 +4,11 @@ package aghnet import ( "encoding/json" "fmt" + "io" "net" - "os" "os/exec" - "runtime" "strings" "syscall" - "time" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" @@ -189,57 +187,30 @@ func GetSubnet(ifaceName string) *net.IPNet { return nil } -// CheckPortAvailable - check if TCP port is available -func CheckPortAvailable(host net.IP, port int) error { - ln, err := net.Listen("tcp", netutil.JoinHostPort(host.String(), port)) - if err != nil { - return err +// CheckPort checks if the port is available for binding. +func CheckPort(network string, ip net.IP, port int) (err error) { + var c io.Closer + addr := netutil.IPPort{IP: ip, Port: port}.String() + switch network { + case "tcp": + c, err = net.Listen(network, addr) + case "udp": + c, err = net.ListenPacket(network, addr) + default: + return nil } - _ = ln.Close() - // It seems that net.Listener.Close() doesn't close file descriptors right away. - // We wait for some time and hope that this fd will be closed. - time.Sleep(100 * time.Millisecond) - return nil + return errors.WithDeferred(err, closePortChecker(c)) } -// CheckPacketPortAvailable - check if UDP port is available -func CheckPacketPortAvailable(host net.IP, port int) error { - ln, err := net.ListenPacket("udp", netutil.JoinHostPort(host.String(), port)) - if err != nil { - return err - } - _ = ln.Close() - - // It seems that net.Listener.Close() doesn't close file descriptors right away. - // We wait for some time and hope that this fd will be closed. - time.Sleep(100 * time.Millisecond) - return err -} - -// ErrorIsAddrInUse - check if error is "address already in use" -func ErrorIsAddrInUse(err error) bool { - errOpError, ok := err.(*net.OpError) - if !ok { +// IsAddrInUse checks if err is about unsuccessful address binding. +func IsAddrInUse(err error) (ok bool) { + var sysErr syscall.Errno + if !errors.As(err, &sysErr) { return false } - errSyscallError, ok := errOpError.Err.(*os.SyscallError) - if !ok { - return false - } - - errErrno, ok := errSyscallError.Err.(syscall.Errno) - if !ok { - return false - } - - if runtime.GOOS == "windows" { - const WSAEADDRINUSE = 10048 - return errErrno == WSAEADDRINUSE - } - - return errErrno == syscall.EADDRINUSE + return isAddrInUse(sysErr) } // SplitHost is a wrapper for net.SplitHostPort for the cases when the hostport diff --git a/internal/aghnet/net_others.go b/internal/aghnet/net_others.go deleted file mode 100644 index de88ed80..00000000 --- a/internal/aghnet/net_others.go +++ /dev/null @@ -1,20 +0,0 @@ -//go:build !(linux || darwin || freebsd || openbsd) -// +build !linux,!darwin,!freebsd,!openbsd - -package aghnet - -import ( - "github.com/AdguardTeam/AdGuardHome/internal/aghos" -) - -func canBindPrivilegedPorts() (can bool, err error) { - return aghos.HaveAdminRights() -} - -func ifaceHasStaticIP(string) (ok bool, err error) { - return false, aghos.Unsupported("checking static ip") -} - -func ifaceSetStaticIP(string) (err error) { - return aghos.Unsupported("setting static ip") -} diff --git a/internal/aghnet/net_unix.go b/internal/aghnet/net_unix.go index efca131b..9f0f5011 100644 --- a/internal/aghnet/net_unix.go +++ b/internal/aghnet/net_unix.go @@ -1,8 +1,19 @@ -//go:build openbsd || freebsd || linux -// +build openbsd freebsd linux +//go:build openbsd || freebsd || linux || darwin +// +build openbsd freebsd linux darwin package aghnet -// interfaceName is a string containing network interface's name. The name is -// used in file walking methods. -type interfaceName string +import ( + "io" + "syscall" + + "github.com/AdguardTeam/golibs/errors" +) + +func closePortChecker(c io.Closer) (err error) { + return c.Close() +} + +func isAddrInUse(err syscall.Errno) (ok bool) { + return errors.Is(err, syscall.EADDRINUSE) +} diff --git a/internal/aghnet/net_windows.go b/internal/aghnet/net_windows.go new file mode 100644 index 00000000..bbbb81d7 --- /dev/null +++ b/internal/aghnet/net_windows.go @@ -0,0 +1,45 @@ +//go:build !(linux || darwin || freebsd || openbsd) +// +build !linux,!darwin,!freebsd,!openbsd + +package aghnet + +import ( + "io" + "syscall" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/aghos" + "github.com/AdguardTeam/golibs/errors" + "golang.org/x/sys/windows" +) + +func canBindPrivilegedPorts() (can bool, err error) { + return aghos.HaveAdminRights() +} + +func ifaceHasStaticIP(string) (ok bool, err error) { + return false, aghos.Unsupported("checking static ip") +} + +func ifaceSetStaticIP(string) (err error) { + return aghos.Unsupported("setting static ip") +} + +func closePortChecker(c io.Closer) (err error) { + if err = c.Close(); err != nil { + return err + } + + // It seems that net.Listener.Close() doesn't close file descriptors right + // away. We wait for some time and hope that this fd will be closed. + // + // TODO(e.burkov): Investigate the purpose of the line and perhaps use more + // reliable approach. + time.Sleep(100 * time.Millisecond) + + return nil +} + +func isAddrInUse(err syscall.Errno) (ok bool) { + return errors.Is(err, windows.WSAEADDRINUSE) +} diff --git a/internal/dhcpd/http.go b/internal/dhcpd/http.go index e016dee8..78016010 100644 --- a/internal/dhcpd/http.go +++ b/internal/dhcpd/http.go @@ -10,18 +10,13 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/timeutil" ) -func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) { - text := fmt.Sprintf(format, args...) - log.Info("DHCP: %s %s: %s", r.Method, r.URL, text) - http.Error(w, text, code) -} - type v4ServerConfJSON struct { GatewayIP net.IP `json:"gateway_ip"` SubnetMask net.IP `json:"subnet_mask"` @@ -87,8 +82,13 @@ func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") err := json.NewEncoder(w).Encode(status) if err != nil { - httpError(r, w, http.StatusInternalServerError, "Unable to marshal DHCP status json: %s", err) - return + aghhttp.Error( + r, + w, + http.StatusInternalServerError, + "Unable to marshal DHCP status json: %s", + err, + ) } } @@ -211,36 +211,34 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { err := json.NewDecoder(r.Body).Decode(conf) if err != nil { - httpError(r, w, http.StatusBadRequest, - "failed to parse new dhcp config json: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "failed to parse new dhcp config json: %s", err) return } srv4, v4Enabled, err := s.handleDHCPSetConfigV4(conf) if err != nil { - httpError(r, w, http.StatusBadRequest, "bad dhcpv4 configuration: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "bad dhcpv4 configuration: %s", err) return } srv6, v6Enabled, err := s.handleDHCPSetConfigV6(conf) if err != nil { - httpError(r, w, http.StatusBadRequest, "bad dhcpv6 configuration: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "bad dhcpv6 configuration: %s", err) return } if conf.Enabled == nbTrue && !v4Enabled && !v6Enabled { - httpError(r, w, http.StatusBadRequest, - "dhcpv4 or dhcpv6 configuration must be complete") + aghhttp.Error(r, w, http.StatusBadRequest, "dhcpv4 or dhcpv6 configuration must be complete") return } err = s.Stop() if err != nil { - httpError(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err) return } @@ -265,7 +263,7 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { err = s.dbLoad() if err != nil { - httpError(r, w, http.StatusInternalServerError, "loading leases db: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "loading leases db: %s", err) return } @@ -274,9 +272,7 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { var code int code, err = s.enableDHCP(conf.InterfaceName) if err != nil { - httpError(r, w, code, "enabling dhcp: %s", err) - - return + aghhttp.Error(r, w, code, "enabling dhcp: %s", err) } } } @@ -295,7 +291,8 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { ifaces, err := net.Interfaces() if err != nil { - httpError(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) + return } @@ -312,7 +309,15 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { var addrs []net.Addr addrs, err = iface.Addrs() if err != nil { - httpError(r, w, http.StatusInternalServerError, "Failed to get addresses for interface %s: %s", iface.Name, err) + aghhttp.Error( + r, + w, + http.StatusInternalServerError, + "Failed to get addresses for interface %s: %s", + iface.Name, + err, + ) + return } @@ -329,7 +334,13 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { ipnet, ok := addr.(*net.IPNet) if !ok { // not an IPNet, should not happen - httpError(r, w, http.StatusInternalServerError, "SHOULD NOT HAPPEN: got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr) + aghhttp.Error( + r, + w, + http.StatusInternalServerError, + "got iface.Addrs() element %[1]s that is not net.IPNet, it is %[1]T", + addr) + return } // ignore link-local @@ -350,8 +361,13 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { err = json.NewEncoder(w).Encode(response) if err != nil { - httpError(r, w, http.StatusInternalServerError, "Failed to marshal json with available interfaces: %s", err) - return + aghhttp.Error( + r, + w, + http.StatusInternalServerError, + "Failed to marshal json with available interfaces: %s", + err, + ) } } @@ -455,9 +471,13 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque w.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w).Encode(result) if err != nil { - httpError(r, w, http.StatusInternalServerError, "Failed to marshal DHCP found json: %s", err) - - return + aghhttp.Error( + r, + w, + http.StatusInternalServerError, + "Failed to marshal DHCP found json: %s", + err, + ) } } @@ -465,13 +485,13 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request l := &Lease{} err := json.NewDecoder(r.Body).Decode(l) if err != nil { - httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err) return } if l.IP == nil { - httpError(r, w, http.StatusBadRequest, "invalid IP") + aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP") return } @@ -483,7 +503,7 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request err = s.srv6.AddStaticLease(l) if err != nil { - httpError(r, w, http.StatusBadRequest, "%s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) } return @@ -492,7 +512,7 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request l.IP = ip4 err = s.srv4.AddStaticLease(l) if err != nil { - httpError(r, w, http.StatusBadRequest, "%s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) return } @@ -502,13 +522,13 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ l := &Lease{} err := json.NewDecoder(r.Body).Decode(l) if err != nil { - httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err) return } if l.IP == nil { - httpError(r, w, http.StatusBadRequest, "invalid IP") + aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP") return } @@ -520,7 +540,7 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ err = s.srv6.RemoveStaticLease(l) if err != nil { - httpError(r, w, http.StatusBadRequest, "%s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) } return @@ -529,7 +549,7 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ l.IP = ip4 err = s.srv4.RemoveStaticLease(l) if err != nil { - httpError(r, w, http.StatusBadRequest, "%s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) return } @@ -545,7 +565,7 @@ const ( func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) { err := s.Stop() if err != nil { - httpError(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err) return } @@ -583,7 +603,7 @@ func (s *Server) handleResetLeases(w http.ResponseWriter, r *http.Request) { err := s.resetLeases() if err != nil { msg := "resetting leases: %s" - httpError(r, w, http.StatusInternalServerError, msg, err) + aghhttp.Error(r, w, http.StatusInternalServerError, msg, err) return } diff --git a/internal/dnsforward/access.go b/internal/dnsforward/access.go index efa3bf8c..79f6c2b7 100644 --- a/internal/dnsforward/access.go +++ b/internal/dnsforward/access.go @@ -7,6 +7,7 @@ import ( "net/http" "strings" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" @@ -187,7 +188,7 @@ func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") err := json.NewEncoder(w).Encode(j) if err != nil { - httpError(r, w, http.StatusInternalServerError, "encoding response: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "encoding response: %s", err) return } @@ -251,14 +252,14 @@ func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) { list := accessListJSON{} err := json.NewDecoder(r.Body).Decode(&list) if err != nil { - httpError(r, w, http.StatusBadRequest, "decoding request: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err) return } err = validateAccessSet(list) if err != nil { - httpError(r, w, http.StatusBadRequest, err.Error()) + aghhttp.Error(r, w, http.StatusBadRequest, err.Error()) return } @@ -266,7 +267,7 @@ func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) { var a *accessCtx a, err = newAccessCtx(list.AllowedClients, list.DisallowedClients, list.BlockedHosts) if err != nil { - httpError(r, w, http.StatusBadRequest, "creating access ctx: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "creating access ctx: %s", err) return } diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 3623de71..42a81946 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" @@ -18,12 +19,6 @@ import ( "github.com/miekg/dns" ) -func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) { - text := fmt.Sprintf(format, args...) - log.Info("dns: %s %s: %s", r.Method, r.URL, text) - http.Error(w, text, code) -} - type dnsConfig struct { Upstreams *[]string `json:"upstream_dns"` UpstreamsFile *string `json:"upstream_dns_file"` @@ -119,7 +114,8 @@ func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if err = json.NewEncoder(w).Encode(resp); err != nil { - httpError(r, w, http.StatusInternalServerError, "json.Encoder: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encoder: %s", err) + return } } @@ -198,34 +194,52 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) { req := dnsConfig{} dec := json.NewDecoder(r.Body) if err := dec.Decode(&req); err != nil { - httpError(r, w, http.StatusBadRequest, "json Encode: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "json Encode: %s", err) + return } if req.Upstreams != nil { if err := ValidateUpstreams(*req.Upstreams); err != nil { - httpError(r, w, http.StatusBadRequest, "wrong upstreams specification: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "wrong upstreams specification: %s", err) + return } } if errBoot, err := req.checkBootstrap(); err != nil { - httpError(r, w, http.StatusBadRequest, "%s can not be used as bootstrap dns cause: %s", errBoot, err) + aghhttp.Error( + r, + w, + http.StatusBadRequest, + "%s can not be used as bootstrap dns cause: %s", + errBoot, + err, + ) + return } if !req.checkBlockingMode() { - httpError(r, w, http.StatusBadRequest, "blocking_mode: incorrect value") + aghhttp.Error(r, w, http.StatusBadRequest, "blocking_mode: incorrect value") + return } if !req.checkUpstreamsMode() { - httpError(r, w, http.StatusBadRequest, "upstream_mode: incorrect value") + aghhttp.Error(r, w, http.StatusBadRequest, "upstream_mode: incorrect value") + return } if !req.checkCacheTTL() { - httpError(r, w, http.StatusBadRequest, "cache_ttl_min must be less or equal than cache_ttl_max") + aghhttp.Error( + r, + w, + http.StatusBadRequest, + "cache_ttl_min must be less or equal than cache_ttl_max", + ) + return } @@ -234,8 +248,7 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) { if restart { if err := s.Reconfigure(nil); err != nil { - httpError(r, w, http.StatusInternalServerError, "%s", err) - return + aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) } } } @@ -582,7 +595,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { req := &upstreamJSON{} err := json.NewDecoder(r.Body).Decode(req) if err != nil { - httpError(r, w, http.StatusBadRequest, "Failed to read request body: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "Failed to read request body: %s", err) return } @@ -620,7 +633,13 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { jsonVal, err := json.Marshal(result) if err != nil { - httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err) + aghhttp.Error( + r, + w, + http.StatusInternalServerError, + "Unable to marshal status json: %s", + err, + ) return } @@ -628,9 +647,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") _, err = w.Write(jsonVal) if err != nil { - httpError(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err) - - return + aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err) } } @@ -641,12 +658,12 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { // -> dnsforward.handleDNSRequest func (s *Server) handleDoH(w http.ResponseWriter, r *http.Request) { if !s.conf.TLSAllowUnencryptedDoH && r.TLS == nil { - httpError(r, w, http.StatusNotFound, "Not Found") + aghhttp.Error(r, w, http.StatusNotFound, "Not Found") return } if !s.IsRunning() { - httpError(r, w, http.StatusInternalServerError, "dns server is not running") + aghhttp.Error(r, w, http.StatusInternalServerError, "dns server is not running") return } diff --git a/internal/filtering/blocked.go b/internal/filtering/blocked.go index a11b72c7..aa0ba979 100644 --- a/internal/filtering/blocked.go +++ b/internal/filtering/blocked.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter/rules" ) @@ -287,7 +288,8 @@ func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Req w.Header().Set("Content-Type", "application/json") err := json.NewEncoder(w).Encode(list) if err != nil { - httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encode: %s", err) + return } } @@ -296,7 +298,8 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ list := []string{} err := json.NewDecoder(r.Body).Decode(&list) if err != nil { - httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err) + return } diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 0e0bb0ab..2f5717f1 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -864,8 +864,7 @@ func makeResult(matchedRules []rules.Rule, reason Reason) (res Result) { } } -// InitModule manually initializes blocked services map using blockedSvcListID -// as list ID for the rules. +// InitModule manually initializes blocked services map. func InitModule() { initBlockedServices() } diff --git a/internal/filtering/rewrites.go b/internal/filtering/rewrites.go index 90eb262e..19885d55 100644 --- a/internal/filtering/rewrites.go +++ b/internal/filtering/rewrites.go @@ -9,6 +9,7 @@ import ( "sort" "strings" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) @@ -206,7 +207,8 @@ func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") err := json.NewEncoder(w).Encode(arr) if err != nil { - httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encode: %s", err) + return } } @@ -215,7 +217,8 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) { jsent := rewriteEntryJSON{} err := json.NewDecoder(r.Body).Decode(&jsent) if err != nil { - httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err) + return } @@ -237,7 +240,8 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) jsent := rewriteEntryJSON{} err := json.NewDecoder(r.Body).Decode(&jsent) if err != nil { - httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err) + return } diff --git a/internal/filtering/safebrowsing.go b/internal/filtering/safebrowsing.go index 495a9067..9de4bcd2 100644 --- a/internal/filtering/safebrowsing.go +++ b/internal/filtering/safebrowsing.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" @@ -368,12 +369,6 @@ func (d *DNSFilter) checkParental( return check(sctx, res, d.parentalUpstream) } -func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) { - text := fmt.Sprintf(format, args...) - log.Info("DNSFilter: %s %s: %s", r.Method, r.URL, text) - http.Error(w, text, code) -} - func (d *DNSFilter) handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) { d.Config.SafeBrowsingEnabled = true d.Config.ConfigModified() @@ -392,7 +387,8 @@ func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Requ Enabled: d.Config.SafeBrowsingEnabled, }) if err != nil { - httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) + return } } @@ -415,8 +411,7 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request) Enabled: d.Config.ParentalEnabled, }) if err != nil { - httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) - return + aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) } } diff --git a/internal/filtering/safesearch.go b/internal/filtering/safesearch.go index c67ac735..23751cb0 100644 --- a/internal/filtering/safesearch.go +++ b/internal/filtering/safesearch.go @@ -11,6 +11,7 @@ import ( "net/http" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" ) @@ -152,8 +153,13 @@ func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Reques Enabled: d.Config.SafeSearchEnabled, }) if err != nil { - httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) - return + aghhttp.Error( + r, + w, + http.StatusInternalServerError, + "Unable to write response json: %s", + err, + ) } } diff --git a/internal/home/auth.go b/internal/home/auth.go index bfcc7267..26a807a9 100644 --- a/internal/home/auth.go +++ b/internal/home/auth.go @@ -13,6 +13,7 @@ import ( "sync" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/timeutil" @@ -417,7 +418,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) { req := loginJSON{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { - httpError(w, http.StatusBadRequest, "json decode: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err) return } @@ -429,7 +430,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) { // // TODO(e.burkov): Use realIP when the issue will be fixed. if remoteAddr, err = netutil.SplitHost(r.RemoteAddr); err != nil { - httpError(w, http.StatusBadRequest, "auth: getting remote address: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "auth: getting remote address: %s", err) return } @@ -437,7 +438,8 @@ func handleLogin(w http.ResponseWriter, r *http.Request) { if blocker := Context.auth.blocker; blocker != nil { if left := blocker.check(remoteAddr); left > 0 { w.Header().Set("Retry-After", strconv.Itoa(int(left.Seconds()))) - httpError( + aghhttp.Error( + r, w, http.StatusTooManyRequests, "auth: blocked for %s", @@ -451,7 +453,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) { var cookie string cookie, err = Context.auth.httpCookie(req, remoteAddr) if err != nil { - httpError(w, http.StatusBadRequest, "crypto rand reader: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "crypto rand reader: %s", err) return } @@ -480,7 +482,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) { w.Header().Set("Pragma", "no-cache") w.Header().Set("Expires", "0") - returnOK(w) + aghhttp.OK(w) } func handleLogout(w http.ResponseWriter, r *http.Request) { diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 412ff002..1f053343 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -6,6 +6,7 @@ import ( "net" "net/http" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/golibs/log" ) @@ -58,7 +59,7 @@ type clientListJSON struct { } // respond with information about configured clients -func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http.Request) { +func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http.Request) { data := clientListJSON{} clients.lock.Lock() @@ -106,7 +107,14 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http w.Header().Set("Content-Type", "application/json") e := json.NewEncoder(w).Encode(data) if e != nil { - httpError(w, http.StatusInternalServerError, "Failed to encode to json: %v", e) + aghhttp.Error( + r, + w, + http.StatusInternalServerError, + "Failed to encode to json: %v", + e, + ) + return } } @@ -154,7 +162,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http. cj := clientJSON{} err := json.NewDecoder(r.Body).Decode(&cj) if err != nil { - httpError(w, http.StatusBadRequest, "failed to process request body: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err) return } @@ -162,11 +170,14 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http. c := jsonToClient(cj) ok, err := clients.Add(c) if err != nil { - httpError(w, http.StatusBadRequest, "%s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) + return } + if !ok { - httpError(w, http.StatusBadRequest, "Client already exists") + aghhttp.Error(r, w, http.StatusBadRequest, "Client already exists") + return } @@ -178,19 +189,19 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http. cj := clientJSON{} err := json.NewDecoder(r.Body).Decode(&cj) if err != nil { - httpError(w, http.StatusBadRequest, "failed to process request body: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err) return } if len(cj.Name) == 0 { - httpError(w, http.StatusBadRequest, "client's name must be non-empty") + aghhttp.Error(r, w, http.StatusBadRequest, "client's name must be non-empty") return } if !clients.Del(cj.Name) { - httpError(w, http.StatusBadRequest, "Client not found") + aghhttp.Error(r, w, http.StatusBadRequest, "Client not found") return } @@ -207,20 +218,22 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht dj := updateJSON{} err := json.NewDecoder(r.Body).Decode(&dj) if err != nil { - httpError(w, http.StatusBadRequest, "failed to process request body: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err) return } if len(dj.Name) == 0 { - httpError(w, http.StatusBadRequest, "Invalid request") + aghhttp.Error(r, w, http.StatusBadRequest, "Invalid request") + return } c := jsonToClient(dj.Data) err = clients.Update(dj.Name, c) if err != nil { - httpError(w, http.StatusBadRequest, "%s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) + return } @@ -256,7 +269,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http w.Header().Set("Content-Type", "application/json") err := json.NewEncoder(w).Encode(data) if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't write response: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write response: %s", err) } } diff --git a/internal/home/config.go b/internal/home/config.go index 99463a6a..e5e1b14b 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -1,7 +1,6 @@ package home import ( - "fmt" "net" "os" "path/filepath" @@ -274,17 +273,34 @@ func getLogSettings() logSettings { } // parseConfig loads configuration from the YAML file -func parseConfig() error { - configFile := config.getConfigFilename() - log.Debug("Reading config file: %s", configFile) - yamlFile, err := readConfigFile() +func parseConfig() (err error) { + var fileData []byte + fileData, err = readConfigFile() if err != nil { return err } + config.fileData = nil - err = yaml.Unmarshal(yamlFile, &config) + err = yaml.Unmarshal(fileData, &config) if err != nil { - log.Error("Couldn't parse config file: %s", err) + return err + } + + pm := portsMap{} + pm.add( + config.BindPort, + config.BetaBindPort, + config.DNS.Port, + ) + if config.TLS.Enabled { + pm.add( + config.TLS.PortHTTPS, + config.TLS.PortDNSOverTLS, + config.TLS.PortDNSOverQUIC, + config.TLS.PortDNSCrypt, + ) + } + if err = pm.validate(); err != nil { return err } @@ -299,18 +315,17 @@ func parseConfig() error { return nil } -// readConfigFile reads config file contents if it exists -func readConfigFile() ([]byte, error) { - if len(config.fileData) != 0 { +// readConfigFile reads configuration file contents. +func readConfigFile() (fileData []byte, err error) { + if len(config.fileData) > 0 { return config.fileData, nil } - configFile := config.getConfigFilename() - d, err := os.ReadFile(configFile) - if err != nil { - return nil, fmt.Errorf("couldn't read config file %s: %w", configFile, err) - } - return d, nil + name := config.getConfigFilename() + log.Debug("reading config file: %s", name) + + // Do not wrap the error because it's informative enough as is. + return os.ReadFile(name) } // Saves configuration to the YAML file and also saves the user filter contents to a file diff --git a/internal/home/control.go b/internal/home/control.go index d9d809ba..df21e310 100644 --- a/internal/home/control.go +++ b/internal/home/control.go @@ -9,6 +9,7 @@ import ( "runtime" "strings" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/version" @@ -17,23 +18,6 @@ import ( "github.com/NYTimes/gziphandler" ) -// ---------------- -// helper functions -// ---------------- - -func returnOK(w http.ResponseWriter) { - _, err := fmt.Fprintf(w, "OK\n") - if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) - } -} - -func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) { - text := fmt.Sprintf(format, args...) - log.Info(text) - http.Error(w, text, code) -} - // appendDNSAddrs is a convenient helper for appending a formatted form of DNS // addresses to a slice of strings. func appendDNSAddrs(dst []string, addrs ...net.IP) (res []string) { @@ -125,12 +109,12 @@ type statusResponse struct { Language string `json:"language"` } -func handleStatus(w http.ResponseWriter, _ *http.Request) { +func handleStatus(w http.ResponseWriter, r *http.Request) { dnsAddrs, err := collectDNSAddresses() if err != nil { // Don't add a lot of formatting, since the error is already // wrapped by collectDNSAddresses. - httpError(w, http.StatusInternalServerError, "%s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) return } @@ -165,7 +149,7 @@ func handleStatus(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w).Encode(resp) if err != nil { - httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) return } @@ -182,7 +166,7 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) { data, err := json.Marshal(pj) if err != nil { - httpError(w, http.StatusInternalServerError, "json.Marshal: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "json.Marshal: %s", err) return } _, _ = w.Write(data) @@ -295,7 +279,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) { host, err := netutil.SplitHost(r.Host) if err != nil { - httpError(w, http.StatusBadRequest, "bad host: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "bad host: %s", err) return false } diff --git a/internal/home/controlfiltering.go b/internal/home/controlfiltering.go index ee19fd16..639403d8 100644 --- a/internal/home/controlfiltering.go +++ b/internal/home/controlfiltering.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) @@ -49,7 +50,8 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request fj := filterAddJSON{} err := json.NewDecoder(r.Body).Decode(&fj) if err != nil { - httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse request body json: %s", err) + return } @@ -63,7 +65,8 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request // Check for duplicates if filterExists(fj.URL) { - httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL) + aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL) + return } @@ -79,17 +82,35 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request // Download the filter contents ok, err := f.update(&filt) if err != nil { - httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", filt.URL, err) - return - } - if !ok { - httpError(w, http.StatusBadRequest, "Filter at the url %s is invalid (maybe it points to blank page?)", filt.URL) + aghhttp.Error( + r, + w, + http.StatusBadRequest, + "Couldn't fetch filter from url %s: %s", + filt.URL, + err, + ) + return } - // URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it + if !ok { + aghhttp.Error( + r, + w, + http.StatusBadRequest, + "Filter at the url %s is invalid (maybe it points to blank page?)", + filt.URL, + ) + + return + } + + // URL is assumed valid so append it to filters, update config, write new + // file and reload it to engines. if !filterAdd(filt) { - httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL) + aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL) + return } @@ -98,7 +119,7 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request _, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount) if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err) } } @@ -111,7 +132,8 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ req := request{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { - httpError(w, http.StatusBadRequest, "failed to parse request body json: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "failed to parse request body json: %s", err) + return } @@ -152,7 +174,7 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ _, err = fmt.Fprintf(w, "OK %d rules\n", deleted.RulesCount) if err != nil { - httpError(w, http.StatusInternalServerError, "couldn't write body: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "couldn't write body: %s", err) } } @@ -172,7 +194,8 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request fj := filterURLReq{} err := json.NewDecoder(r.Body).Decode(&fj) if err != nil { - httpError(w, http.StatusBadRequest, "json decode: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err) + return } @@ -228,7 +251,8 @@ func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Reque // This use of ReadAll is safe, because request's body is now limited. body, err := io.ReadAll(r.Body) if err != nil { - httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "Failed to read request body: %s", err) + return } @@ -250,7 +274,8 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques req := Req{} err = json.NewDecoder(r.Body).Decode(&req) if err != nil { - httpError(w, http.StatusBadRequest, "json decode: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err) + return } @@ -270,13 +295,15 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques resp.Updated, err = f.refreshFilters(flags|filterRefreshForce, false) }() if err != nil { - httpError(w, http.StatusInternalServerError, "%s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) + return } js, err := json.Marshal(resp) if err != nil { - httpError(w, http.StatusInternalServerError, "json encode: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err) + return } w.Header().Set("Content-Type", "application/json") @@ -335,13 +362,14 @@ func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request jsonVal, err := json.Marshal(resp) if err != nil { - httpError(w, http.StatusInternalServerError, "json encode: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err) + return } w.Header().Set("Content-Type", "application/json") _, err = w.Write(jsonVal) if err != nil { - httpError(w, http.StatusInternalServerError, "http write: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "http write: %s", err) } } @@ -350,12 +378,14 @@ func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request req := filteringConfig{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { - httpError(w, http.StatusBadRequest, "json decode: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err) + return } if !checkFiltersUpdateIntervalHours(req.Interval) { - httpError(w, http.StatusBadRequest, "Unsupported interval") + aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval") + return } @@ -408,7 +438,15 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) { Context.dnsFilter.ApplyBlockedServices(&setts, nil, true) result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts) if err != nil { - httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err) + aghhttp.Error( + r, + w, + http.StatusInternalServerError, + "couldn't apply filtering: %s: %s", + host, + err, + ) + return } @@ -433,7 +471,8 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) { js, err := json.Marshal(resp) if err != nil { - httpError(w, http.StatusInternalServerError, "json encode: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err) + return } w.Header().Set("Content-Type", "application/json") diff --git a/internal/home/controlinstall.go b/internal/home/controlinstall.go index 84996017..baae5d00 100644 --- a/internal/home/controlinstall.go +++ b/internal/home/controlinstall.go @@ -14,6 +14,7 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" @@ -34,7 +35,8 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request ifaces, err := aghnet.GetValidNetInterfacesForWeb() if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) + return } @@ -46,7 +48,14 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request w.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w).Encode(data) if err != nil { - httpError(w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err) + aghhttp.Error( + r, + w, + http.StatusInternalServerError, + "Unable to marshal default addresses to json: %s", + err, + ) + return } } @@ -84,23 +93,32 @@ type checkConfigResp struct { func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) { reqData := checkConfigReq{} respData := checkConfigResp{} + err := json.NewDecoder(r.Body).Decode(&reqData) if err != nil { - httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err) + return } - if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort && reqData.Web.Port != config.BetaBindPort { - err = aghnet.CheckPortAvailable(reqData.Web.IP, reqData.Web.Port) + pm := portsMap{} + pm.add(config.BindPort, config.BetaBindPort, reqData.Web.Port) + if err = pm.validate(); err != nil { + respData.Web.Status = err.Error() + } else if reqData.Web.Port != 0 { + err = aghnet.CheckPort("tcp", reqData.Web.IP, reqData.Web.Port) if err != nil { respData.Web.Status = err.Error() } } - if reqData.DNS.Port != 0 { - err = aghnet.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port) + pm.add(reqData.DNS.Port) + if err = pm.validate(); err != nil { + respData.DNS.Status = err.Error() + } else if reqData.DNS.Port != 0 { + err = aghnet.CheckPort("udp", reqData.DNS.IP, reqData.DNS.Port) - if aghnet.ErrorIsAddrInUse(err) { + if aghnet.IsAddrInUse(err) { canAutofix := checkDNSStubListener() if canAutofix && reqData.DNS.Autofix { @@ -109,7 +127,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) log.Error("Couldn't disable DNSStubListener: %s", err) } - err = aghnet.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port) + err = aghnet.CheckPort("udp", reqData.DNS.IP, reqData.DNS.Port) canAutofix = false } @@ -117,7 +135,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) } if err == nil { - err = aghnet.CheckPortAvailable(reqData.DNS.IP, reqData.DNS.Port) + err = aghnet.CheckPort("tcp", reqData.DNS.IP, reqData.DNS.Port) } if err != nil { @@ -130,7 +148,8 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) w.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w).Encode(respData) if err != nil { - httpError(w, http.StatusInternalServerError, "Unable to marshal JSON: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to marshal JSON: %s", err) + return } } @@ -287,21 +306,21 @@ func shutdownSrv(ctx context.Context, srv *http.Server) { func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { req, restartHTTP, err := decodeApplyConfigReq(r.Body) if err != nil { - httpError(w, http.StatusBadRequest, "%s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) return } - err = aghnet.CheckPacketPortAvailable(req.DNS.IP, req.DNS.Port) + err = aghnet.CheckPort("udp", req.DNS.IP, req.DNS.Port) if err != nil { - httpError(w, http.StatusBadRequest, "%s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) return } - err = aghnet.CheckPortAvailable(req.DNS.IP, req.DNS.Port) + err = aghnet.CheckPort("tcp", req.DNS.IP, req.DNS.Port) if err != nil { - httpError(w, http.StatusBadRequest, "%s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) return } @@ -315,28 +334,29 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { config.DNS.BindHosts = []net.IP{req.DNS.IP} config.DNS.Port = req.DNS.Port - // TODO(e.burkov): StartMods() should be put in a separate goroutine at - // the moment we'll allow setting up TLS in the initial configuration or - // the configuration itself will use HTTPS protocol, because the - // underlying functions potentially restart the HTTPS server. + // TODO(e.burkov): StartMods() should be put in a separate goroutine at the + // moment we'll allow setting up TLS in the initial configuration or the + // configuration itself will use HTTPS protocol, because the underlying + // functions potentially restart the HTTPS server. err = StartMods() if err != nil { Context.firstRun = true copyInstallSettings(config, curConfig) - httpError(w, http.StatusInternalServerError, "%s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) return } - u := User{} - u.Name = req.Username - Context.auth.UserAdd(&u, req.Password) + u := &User{ + Name: req.Username, + } + Context.auth.UserAdd(u, req.Password) err = config.write() if err != nil { Context.firstRun = true copyInstallSettings(config, curConfig) - httpError(w, http.StatusInternalServerError, "Couldn't write config: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write config: %s", err) return } @@ -347,7 +367,7 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { registerControlHandlers() - returnOK(w) + aghhttp.OK(w) if f, ok := w.(http.Flusher); ok { f.Flush() } @@ -386,7 +406,7 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e restartHTTP = !config.BindHost.Equal(req.Web.IP) || config.BindPort != req.Web.Port if restartHTTP { - err = aghnet.CheckPortAvailable(req.Web.IP, req.Web.Port) + err = aghnet.CheckPort("tcp", req.Web.IP, req.Web.Port) if err != nil { return nil, false, fmt.Errorf( "checking address %s:%d: %w", @@ -437,12 +457,14 @@ func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Requ reqData := checkConfigReqBeta{} err := json.NewDecoder(r.Body).Decode(&reqData) if err != nil { - httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err) + return } if len(reqData.DNS.IP) == 0 || len(reqData.Web.IP) == 0 { - httpError(w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest)) + aghhttp.Error(r, w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest)) + return } @@ -464,7 +486,14 @@ func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Requ err = json.NewEncoder(nonBetaReqBody).Encode(nonBetaReqData) if err != nil { - httpError(w, http.StatusBadRequest, "Failed to encode 'check_config' JSON data: %s", err) + aghhttp.Error( + r, + w, + http.StatusBadRequest, + "Failed to encode 'check_config' JSON data: %s", + err, + ) + return } body := nonBetaReqBody.String() @@ -505,12 +534,14 @@ func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Reques reqData := applyConfigReqBeta{} err := json.NewDecoder(r.Body).Decode(&reqData) if err != nil { - httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err) + return } if len(reqData.DNS.IP) == 0 || len(reqData.Web.IP) == 0 { - httpError(w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest)) + aghhttp.Error(r, w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest)) + return } @@ -531,7 +562,14 @@ func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Reques err = json.NewEncoder(nonBetaReqBody).Encode(nonBetaReqData) if err != nil { - httpError(w, http.StatusBadRequest, "Failed to encode 'check_config' JSON data: %s", err) + aghhttp.Error( + r, + w, + http.StatusBadRequest, + "Failed to encode 'check_config' JSON data: %s", + err, + ) + return } body := nonBetaReqBody.String() @@ -564,7 +602,8 @@ func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Req ifaces, err := aghnet.GetValidNetInterfacesForWeb() if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) + return } @@ -573,7 +612,14 @@ func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Req w.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w).Encode(data) if err != nil { - httpError(w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err) + aghhttp.Error( + r, + w, + http.StatusInternalServerError, + "Unable to marshal default addresses to json: %s", + err, + ) + return } } diff --git a/internal/home/controlupdate.go b/internal/home/controlupdate.go index 5217190e..79b9f37e 100644 --- a/internal/home/controlupdate.go +++ b/internal/home/controlupdate.go @@ -11,6 +11,7 @@ import ( "syscall" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/updater" "github.com/AdguardTeam/golibs/errors" @@ -43,7 +44,8 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { if r.ContentLength != 0 { err = json.NewDecoder(r.Body).Decode(req) if err != nil { - httpError(w, http.StatusBadRequest, "JSON parse: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "JSON parse: %s", err) + return } } @@ -77,7 +79,15 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { if err != nil { vcu := Context.updater.VersionCheckURL() // TODO(a.garipov): Figure out the purpose of %T verb. - httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", vcu, err, err) + aghhttp.Error( + r, + w, + http.StatusBadGateway, + "Couldn't get version check json from %s: %T %s\n", + vcu, + err, + err, + ) return } @@ -87,24 +97,26 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w).Encode(resp) if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err) } } // handleUpdate performs an update to the latest available version procedure. -func handleUpdate(w http.ResponseWriter, _ *http.Request) { +func handleUpdate(w http.ResponseWriter, r *http.Request) { if Context.updater.NewVersion() == "" { - httpError(w, http.StatusBadRequest, "/update request isn't allowed now") + aghhttp.Error(r, w, http.StatusBadRequest, "/update request isn't allowed now") + return } err := Context.updater.Update() if err != nil { - httpError(w, http.StatusInternalServerError, "%s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) + return } - returnOK(w) + aghhttp.OK(w) if f, ok := w.(http.Flusher); ok { f.Flush() } diff --git a/internal/home/home.go b/internal/home/home.go index 1238aaec..fe11a8d3 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -159,14 +159,11 @@ func setupContext(args options) { } if !Context.firstRun { - // Do the upgrade if necessary + // Do the upgrade if necessary. err := upgradeConfig() - if err != nil { - log.Fatal(err) - } + fatalOnError(err) - err = parseConfig() - if err != nil { + if err = parseConfig(); err != nil { log.Error("parsing configuration file: %s", err) os.Exit(1) @@ -186,13 +183,13 @@ func setupContext(args options) { // unsupported errors and returns nil. If err is nil, logIfUnsupported returns // nil. Otherise, it returns err. func logIfUnsupported(msg string, err error) (outErr error) { - if unsupErr := (&aghos.UnsupportedError{}); errors.As(err, &unsupErr) { + if errors.As(err, new(*aghos.UnsupportedError)) { log.Debug(msg, err) - } else if err != nil { - return err + + return nil } - return nil + return err } // configureOS sets the OS-related configuration. @@ -297,13 +294,32 @@ func setupConfig(args options) (err error) { Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts) + if args.bindPort != 0 { + pm := portsMap{} + pm.add( + args.bindPort, + config.BetaBindPort, + config.DNS.Port, + ) + if config.TLS.Enabled { + pm.add( + config.TLS.PortHTTPS, + config.TLS.PortDNSOverTLS, + config.TLS.PortDNSOverQUIC, + config.TLS.PortDNSCrypt, + ) + } + if err = pm.validate(); err != nil { + return err + } + + config.BindPort = args.bindPort + } + // override bind host/port from the console if args.bindHost != nil { config.BindHost = args.bindHost } - if args.bindPort != 0 { - config.BindPort = args.bindPort - } if len(args.pidFile) != 0 && writePIDFile(args.pidFile) { Context.pidFileName = args.pidFile } @@ -766,8 +782,7 @@ func printHTTPAddresses(proto string) { port = tlsConf.PortHTTPS } - // TODO(e.burkov): Inspect and perhaps merge with the previous - // condition. + // TODO(e.burkov): Inspect and perhaps merge with the previous condition. if proto == schemeHTTPS && tlsConf.ServerName != "" { printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS, 0) diff --git a/internal/home/i18n.go b/internal/home/i18n.go index d46bd6c6..51f6a2ac 100644 --- a/internal/home/i18n.go +++ b/internal/home/i18n.go @@ -6,6 +6,7 @@ import ( "net/http" "strings" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/stringutil" ) @@ -96,5 +97,5 @@ func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) { }() onConfigModified() - returnOK(w) + aghhttp.OK(w) } diff --git a/internal/home/portsmap.go b/internal/home/portsmap.go new file mode 100644 index 00000000..cfbfcde9 --- /dev/null +++ b/internal/home/portsmap.go @@ -0,0 +1,63 @@ +package home + +import ( + "fmt" + "strconv" + "strings" + + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/stringutil" +) + +// portsMap is a helper type for mapping a network port to the number of its +// users. +type portsMap map[int]int + +// add binds each of ps. Zeroes are skipped. +func (pm portsMap) add(ps ...int) { + for _, p := range ps { + if p == 0 { + continue + } + + pm[p]++ + } +} + +// validate returns an error about all the ports bound several times. +func (pm portsMap) validate() (err error) { + overbound := []int{} + for p, num := range pm { + if num > 1 { + overbound = append(overbound, p) + pm[p] = 1 + } + } + + switch len(overbound) { + case 0: + return nil + case 1: + return fmt.Errorf("port %d is already used", overbound[0]) + default: + b := &strings.Builder{} + + // TODO(e.burkov, a.garipov): Add JoinToBuilder helper to stringutil. + stringutil.WriteToBuilder(b, "ports ", strconv.Itoa(overbound[0])) + for _, p := range overbound[1:] { + stringutil.WriteToBuilder(b, ", ", strconv.Itoa(p)) + } + stringutil.WriteToBuilder(b, " are already used") + + return errors.Error(b.String()) + } +} + +// validatePorts is a helper function for a single-step ports binding +// validation. +func validatePorts(ps ...int) (err error) { + pm := portsMap{} + pm.add(ps...) + + return pm.validate() +} diff --git a/internal/home/tls.go b/internal/home/tls.go index c334e2e1..35138c34 100644 --- a/internal/home/tls.go +++ b/internal/home/tls.go @@ -20,6 +20,7 @@ import ( "sync" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" @@ -224,7 +225,7 @@ type tlsConfigSettingsExt struct { PrivateKeySaved bool `yaml:"-" json:"private_key_saved,inline"` } -func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, _ *http.Request) { +func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, r *http.Request) { t.confLock.Lock() data := tlsConfig{ tlsConfigSettingsExt: tlsConfigSettingsExt{ @@ -233,13 +234,14 @@ func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, _ *http.Request) { tlsConfigStatus: t.status, } t.confLock.Unlock() - marshalTLS(w, data) + marshalTLS(w, r, data) } func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) { setts, err := unmarshalTLS(r) if err != nil { - httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) + return } @@ -247,8 +249,31 @@ func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) { setts.PrivateKey = t.conf.PrivateKey } + if setts.Enabled { + if err = validatePorts( + config.BindPort, + config.BetaBindPort, + config.DNS.Port, + setts.PortHTTPS, + setts.PortDNSOverTLS, + setts.PortDNSOverQUIC, + setts.PortDNSCrypt, + ); err != nil { + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) + + return + } + } + if !WebCheckPortAvailable(setts.PortHTTPS) { - httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", setts.PortHTTPS) + aghhttp.Error( + r, + w, + http.StatusBadRequest, + "port %d is not available, cannot enable HTTPS on it", + setts.PortHTTPS, + ) + return } @@ -261,7 +286,8 @@ func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) { tlsConfigSettingsExt: setts, tlsConfigStatus: status, } - marshalTLS(w, data) + + marshalTLS(w, r, data) } func (t *TLSMod) setConfig(newConf tlsConfigSettings, status tlsConfigStatus) (restartHTTPS bool) { @@ -302,7 +328,8 @@ func (t *TLSMod) setConfig(newConf tlsConfigSettings, status tlsConfigStatus) (r func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { data, err := unmarshalTLS(r) if err != nil { - httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) + return } @@ -310,8 +337,32 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { data.PrivateKey = t.conf.PrivateKey } + if data.Enabled { + if err = validatePorts( + config.BindPort, + config.BetaBindPort, + config.DNS.Port, + data.PortHTTPS, + data.PortDNSOverTLS, + data.PortDNSOverQUIC, + data.PortDNSCrypt, + ); err != nil { + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) + + return + } + } + + // TODO(e.burkov): Investigate and perhaps check other ports. if !WebCheckPortAvailable(data.PortHTTPS) { - httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS) + aghhttp.Error( + r, + w, + http.StatusBadRequest, + "port %d is not available, cannot enable HTTPS on it", + data.PortHTTPS, + ) + return } @@ -321,7 +372,7 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { tlsConfigSettingsExt: data, tlsConfigStatus: t.status, } - marshalTLS(w, data2) + marshalTLS(w, r, data2) return } @@ -334,7 +385,7 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { err = reconfigureDNSServer() if err != nil { - httpError(w, http.StatusInternalServerError, "%s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) return } @@ -344,15 +395,15 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { tlsConfigStatus: t.status, } - marshalTLS(w, data2) + marshalTLS(w, r, data2) if f, ok := w.(http.Flusher); ok { f.Flush() } - // The background context is used because the TLSConfigChanged wraps - // context with timeout on its own and shuts down the server, which - // handles current request. It is also should be done in a separate - // goroutine due to the same reason. + // The background context is used because the TLSConfigChanged wraps context + // with timeout on its own and shuts down the server, which handles current + // request. It is also should be done in a separate goroutine due to the + // same reason. if restartHTTPS { go func() { Context.web.TLSConfigChanged(context.Background(), data.tlsConfigSettings) @@ -595,7 +646,7 @@ func unmarshalTLS(r *http.Request) (tlsConfigSettingsExt, error) { return data, nil } -func marshalTLS(w http.ResponseWriter, data tlsConfig) { +func marshalTLS(w http.ResponseWriter, r *http.Request, data tlsConfig) { w.Header().Set("Content-Type", "application/json") if data.CertificateChain != "" { @@ -610,8 +661,13 @@ func marshalTLS(w http.ResponseWriter, data tlsConfig) { err := json.NewEncoder(w).Encode(data) if err != nil { - httpError(w, http.StatusInternalServerError, "Failed to marshal json with TLS status: %s", err) - return + aghhttp.Error( + r, + w, + http.StatusInternalServerError, + "Failed to marshal json with TLS status: %s", + err, + ) } } diff --git a/internal/home/web.go b/internal/home/web.go index 23e629b9..07f9d73a 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -114,17 +114,8 @@ func CreateWeb(conf *webConfig) *Web { // WebCheckPortAvailable - check if port is available // BUT: if we are already using this port, no need func WebCheckPortAvailable(port int) bool { - alreadyRunning := false - if Context.web.httpsServer.server != nil { - alreadyRunning = true - } - if !alreadyRunning { - err := aghnet.CheckPortAvailable(config.BindHost, port) - if err != nil { - return false - } - } - return true + return Context.web.httpsServer.server != nil || + aghnet.CheckPort("tcp", config.BindHost, port) == nil } // TLSConfigChanged updates the TLS configuration and restarts the HTTPS server diff --git a/internal/querylog/http.go b/internal/querylog/http.go index f5ddf7c0..6a2bdcee 100644 --- a/internal/querylog/http.go +++ b/internal/querylog/http.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/golibs/jsonutil" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/stringutil" @@ -33,18 +34,11 @@ func (l *queryLog) initWeb() { l.conf.HTTPRegister(http.MethodPost, "/control/querylog_config", l.handleQueryLogConfig) } -func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) { - text := fmt.Sprintf(format, args...) - - log.Info("QueryLog: %s %s: %s", r.Method, r.URL, text) - - http.Error(w, text, code) -} - func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) { params, err := l.parseSearchParams(r) if err != nil { - httpError(r, w, http.StatusBadRequest, "failed to parse params: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "failed to parse params: %s", err) + return } @@ -56,14 +50,21 @@ func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) { jsonVal, err := json.Marshal(data) if err != nil { - httpError(r, w, http.StatusInternalServerError, "Couldn't marshal data into json: %s", err) + aghhttp.Error( + r, + w, + http.StatusInternalServerError, + "Couldn't marshal data into json: %s", + err, + ) + return } w.Header().Set("Content-Type", "application/json") _, err = w.Write(jsonVal) if err != nil { - httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) } } @@ -80,13 +81,15 @@ func (l *queryLog) handleQueryLogInfo(w http.ResponseWriter, r *http.Request) { jsonVal, err := json.Marshal(resp) if err != nil { - httpError(r, w, http.StatusInternalServerError, "json encode: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err) + return } + w.Header().Set("Content-Type", "application/json") _, err = w.Write(jsonVal) if err != nil { - httpError(r, w, http.StatusInternalServerError, "http write: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "http write: %s", err) } } @@ -109,13 +112,15 @@ func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request) d := &qlogConfig{} req, err := jsonutil.DecodeObject(d, r.Body) if err != nil { - httpError(r, w, http.StatusBadRequest, "%s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) + return } ivl := time.Duration(float64(timeutil.Day) * d.Interval) if req.Exists("interval") && !checkInterval(ivl) { - httpError(r, w, http.StatusBadRequest, "Unsupported interval") + aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval") + return } diff --git a/internal/stats/http.go b/internal/stats/http.go index e828a2b5..e2f00039 100644 --- a/internal/stats/http.go +++ b/internal/stats/http.go @@ -4,21 +4,13 @@ package stats import ( "encoding/json" - "fmt" "net/http" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/golibs/log" ) -func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) { - text := fmt.Sprintf(format, args...) - - log.Info("Stats: %s %s: %s", r.Method, r.URL, text) - - http.Error(w, text, code) -} - // topAddrs is an alias for the types of the TopFoo fields of statsResponse. // The key is either a client's address or a requested address. type topAddrs = map[string]uint64 @@ -71,7 +63,7 @@ func (s *statsCtx) handleStats(w http.ResponseWriter, r *http.Request) { log.Debug("stats: prepared data in %v", time.Since(start)) if !ok { - httpError(r, w, http.StatusInternalServerError, "Couldn't get statistics data") + aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get statistics data") return } @@ -81,7 +73,7 @@ func (s *statsCtx) handleStats(w http.ResponseWriter, r *http.Request) { err := json.NewEncoder(w).Encode(resp) if err != nil { - httpError(r, w, http.StatusInternalServerError, "json encode: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err) return } @@ -98,13 +90,14 @@ func (s *statsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) { data, err := json.Marshal(resp) if err != nil { - httpError(r, w, http.StatusInternalServerError, "json encode: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err) + return } w.Header().Set("Content-Type", "application/json") _, err = w.Write(data) if err != nil { - httpError(r, w, http.StatusInternalServerError, "http write: %s", err) + aghhttp.Error(r, w, http.StatusInternalServerError, "http write: %s", err) } } @@ -113,12 +106,14 @@ func (s *statsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) { reqData := config{} err := json.NewDecoder(r.Body).Decode(&reqData) if err != nil { - httpError(r, w, http.StatusBadRequest, "json decode: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err) + return } if !checkInterval(reqData.IntervalDays) { - httpError(r, w, http.StatusBadRequest, "Unsupported interval") + aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval") + return } diff --git a/main.go b/main.go index 4d773d2f..505eb3e5 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,10 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/home" ) +// Embed the prebuilt client here since we strive to keep .go files inside the +// internal directory and the embed package is unable to embed files located +// outside of the same or underlying directory. + //go:embed build build2 var clientBuildFS embed.FS