From 1b15bee2b09df5b2c31edf6d6d594aa53ac9124c Mon Sep 17 00:00:00 2001 From: Aleksey Dmitrevskiy Date: Wed, 6 Mar 2019 18:24:14 +0300 Subject: [PATCH] [change] control: add upstreams validation --- control.go | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/control.go b/control.go index fb1cf0af..fe340ecd 100644 --- a/control.go +++ b/control.go @@ -322,15 +322,21 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) { return } + for _, u := range newconfig.Upstreams { + if err = validateUpstream(u); err != nil { + httpError(w, http.StatusBadRequest, "%s can not be used as upstream cause: %s", u, err) + return + } + } + config.DNS.UpstreamDNS = defaultDNS if len(newconfig.Upstreams) > 0 { config.DNS.UpstreamDNS = newconfig.Upstreams } - // bootstrap servers are plain DNS only. We should return http error if there are tls:// https:// or sdns:// hosts in slice + // bootstrap servers are plain DNS only. for _, host := range newconfig.BootstrapDNS { - err := checkPlainDNS(host) - if err != nil { + if err := checkPlainDNS(host); err != nil { httpError(w, http.StatusBadRequest, "%s can not be used as bootstrap dns cause: %s", host, err) return } @@ -345,26 +351,38 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) { httpUpdateConfigReloadDNSReturnOK(w, r) } +func validateUpstream(upstream string) error { + if strings.HasPrefix(upstream, "tls://") || strings.HasPrefix(upstream, "https://") || strings.HasPrefix(upstream, "sdns://") || strings.HasPrefix(upstream, "tcp://") { + return nil + } + + if strings.Contains(upstream, "://") { + return fmt.Errorf("wrong protocol") + } + + return checkPlainDNS(upstream) +} + // checkPlainDNS checks if host is plain DNS -func checkPlainDNS(host string) error { +func checkPlainDNS(upstream string) error { // Check if host is ip without port - if net.ParseIP(host) != nil { + if net.ParseIP(upstream) != nil { return nil } // Check if host is ip with port - ip, port, err := net.SplitHostPort(host) + ip, port, err := net.SplitHostPort(upstream) if err != nil { return err } if net.ParseIP(ip) == nil { - return fmt.Errorf("%s is not valid IP", ip) + return fmt.Errorf("%s is not a valid IP", ip) } _, err = strconv.ParseInt(port, 0, 64) if err != nil { - return fmt.Errorf("%s is not valid port: %s", port, err) + return fmt.Errorf("%s is not a valid port: %s", port, err) } return nil