[change] control: add upstreams validation
This commit is contained in:
parent
8e09424774
commit
1b15bee2b0
34
control.go
34
control.go
|
@ -322,15 +322,21 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, u := range newconfig.Upstreams {
|
||||||
|
if err = validateUpstream(u); err != nil {
|
||||||
|
httpError(w, http.StatusBadRequest, "%s can not be used as upstream cause: %s", u, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
config.DNS.UpstreamDNS = defaultDNS
|
config.DNS.UpstreamDNS = defaultDNS
|
||||||
if len(newconfig.Upstreams) > 0 {
|
if len(newconfig.Upstreams) > 0 {
|
||||||
config.DNS.UpstreamDNS = newconfig.Upstreams
|
config.DNS.UpstreamDNS = newconfig.Upstreams
|
||||||
}
|
}
|
||||||
|
|
||||||
// bootstrap servers are plain DNS only. We should return http error if there are tls:// https:// or sdns:// hosts in slice
|
// bootstrap servers are plain DNS only.
|
||||||
for _, host := range newconfig.BootstrapDNS {
|
for _, host := range newconfig.BootstrapDNS {
|
||||||
err := checkPlainDNS(host)
|
if err := checkPlainDNS(host); err != nil {
|
||||||
if err != nil {
|
|
||||||
httpError(w, http.StatusBadRequest, "%s can not be used as bootstrap dns cause: %s", host, err)
|
httpError(w, http.StatusBadRequest, "%s can not be used as bootstrap dns cause: %s", host, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -345,26 +351,38 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) {
|
||||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateUpstream(upstream string) error {
|
||||||
|
if strings.HasPrefix(upstream, "tls://") || strings.HasPrefix(upstream, "https://") || strings.HasPrefix(upstream, "sdns://") || strings.HasPrefix(upstream, "tcp://") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(upstream, "://") {
|
||||||
|
return fmt.Errorf("wrong protocol")
|
||||||
|
}
|
||||||
|
|
||||||
|
return checkPlainDNS(upstream)
|
||||||
|
}
|
||||||
|
|
||||||
// checkPlainDNS checks if host is plain DNS
|
// checkPlainDNS checks if host is plain DNS
|
||||||
func checkPlainDNS(host string) error {
|
func checkPlainDNS(upstream string) error {
|
||||||
// Check if host is ip without port
|
// Check if host is ip without port
|
||||||
if net.ParseIP(host) != nil {
|
if net.ParseIP(upstream) != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if host is ip with port
|
// Check if host is ip with port
|
||||||
ip, port, err := net.SplitHostPort(host)
|
ip, port, err := net.SplitHostPort(upstream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if net.ParseIP(ip) == nil {
|
if net.ParseIP(ip) == nil {
|
||||||
return fmt.Errorf("%s is not valid IP", ip)
|
return fmt.Errorf("%s is not a valid IP", ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = strconv.ParseInt(port, 0, 64)
|
_, err = strconv.ParseInt(port, 0, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s is not valid port: %s", port, err)
|
return fmt.Errorf("%s is not a valid port: %s", port, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
Loading…
Reference in New Issue