Pull request: 3835 check ports properly

Merge in DNS/adguard-home from 3835-imp-error-msg to master

Updates #3835.

Squashed commit of the following:

commit ba31cb67833df9f293fe13be96a35c2a823f115b
Merge: 19c7dfc9 4be69d35
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Dec 16 20:07:25 2021 +0300

    Merge branch 'master' into 3835-imp-error-msg

commit 19c7dfc96284a271d30d7111c86c439be3461389
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Dec 16 19:42:10 2021 +0300

    all: imp more

commit 5b9c6a3e357238bf44ef800a6033a7671f27d469
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Dec 16 18:57:02 2021 +0300

    all: introduce aghhttp

commit 29caa17200957aad2b98461573bb33d80931adcf
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Dec 16 14:23:53 2021 +0300

    all: imp more

commit 754c020191d7b9518cb0e789f3f5741ba38c3cf4
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Dec 15 20:53:41 2021 +0300

    all: imp code, log changes

commit ec712dd562f31fcc2fbc27e7035f926c79827444
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Dec 15 18:40:54 2021 +0300

    home: check ports properly
This commit is contained in:
Eugene Burkov 2021-12-16 20:54:59 +03:00
parent 4be69d35eb
commit b3210cfa7e
31 changed files with 675 additions and 348 deletions

View File

