[change] control: add upstreams validation
This commit is contained in:
parent
8e09424774
commit
1b15bee2b0
34
control.go
34
control.go
@ -322,15 +322,21 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
for _, u := range newconfig.Upstreams {
|
||||
if err = validateUpstream(u); err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s can not be used as upstream cause: %s", u, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
config.DNS.UpstreamDNS = defaultDNS
|
||||
if len(newconfig.Upstreams) > 0 {
|
||||
config.DNS.UpstreamDNS = newconfig.Upstreams
|
||||
}
|
||||
|
||||
// bootstrap servers are plain DNS only. We should return http error if there are tls:// https:// or sdns:// hosts in slice
|
||||
// bootstrap servers are plain DNS only.
|
||||
for _, host := range newconfig.BootstrapDNS {
|
||||
err := checkPlainDNS(host)
|
||||
if err != nil {
|
||||
if err := checkPlainDNS(host); err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s can not be used as bootstrap dns cause: %s", host, err)
|
||||
return
|
||||
}
|
||||
@ -345,26 +351,38 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) {
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
func validateUpstream(upstream string) error {
|
||||
if strings.HasPrefix(upstream, "tls://") || strings.HasPrefix(upstream, "https://") || strings.HasPrefix(upstream, "sdns://") || strings.HasPrefix(upstream, "tcp://") {
|
||||
return nil
|
||||
}
|
||||
|
||||
if strings.Contains(upstream, "://") {
|
||||
return fmt.Errorf("wrong protocol")
|
||||
}
|
||||
|
||||
return checkPlainDNS(upstream)
|
||||
}
|
||||
|
||||
// checkPlainDNS checks if host is plain DNS
|
||||
func checkPlainDNS(host string) error {
|
||||
func checkPlainDNS(upstream string) error {
|
||||
// Check if host is ip without port
|
||||
if net.ParseIP(host) != nil {
|
||||
if net.ParseIP(upstream) != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if host is ip with port
|
||||
ip, port, err := net.SplitHostPort(host)
|
||||
ip, port, err := net.SplitHostPort(upstream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if net.ParseIP(ip) == nil {
|
||||
return fmt.Errorf("%s is not valid IP", ip)
|
||||
return fmt.Errorf("%s is not a valid IP", ip)
|
||||
}
|
||||
|
||||
_, err = strconv.ParseInt(port, 0, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s is not valid port: %s", port, err)
|
||||
return fmt.Errorf("%s is not a valid port: %s", port, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
Loading…
Reference in New Issue
Block a user