[change] control: add upstreams validation

This commit is contained in:
Aleksey Dmitrevskiy 2019-03-06 18:24:14 +03:00
parent 8e09424774
commit 1b15bee2b0

View File

@ -322,15 +322,21 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) {
return
}
for _, u := range newconfig.Upstreams {
if err = validateUpstream(u); err != nil {
httpError(w, http.StatusBadRequest, "%s can not be used as upstream cause: %s", u, err)
return
}
}
config.DNS.UpstreamDNS = defaultDNS
if len(newconfig.Upstreams) > 0 {
config.DNS.UpstreamDNS = newconfig.Upstreams
}
// bootstrap servers are plain DNS only. We should return http error if there are tls:// https:// or sdns:// hosts in slice
// bootstrap servers are plain DNS only.
for _, host := range newconfig.BootstrapDNS {
err := checkPlainDNS(host)
if err != nil {
if err := checkPlainDNS(host); err != nil {
httpError(w, http.StatusBadRequest, "%s can not be used as bootstrap dns cause: %s", host, err)
return
}
@ -345,26 +351,38 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) {
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func validateUpstream(upstream string) error {
if strings.HasPrefix(upstream, "tls://") || strings.HasPrefix(upstream, "https://") || strings.HasPrefix(upstream, "sdns://") || strings.HasPrefix(upstream, "tcp://") {
return nil
}
if strings.Contains(upstream, "://") {
return fmt.Errorf("wrong protocol")
}
return checkPlainDNS(upstream)
}
// checkPlainDNS checks if host is plain DNS
func checkPlainDNS(host string) error {
func checkPlainDNS(upstream string) error {
// Check if host is ip without port
if net.ParseIP(host) != nil {
if net.ParseIP(upstream) != nil {
return nil
}
// Check if host is ip with port
ip, port, err := net.SplitHostPort(host)
ip, port, err := net.SplitHostPort(upstream)
if err != nil {
return err
}
if net.ParseIP(ip) == nil {
return fmt.Errorf("%s is not valid IP", ip)
return fmt.Errorf("%s is not a valid IP", ip)
}
_, err = strconv.ParseInt(port, 0, 64)
if err != nil {
return fmt.Errorf("%s is not valid port: %s", port, err)
return fmt.Errorf("%s is not a valid port: %s", port, err)
}
return nil