@ -49,6 +49,7 @@ and this project adheres to
### Changed ### Changed
- Port bindings are now checked for uniqueness ([#3835]).
- The DNSSEC check now simply checks against the AD flag in the response - The DNSSEC check now simply checks against the AD flag in the response
([#3904]). ([#3904]).
- Client objects in the configuration file are now sorted ([#3933]). - Client objects in the configuration file are now sorted ([#3933]).
@ -241,6 +242,7 @@ In this release, the schema version has changed from 10 to 12.
[#3772]: https://github.com/AdguardTeam/AdGuardHome/issues/3772 [#3772]: https://github.com/AdguardTeam/AdGuardHome/issues/3772
[#3778]: https://github.com/AdguardTeam/AdGuardHome/issues/3778 [#3778]: https://github.com/AdguardTeam/AdGuardHome/issues/3778
[#3815]: https://github.com/AdguardTeam/AdGuardHome/issues/3815 [#3815]: https://github.com/AdguardTeam/AdGuardHome/issues/3815
[#3835]: https://github.com/AdguardTeam/AdGuardHome/issues/3835
[#3887]: https://github.com/AdguardTeam/AdGuardHome/issues/3887 [#3887]: https://github.com/AdguardTeam/AdGuardHome/issues/3887
[#3890]: https://github.com/AdguardTeam/AdGuardHome/issues/3890 [#3890]: https://github.com/AdguardTeam/AdGuardHome/issues/3890
[#3904]: https://github.com/AdguardTeam/AdGuardHome/issues/3904 [#3904]: https://github.com/AdguardTeam/AdGuardHome/issues/3904

View File

@ -0,0 +1,24 @@
// Package aghhttp provides some common methods to work with HTTP.
package aghhttp
import (
"fmt"
"io"
"net/http"
"github.com/AdguardTeam/golibs/log"
)
// OK responds with word OK.
func OK(w http.ResponseWriter) {
if _, err := io.WriteString(w, "OK\n"); err != nil {
log.Error("couldn't write body: %s", err)
}
}
// Error writes formatted message to w and also logs it.
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Error("%s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)
}

View File

@ -431,7 +431,7 @@ func (hp *hostsParser) writeMainHostRule(host string, ip net.IP) (added, addedPt
rwSuccess = "^$dnsrewrite=NOERROR;" rwSuccess = "^$dnsrewrite=NOERROR;"
rwSuccessPTR = "^$dnsrewrite=NOERROR;PTR;" rwSuccessPTR = "^$dnsrewrite=NOERROR;PTR;"
modLen = len("||") + len(rwSuccess) modLen = len("||") + len(rwSuccess) + len(";")
modLenPTR = len("||") + len(rwSuccessPTR) modLenPTR = len("||") + len(rwSuccessPTR)
) )

View File

@ -118,3 +118,11 @@ func IfaceDNSIPAddrs(
return addrs, nil return addrs, nil
} }
// interfaceName is a string containing network interface's name. The name is
// used in file walking methods.
type interfaceName string
// Use interfaceName in the OS-independent code since it's actually only used in
// several OS-dependent implementations which causes linting issues.
var _ = interfaceName("")

View File

@ -4,13 +4,11 @@ package aghnet
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net" "net"
"os"
"os/exec" "os/exec"
"runtime"
"strings" "strings"
"syscall" "syscall"
"time"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -189,57 +187,30 @@ func GetSubnet(ifaceName string) *net.IPNet {
return nil return nil
} }
// CheckPortAvailable - check if TCP port is available // CheckPort checks if the port is available for binding.
func CheckPortAvailable(host net.IP, port int) error { func CheckPort(network string, ip net.IP, port int) (err error) {
ln, err := net.Listen("tcp", netutil.JoinHostPort(host.String(), port)) var c io.Closer
if err != nil { addr := netutil.IPPort{IP: ip, Port: port}.String()
return err switch network {
} case "tcp":
_ = ln.Close() c, err = net.Listen(network, addr)
case "udp":
// It seems that net.Listener.Close() doesn't close file descriptors right away. c, err = net.ListenPacket(network, addr)
// We wait for some time and hope that this fd will be closed. default:
time.Sleep(100 * time.Millisecond)
return nil return nil
} }
// CheckPacketPortAvailable - check if UDP port is available return errors.WithDeferred(err, closePortChecker(c))
func CheckPacketPortAvailable(host net.IP, port int) error {
ln, err := net.ListenPacket("udp", netutil.JoinHostPort(host.String(), port))
if err != nil {
return err
}
_ = ln.Close()
// It seems that net.Listener.Close() doesn't close file descriptors right away.
// We wait for some time and hope that this fd will be closed.
time.Sleep(100 * time.Millisecond)
return err
} }
// ErrorIsAddrInUse - check if error is "address already in use" // IsAddrInUse checks if err is about unsuccessful address binding.
func ErrorIsAddrInUse(err error) bool { func IsAddrInUse(err error) (ok bool) {
errOpError, ok := err.(*net.OpError) var sysErr syscall.Errno
if !ok { if !errors.As(err, &sysErr) {
return false return false
} }
errSyscallError, ok := errOpError.Err.(*os.SyscallError) return isAddrInUse(sysErr)
if !ok {
return false
}
errErrno, ok := errSyscallError.Err.(syscall.Errno)
if !ok {
return false
}
if runtime.GOOS == "windows" {
const WSAEADDRINUSE = 10048
return errErrno == WSAEADDRINUSE
}
return errErrno == syscall.EADDRINUSE
} }
// SplitHost is a wrapper for net.SplitHostPort for the cases when the hostport // SplitHost is a wrapper for net.SplitHostPort for the cases when the hostport

View File

@ -1,20 +0,0 @@
//go:build !(linux || darwin || freebsd || openbsd)
// +build !linux,!darwin,!freebsd,!openbsd
package aghnet
import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
)
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}
func ifaceHasStaticIP(string) (ok bool, err error) {
return false, aghos.Unsupported("checking static ip")
}
func ifaceSetStaticIP(string) (err error) {
return aghos.Unsupported("setting static ip")
}

View File

@ -1,8 +1,19 @@
//go:build openbsd || freebsd || linux //go:build openbsd || freebsd || linux || darwin
// +build openbsd freebsd linux // +build openbsd freebsd linux darwin
package aghnet package aghnet
// interfaceName is a string containing network interface's name. The name is import (
// used in file walking methods. "io"
type interfaceName string "syscall"
"github.com/AdguardTeam/golibs/errors"
)
func closePortChecker(c io.Closer) (err error) {
return c.Close()
}
func isAddrInUse(err syscall.Errno) (ok bool) {
return errors.Is(err, syscall.EADDRINUSE)
}

View File

@ -0,0 +1,45 @@
//go:build !(linux || darwin || freebsd || openbsd)
// +build !linux,!darwin,!freebsd,!openbsd
package aghnet
import (
"io"
"syscall"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/sys/windows"
)
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}
func ifaceHasStaticIP(string) (ok bool, err error) {
return false, aghos.Unsupported("checking static ip")
}
func ifaceSetStaticIP(string) (err error) {
return aghos.Unsupported("setting static ip")
}
func closePortChecker(c io.Closer) (err error) {
if err = c.Close(); err != nil {
return err
}
// It seems that net.Listener.Close() doesn't close file descriptors right
// away. We wait for some time and hope that this fd will be closed.
//
// TODO(e.burkov): Investigate the purpose of the line and perhaps use more
// reliable approach.
time.Sleep(100 * time.Millisecond)
return nil
}
func isAddrInUse(err syscall.Errno) (ok bool) {
return errors.Is(err, windows.WSAEADDRINUSE)
}

View File

@ -10,18 +10,13 @@ import (
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
) )
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info("DHCP: %s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)
}
type v4ServerConfJSON struct { type v4ServerConfJSON struct {
GatewayIP net.IP `json:"gateway_ip"` GatewayIP net.IP `json:"gateway_ip"`
SubnetMask net.IP `json:"subnet_mask"` SubnetMask net.IP `json:"subnet_mask"`
@ -87,8 +82,13 @@ func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(status) err := json.NewEncoder(w).Encode(status)
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to marshal DHCP status json: %s", err) aghhttp.Error(
return r,
w,
http.StatusInternalServerError,
"Unable to marshal DHCP status json: %s",
err,
)
} }
} }
@ -211,36 +211,34 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
err := json.NewDecoder(r.Body).Decode(conf) err := json.NewDecoder(r.Body).Decode(conf)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, aghhttp.Error(r, w, http.StatusBadRequest, "failed to parse new dhcp config json: %s", err)
"failed to parse new dhcp config json: %s", err)
return return
} }
srv4, v4Enabled, err := s.handleDHCPSetConfigV4(conf) srv4, v4Enabled, err := s.handleDHCPSetConfigV4(conf)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "bad dhcpv4 configuration: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "bad dhcpv4 configuration: %s", err)
return return
} }
srv6, v6Enabled, err := s.handleDHCPSetConfigV6(conf) srv6, v6Enabled, err := s.handleDHCPSetConfigV6(conf)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "bad dhcpv6 configuration: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "bad dhcpv6 configuration: %s", err)
return return
} }
if conf.Enabled == nbTrue && !v4Enabled && !v6Enabled { if conf.Enabled == nbTrue && !v4Enabled && !v6Enabled {
httpError(r, w, http.StatusBadRequest, aghhttp.Error(r, w, http.StatusBadRequest, "dhcpv4 or dhcpv6 configuration must be complete")
"dhcpv4 or dhcpv6 configuration must be complete")
return return
} }
err = s.Stop() err = s.Stop()
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
return return
} }
@ -265,7 +263,7 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
err = s.dbLoad() err = s.dbLoad()
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "loading leases db: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "loading leases db: %s", err)
return return
} }
@ -274,9 +272,7 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
var code int var code int
code, err = s.enableDHCP(conf.InterfaceName) code, err = s.enableDHCP(conf.InterfaceName)
if err != nil { if err != nil {
httpError(r, w, code, "enabling dhcp: %s", err) aghhttp.Error(r, w, code, "enabling dhcp: %s", err)
return
} }
} }
} }
@ -295,7 +291,8 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
ifaces, err := net.Interfaces() ifaces, err := net.Interfaces()
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
return return
} }
@ -312,7 +309,15 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
var addrs []net.Addr var addrs []net.Addr
addrs, err = iface.Addrs() addrs, err = iface.Addrs()
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "Failed to get addresses for interface %s: %s", iface.Name, err) aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Failed to get addresses for interface %s: %s",
iface.Name,
err,
)
return return
} }
@ -329,7 +334,13 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
ipnet, ok := addr.(*net.IPNet) ipnet, ok := addr.(*net.IPNet)
if !ok { if !ok {
// not an IPNet, should not happen // not an IPNet, should not happen
httpError(r, w, http.StatusInternalServerError, "SHOULD NOT HAPPEN: got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr) aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"got iface.Addrs() element %[1]s that is not net.IPNet, it is %[1]T",
addr)
return return
} }
// ignore link-local // ignore link-local
@ -350,8 +361,13 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
err = json.NewEncoder(w).Encode(response) err = json.NewEncoder(w).Encode(response)
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "Failed to marshal json with available interfaces: %s", err) aghhttp.Error(
return r,
w,
http.StatusInternalServerError,
"Failed to marshal json with available interfaces: %s",
err,
)
} }
} }
@ -455,9 +471,13 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(result) err = json.NewEncoder(w).Encode(result)
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "Failed to marshal DHCP found json: %s", err) aghhttp.Error(
r,
return w,
http.StatusInternalServerError,
"Failed to marshal DHCP found json: %s",
err,
)
} }
} }
@ -465,13 +485,13 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
l := &Lease{} l := &Lease{}
err := json.NewDecoder(r.Body).Decode(l) err := json.NewDecoder(r.Body).Decode(l)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return return
} }
if l.IP == nil { if l.IP == nil {
httpError(r, w, http.StatusBadRequest, "invalid IP") aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP")
return return
} }
@ -483,7 +503,7 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
err = s.srv6.AddStaticLease(l) err = s.srv6.AddStaticLease(l)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
} }
return return
@ -492,7 +512,7 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
l.IP = ip4 l.IP = ip4
err = s.srv4.AddStaticLease(l) err = s.srv4.AddStaticLease(l)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return return
} }
@ -502,13 +522,13 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
l := &Lease{} l := &Lease{}
err := json.NewDecoder(r.Body).Decode(l) err := json.NewDecoder(r.Body).Decode(l)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return return
} }
if l.IP == nil { if l.IP == nil {
httpError(r, w, http.StatusBadRequest, "invalid IP") aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP")
return return
} }
@ -520,7 +540,7 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
err = s.srv6.RemoveStaticLease(l) err = s.srv6.RemoveStaticLease(l)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
} }
return return
@ -529,7 +549,7 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
l.IP = ip4 l.IP = ip4
err = s.srv4.RemoveStaticLease(l) err = s.srv4.RemoveStaticLease(l)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return return
} }
@ -545,7 +565,7 @@ const (
func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) { func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
err := s.Stop() err := s.Stop()
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
return return
} }
@ -583,7 +603,7 @@ func (s *Server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
err := s.resetLeases() err := s.resetLeases()
if err != nil { if err != nil {
msg := "resetting leases: %s" msg := "resetting leases: %s"
httpError(r, w, http.StatusInternalServerError, msg, err) aghhttp.Error(r, w, http.StatusInternalServerError, msg, err)
return return
} }

View File

@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
@ -187,7 +188,7 @@ func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(j) err := json.NewEncoder(w).Encode(j)
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "encoding response: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "encoding response: %s", err)
return return
} }
@ -251,14 +252,14 @@ func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
list := accessListJSON{} list := accessListJSON{}
err := json.NewDecoder(r.Body).Decode(&list) err := json.NewDecoder(r.Body).Decode(&list)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "decoding request: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err)
return return
} }
err = validateAccessSet(list) err = validateAccessSet(list)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, err.Error()) aghhttp.Error(r, w, http.StatusBadRequest, err.Error())
return return
} }
@ -266,7 +267,7 @@ func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
var a *accessCtx var a *accessCtx
a, err = newAccessCtx(list.AllowedClients, list.DisallowedClients, list.BlockedHosts) a, err = newAccessCtx(list.AllowedClients, list.DisallowedClients, list.BlockedHosts)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "creating access ctx: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "creating access ctx: %s", err)
return return
} }

