[change] control: add upstreams validation

This commit is contained in:
Aleksey Dmitrevskiy 2019-03-06 18:24:14 +03:00
parent 8e09424774
commit 1b15bee2b0
1 changed files with 26 additions and 8 deletions

View File

@ -322,15 +322,21 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) {
return return
} }
for _, u := range newconfig.Upstreams {
if err = validateUpstream(u); err != nil {
httpError(w, http.StatusBadRequest, "%s can not be used as upstream cause: %s", u, err)
return
}
}
config.DNS.UpstreamDNS = defaultDNS config.DNS.UpstreamDNS = defaultDNS
if len(newconfig.Upstreams) > 0 { if len(newconfig.Upstreams) > 0 {
config.DNS.UpstreamDNS = newconfig.Upstreams config.DNS.UpstreamDNS = newconfig.Upstreams
} }
// bootstrap servers are plain DNS only. We should return http error if there are tls:// https:// or sdns:// hosts in slice // bootstrap servers are plain DNS only.
for _, host := range newconfig.BootstrapDNS { for _, host := range newconfig.BootstrapDNS {
err := checkPlainDNS(host) if err := checkPlainDNS(host); err != nil {
if err != nil {
httpError(w, http.StatusBadRequest, "%s can not be used as bootstrap dns cause: %s", host, err) httpError(w, http.StatusBadRequest, "%s can not be used as bootstrap dns cause: %s", host, err)
return return
} }
@ -345,26 +351,38 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) {
httpUpdateConfigReloadDNSReturnOK(w, r) httpUpdateConfigReloadDNSReturnOK(w, r)
} }
func validateUpstream(upstream string) error {
if strings.HasPrefix(upstream, "tls://") || strings.HasPrefix(upstream, "https://") || strings.HasPrefix(upstream, "sdns://") || strings.HasPrefix(upstream, "tcp://") {
return nil
}
if strings.Contains(upstream, "://") {
return fmt.Errorf("wrong protocol")
}
return checkPlainDNS(upstream)
}
// checkPlainDNS checks if host is plain DNS // checkPlainDNS checks if host is plain DNS
func checkPlainDNS(host string) error { func checkPlainDNS(upstream string) error {
// Check if host is ip without port // Check if host is ip without port
if net.ParseIP(host) != nil { if net.ParseIP(upstream) != nil {
return nil return nil
} }
// Check if host is ip with port // Check if host is ip with port
ip, port, err := net.SplitHostPort(host) ip, port, err := net.SplitHostPort(upstream)
if err != nil { if err != nil {
return err return err
} }
if net.ParseIP(ip) == nil { if net.ParseIP(ip) == nil {
return fmt.Errorf("%s is not valid IP", ip) return fmt.Errorf("%s is not a valid IP", ip)
} }
_, err = strconv.ParseInt(port, 0, 64) _, err = strconv.ParseInt(port, 0, 64)
if err != nil { if err != nil {
return fmt.Errorf("%s is not valid port: %s", port, err) return fmt.Errorf("%s is not a valid port: %s", port, err)
} }
return nil return nil