2018-08-30 14:25:33 +00:00
package dnsfilter
import (
"bufio"
"errors"
"fmt"
"log"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/coredns/coredns/request"
"github.com/mholt/caddy"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
"github.com/AdguardTeam/AdguardDNS/dnsfilter"
"golang.org/x/net/context"
)
var defaultSOA = & dns . SOA {
// values copied from verisign's nonexistent .com domain
// their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers
Refresh : 1800 ,
Retry : 900 ,
Expire : 604800 ,
Minttl : 86400 ,
}
func init ( ) {
caddy . RegisterPlugin ( "dnsfilter" , caddy . Plugin {
ServerType : "dns" ,
Action : setup ,
} )
}
type Plugin struct {
d * dnsfilter . Dnsfilter
Next plugin . Handler
upstream upstream . Upstream
hosts map [ string ] net . IP
SafeBrowsingBlockHost string
ParentalBlockHost string
QueryLogEnabled bool
}
var defaultPlugin = Plugin {
SafeBrowsingBlockHost : "safebrowsing.block.dns.adguard.com" ,
ParentalBlockHost : "family.block.dns.adguard.com" ,
}
func newDnsCounter ( name string , help string ) prometheus . Counter {
return prometheus . NewCounter ( prometheus . CounterOpts {
Namespace : plugin . Namespace ,
Subsystem : "dnsfilter" ,
Name : name ,
Help : help ,
} )
}
var (
requests = newDnsCounter ( "requests_total" , "Count of requests seen by dnsfilter." )
filtered = newDnsCounter ( "filtered_total" , "Count of requests filtered by dnsfilter." )
filteredLists = newDnsCounter ( "filtered_lists_total" , "Count of requests filtered by dnsfilter using lists." )
filteredSafebrowsing = newDnsCounter ( "filtered_safebrowsing_total" , "Count of requests filtered by dnsfilter using safebrowsing." )
filteredParental = newDnsCounter ( "filtered_parental_total" , "Count of requests filtered by dnsfilter using parental." )
filteredInvalid = newDnsCounter ( "filtered_invalid_total" , "Count of requests filtered by dnsfilter because they were invalid." )
whitelisted = newDnsCounter ( "whitelisted_total" , "Count of requests not filtered by dnsfilter because they are whitelisted." )
safesearch = newDnsCounter ( "safesearch_total" , "Count of requests replaced by dnsfilter safesearch." )
errorsTotal = newDnsCounter ( "errors_total" , "Count of requests that dnsfilter couldn't process because of transitive errors." )
)
//
// coredns handling functions
//
func setupPlugin ( c * caddy . Controller ) ( * Plugin , error ) {
// create new Plugin and copy default values
var d = new ( Plugin )
* d = defaultPlugin
d . d = dnsfilter . New ( )
d . hosts = make ( map [ string ] net . IP )
var filterFileName string
for c . Next ( ) {
args := c . RemainingArgs ( )
if len ( args ) == 0 {
// must have at least one argument
return nil , c . ArgErr ( )
}
filterFileName = args [ 0 ]
for c . NextBlock ( ) {
switch c . Val ( ) {
case "safebrowsing" :
d . d . EnableSafeBrowsing ( )
if c . NextArg ( ) {
if len ( c . Val ( ) ) == 0 {
return nil , c . ArgErr ( )
}
d . d . SetSafeBrowsingServer ( c . Val ( ) )
}
case "safesearch" :
d . d . EnableSafeSearch ( )
case "parental" :
if ! c . NextArg ( ) {
return nil , c . ArgErr ( )
}
sensitivity , err := strconv . Atoi ( c . Val ( ) )
if err != nil {
return nil , c . ArgErr ( )
}
err = d . d . EnableParental ( sensitivity )
if err != nil {
return nil , c . ArgErr ( )
}
if c . NextArg ( ) {
if len ( c . Val ( ) ) == 0 {
return nil , c . ArgErr ( )
}
d . ParentalBlockHost = c . Val ( )
}
case "querylog" :
d . QueryLogEnabled = true
2018-09-05 18:21:46 +00:00
onceQueryLog . Do ( func ( ) {
2018-08-30 14:25:33 +00:00
go startQueryLogServer ( ) // TODO: how to handle errors?
} )
}
}
}
file , err := os . Open ( filterFileName )
if err != nil {
return nil , err
}
defer file . Close ( )
2018-09-05 23:07:23 +00:00
count := 0
2018-08-30 14:25:33 +00:00
scanner := bufio . NewScanner ( file )
for scanner . Scan ( ) {
text := scanner . Text ( )
if d . parseEtcHosts ( text ) {
continue
}
err = d . d . AddRule ( text , 0 )
if err == dnsfilter . ErrInvalidSyntax {
continue
}
if err != nil {
return nil , err
}
2018-09-05 23:07:23 +00:00
count ++
2018-08-30 14:25:33 +00:00
}
2018-09-05 23:07:23 +00:00
log . Printf ( "Added %d rules from %s" , count , filterFileName )
2018-08-30 14:25:33 +00:00
if err = scanner . Err ( ) ; err != nil {
return nil , err
}
d . upstream , err = upstream . New ( nil )
if err != nil {
return nil , err
}
return d , nil
}
func setup ( c * caddy . Controller ) error {
d , err := setupPlugin ( c )
if err != nil {
return err
}
config := dnsserver . GetConfig ( c )
config . AddPlugin ( func ( next plugin . Handler ) plugin . Handler {
d . Next = next
return d
} )
c . OnStartup ( func ( ) error {
2018-09-05 23:07:57 +00:00
m := dnsserver . GetConfig ( c ) . Handler ( "prometheus" )
if m == nil {
return nil
}
if x , ok := m . ( * metrics . Metrics ) ; ok {
x . MustRegister ( requests )
x . MustRegister ( filtered )
x . MustRegister ( filteredLists )
x . MustRegister ( filteredSafebrowsing )
x . MustRegister ( filteredParental )
x . MustRegister ( whitelisted )
x . MustRegister ( safesearch )
x . MustRegister ( errorsTotal )
x . MustRegister ( d )
}
2018-08-30 14:25:33 +00:00
return nil
} )
c . OnShutdown ( d . OnShutdown )
return nil
}
func ( d * Plugin ) parseEtcHosts ( text string ) bool {
if pos := strings . IndexByte ( text , '#' ) ; pos != - 1 {
text = text [ 0 : pos ]
}
fields := strings . Fields ( text )
if len ( fields ) < 2 {
return false
}
addr := net . ParseIP ( fields [ 0 ] )
if addr == nil {
return false
}
for _ , host := range fields [ 1 : ] {
if val , ok := d . hosts [ host ] ; ok {
log . Printf ( "warning: host %s already has value %s, will overwrite it with %s" , host , val , addr )
}
d . hosts [ host ] = addr
}
return true
}
func ( d * Plugin ) OnShutdown ( ) error {
d . d . Destroy ( )
d . d = nil
return nil
}
type statsFunc func ( ch interface { } , name string , text string , value float64 , valueType prometheus . ValueType )
func doDesc ( ch interface { } , name string , text string , value float64 , valueType prometheus . ValueType ) {
realch , ok := ch . ( chan <- * prometheus . Desc )
if ok == false {
log . Printf ( "Couldn't convert ch to chan<- *prometheus.Desc\n" )
return
}
realch <- prometheus . NewDesc ( name , text , nil , nil )
}
func doMetric ( ch interface { } , name string , text string , value float64 , valueType prometheus . ValueType ) {
realch , ok := ch . ( chan <- prometheus . Metric )
if ok == false {
log . Printf ( "Couldn't convert ch to chan<- prometheus.Metric\n" )
return
}
desc := prometheus . NewDesc ( name , text , nil , nil )
realch <- prometheus . MustNewConstMetric ( desc , valueType , value )
}
func gen ( ch interface { } , doFunc statsFunc , name string , text string , value float64 , valueType prometheus . ValueType ) {
doFunc ( ch , name , text , value , valueType )
}
func ( d * Plugin ) doStats ( ch interface { } , doFunc statsFunc ) {
stats := d . d . GetStats ( )
gen ( ch , doFunc , "coredns_dnsfilter_safebrowsing_requests" , "Number of safebrowsing HTTP requests that were sent" , float64 ( stats . Safebrowsing . Requests ) , prometheus . CounterValue )
gen ( ch , doFunc , "coredns_dnsfilter_safebrowsing_cachehits" , "Number of safebrowsing lookups that didn't need HTTP requests" , float64 ( stats . Safebrowsing . CacheHits ) , prometheus . CounterValue )
gen ( ch , doFunc , "coredns_dnsfilter_safebrowsing_pending" , "Number of currently pending safebrowsing HTTP requests" , float64 ( stats . Safebrowsing . Pending ) , prometheus . GaugeValue )
gen ( ch , doFunc , "coredns_dnsfilter_safebrowsing_pending_max" , "Maximum number of pending safebrowsing HTTP requests" , float64 ( stats . Safebrowsing . PendingMax ) , prometheus . GaugeValue )
gen ( ch , doFunc , "coredns_dnsfilter_parental_requests" , "Number of parental HTTP requests that were sent" , float64 ( stats . Parental . Requests ) , prometheus . CounterValue )
gen ( ch , doFunc , "coredns_dnsfilter_parental_cachehits" , "Number of parental lookups that didn't need HTTP requests" , float64 ( stats . Parental . CacheHits ) , prometheus . CounterValue )
gen ( ch , doFunc , "coredns_dnsfilter_parental_pending" , "Number of currently pending parental HTTP requests" , float64 ( stats . Parental . Pending ) , prometheus . GaugeValue )
gen ( ch , doFunc , "coredns_dnsfilter_parental_pending_max" , "Maximum number of pending parental HTTP requests" , float64 ( stats . Parental . PendingMax ) , prometheus . GaugeValue )
}
func ( d * Plugin ) Describe ( ch chan <- * prometheus . Desc ) {
d . doStats ( ch , doDesc )
}
func ( d * Plugin ) Collect ( ch chan <- prometheus . Metric ) {
d . doStats ( ch , doMetric )
}
func ( d * Plugin ) replaceHostWithValAndReply ( ctx context . Context , w dns . ResponseWriter , r * dns . Msg , host string , val string , question dns . Question ) ( int , error ) {
// check if it's a domain name or IP address
addr := net . ParseIP ( val )
var records [ ] dns . RR
log . Println ( "Will give" , val , "instead of" , host )
if addr != nil {
// this is an IP address, return it
result , err := dns . NewRR ( host + " A " + val )
if err != nil {
log . Printf ( "Got error %s\n" , err )
return dns . RcodeServerFailure , fmt . Errorf ( "plugin/dnsfilter: %s" , err )
}
records = append ( records , result )
} else {
// this is a domain name, need to look it up
req := new ( dns . Msg )
req . SetQuestion ( dns . Fqdn ( val ) , question . Qtype )
req . RecursionDesired = true
reqstate := request . Request { W : w , Req : req , Context : ctx }
result , err := d . upstream . Lookup ( reqstate , dns . Fqdn ( val ) , reqstate . QType ( ) )
if err != nil {
log . Printf ( "Got error %s\n" , err )
return dns . RcodeServerFailure , fmt . Errorf ( "plugin/dnsfilter: %s" , err )
}
if result != nil {
for _ , answer := range result . Answer {
answer . Header ( ) . Name = question . Name
}
records = result . Answer
}
}
m := new ( dns . Msg )
m . SetReply ( r )
m . Authoritative , m . RecursionAvailable , m . Compress = true , true , true
m . Answer = append ( m . Answer , records ... )
state := request . Request { W : w , Req : r , Context : ctx }
state . SizeAndDo ( m )
err := state . W . WriteMsg ( m )
if err != nil {
log . Printf ( "Got error %s\n" , err )
return dns . RcodeServerFailure , fmt . Errorf ( "plugin/dnsfilter: %s" , err )
}
return dns . RcodeSuccess , nil
}
// generate SOA record that makes DNS clients cache NXdomain results
// the only value that is important is TTL in header, other values like refresh, retry, expire and minttl are irrelevant
func genSOA ( r * dns . Msg ) [ ] dns . RR {
zone := r . Question [ 0 ] . Name
header := dns . RR_Header { Name : zone , Rrtype : dns . TypeSOA , Ttl : 3600 , Class : dns . ClassINET }
Mbox := "hostmaster."
if zone [ 0 ] != '.' {
Mbox += zone
}
Ns := "fake-for-negative-caching.adguard.com."
soa := defaultSOA
soa . Hdr = header
soa . Mbox = Mbox
soa . Ns = Ns
soa . Serial = uint32 ( time . Now ( ) . Unix ( ) )
return [ ] dns . RR { soa }
}
func writeNXdomain ( ctx context . Context , w dns . ResponseWriter , r * dns . Msg ) ( int , error ) {
state := request . Request { W : w , Req : r , Context : ctx }
m := new ( dns . Msg )
m . SetRcode ( state . Req , dns . RcodeNameError )
m . Authoritative , m . RecursionAvailable , m . Compress = true , true , true
m . Ns = genSOA ( r )
state . SizeAndDo ( m )
err := state . W . WriteMsg ( m )
if err != nil {
log . Printf ( "Got error %s\n" , err )
return dns . RcodeServerFailure , err
}
return dns . RcodeNameError , nil
}
func ( d * Plugin ) serveDNSInternal ( ctx context . Context , w dns . ResponseWriter , r * dns . Msg ) ( int , error , dnsfilter . Result ) {
if len ( r . Question ) != 1 {
// google DNS, bind and others do the same
return dns . RcodeFormatError , fmt . Errorf ( "Got DNS request with != 1 questions" ) , dnsfilter . Result { }
}
for _ , question := range r . Question {
host := strings . ToLower ( strings . TrimSuffix ( question . Name , "." ) )
// if input is empty host, filter it out right away
if index := strings . IndexByte ( host , byte ( '.' ) ) ; index == - 1 {
rcode , err := writeNXdomain ( ctx , w , r )
if err != nil {
return rcode , err , dnsfilter . Result { }
}
return rcode , err , dnsfilter . Result { Reason : dnsfilter . FilteredInvalid }
}
// is it a safesearch domain?
if val , ok := d . d . SafeSearchDomain ( host ) ; ok {
rcode , err := d . replaceHostWithValAndReply ( ctx , w , r , host , val , question )
if err != nil {
return rcode , err , dnsfilter . Result { }
}
return rcode , err , dnsfilter . Result { Reason : dnsfilter . FilteredSafeSearch }
}
// is it in hosts?
if val , ok := d . hosts [ host ] ; ok {
// it is, if it's a loopback host, reply with NXDOMAIN
if val . IsLoopback ( ) {
rcode , err := writeNXdomain ( ctx , w , r )
if err != nil {
return rcode , err , dnsfilter . Result { }
}
return rcode , err , dnsfilter . Result { Reason : dnsfilter . FilteredInvalid }
}
// it's not a loopback host, replace it with value specified
rcode , err := d . replaceHostWithValAndReply ( ctx , w , r , host , val . String ( ) , question )
if err != nil {
return rcode , err , dnsfilter . Result { }
}
return rcode , err , dnsfilter . Result { Reason : dnsfilter . FilteredSafeSearch }
}
// needs to be filtered instead
result , err := d . d . CheckHost ( host )
if err != nil {
log . Printf ( "plugin/dnsfilter: %s\n" , err )
return dns . RcodeServerFailure , fmt . Errorf ( "plugin/dnsfilter: %s" , err ) , dnsfilter . Result { }
}
2018-09-05 23:09:57 +00:00
if result . IsFiltered {
switch result . Reason {
case dnsfilter . FilteredSafeBrowsing :
// return cname safebrowsing.block.dns.adguard.com
val := d . SafeBrowsingBlockHost
rcode , err := d . replaceHostWithValAndReply ( ctx , w , r , host , val , question )
if err != nil {
return rcode , err , dnsfilter . Result { }
}
return rcode , err , result
case dnsfilter . FilteredParental :
// return cname family.block.dns.adguard.com
val := d . ParentalBlockHost
rcode , err := d . replaceHostWithValAndReply ( ctx , w , r , host , val , question )
if err != nil {
return rcode , err , dnsfilter . Result { }
}
return rcode , err , result
case dnsfilter . FilteredBlackList :
// return NXdomain
rcode , err := writeNXdomain ( ctx , w , r )
if err != nil {
return rcode , err , dnsfilter . Result { }
}
return rcode , err , result
default :
log . Printf ( "SHOULD NOT HAPPEN -- got unknown reason for filtering: %T %v %s" , result . Reason , result . Reason , result . Reason . String ( ) )
2018-08-30 14:25:33 +00:00
}
2018-09-05 23:09:57 +00:00
} else {
switch result . Reason {
case dnsfilter . NotFilteredWhiteList :
rcode , err := plugin . NextOrFailure ( d . Name ( ) , d . Next , ctx , w , r )
return rcode , err , result
case dnsfilter . NotFilteredNotFound :
// do nothing, pass through to lower code
default :
log . Printf ( "SHOULD NOT HAPPEN -- got unknown reason for not filtering: %T %v %s" , result . Reason , result . Reason , result . Reason . String ( ) )
2018-08-30 14:25:33 +00:00
}
}
}
rcode , err := plugin . NextOrFailure ( d . Name ( ) , d . Next , ctx , w , r )
return rcode , err , dnsfilter . Result { }
}
func ( d * Plugin ) ServeDNS ( ctx context . Context , w dns . ResponseWriter , r * dns . Msg ) ( int , error ) {
start := time . Now ( )
requests . Inc ( )
2018-08-31 16:59:04 +00:00
state := request . Request { W : w , Req : r }
ip := state . IP ( )
2018-08-30 14:25:33 +00:00
// capture the written answer
rrw := dnstest . NewRecorder ( w )
rcode , err , result := d . serveDNSInternal ( ctx , rrw , r )
if rcode > 0 {
// actually send the answer if we have one
answer := new ( dns . Msg )
answer . SetRcode ( r , rcode )
state . SizeAndDo ( answer )
w . WriteMsg ( answer )
}
// increment counters
switch {
case err != nil :
errorsTotal . Inc ( )
case result . Reason == dnsfilter . FilteredBlackList :
filtered . Inc ( )
filteredLists . Inc ( )
case result . Reason == dnsfilter . FilteredSafeBrowsing :
filtered . Inc ( )
filteredSafebrowsing . Inc ( )
case result . Reason == dnsfilter . FilteredParental :
filtered . Inc ( )
filteredParental . Inc ( )
case result . Reason == dnsfilter . FilteredInvalid :
filtered . Inc ( )
filteredInvalid . Inc ( )
case result . Reason == dnsfilter . FilteredSafeSearch :
// the request was passsed through but not filtered, don't increment filtered
safesearch . Inc ( )
case result . Reason == dnsfilter . NotFilteredWhiteList :
whitelisted . Inc ( )
case result . Reason == dnsfilter . NotFilteredNotFound :
// do nothing
case result . Reason == dnsfilter . NotFilteredError :
text := "SHOULD NOT HAPPEN: got DNSFILTER_NOTFILTERED_ERROR without err != nil!"
log . Println ( text )
err = errors . New ( text )
rcode = dns . RcodeServerFailure
}
// log
if d . QueryLogEnabled {
2018-09-05 18:21:46 +00:00
logRequest ( r , rrw . Msg , result , time . Since ( start ) , ip )
2018-08-30 14:25:33 +00:00
}
return rcode , err
}
func ( d * Plugin ) Name ( ) string { return "dnsfilter" }
2018-09-05 23:07:57 +00:00
var onceQueryLog sync . Once