View File

@ -9,6 +9,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
@ -18,12 +19,6 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info("dns: %s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)
}
type dnsConfig struct { type dnsConfig struct {
Upstreams *[]string `json:"upstream_dns"` Upstreams *[]string `json:"upstream_dns"`
UpstreamsFile *string `json:"upstream_dns_file"` UpstreamsFile *string `json:"upstream_dns_file"`
@ -119,7 +114,8 @@ func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
if err = json.NewEncoder(w).Encode(resp); err != nil { if err = json.NewEncoder(w).Encode(resp); err != nil {
httpError(r, w, http.StatusInternalServerError, "json.Encoder: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encoder: %s", err)
return return
} }
} }
@ -198,34 +194,52 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
req := dnsConfig{} req := dnsConfig{}
dec := json.NewDecoder(r.Body) dec := json.NewDecoder(r.Body)
if err := dec.Decode(&req); err != nil { if err := dec.Decode(&req); err != nil {
httpError(r, w, http.StatusBadRequest, "json Encode: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "json Encode: %s", err)
return return
} }
if req.Upstreams != nil { if req.Upstreams != nil {
if err := ValidateUpstreams(*req.Upstreams); err != nil { if err := ValidateUpstreams(*req.Upstreams); err != nil {
httpError(r, w, http.StatusBadRequest, "wrong upstreams specification: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "wrong upstreams specification: %s", err)
return return
} }
} }
if errBoot, err := req.checkBootstrap(); err != nil { if errBoot, err := req.checkBootstrap(); err != nil {
httpError(r, w, http.StatusBadRequest, "%s can not be used as bootstrap dns cause: %s", errBoot, err) aghhttp.Error(
r,
w,
http.StatusBadRequest,
"%s can not be used as bootstrap dns cause: %s",
errBoot,
err,
)
return return
} }
if !req.checkBlockingMode() { if !req.checkBlockingMode() {
httpError(r, w, http.StatusBadRequest, "blocking_mode: incorrect value") aghhttp.Error(r, w, http.StatusBadRequest, "blocking_mode: incorrect value")
return return
} }
if !req.checkUpstreamsMode() { if !req.checkUpstreamsMode() {
httpError(r, w, http.StatusBadRequest, "upstream_mode: incorrect value") aghhttp.Error(r, w, http.StatusBadRequest, "upstream_mode: incorrect value")
return return
} }
if !req.checkCacheTTL() { if !req.checkCacheTTL() {
httpError(r, w, http.StatusBadRequest, "cache_ttl_min must be less or equal than cache_ttl_max") aghhttp.Error(
r,
w,
http.StatusBadRequest,
"cache_ttl_min must be less or equal than cache_ttl_max",
)
return return
} }
@ -234,8 +248,7 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
if restart { if restart {
if err := s.Reconfigure(nil); err != nil { if err := s.Reconfigure(nil); err != nil {
httpError(r, w, http.StatusInternalServerError, "%s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return
} }
} }
} }
@ -582,7 +595,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
req := &upstreamJSON{} req := &upstreamJSON{}
err := json.NewDecoder(r.Body).Decode(req) err := json.NewDecoder(r.Body).Decode(req)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "Failed to read request body: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "Failed to read request body: %s", err)
return return
} }
@ -620,7 +633,13 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
jsonVal, err := json.Marshal(result) jsonVal, err := json.Marshal(result)
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err) aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Unable to marshal status json: %s",
err,
)
return return
} }
@ -628,9 +647,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
return
} }
} }
@ -641,12 +658,12 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
// -> dnsforward.handleDNSRequest // -> dnsforward.handleDNSRequest
func (s *Server) handleDoH(w http.ResponseWriter, r *http.Request) { func (s *Server) handleDoH(w http.ResponseWriter, r *http.Request) {
if !s.conf.TLSAllowUnencryptedDoH && r.TLS == nil { if !s.conf.TLSAllowUnencryptedDoH && r.TLS == nil {
httpError(r, w, http.StatusNotFound, "Not Found") aghhttp.Error(r, w, http.StatusNotFound, "Not Found")
return return
} }
if !s.IsRunning() { if !s.IsRunning() {
httpError(r, w, http.StatusInternalServerError, "dns server is not running") aghhttp.Error(r, w, http.StatusInternalServerError, "dns server is not running")
return return
} }

View File

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
) )
@ -287,7 +288,8 @@ func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Req
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(list) err := json.NewEncoder(w).Encode(list)
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
return return
} }
} }
@ -296,7 +298,8 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
list := []string{} list := []string{}
err := json.NewDecoder(r.Body).Decode(&list) err := json.NewDecoder(r.Body).Decode(&list)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return return
} }

View File

@ -864,8 +864,7 @@ func makeResult(matchedRules []rules.Rule, reason Reason) (res Result) {
} }
} }
// InitModule manually initializes blocked services map using blockedSvcListID // InitModule manually initializes blocked services map.
// as list ID for the rules.
func InitModule() { func InitModule() {
initBlockedServices() initBlockedServices()
} }

View File

@ -9,6 +9,7 @@ import (
"sort" "sort"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -206,7 +207,8 @@ func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(arr) err := json.NewEncoder(w).Encode(arr)
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
return return
} }
} }
@ -215,7 +217,8 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
jsent := rewriteEntryJSON{} jsent := rewriteEntryJSON{}
err := json.NewDecoder(r.Body).Decode(&jsent) err := json.NewDecoder(r.Body).Decode(&jsent)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return return
} }
@ -237,7 +240,8 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
jsent := rewriteEntryJSON{} jsent := rewriteEntryJSON{}
err := json.NewDecoder(r.Body).Decode(&jsent) err := json.NewDecoder(r.Body).Decode(&jsent)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return return
} }

View File

@ -13,6 +13,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -368,12 +369,6 @@ func (d *DNSFilter) checkParental(
return check(sctx, res, d.parentalUpstream) return check(sctx, res, d.parentalUpstream)
} }
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info("DNSFilter: %s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)
}
func (d *DNSFilter) handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) { func (d *DNSFilter) handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) {
d.Config.SafeBrowsingEnabled = true d.Config.SafeBrowsingEnabled = true
d.Config.ConfigModified() d.Config.ConfigModified()
@ -392,7 +387,8 @@ func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Requ
Enabled: d.Config.SafeBrowsingEnabled, Enabled: d.Config.SafeBrowsingEnabled,
}) })
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return return
} }
} }
@ -415,8 +411,7 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request)
Enabled: d.Config.ParentalEnabled, Enabled: d.Config.ParentalEnabled,
}) })
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
} }
} }

View File

@ -11,6 +11,7 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -152,8 +153,13 @@ func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Reques
Enabled: d.Config.SafeSearchEnabled, Enabled: d.Config.SafeSearchEnabled,
}) })
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) aghhttp.Error(
return r,
w,
http.StatusInternalServerError,
"Unable to write response json: %s",
err,
)
} }
} }

View File

@ -13,6 +13,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
@ -417,7 +418,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
req := loginJSON{} req := loginJSON{}
err := json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
return return
} }
@ -429,7 +430,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
// //
// TODO(e.burkov): Use realIP when the issue will be fixed. // TODO(e.burkov): Use realIP when the issue will be fixed.
if remoteAddr, err = netutil.SplitHost(r.RemoteAddr); err != nil { if remoteAddr, err = netutil.SplitHost(r.RemoteAddr); err != nil {
httpError(w, http.StatusBadRequest, "auth: getting remote address: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "auth: getting remote address: %s", err)
return return
} }
@ -437,7 +438,8 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
if blocker := Context.auth.blocker; blocker != nil { if blocker := Context.auth.blocker; blocker != nil {
if left := blocker.check(remoteAddr); left > 0 { if left := blocker.check(remoteAddr); left > 0 {
w.Header().Set("Retry-After", strconv.Itoa(int(left.Seconds()))) w.Header().Set("Retry-After", strconv.Itoa(int(left.Seconds())))
httpError( aghhttp.Error(
r,
w, w,
http.StatusTooManyRequests, http.StatusTooManyRequests,
"auth: blocked for %s", "auth: blocked for %s",
@ -451,7 +453,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
var cookie string var cookie string
cookie, err = Context.auth.httpCookie(req, remoteAddr) cookie, err = Context.auth.httpCookie(req, remoteAddr)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "crypto rand reader: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "crypto rand reader: %s", err)
return return
} }
@ -480,7 +482,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Pragma", "no-cache") w.Header().Set("Pragma", "no-cache")
w.Header().Set("Expires", "0") w.Header().Set("Expires", "0")
returnOK(w) aghhttp.OK(w)
} }
func handleLogout(w http.ResponseWriter, r *http.Request) { func handleLogout(w http.ResponseWriter, r *http.Request) {

View File

@ -6,6 +6,7 @@ import (
"net" "net"
"net/http" "net/http"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -58,7 +59,7 @@ type clientListJSON struct {
} }
// respond with information about configured clients // respond with information about configured clients
func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http.Request) { func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http.Request) {
data := clientListJSON{} data := clientListJSON{}
clients.lock.Lock() clients.lock.Lock()
@ -106,7 +107,14 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
e := json.NewEncoder(w).Encode(data) e := json.NewEncoder(w).Encode(data)
if e != nil { if e != nil {
httpError(w, http.StatusInternalServerError, "Failed to encode to json: %v", e) aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Failed to encode to json: %v",
e,
)
return return
} }
} }
@ -154,7 +162,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
cj := clientJSON{} cj := clientJSON{}
err := json.NewDecoder(r.Body).Decode(&cj) err := json.NewDecoder(r.Body).Decode(&cj)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "failed to process request body: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err)
return return
} }
@ -162,11 +170,14 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
c := jsonToClient(cj) c := jsonToClient(cj)
ok, err := clients.Add(c) ok, err := clients.Add(c)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return return
} }
if !ok { if !ok {
httpError(w, http.StatusBadRequest, "Client already exists") aghhttp.Error(r, w, http.StatusBadRequest, "Client already exists")
return return
} }
@ -178,19 +189,19 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
cj := clientJSON{} cj := clientJSON{}
err := json.NewDecoder(r.Body).Decode(&cj) err := json.NewDecoder(r.Body).Decode(&cj)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "failed to process request body: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err)
return return
} }
if len(cj.Name) == 0 { if len(cj.Name) == 0 {
httpError(w, http.StatusBadRequest, "client's name must be non-empty") aghhttp.Error(r, w, http.StatusBadRequest, "client's name must be non-empty")
return return
} }
if !clients.Del(cj.Name) { if !clients.Del(cj.Name) {
httpError(w, http.StatusBadRequest, "Client not found") aghhttp.Error(r, w, http.StatusBadRequest, "Client not found")
return return
} }
@ -207,20 +218,22 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
dj := updateJSON{} dj := updateJSON{}
err := json.NewDecoder(r.Body).Decode(&dj) err := json.NewDecoder(r.Body).Decode(&dj)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "failed to process request body: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err)
return return
} }
if len(dj.Name) == 0 { if len(dj.Name) == 0 {
httpError(w, http.StatusBadRequest, "Invalid request") aghhttp.Error(r, w, http.StatusBadRequest, "Invalid request")
return return
} }
c := jsonToClient(dj.Data) c := jsonToClient(dj.Data)
err = clients.Update(dj.Name, c) err = clients.Update(dj.Name, c)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return return
} }
@ -256,7 +269,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(data) err := json.NewEncoder(w).Encode(data)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write response: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write response: %s", err)
} }
} }

View File

@ -1,7 +1,6 @@
package home package home
import ( import (
"fmt"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
@ -274,17 +273,34 @@ func getLogSettings() logSettings {
} }
// parseConfig loads configuration from the YAML file // parseConfig loads configuration from the YAML file
func parseConfig() error { func parseConfig() (err error) {
configFile := config.getConfigFilename() var fileData []byte
log.Debug("Reading config file: %s", configFile) fileData, err = readConfigFile()
yamlFile, err := readConfigFile()
if err != nil { if err != nil {
return err return err
} }
config.fileData = nil config.fileData = nil
err = yaml.Unmarshal(yamlFile, &config) err = yaml.Unmarshal(fileData, &config)
if err != nil { if err != nil {
log.Error("Couldn't parse config file: %s", err) return err
}
pm := portsMap{}
pm.add(
config.BindPort,
config.BetaBindPort,
config.DNS.Port,
)
if config.TLS.Enabled {
pm.add(
config.TLS.PortHTTPS,
config.TLS.PortDNSOverTLS,
config.TLS.PortDNSOverQUIC,
config.TLS.PortDNSCrypt,
)
}
if err = pm.validate(); err != nil {
return err return err
} }
@ -299,18 +315,17 @@ func parseConfig() error {
return nil return nil
} }
// readConfigFile reads config file contents if it exists // readConfigFile reads configuration file contents.
func readConfigFile() ([]byte, error) { func readConfigFile() (fileData []byte, err error) {
if len(config.fileData) != 0 { if len(config.fileData) > 0 {
return config.fileData, nil return config.fileData, nil
} }
configFile := config.getConfigFilename() name := config.getConfigFilename()
d, err := os.ReadFile(configFile) log.Debug("reading config file: %s", name)
if err != nil {
return nil, fmt.Errorf("couldn't read config file %s: %w", configFile, err) // Do not wrap the error because it's informative enough as is.
} return os.ReadFile(name)
return d, nil
} }
// Saves configuration to the YAML file and also saves the user filter contents to a file // Saves configuration to the YAML file and also saves the user filter contents to a file

View File

@ -9,6 +9,7 @@ import (
"runtime" "runtime"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/AdGuardHome/internal/version"
@ -17,23 +18,6 @@ import (
"github.com/NYTimes/gziphandler" "github.com/NYTimes/gziphandler"
) )
// ----------------
// helper functions
// ----------------
func returnOK(w http.ResponseWriter) {
_, err := fmt.Fprintf(w, "OK\n")
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info(text)
http.Error(w, text, code)
}
// appendDNSAddrs is a convenient helper for appending a formatted form of DNS // appendDNSAddrs is a convenient helper for appending a formatted form of DNS
// addresses to a slice of strings. // addresses to a slice of strings.
func appendDNSAddrs(dst []string, addrs ...net.IP) (res []string) { func appendDNSAddrs(dst []string, addrs ...net.IP) (res []string) {
@ -125,12 +109,12 @@ type statusResponse struct {
Language string `json:"language"` Language string `json:"language"`
} }
func handleStatus(w http.ResponseWriter, _ *http.Request) { func handleStatus(w http.ResponseWriter, r *http.Request) {
dnsAddrs, err := collectDNSAddresses() dnsAddrs, err := collectDNSAddresses()
if err != nil { if err != nil {
// Don't add a lot of formatting, since the error is already // Don't add a lot of formatting, since the error is already
// wrapped by collectDNSAddresses. // wrapped by collectDNSAddresses.
httpError(w, http.StatusInternalServerError, "%s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return return
} }
@ -165,7 +149,7 @@ func handleStatus(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(resp) err = json.NewEncoder(w).Encode(resp)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return return
} }
@ -182,7 +166,7 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) {
data, err := json.Marshal(pj) data, err := json.Marshal(pj)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "json.Marshal: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "json.Marshal: %s", err)
return return
} }
_, _ = w.Write(data) _, _ = w.Write(data)
@ -295,7 +279,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
host, err := netutil.SplitHost(r.Host) host, err := netutil.SplitHost(r.Host)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "bad host: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "bad host: %s", err)
return false return false
} }

View File

@ -12,6 +12,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -49,7 +50,8 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
fj := filterAddJSON{} fj := filterAddJSON{}
err := json.NewDecoder(r.Body).Decode(&fj) err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
return return
} }
@ -63,7 +65,8 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
// Check for duplicates // Check for duplicates
if filterExists(fj.URL) { if filterExists(fj.URL) {
httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL) aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
return return
} }
@ -79,17 +82,35 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
// Download the filter contents // Download the filter contents
ok, err := f.update(&filt) ok, err := f.update(&filt)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", filt.URL, err) aghhttp.Error(
return r,
} w,
if !ok { http.StatusBadRequest,
httpError(w, http.StatusBadRequest, "Filter at the url %s is invalid (maybe it points to blank page?)", filt.URL) "Couldn't fetch filter from url %s: %s",
filt.URL,
err,
)
return return
} }
// URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it if !ok {
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"Filter at the url %s is invalid (maybe it points to blank page?)",
filt.URL,
)
return
}
// URL is assumed valid so append it to filters, update config, write new
// file and reload it to engines.
if !filterAdd(filt) { if !filterAdd(filt) {
httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL) aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
return return
} }
@ -98,7 +119,7 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
_, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount) _, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
} }
} }
@ -111,7 +132,8 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
req := request{} req := request{}
err := json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "failed to parse request body json: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "failed to parse request body json: %s", err)
return return
} }
@ -152,7 +174,7 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
_, err = fmt.Fprintf(w, "OK %d rules\n", deleted.RulesCount) _, err = fmt.Fprintf(w, "OK %d rules\n", deleted.RulesCount)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "couldn't write body: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "couldn't write body: %s", err)
} }
} }
@ -172,7 +194,8 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
fj := filterURLReq{} fj := filterURLReq{}
err := json.NewDecoder(r.Body).Decode(&fj) err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
return return
} }
@ -228,7 +251,8 @@ func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Reque
// This use of ReadAll is safe, because request's body is now limited. // This use of ReadAll is safe, because request's body is now limited.
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "Failed to read request body: %s", err)
return return
} }
@ -250,7 +274,8 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
req := Req{} req := Req{}
err = json.NewDecoder(r.Body).Decode(&req) err = json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
return return
} }
@ -270,13 +295,15 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
resp.Updated, err = f.refreshFilters(flags|filterRefreshForce, false) resp.Updated, err = f.refreshFilters(flags|filterRefreshForce, false)
}() }()
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return return
} }
js, err := json.Marshal(resp) js, err := json.Marshal(resp)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@ -335,13 +362,14 @@ func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request
jsonVal, err := json.Marshal(resp) jsonVal, err := json.Marshal(resp)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "http write: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "http write: %s", err)
} }
} }
@ -350,12 +378,14 @@ func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request
req := filteringConfig{} req := filteringConfig{}
err := json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
return return
} }
if !checkFiltersUpdateIntervalHours(req.Interval) { if !checkFiltersUpdateIntervalHours(req.Interval) {
httpError(w, http.StatusBadRequest, "Unsupported interval") aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval")
return return
} }
@ -408,7 +438,15 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true) Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts) result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err) aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"couldn't apply filtering: %s: %s",
host,
err,
)
return return
} }
@ -433,7 +471,8 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
js, err := json.Marshal(resp) js, err := json.Marshal(resp)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")

View File

@ -14,6 +14,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -34,7 +35,8 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request
ifaces, err := aghnet.GetValidNetInterfacesForWeb() ifaces, err := aghnet.GetValidNetInterfacesForWeb()
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
return return
} }
@ -46,7 +48,14 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(data) err = json.NewEncoder(w).Encode(data)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err) aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Unable to marshal default addresses to json: %s",
err,
)
return return
} }
} }
@ -84,23 +93,32 @@ type checkConfigResp struct {
func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) { func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
reqData := checkConfigReq{} reqData := checkConfigReq{}
respData := checkConfigResp{} respData := checkConfigResp{}
err := json.NewDecoder(r.Body).Decode(&reqData) err := json.NewDecoder(r.Body).Decode(&reqData)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
return return
} }
if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort && reqData.Web.Port != config.BetaBindPort { pm := portsMap{}
err = aghnet.CheckPortAvailable(reqData.Web.IP, reqData.Web.Port) pm.add(config.BindPort, config.BetaBindPort, reqData.Web.Port)
if err = pm.validate(); err != nil {
respData.Web.Status = err.Error()
} else if reqData.Web.Port != 0 {
err = aghnet.CheckPort("tcp", reqData.Web.IP, reqData.Web.Port)
if err != nil { if err != nil {
respData.Web.Status = err.Error() respData.Web.Status = err.Error()
} }
} }
if reqData.DNS.Port != 0 { pm.add(reqData.DNS.Port)
err = aghnet.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port) if err = pm.validate(); err != nil {
respData.DNS.Status = err.Error()
} else if reqData.DNS.Port != 0 {
err = aghnet.CheckPort("udp", reqData.DNS.IP, reqData.DNS.Port)
if aghnet.ErrorIsAddrInUse(err) { if aghnet.IsAddrInUse(err) {
canAutofix := checkDNSStubListener() canAutofix := checkDNSStubListener()
if canAutofix && reqData.DNS.Autofix { if canAutofix && reqData.DNS.Autofix {
@ -109,7 +127,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
log.Error("Couldn't disable DNSStubListener: %s", err) log.Error("Couldn't disable DNSStubListener: %s", err)
} }
err = aghnet.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port) err = aghnet.CheckPort("udp", reqData.DNS.IP, reqData.DNS.Port)
canAutofix = false canAutofix = false
} }
@ -117,7 +135,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
} }
if err == nil { if err == nil {
err = aghnet.CheckPortAvailable(reqData.DNS.IP, reqData.DNS.Port) err = aghnet.CheckPort("tcp", reqData.DNS.IP, reqData.DNS.Port)
} }
if err != nil { if err != nil {
@ -130,7 +148,8 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(respData) err = json.NewEncoder(w).Encode(respData)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal JSON: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to marshal JSON: %s", err)
return return
} }
} }
@ -287,21 +306,21 @@ func shutdownSrv(ctx context.Context, srv *http.Server) {
func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
req, restartHTTP, err := decodeApplyConfigReq(r.Body) req, restartHTTP, err := decodeApplyConfigReq(r.Body)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return return
} }
err = aghnet.CheckPacketPortAvailable(req.DNS.IP, req.DNS.Port) err = aghnet.CheckPort("udp", req.DNS.IP, req.DNS.Port)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return return
} }
err = aghnet.CheckPortAvailable(req.DNS.IP, req.DNS.Port) err = aghnet.CheckPort("tcp", req.DNS.IP, req.DNS.Port)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return return
} }
@ -315,28 +334,29 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
config.DNS.BindHosts = []net.IP{req.DNS.IP} config.DNS.BindHosts = []net.IP{req.DNS.IP}
config.DNS.Port = req.DNS.Port config.DNS.Port = req.DNS.Port
// TODO(e.burkov): StartMods() should be put in a separate goroutine at // TODO(e.burkov): StartMods() should be put in a separate goroutine at the
// the moment we'll allow setting up TLS in the initial configuration or // moment we'll allow setting up TLS in the initial configuration or the
// the configuration itself will use HTTPS protocol, because the // configuration itself will use HTTPS protocol, because the underlying
// underlying functions potentially restart the HTTPS server. // functions potentially restart the HTTPS server.
err = StartMods() err = StartMods()
if err != nil { if err != nil {
Context.firstRun = true Context.firstRun = true
copyInstallSettings(config, curConfig) copyInstallSettings(config, curConfig)
httpError(w, http.StatusInternalServerError, "%s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return return
} }
u := User{} u := &User{
u.Name = req.Username Name: req.Username,
Context.auth.UserAdd(&u, req.Password) }
Context.auth.UserAdd(u, req.Password)
err = config.write() err = config.write()
if err != nil { if err != nil {
Context.firstRun = true Context.firstRun = true
copyInstallSettings(config, curConfig) copyInstallSettings(config, curConfig)
httpError(w, http.StatusInternalServerError, "Couldn't write config: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write config: %s", err)
return return
} }
@ -347,7 +367,7 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
registerControlHandlers() registerControlHandlers()
returnOK(w) aghhttp.OK(w)
if f, ok := w.(http.Flusher); ok { if f, ok := w.(http.Flusher); ok {
f.Flush() f.Flush()
} }
@ -386,7 +406,7 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
restartHTTP = !config.BindHost.Equal(req.Web.IP) || config.BindPort != req.Web.Port restartHTTP = !config.BindHost.Equal(req.Web.IP) || config.BindPort != req.Web.Port
if restartHTTP { if restartHTTP {
err = aghnet.CheckPortAvailable(req.Web.IP, req.Web.Port) err = aghnet.CheckPort("tcp", req.Web.IP, req.Web.Port)
if err != nil { if err != nil {
return nil, false, fmt.Errorf( return nil, false, fmt.Errorf(
"checking address %s:%d: %w", "checking address %s:%d: %w",
@ -437,12 +457,14 @@ func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Requ
reqData := checkConfigReqBeta{} reqData := checkConfigReqBeta{}
err := json.NewDecoder(r.Body).Decode(&reqData) err := json.NewDecoder(r.Body).Decode(&reqData)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
return return
} }
if len(reqData.DNS.IP) == 0 || len(reqData.Web.IP) == 0 { if len(reqData.DNS.IP) == 0 || len(reqData.Web.IP) == 0 {
httpError(w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest)) aghhttp.Error(r, w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
return return
} }
@ -464,7 +486,14 @@ func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Requ
err = json.NewEncoder(nonBetaReqBody).Encode(nonBetaReqData) err = json.NewEncoder(nonBetaReqBody).Encode(nonBetaReqData)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Failed to encode 'check_config' JSON data: %s", err) aghhttp.Error(
r,
w,
http.StatusBadRequest,
"Failed to encode 'check_config' JSON data: %s",
err,
)
return return
} }
body := nonBetaReqBody.String() body := nonBetaReqBody.String()
@ -505,12 +534,14 @@ func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Reques
reqData := applyConfigReqBeta{} reqData := applyConfigReqBeta{}
err := json.NewDecoder(r.Body).Decode(&reqData) err := json.NewDecoder(r.Body).Decode(&reqData)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
return return
} }
if len(reqData.DNS.IP) == 0 || len(reqData.Web.IP) == 0 { if len(reqData.DNS.IP) == 0 || len(reqData.Web.IP) == 0 {
httpError(w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest)) aghhttp.Error(r, w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
return return
} }
@ -531,7 +562,14 @@ func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Reques
err = json.NewEncoder(nonBetaReqBody).Encode(nonBetaReqData) err = json.NewEncoder(nonBetaReqBody).Encode(nonBetaReqData)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Failed to encode 'check_config' JSON data: %s", err) aghhttp.Error(
r,
w,
http.StatusBadRequest,
"Failed to encode 'check_config' JSON data: %s",
err,
)
return return
} }
body := nonBetaReqBody.String() body := nonBetaReqBody.String()
@ -564,7 +602,8 @@ func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Req
ifaces, err := aghnet.GetValidNetInterfacesForWeb() ifaces, err := aghnet.GetValidNetInterfacesForWeb()
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
return return
} }
@ -573,7 +612,14 @@ func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Req
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(data) err = json.NewEncoder(w).Encode(data)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err) aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Unable to marshal default addresses to json: %s",
err,
)
return return
} }
} }

View File

@ -11,6 +11,7 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/updater" "github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
@ -43,7 +44,8 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
if r.ContentLength != 0 { if r.ContentLength != 0 {
err = json.NewDecoder(r.Body).Decode(req) err = json.NewDecoder(r.Body).Decode(req)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "JSON parse: %s", err)
return return
} }
} }
@ -77,7 +79,15 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
vcu := Context.updater.VersionCheckURL() vcu := Context.updater.VersionCheckURL()
// TODO(a.garipov): Figure out the purpose of %T verb. // TODO(a.garipov): Figure out the purpose of %T verb.
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", vcu, err, err) aghhttp.Error(
r,
w,
http.StatusBadGateway,
"Couldn't get version check json from %s: %T %s\n",
vcu,
err,
err,
)
return return
} }
@ -87,24 +97,26 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(resp) err = json.NewEncoder(w).Encode(resp)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
} }
} }
// handleUpdate performs an update to the latest available version procedure. // handleUpdate performs an update to the latest available version procedure.
func handleUpdate(w http.ResponseWriter, _ *http.Request) { func handleUpdate(w http.ResponseWriter, r *http.Request) {
if Context.updater.NewVersion() == "" { if Context.updater.NewVersion() == "" {
httpError(w, http.StatusBadRequest, "/update request isn't allowed now") aghhttp.Error(r, w, http.StatusBadRequest, "/update request isn't allowed now")
return return
} }
err := Context.updater.Update() err := Context.updater.Update()
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return return
} }
returnOK(w) aghhttp.OK(w)
if f, ok := w.(http.Flusher); ok { if f, ok := w.(http.Flusher); ok {
f.Flush() f.Flush()
} }

View File

@ -159,14 +159,11 @@ func setupContext(args options) {
} }
if !Context.firstRun { if !Context.firstRun {
// Do the upgrade if necessary // Do the upgrade if necessary.
err := upgradeConfig() err := upgradeConfig()
if err != nil { fatalOnError(err)
log.Fatal(err)
}
err = parseConfig() if err = parseConfig(); err != nil {
if err != nil {
log.Error("parsing configuration file: %s", err) log.Error("parsing configuration file: %s", err)
os.Exit(1) os.Exit(1)
@ -186,15 +183,15 @@ func setupContext(args options) {
// unsupported errors and returns nil. If err is nil, logIfUnsupported returns // unsupported errors and returns nil. If err is nil, logIfUnsupported returns
// nil. Otherise, it returns err. // nil. Otherise, it returns err.
func logIfUnsupported(msg string, err error) (outErr error) { func logIfUnsupported(msg string, err error) (outErr error) {
if unsupErr := (&aghos.UnsupportedError{}); errors.As(err, &unsupErr) { if errors.As(err, new(*aghos.UnsupportedError)) {
log.Debug(msg, err) log.Debug(msg, err)
} else if err != nil {
return err
}
return nil return nil
} }
return err
}
// configureOS sets the OS-related configuration. // configureOS sets the OS-related configuration.
func configureOS(conf *configuration) (err error) { func configureOS(conf *configuration) (err error) {
osConf := conf.OSConfig osConf := conf.OSConfig
@ -297,13 +294,32 @@ func setupConfig(args options) (err error) {
Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts) Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts)
if args.bindPort != 0 {
pm := portsMap{}
pm.add(
args.bindPort,
config.BetaBindPort,
config.DNS.Port,
)
if config.TLS.Enabled {
pm.add(
config.TLS.PortHTTPS,
config.TLS.PortDNSOverTLS,
config.TLS.PortDNSOverQUIC,
config.TLS.PortDNSCrypt,
)
}
if err = pm.validate(); err != nil {
return err
}
config.BindPort = args.bindPort
}
// override bind host/port from the console // override bind host/port from the console
if args.bindHost != nil { if args.bindHost != nil {
config.BindHost = args.bindHost config.BindHost = args.bindHost
} }
if args.bindPort != 0 {
config.BindPort = args.bindPort
}
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) { if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
Context.pidFileName = args.pidFile Context.pidFileName = args.pidFile
} }
@ -766,8 +782,7 @@ func printHTTPAddresses(proto string) {
port = tlsConf.PortHTTPS port = tlsConf.PortHTTPS
} }
// TODO(e.burkov): Inspect and perhaps merge with the previous // TODO(e.burkov): Inspect and perhaps merge with the previous condition.
// condition.
if proto == schemeHTTPS && tlsConf.ServerName != "" { if proto == schemeHTTPS && tlsConf.ServerName != "" {
printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS, 0) printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS, 0)

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/stringutil"
) )
@ -96,5 +97,5 @@ func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
}() }()
onConfigModified() onConfigModified()
returnOK(w) aghhttp.OK(w)
} }

63
internal/home/portsmap.go Normal file
View File

@ -0,0 +1,63 @@
package home
import (
"fmt"
"strconv"
"strings"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/stringutil"
)
// portsMap is a helper type for mapping a network port to the number of its
// users.
type portsMap map[int]int
// add binds each of ps. Zeroes are skipped.
func (pm portsMap) add(ps ...int) {
for _, p := range ps {
if p == 0 {
continue
}
pm[p]++
}
}
// validate returns an error about all the ports bound several times.
func (pm portsMap) validate() (err error) {
overbound := []int{}
for p, num := range pm {
if num > 1 {
overbound = append(overbound, p)
pm[p] = 1
}
}
switch len(overbound) {
case 0:
return nil
case 1:
return fmt.Errorf("port %d is already used", overbound[0])
default:
b := &strings.Builder{}
// TODO(e.burkov, a.garipov): Add JoinToBuilder helper to stringutil.
stringutil.WriteToBuilder(b, "ports ", strconv.Itoa(overbound[0]))
for _, p := range overbound[1:] {
stringutil.WriteToBuilder(b, ", ", strconv.Itoa(p))
}
stringutil.WriteToBuilder(b, " are already used")
return errors.Error(b.String())
}
}
// validatePorts is a helper function for a single-step ports binding
// validation.
func validatePorts(ps ...int) (err error) {
pm := portsMap{}
pm.add(ps...)
return pm.validate()
}

View File

@ -20,6 +20,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -224,7 +225,7 @@ type tlsConfigSettingsExt struct {
PrivateKeySaved bool `yaml:"-" json:"private_key_saved,inline"` PrivateKeySaved bool `yaml:"-" json:"private_key_saved,inline"`
} }
func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, _ *http.Request) { func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
t.confLock.Lock() t.confLock.Lock()
data := tlsConfig{ data := tlsConfig{
tlsConfigSettingsExt: tlsConfigSettingsExt{ tlsConfigSettingsExt: tlsConfigSettingsExt{
@ -233,13 +234,14 @@ func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, _ *http.Request) {
tlsConfigStatus: t.status, tlsConfigStatus: t.status,
} }
t.confLock.Unlock() t.confLock.Unlock()
marshalTLS(w, data) marshalTLS(w, r, data)
} }
func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) { func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
setts, err := unmarshalTLS(r) setts, err := unmarshalTLS(r)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
return return
} }
@ -247,8 +249,31 @@ func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
setts.PrivateKey = t.conf.PrivateKey setts.PrivateKey = t.conf.PrivateKey
} }
if setts.Enabled {
if err = validatePorts(
config.BindPort,
config.BetaBindPort,
config.DNS.Port,
setts.PortHTTPS,
setts.PortDNSOverTLS,
setts.PortDNSOverQUIC,
setts.PortDNSCrypt,
); err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
}
if !WebCheckPortAvailable(setts.PortHTTPS) { if !WebCheckPortAvailable(setts.PortHTTPS) {
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", setts.PortHTTPS) aghhttp.Error(
r,
w,
http.StatusBadRequest,
"port %d is not available, cannot enable HTTPS on it",
setts.PortHTTPS,
)
return return
} }
@ -261,7 +286,8 @@ func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
tlsConfigSettingsExt: setts, tlsConfigSettingsExt: setts,
tlsConfigStatus: status, tlsConfigStatus: status,
} }
marshalTLS(w, data)
marshalTLS(w, r, data)
} }
func (t *TLSMod) setConfig(newConf tlsConfigSettings, status tlsConfigStatus) (restartHTTPS bool) { func (t *TLSMod) setConfig(newConf tlsConfigSettings, status tlsConfigStatus) (restartHTTPS bool) {
@ -302,7 +328,8 @@ func (t *TLSMod) setConfig(newConf tlsConfigSettings, status tlsConfigStatus) (r
func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
data, err := unmarshalTLS(r) data, err := unmarshalTLS(r)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
return return
} }
@ -310,8 +337,32 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
data.PrivateKey = t.conf.PrivateKey data.PrivateKey = t.conf.PrivateKey
} }
if data.Enabled {
if err = validatePorts(
config.BindPort,
config.BetaBindPort,
config.DNS.Port,
data.PortHTTPS,
data.PortDNSOverTLS,
data.PortDNSOverQUIC,
data.PortDNSCrypt,
); err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
}
// TODO(e.burkov): Investigate and perhaps check other ports.
if !WebCheckPortAvailable(data.PortHTTPS) { if !WebCheckPortAvailable(data.PortHTTPS) {
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS) aghhttp.Error(
r,
w,
http.StatusBadRequest,
"port %d is not available, cannot enable HTTPS on it",
data.PortHTTPS,
)
return return
} }
@ -321,7 +372,7 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
tlsConfigSettingsExt: data, tlsConfigSettingsExt: data,
tlsConfigStatus: t.status, tlsConfigStatus: t.status,
} }
marshalTLS(w, data2) marshalTLS(w, r, data2)
return return
} }
@ -334,7 +385,7 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
err = reconfigureDNSServer() err = reconfigureDNSServer()
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return return
} }
@ -344,15 +395,15 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
tlsConfigStatus: t.status, tlsConfigStatus: t.status,
} }
marshalTLS(w, data2) marshalTLS(w, r, data2)
if f, ok := w.(http.Flusher); ok { if f, ok := w.(http.Flusher); ok {
f.Flush() f.Flush()
} }
// The background context is used because the TLSConfigChanged wraps // The background context is used because the TLSConfigChanged wraps context
// context with timeout on its own and shuts down the server, which // with timeout on its own and shuts down the server, which handles current
// handles current request. It is also should be done in a separate // request. It is also should be done in a separate goroutine due to the
// goroutine due to the same reason. // same reason.
if restartHTTPS { if restartHTTPS {
go func() { go func() {
Context.web.TLSConfigChanged(context.Background(), data.tlsConfigSettings) Context.web.TLSConfigChanged(context.Background(), data.tlsConfigSettings)
@ -595,7 +646,7 @@ func unmarshalTLS(r *http.Request) (tlsConfigSettingsExt, error) {
return data, nil return data, nil
} }
func marshalTLS(w http.ResponseWriter, data tlsConfig) { func marshalTLS(w http.ResponseWriter, r *http.Request, data tlsConfig) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
if data.CertificateChain != "" { if data.CertificateChain != "" {
@ -610,8 +661,13 @@ func marshalTLS(w http.ResponseWriter, data tlsConfig) {
err := json.NewEncoder(w).Encode(data) err := json.NewEncoder(w).Encode(data)
if err != nil { if err != nil {
httpError(w, http.StatusInternalServerError, "Failed to marshal json with TLS status: %s", err) aghhttp.Error(
return r,
w,
http.StatusInternalServerError,
"Failed to marshal json with TLS status: %s",
err,
)
} }
} }

View File

@ -114,17 +114,8 @@ func CreateWeb(conf *webConfig) *Web {
// WebCheckPortAvailable - check if port is available // WebCheckPortAvailable - check if port is available
// BUT: if we are already using this port, no need // BUT: if we are already using this port, no need
func WebCheckPortAvailable(port int) bool { func WebCheckPortAvailable(port int) bool {
alreadyRunning := false return Context.web.httpsServer.server != nil ||
if Context.web.httpsServer.server != nil { aghnet.CheckPort("tcp", config.BindHost, port) == nil
alreadyRunning = true
}
if !alreadyRunning {
err := aghnet.CheckPortAvailable(config.BindHost, port)
if err != nil {
return false
}
}
return true
} }
// TLSConfigChanged updates the TLS configuration and restarts the HTTPS server // TLSConfigChanged updates the TLS configuration and restarts the HTTPS server

View File

@ -10,6 +10,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/jsonutil" "github.com/AdguardTeam/golibs/jsonutil"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/stringutil"
@ -33,18 +34,11 @@ func (l *queryLog) initWeb() {
l.conf.HTTPRegister(http.MethodPost, "/control/querylog_config", l.handleQueryLogConfig) l.conf.HTTPRegister(http.MethodPost, "/control/querylog_config", l.handleQueryLogConfig)
} }
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info("QueryLog: %s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)
}
func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) { func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
params, err := l.parseSearchParams(r) params, err := l.parseSearchParams(r)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "failed to parse params: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "failed to parse params: %s", err)
return return
} }
@ -56,14 +50,21 @@ func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
jsonVal, err := json.Marshal(data) jsonVal, err := json.Marshal(data)
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "Couldn't marshal data into json: %s", err) aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Couldn't marshal data into json: %s",
err,
)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
} }
} }
@ -80,13 +81,15 @@ func (l *queryLog) handleQueryLogInfo(w http.ResponseWriter, r *http.Request) {
jsonVal, err := json.Marshal(resp) jsonVal, err := json.Marshal(resp)
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "json encode: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "http write: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "http write: %s", err)
} }
} }
@ -109,13 +112,15 @@ func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request)
d := &qlogConfig{} d := &qlogConfig{}
req, err := jsonutil.DecodeObject(d, r.Body) req, err := jsonutil.DecodeObject(d, r.Body)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return return
} }
ivl := time.Duration(float64(timeutil.Day) * d.Interval) ivl := time.Duration(float64(timeutil.Day) * d.Interval)
if req.Exists("interval") && !checkInterval(ivl) { if req.Exists("interval") && !checkInterval(ivl) {
httpError(r, w, http.StatusBadRequest, "Unsupported interval") aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval")
return return
} }

View File

@ -4,21 +4,13 @@ package stats
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info("Stats: %s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)
}
// topAddrs is an alias for the types of the TopFoo fields of statsResponse. // topAddrs is an alias for the types of the TopFoo fields of statsResponse.
// The key is either a client's address or a requested address. // The key is either a client's address or a requested address.
type topAddrs = map[string]uint64 type topAddrs = map[string]uint64
@ -71,7 +63,7 @@ func (s *statsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
log.Debug("stats: prepared data in %v", time.Since(start)) log.Debug("stats: prepared data in %v", time.Since(start))
if !ok { if !ok {
httpError(r, w, http.StatusInternalServerError, "Couldn't get statistics data") aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get statistics data")
return return
} }
@ -81,7 +73,7 @@ func (s *statsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
err := json.NewEncoder(w).Encode(resp) err := json.NewEncoder(w).Encode(resp)
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "json encode: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return return
} }
@ -98,13 +90,14 @@ func (s *statsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
data, err := json.Marshal(resp) data, err := json.Marshal(resp)
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "json encode: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(data) _, err = w.Write(data)
if err != nil { if err != nil {
httpError(r, w, http.StatusInternalServerError, "http write: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "http write: %s", err)
} }
} }
@ -113,12 +106,14 @@ func (s *statsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
reqData := config{} reqData := config{}
err := json.NewDecoder(r.Body).Decode(&reqData) err := json.NewDecoder(r.Body).Decode(&reqData)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "json decode: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
return return
} }
if !checkInterval(reqData.IntervalDays) { if !checkInterval(reqData.IntervalDays) {
httpError(r, w, http.StatusBadRequest, "Unsupported interval") aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval")
return return
} }

View File

@ -6,6 +6,10 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/home" "github.com/AdguardTeam/AdGuardHome/internal/home"
) )
// Embed the prebuilt client here since we strive to keep .go files inside the
// internal directory and the embed package is unable to embed files located
// outside of the same or underlying directory.
//go:embed build build2 //go:embed build build2
var clientBuildFS embed.FS var clientBuildFS embed.FS