diff --git a/README.md b/README.md index 634b52548..4ec594bdc 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,24 @@ [![GolangCI](https://golangci.com/badges/github.com/AdguardTeam/dnsproxy.svg)](https://golangci.com/r/github.com/AdguardTeam/dnsproxy) [![Go Doc](https://godoc.org/github.com/AdguardTeam/dnsproxy?status.svg)](https://godoc.org/github.com/AdguardTeam/dnsproxy) -# DNS Proxy +# DNS Proxy A simple DNS proxy server that supports all existing DNS protocols including `DNS-over-TLS`, `DNS-over-HTTPS`, and `DNSCrypt`. Moreover, it can work as a `DNS-over-HTTPS` and/or `DNS-over-TLS` server. +- [How to build](#how-to-build) +- [Usage](#usage) +- [Examples](#examples) + - [Simple options](#simple-options) + - [Encrypted upstreams](#encrypted-upstreams) + - [Encrypted DNS server](#encrypted-dns-server) + - [Additional features](#additional-features) + - [Fastest addr + cache-min-ttl](#fastest-addr--cache-min-ttl) + - [Specifying upstreams for domains](#specifying-upstreams-for-domains) + - [EDNS Client Subnet](#edns-client-subnet) + - [Bogus NXDomain](#bogus-nxdomain) + ## How to build You will need go v1.14 or later. @@ -25,33 +37,34 @@ Usage: dnsproxy [OPTIONS] Application Options: - -v, --verbose Verbose output (optional) - -o, --output= Path to the log file. If not set, write to stdout. - -l, --listen= Listen address (default: 0.0.0.0) - -p, --port= Listen port. Zero value disables TCP and UDP listeners (default: 53) - -h, --https-port= Listen port for DNS-over-HTTPS (default: 0) - -t, --tls-port= Listen port for DNS-over-TLS (default: 0) - -c, --tls-crt= Path to a file with the certificate chain - -k, --tls-key= Path to a file with the private key - -b, --bootstrap= Bootstrap DNS for DoH and DoT, can be specified multiple times (default: 8.8.8.8:53) - -r, --ratelimit= Ratelimit (requests per second) (default: 0) - -z, --cache If specified, DNS cache is enabled - -e --cache-size= Cache size (in bytes). Default: 65536 - --cache-min-ttl= Minimum TTL value for DNS entries, in seconds. Capped at 3600 seconds (1 hour). - Artificially extending TTLs should only be done with careful consideration. - --cache-max-ttl= Maximum TTL value for DNS entries, in seconds. - -a, --refuse-any If specified, refuse ANY requests - -u, --upstream= An upstream to be used (can be specified multiple times) - -f, --fallback= Fallback resolvers to use when regular ones are unavailable, can be specified multiple times - -s, --all-servers Use parallel queries to speed up resolving by querying all upstream servers simultaneously - -d, --ipv6-disabled Disable IPv6. All AAAA requests will be replied with No Error response code and empty answer - --edns Use EDNS Client Subnet extension - --edns-addr= Send EDNS Client Address - --fastest-addr Respond to A or AAAA requests only with the fastest IP address + -v, --verbose Verbose output (optional) + -o, --output= Path to the log file. If not set, write to stdout. + -l, --listen= Listen address (default: 0.0.0.0) + -p, --port= Listen port. Zero value disables TCP and UDP listeners (default: 53) + -h, --https-port= Listen port for DNS-over-HTTPS (default: 0) + -t, --tls-port= Listen port for DNS-over-TLS (default: 0) + -c, --tls-crt= Path to a file with the certificate chain + -k, --tls-key= Path to a file with the private key + -u, --upstream= An upstream to be used (can be specified multiple times) + -b, --bootstrap= Bootstrap DNS for DoH and DoT, can be specified multiple times (default: 8.8.8.8:53) + -f, --fallback= Fallback resolvers to use when regular ones are unavailable, can be specified multiple times + --all-servers If specified, parallel queries to all configured upstream servers are enabled + --fastest-addr Respond to A or AAAA requests only with the fastest IP address + --cache If specified, DNS cache is enabled + --cache-size= Cache size (in bytes). Default: 64k + --cache-min-ttl= Minimum TTL value for DNS entries, in seconds. Capped at 3600. Artificially extending TTLs should only be done with + careful consideration. + --cache-max-ttl= Maximum TTL value for DNS entries, in seconds. + -r, --ratelimit= Ratelimit (requests per second) (default: 0) + --refuse-any If specified, refuse ANY requests + --edns Use EDNS Client Subnet extension + --edns-addr= Send EDNS Client Address + --ipv6-disabled If specified, all AAAA requests will be replied with NoError RCode and empty answer + --bogus-nxdomain= Transform responses that contain only given IP addresses into NXDOMAIN. Can be specified multiple times. + --version Prints the program version Help Options: - -h, --help Show this help message - --version Print DNS proxy version + -h, --help Show this help message ``` ## Examples @@ -154,7 +167,7 @@ If one or more domains are specified, that upstream (`upstreamString`) is used o 2. More specific domains take precedence over less specific domains, so: `--upstream=[/host.com/]1.2.3.4 --upstream=[/www.host.com/]2.3.4.5` will send queries for *.host.com to 1.2.3.4, except *.www.host.com, which will go to 2.3.4.5 3. The special server address '#' means, "use the standard servers", so: `--upstream=[/host.com/]1.2.3.4 --upstream=[/www.host.com/]#` will send queries for *.host.com to 1.2.3.4, except *.www.host.com which will be forwarded as usual. -#### Examples +**Examples** Sends queries for `*.local` domains to `192.168.0.1:53`. Other queries are sent to `8.8.8.8:53`. ``` @@ -183,3 +196,13 @@ If you want to use EDNS CS feature when you're connecting to the proxy from a lo ``` Now even if your IP address is 192.168.0.1 and it's not a public IP, the proxy will pass through 72.72.72.72 to the upstream server. + +### Bogus NXDomain + +This option is similar to dnsmasq `bogus-nxdomain`. If specified, `dnsproxy` transforms responses that contain only the given IP addresses into `NXDOMAIN`. Can be specified multiple times. + +In the example below, we use AdGuard DNS server that returns `0.0.0.0` for blocked domains, and tranform them to `NXDOMAIN`. + +``` +./dnsproxy -u 176.103.130.130:53 --bogus-nxdomain=0.0.0.0 +``` \ No newline at end of file diff --git a/main.go b/main.go index 6c01e8ed8..899aa2c66 100644 --- a/main.go +++ b/main.go @@ -18,12 +18,18 @@ import ( // Options represents console arguments type Options struct { + // Log settings + // -- + // Should we write Verbose bool `short:"v" long:"verbose" description:"Verbose output (optional)" optional:"yes" optional-value:"true"` // Path to a log file LogOutput string `short:"o" long:"output" description:"Path to the log file. If not set, write to stdout." default:""` + // Server settings + // -- + // Server listen address ListenAddr string `short:"l" long:"listen" description:"Listen address" default:"0.0.0.0"` @@ -42,17 +48,33 @@ type Options struct { // Path to the file with the private key TLSKeyPath string `short:"k" long:"tls-key" description:"Path to a file with the private key"` + // Upstream DNS servers settings + // -- + + // DNS upstreams + Upstreams []string `short:"u" long:"upstream" description:"An upstream to be used (can be specified multiple times)" required:"true"` + // Bootstrap DNS BootstrapDNS []string `short:"b" long:"bootstrap" description:"Bootstrap DNS for DoH and DoT, can be specified multiple times (default: 8.8.8.8:53)"` - // Ratelimit value - Ratelimit int `short:"r" long:"ratelimit" description:"Ratelimit (requests per second)" default:"0"` + // Fallback DNS resolver + Fallbacks []string `short:"f" long:"fallback" description:"Fallback resolvers to use when regular ones are unavailable, can be specified multiple times"` + + // If true, parallel queries to all configured upstream servers + AllServers bool `long:"all-servers" description:"If specified, parallel queries to all configured upstream servers are enabled" optional:"yes" optional-value:"true"` + + // Respond to A or AAAA requests only with the fastest IP address + // detected by ICMP response time or TCP connection time + FastestAddress bool `long:"fastest-addr" description:"Respond to A or AAAA requests only with the fastest IP address" optional:"yes" optional-value:"true"` + + // Cache settings + // -- // If true, DNS cache is enabled - Cache bool `short:"z" long:"cache" description:"If specified, DNS cache is enabled" optional:"yes" optional-value:"true"` + Cache bool `long:"cache" description:"If specified, DNS cache is enabled" optional:"yes" optional-value:"true"` // Cache size value - CacheSizeBytes int `short:"e" long:"cache-size" description:"Cache size (in bytes). Default: 64k"` + CacheSizeBytes int `long:"cache-size" description:"Cache size (in bytes). Default: 64k"` // DNS cache minimum TTL value - overrides record value CacheMinTTL uint32 `long:"cache-min-ttl" description:"Minimum TTL value for DNS entries, in seconds. Capped at 3600. Artificially extending TTLs should only be done with careful consideration."` @@ -60,20 +82,17 @@ type Options struct { // DNS cache maximum TTL value - overrides record value CacheMaxTTL uint32 `long:"cache-max-ttl" description:"Maximum TTL value for DNS entries, in seconds."` - // If true, refuse ANY requests - RefuseAny bool `short:"a" long:"refuse-any" description:"If specified, refuse ANY requests" optional:"yes" optional-value:"true"` - - // DNS upstreams - Upstreams []string `short:"u" long:"upstream" description:"An upstream to be used (can be specified multiple times)" required:"true"` + // Anti-DNS amplification measures + // -- - // Fallback DNS resolver - Fallbacks []string `short:"f" long:"fallback" description:"Fallback resolvers to use when regular ones are unavailable, can be specified multiple times"` + // Ratelimit value + Ratelimit int `short:"r" long:"ratelimit" description:"Ratelimit (requests per second)" default:"0"` - // If true, parallel queries to all configured upstream servers - AllServers bool `short:"s" long:"all-servers" description:"If specified, parallel queries to all configured upstream servers are enabled" optional:"yes" optional-value:"true"` + // If true, refuse ANY requests + RefuseAny bool `long:"refuse-any" description:"If specified, refuse ANY requests" optional:"yes" optional-value:"true"` - // If true, all AAAA requests will be replied with NoError RCode and empty answer - IPv6Disabled bool `short:"d" long:"ipv6-disabled" description:"If specified, all AAAA requests will be replied with NoError RCode and empty answer" optional:"yes" optional-value:"true"` + // ECS settings + // -- // Use EDNS Client Subnet extension EnableEDNSSubnet bool `long:"edns" description:"Use EDNS Client Subnet extension" optional:"yes" optional-value:"true"` @@ -81,9 +100,14 @@ type Options struct { // Use Custom EDNS Client Address EDNSAddr string `long:"edns-addr" description:"Send EDNS Client Address"` - // Respond to A or AAAA requests only with the fastest IP address - // detected by ICMP response time or TCP connection time - FastestAddress bool `long:"fastest-addr" description:"Respond to A or AAAA requests only with the fastest IP address" optional:"yes" optional-value:"true"` + // Other settings and options + // -- + + // If true, all AAAA requests will be replied with NoError RCode and empty answer + IPv6Disabled bool `long:"ipv6-disabled" description:"If specified, all AAAA requests will be replied with NoError RCode and empty answer" optional:"yes" optional-value:"true"` + + // Transform responses that contain only given IP addresses into NXDOMAIN + BogusNXDomain []string `long:"bogus-nxdomain" description:"Transform responses that contain only given IP addresses into NXDOMAIN. Can be specified multiple times."` // Print DNSProxy version (just for the help) Version bool `long:"version" description:"Prints the program version"` @@ -209,6 +233,19 @@ func createProxyConfig(options Options) proxy.Config { config.Fallbacks = fallbacks } + if len(options.BogusNXDomain) > 0 { + bogusIP := []net.IP{} + for _, s := range options.BogusNXDomain { + ip := net.ParseIP(s) + if ip == nil { + log.Error("Invalid IP: %s", s) + } else { + bogusIP = append(bogusIP, ip) + } + } + config.BogusNXDomain = bogusIP + } + // Prepare the TLS config if options.TLSCertPath != "" && options.TLSKeyPath != "" { tlsConfig, err := newTLSConfig(options.TLSCertPath, options.TLSKeyPath) diff --git a/proxy/bogus_nxdomain.go b/proxy/bogus_nxdomain.go new file mode 100644 index 000000000..14efd4c19 --- /dev/null +++ b/proxy/bogus_nxdomain.go @@ -0,0 +1,28 @@ +package proxy + +import ( + "github.com/AdguardTeam/dnsproxy/proxyutil" + "github.com/miekg/dns" +) + +// isBogusNXDomain - checks if the specified DNS message +// contains ONLY ip addresses from the Proxy.BogusNXDomain list +func (p *Proxy) isBogusNXDomain(reply *dns.Msg) bool { + if reply == nil || + len(p.BogusNXDomain) == 0 || + len(reply.Answer) == 0 || + (reply.Question[0].Qtype != dns.TypeA && + reply.Question[0].Qtype != dns.TypeAAAA) { + return false + } + + for _, rr := range reply.Answer { + ip := proxyutil.GetIPFromDNSRecord(rr) + if !proxyutil.ContainsIP(p.BogusNXDomain, ip) { + return false + } + } + + // All IPs are bogus if we got here + return true +} diff --git a/proxy/bogus_nxdomain_test.go b/proxy/bogus_nxdomain_test.go new file mode 100644 index 000000000..b2272ca33 --- /dev/null +++ b/proxy/bogus_nxdomain_test.go @@ -0,0 +1,60 @@ +package proxy + +import ( + "net" + "testing" + + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" +) + +func TestBogusNXDomainTypeA(t *testing.T) { + dnsProxy := createTestProxy(t, nil) + dnsProxy.CacheEnabled = true + dnsProxy.BogusNXDomain = []net.IP{net.ParseIP("4.3.2.1")} + + u := testUpstream{} + dnsProxy.Upstreams = []upstream.Upstream{&u} + err := dnsProxy.Start() + assert.Nil(t, err) + + // first request + // upstream answers with a bogus IP + u.aResp = new(dns.A) + u.aResp.Hdr.Rrtype = dns.TypeA + u.aResp.Hdr.Name = "host." + u.aResp.A = net.ParseIP("4.3.2.1") + u.aResp.Hdr.Ttl = 10 + + clientIP := net.IP{1, 2, 3, 0} + d := DNSContext{} + d.Req = createHostTestMessage("host") + d.Addr = &net.TCPAddr{ + IP: clientIP, + } + + err = dnsProxy.Resolve(&d) + assert.Nil(t, err) + + // check response + assert.NotNil(t, d.Res) + assert.Equal(t, dns.RcodeNameError, d.Res.Rcode) + + // second request + // upstream answers with a normal IP + u.aResp = new(dns.A) + u.aResp.Hdr.Rrtype = dns.TypeA + u.aResp.Hdr.Name = "host." + u.aResp.A = net.ParseIP("4.3.2.2") + u.aResp.Hdr.Ttl = 10 + + err = dnsProxy.Resolve(&d) + assert.Nil(t, err) + + // check response + assert.NotNil(t, d.Res) + assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) + + _ = dnsProxy.Stop() +} diff --git a/proxy/config.go b/proxy/config.go new file mode 100644 index 000000000..af82b3bcb --- /dev/null +++ b/proxy/config.go @@ -0,0 +1,222 @@ +package proxy + +import ( + "crypto/tls" + "errors" + "fmt" + "net" + "strings" + "time" + + "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/utils" + + "github.com/AdguardTeam/dnsproxy/upstream" +) + +// Config contains all the fields necessary for proxy configuration +type Config struct { + // Listeners + // -- + + UDPListenAddr *net.UDPAddr // if nil, then it does not listen for UDP + TCPListenAddr *net.TCPAddr // if nil, then it does not listen for TCP + HTTPSListenAddr *net.TCPAddr // if nil, then it does not listen for HTTPS (DoH) + TLSListenAddr *net.TCPAddr // if nil, then it does not listen for TLS (DoT) + TLSConfig *tls.Config // necessary for listening for TLS + + // Rate-limiting and anti-DNS amplification measures + // -- + + Ratelimit int // max number of requests per second from a given IP (0 to disable) + RatelimitWhitelist []string // a list of whitelisted client IP addresses + RefuseAny bool // if true, refuse ANY requests + + // Upstream DNS servers and their settings + // -- + + Upstreams []upstream.Upstream // list of upstreams + Fallbacks []upstream.Upstream // list of fallback resolvers (which will be used if regular upstream failed to answer) + AllServers bool // if true, parallel queries to all configured upstream servers are enabled + DomainsReservedUpstreams map[string][]upstream.Upstream // map of domains and lists of corresponding upstreams + FindFastestAddr bool // use Fastest Address algorithm + + // BogusNXDomain - transforms responses that contain only given IP addresses into NXDOMAIN + // Similar to dnsmasq's "bogus-nxdomain" + BogusNXDomain []net.IP + + // Enable EDNS Client Subnet option + // DNS requests to the upstream server will contain an OPT record with Client Subnet option. + // If the original request already has this option set, we pass it through as is. + // Otherwise, we set it ourselves using the client IP with subnet /24 (for IPv4) and /112 (for IPv6). + // + // If the upstream server supports ECS, it sets subnet number in the response. + // This subnet number along with the client IP and other data is used as a cache key. + // Next time, if a client from the same subnet requests this host name, + // we get the response from cache. + // If another client from a different subnet requests this host name, + // we pass his request to the upstream server. + // + // If the upstream server doesn't support ECS (there's no subnet number in response), + // this response will be cached for all clients. + // + // If client IP is private (i.e. not public), we don't add EDNS record into a request. + // And so there will be no EDNS record in response either. + // We store these responses in general cache (without subnet) + // so they will never be used for clients with public IP addresses. + EnableEDNSClientSubnet bool + EDNSAddr net.IP // ECS IP used in request + + // Cache settings + // -- + + CacheEnabled bool // cache status + CacheSizeBytes int // Cache size (in bytes). Default: 64k + CacheMinTTL uint32 // Minimum TTL for DNS entries (in seconds). + CacheMaxTTL uint32 // Maximum TTL for DNS entries (in seconds). + + // Handlers (for the case when dnsproxy is used as a library) + // -- + + BeforeRequestHandler BeforeRequestHandler // callback that is called before each request + RequestHandler RequestHandler // callback that can handle incoming DNS requests + ResponseHandler ResponseHandler // response callback + + // Other settings + // -- + + MaxGoroutines int // maximum number of goroutines processing the DNS requests (important for mobile) +} + +// UpstreamConfig is a wrapper for list of default upstreams and map of reserved domains and corresponding upstreams +type UpstreamConfig struct { + Upstreams []upstream.Upstream // list of default upstreams + DomainReservedUpstreams map[string][]upstream.Upstream // map of reserved domains and lists of corresponding upstreams +} + +// ParseUpstreamsConfig returns UpstreamConfig and error if upstreams configuration is invalid +// default upstream syntax: +// reserved upstream syntax: [/domain1/../domainN/] +// More specific domains take priority over less specific domains, +// To exclude more specific domains from reserved upstreams querying you should use the following syntax: [/domain1/../domainN/]# +// So the following config: ["[/host.com/]1.2.3.4", "[/www.host.com/]2.3.4.5", "[/maps.host.com/]#", "3.4.5.6"] +// will send queries for *.host.com to 1.2.3.4, except for *.www.host.com, which will go to 2.3.4.5 and *.maps.host.com, +// which will go to default server 3.4.5.6 with all other domains +func ParseUpstreamsConfig(upstreamConfig, bootstrapDNS []string, timeout time.Duration) (UpstreamConfig, error) { + return ParseUpstreamsConfigEx(upstreamConfig, bootstrapDNS, timeout, func(address string, opts upstream.Options) (upstream.Upstream, error) { + return upstream.AddressToUpstream(address, opts) + }) +} + +// AddressToUpstreamFunction is a type for a callback function which creates an upstream object +type AddressToUpstreamFunction func(address string, opts upstream.Options) (upstream.Upstream, error) + +// ParseUpstreamsConfigEx is an extended version of ParseUpstreamsConfig() which has a custom callback function which creates an upstream object +func ParseUpstreamsConfigEx(upstreamConfig, bootstrapDNS []string, timeout time.Duration, addressToUpstreamFunction AddressToUpstreamFunction) (UpstreamConfig, error) { + upstreams := []upstream.Upstream{} + domainReservedUpstreams := map[string][]upstream.Upstream{} + + if len(bootstrapDNS) > 0 { + for i, b := range bootstrapDNS { + log.Info("Bootstrap %d: %s", i, b) + } + } + + for i, u := range upstreamConfig { + hosts := []string{} + if strings.HasPrefix(u, "[/") { + // split domains and upstream string + domainsAndUpstream := strings.Split(strings.TrimPrefix(u, "[/"), "/]") + if len(domainsAndUpstream) != 2 { + return UpstreamConfig{}, fmt.Errorf("wrong upstream specification: %s", u) + } + + // split domains list + for _, host := range strings.Split(domainsAndUpstream[0], "/") { + if host != "" { + if err := utils.IsValidHostname(host); err != nil { + return UpstreamConfig{}, err + } + hosts = append(hosts, strings.ToLower(host+".")) + } else { + // empty domain specification means `unqualified names only` + hosts = append(hosts, UnqualifiedNames) + } + } + u = domainsAndUpstream[1] + } + + // # excludes more specific domain from reserved upstreams querying + if u == "#" && len(hosts) > 0 { + for _, host := range hosts { + domainReservedUpstreams[host] = nil + } + continue + } + + // create an upstream + dnsUpstream, err := addressToUpstreamFunction(u, upstream.Options{Bootstrap: bootstrapDNS, Timeout: timeout}) + if err != nil { + return UpstreamConfig{}, fmt.Errorf("cannot prepare the upstream %s (%s): %s", u, bootstrapDNS, err) + } + + if len(hosts) > 0 { + for _, host := range hosts { + _, ok := domainReservedUpstreams[host] + if !ok { + domainReservedUpstreams[host] = []upstream.Upstream{} + } + domainReservedUpstreams[host] = append(domainReservedUpstreams[host], dnsUpstream) + } + log.Printf("Upstream %d: %s is reserved for next domains: %s", i, dnsUpstream.Address(), strings.Join(hosts, ", ")) + } else { + log.Printf("Upstream %d: %s", i, dnsUpstream.Address()) + upstreams = append(upstreams, dnsUpstream) + } + } + return UpstreamConfig{Upstreams: upstreams, DomainReservedUpstreams: domainReservedUpstreams}, nil +} + +// validateConfig verifies that the supplied configuration is valid and returns an error if it's not +func (p *Proxy) validateConfig() error { + if p.started { + return errors.New("server has been already started") + } + + if p.UDPListenAddr == nil && p.TCPListenAddr == nil && p.TLSListenAddr == nil && p.HTTPSListenAddr == nil { + return errors.New("no listen address specified") + } + + if p.TLSListenAddr != nil && p.TLSConfig == nil { + return errors.New("cannot create a TLS listener without TLS config") + } + + if p.HTTPSListenAddr != nil && p.TLSConfig == nil { + return errors.New("cannot create an HTTPS listener without TLS config") + } + + if len(p.Upstreams) == 0 { + if len(p.DomainsReservedUpstreams) == 0 { + return errors.New("no upstreams specified") + } + return errors.New("no default upstreams specified") + } + + if p.CacheMinTTL > 0 || p.CacheMaxTTL > 0 { + log.Info("Cache TTL override is enabled. Min=%d, Max=%d", p.CacheMinTTL, p.CacheMaxTTL) + } + + if p.Ratelimit > 0 { + log.Info("Ratelimit is enabled and set to %d rps", p.Ratelimit) + } + + if p.RefuseAny { + log.Info("The server is configured to refuse ANY requests") + } + + if len(p.BogusNXDomain) > 0 { + log.Info("%d bogus-nxdomain IP specified", len(p.BogusNXDomain)) + } + + return nil +} diff --git a/proxy/dns64.go b/proxy/dns64.go index 7b7f620c9..cb85407ed 100644 --- a/proxy/dns64.go +++ b/proxy/dns64.go @@ -12,7 +12,9 @@ import ( // isEmptyAAAAResponse checks AAAA answer to be empty // returns true if NAT64 prefix already calculated and there are no answers for AAAA question func (p *Proxy) isEmptyAAAAResponse(resp, req *dns.Msg) bool { - return p.isNAT64PrefixAvailable() && req.Question[0].Qtype == dns.TypeAAAA && (resp == nil || len(resp.Answer) == 0) + return p.isNAT64PrefixAvailable() && + (resp == nil || len(resp.Answer) == 0) && + req.Question[0].Qtype == dns.TypeAAAA } // isNAT64PrefixAvailable returns true if NAT64 prefix was calculated diff --git a/proxy/exchange.go b/proxy/exchange.go new file mode 100644 index 000000000..5bcfc181f --- /dev/null +++ b/proxy/exchange.go @@ -0,0 +1,86 @@ +package proxy + +import ( + "sort" + "time" + + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/log" + "github.com/joomcode/errorx" + "github.com/miekg/dns" +) + +// exchange -- sends DNS query to the upstream DNS server and returns the response +func (p *Proxy) exchange(req *dns.Msg, upstreams []upstream.Upstream) (reply *dns.Msg, u upstream.Upstream, err error) { + qtype := req.Question[0].Qtype + if p.FindFastestAddr && (qtype == dns.TypeA || qtype == dns.TypeAAAA) { + reply, u, err = p.fastestAddr.ExchangeFastest(req, upstreams) + return + } + + if p.AllServers { + reply, u, err = upstream.ExchangeParallel(upstreams, req) + return + } + + if len(upstreams) == 1 { + u = upstreams[0] + reply, _, err = exchangeWithUpstream(u, req) + return + } + + // sort upstreams by rtt from fast to slow + sortedUpstreams := p.getSortedUpstreams(upstreams) + + errs := []error{} + for _, dnsUpstream := range sortedUpstreams { + reply, elapsed, err := exchangeWithUpstream(dnsUpstream, req) + if err == nil { + p.updateRtt(dnsUpstream.Address(), elapsed) + return reply, dnsUpstream, err + } + errs = append(errs, err) + p.updateRtt(dnsUpstream.Address(), int(defaultTimeout/time.Millisecond)) + } + return nil, nil, errorx.DecorateMany("all upstreams failed to exchange request", errs...) +} + +func (p *Proxy) getSortedUpstreams(u []upstream.Upstream) []upstream.Upstream { + // clone upstreams list to avoid race conditions + p.rttLock.Lock() + clone := make([]upstream.Upstream, len(u)) + copy(clone, u) + + sort.Slice(clone, func(i, j int) bool { + if p.upstreamRttStats[clone[i].Address()] < p.upstreamRttStats[clone[j].Address()] { + return true + } + return false + }) + p.rttLock.Unlock() + + return clone +} + +// exchangeWithUpstream returns result of Exchange with elapsed time +func exchangeWithUpstream(u upstream.Upstream, req *dns.Msg) (*dns.Msg, int, error) { + startTime := time.Now() + reply, err := u.Exchange(req) + elapsed := int(time.Since(startTime) / time.Millisecond) + if err != nil { + log.Tracef("upstream %s failed to exchange %s in %d milliseconds. Cause: %s", u.Address(), req.Question[0].String(), elapsed, err) + } else { + log.Tracef("upstream %s successfully finished exchange of %s. Elapsed %d ms.", u.Address(), req.Question[0].String(), elapsed) + } + return reply, elapsed, err +} + +// updateRtt updates rtt in upstreamRttStats for given address +func (p *Proxy) updateRtt(address string, rtt int) { + p.rttLock.Lock() + if p.upstreamRttStats == nil { + p.upstreamRttStats = map[string]int{} + } + p.upstreamRttStats[address] = (p.upstreamRttStats[address] + rtt) / 2 + p.rttLock.Unlock() +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 8f5868bcb..a8d35a4bf 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -2,15 +2,8 @@ package proxy import ( - "crypto/tls" - "encoding/base64" - "errors" - "fmt" - "io/ioutil" "net" "net/http" - "sort" - "strconv" "strings" "sync" "time" @@ -19,7 +12,6 @@ import ( "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/utils" "github.com/joomcode/errorx" "github.com/miekg/dns" gocache "github.com/patrickmn/go-cache" @@ -91,62 +83,6 @@ type Proxy struct { sync.RWMutex // protects parallel access to proxy structures } -// Config contains all the fields necessary for proxy configuration -type Config struct { - UDPListenAddr *net.UDPAddr // if nil, then it does not listen for UDP - TCPListenAddr *net.TCPAddr // if nil, then it does not listen for TCP - - HTTPSListenAddr *net.TCPAddr // if nil, then it does not listen for HTTPS (DoH) - TLSListenAddr *net.TCPAddr // if nil, then it does not listen for TLS (DoT) - TLSConfig *tls.Config // necessary for listening for TLS - - Ratelimit int // max number of requests per second from a given IP (0 to disable) - RatelimitWhitelist []string // a list of whitelisted client IP addresses - - RefuseAny bool // if true, refuse ANY requests - AllServers bool // if true, parallel queries to all configured upstream servers are enabled - - // Enable EDNS Client Subnet option - // DNS requests to the upstream server will contain an OPT record with Client Subnet option. - // If the original request already has this option set, we pass it through as is. - // Otherwise, we set it ourselves using the client IP with subnet /24 (for IPv4) and /112 (for IPv6). - // - // If the upstream server supports ECS, it sets subnet number in the response. - // This subnet number along with the client IP and other data is used as a cache key. - // Next time, if a client from the same subnet requests this host name, - // we get the response from cache. - // If another client from a different subnet requests this host name, - // we pass his request to the upstream server. - // - // If the upstream server doesn't support ECS (there's no subnet number in response), - // this response will be cached for all clients. - // - // If client IP is private (i.e. not public), we don't add EDNS record into a request. - // And so there will be no EDNS record in response either. - // We store these responses in general cache (without subnet) - // so they will never be used for clients with public IP addresses. - EnableEDNSClientSubnet bool - EDNSAddr net.IP // ECS IP used in request - - CacheEnabled bool // cache status - CacheSizeBytes int // Cache size (in bytes). Default: 64k - CacheMinTTL uint32 // Minimum TTL for DNS entries (in seconds). - CacheMaxTTL uint32 // Maximum TTL for DNS entries (in seconds). - - Upstreams []upstream.Upstream // list of upstreams - Fallbacks []upstream.Upstream // list of fallback resolvers (which will be used if regular upstream failed to answer) - - BeforeRequestHandler BeforeRequestHandler // callback that is called before each request - RequestHandler RequestHandler // callback that can handle incoming DNS requests - ResponseHandler ResponseHandler // response callback - - DomainsReservedUpstreams map[string][]upstream.Upstream // map of domains and lists of corresponding upstreams - - FindFastestAddr bool // use Fastest Address algorithm - - MaxGoroutines int // maximum number of goroutines processing the DNS requests (important for mobile) -} - // DNSContext represents a DNS request message context type DNSContext struct { Proto string // "udp", "tcp", "tls", "https" @@ -168,95 +104,6 @@ type DNSContext struct { ecsReqMask uint8 // ECS mask used in request } -// UpstreamConfig is a wrapper for list of default upstreams and map of reserved domains and corresponding upstreams -type UpstreamConfig struct { - Upstreams []upstream.Upstream // list of default upstreams - DomainReservedUpstreams map[string][]upstream.Upstream // map of reserved domains and lists of corresponding upstreams -} - -// ParseUpstreamsConfig returns UpstreamConfig and error if upstreams configuration is invalid -// default upstream syntax: -// reserved upstream syntax: [/domain1/../domainN/] -// More specific domains take priority over less specific domains, -// To exclude more specific domains from reserved upstreams querying you should use the following syntax: [/domain1/../domainN/]# -// So the following config: ["[/host.com/]1.2.3.4", "[/www.host.com/]2.3.4.5", "[/maps.host.com/]#", "3.4.5.6"] -// will send queries for *.host.com to 1.2.3.4, except for *.www.host.com, which will go to 2.3.4.5 and *.maps.host.com, -// which will go to default server 3.4.5.6 with all other domains -func ParseUpstreamsConfig(upstreamConfig, bootstrapDNS []string, timeout time.Duration) (UpstreamConfig, error) { - return ParseUpstreamsConfigEx(upstreamConfig, bootstrapDNS, timeout, func(address string, opts upstream.Options) (upstream.Upstream, error) { - return upstream.AddressToUpstream(address, opts) - }) -} - -// AddressToUpstreamFunction is a type for a callback function which creates an upstream object -type AddressToUpstreamFunction func(address string, opts upstream.Options) (upstream.Upstream, error) - -// ParseUpstreamsConfigEx is an extended version of ParseUpstreamsConfig() which has a custom callback function which creates an upstream object -func ParseUpstreamsConfigEx(upstreamConfig, bootstrapDNS []string, timeout time.Duration, addressToUpstreamFunction AddressToUpstreamFunction) (UpstreamConfig, error) { - upstreams := []upstream.Upstream{} - domainReservedUpstreams := map[string][]upstream.Upstream{} - - if len(bootstrapDNS) > 0 { - for i, b := range bootstrapDNS { - log.Info("Bootstrap %d: %s", i, b) - } - } - - for i, u := range upstreamConfig { - hosts := []string{} - if strings.HasPrefix(u, "[/") { - // split domains and upstream string - domainsAndUpstream := strings.Split(strings.TrimPrefix(u, "[/"), "/]") - if len(domainsAndUpstream) != 2 { - return UpstreamConfig{}, fmt.Errorf("wrong upstream specification: %s", u) - } - - // split domains list - for _, host := range strings.Split(domainsAndUpstream[0], "/") { - if host != "" { - if err := utils.IsValidHostname(host); err != nil { - return UpstreamConfig{}, err - } - hosts = append(hosts, strings.ToLower(host+".")) - } else { - // empty domain specification means `unqualified names only` - hosts = append(hosts, UnqualifiedNames) - } - } - u = domainsAndUpstream[1] - } - - // # excludes more specific domain from reserved upstreams querying - if u == "#" && len(hosts) > 0 { - for _, host := range hosts { - domainReservedUpstreams[host] = nil - } - continue - } - - // create an upstream - dnsUpstream, err := addressToUpstreamFunction(u, upstream.Options{Bootstrap: bootstrapDNS, Timeout: timeout}) - if err != nil { - return UpstreamConfig{}, fmt.Errorf("cannot prepare the upstream %s (%s): %s", u, bootstrapDNS, err) - } - - if len(hosts) > 0 { - for _, host := range hosts { - _, ok := domainReservedUpstreams[host] - if !ok { - domainReservedUpstreams[host] = []upstream.Upstream{} - } - domainReservedUpstreams[host] = append(domainReservedUpstreams[host], dnsUpstream) - } - log.Printf("Upstream %d: %s is reserved for next domains: %s", i, dnsUpstream.Address(), strings.Join(hosts, ", ")) - } else { - log.Printf("Upstream %d: %s", i, dnsUpstream.Address()) - upstreams = append(upstreams, dnsUpstream) - } - } - return UpstreamConfig{Upstreams: upstreams, DomainReservedUpstreams: domainReservedUpstreams}, nil -} - // Init - initializes the proxy structures but does not start it func (p *Proxy) Init() { if p.CacheEnabled { @@ -401,82 +248,6 @@ func (p *Proxy) Addr(proto string) net.Addr { } } -// getUpstreamsForDomain looks for a domain in reserved domains map and returns a list of corresponding upstreams. -// returns default upstreams list if domain isn't found. More specific domains take priority over less specific domains. -// For example, map contains the following keys: host.com and www.host.com -// If we are looking for domain mail.host.com, this method will return value of host.com key -// If we are looking for domain www.host.com, this method will return value of www.host.com key -// If more specific domain value is nil, it means that domain was excluded and should be exchanged with default upstreams -func (p *Proxy) getUpstreamsForDomain(host string) []upstream.Upstream { - if len(p.DomainsReservedUpstreams) == 0 { - return p.Upstreams - } - - dotsCount := strings.Count(host, ".") - if dotsCount < 2 { - return p.DomainsReservedUpstreams[UnqualifiedNames] - } - - for i := 1; i <= dotsCount; i++ { - h := strings.SplitAfterN(host, ".", i) - name := h[i-1] - if u, ok := p.DomainsReservedUpstreams[strings.ToLower(name)]; ok { - if u == nil { - // domain was excluded from reserved upstreams querying - return p.Upstreams - } - return u - } - } - - return p.Upstreams -} - -// Set EDNS Client-Subnet data in DNS request -func (p *Proxy) processECS(d *DNSContext) { - d.ecsReqIP = nil - d.ecsReqMask = uint8(0) - - ip, mask, _ := parseECS(d.Req) - if mask == 0 { - // Set EDNS Client-Subnet data - var clientIP net.IP - if p.Config.EDNSAddr != nil { - clientIP = p.Config.EDNSAddr - } else { - switch addr := d.Addr.(type) { - case *net.UDPAddr: - clientIP = addr.IP - case *net.TCPAddr: - clientIP = addr.IP - } - } - - if clientIP != nil && isPublicIP(clientIP) { - ip, mask = setECS(d.Req, clientIP, 0) - log.Debug("Set ECS data: %s/%d", ip, mask) - } - } else { - log.Debug("Passing through ECS data: %s/%d", ip, mask) - } - - d.ecsReqIP = ip - d.ecsReqMask = mask -} - -// Set TTL value of all records according to our settings -func (p *Proxy) setMinMaxTTL(r *dns.Msg) { - for _, rr := range r.Answer { - originalTTL := rr.Header().Ttl - newTTL := respectTTLOverrides(originalTTL, p.CacheMinTTL, p.CacheMaxTTL) - - if originalTTL != newTTL { - log.Debug("Override TTL from %d to %d", originalTTL, newTTL) - rr.Header().Ttl = newTTL - } - } -} - // Resolve is the default resolving method used by the DNS proxy to query upstreams func (p *Proxy) Resolve(d *DNSContext) error { if p.Config.EnableEDNSClientSubnet { @@ -498,7 +269,11 @@ func (p *Proxy) Resolve(d *DNSContext) error { startTime := time.Now() reply, u, err := p.exchange(d.Req, upstreams) if p.isEmptyAAAAResponse(reply, d.Req) { + log.Tracef("Received empty AAAA response, checking DNS64") reply, u, err = p.checkDNS64(d.Req, reply, upstreams) + } else if p.isBogusNXDomain(reply) { + log.Tracef("Received IP from the bogus-nxdomain list, replacing response") + reply = p.genNXDomain(reply) } rtt := int(time.Since(startTime) / time.Millisecond) @@ -533,545 +308,65 @@ func (p *Proxy) Resolve(d *DNSContext) error { return err } -func (p *Proxy) exchange(req *dns.Msg, upstreams []upstream.Upstream) (reply *dns.Msg, u upstream.Upstream, err error) { - qtype := req.Question[0].Qtype - if p.FindFastestAddr && (qtype == dns.TypeA || qtype == dns.TypeAAAA) { - reply, u, err = p.fastestAddr.ExchangeFastest(req, upstreams) - return - } - - if p.AllServers { - reply, u, err = upstream.ExchangeParallel(upstreams, req) - return - } - - if len(upstreams) == 1 { - u = upstreams[0] - reply, _, err = exchangeWithUpstream(u, req) - return - } - - // sort upstreams by rtt from fast to slow - sortedUpstreams := p.getSortedUpstreams(upstreams) - - errs := []error{} - for _, dnsUpstream := range sortedUpstreams { - reply, elapsed, err := exchangeWithUpstream(dnsUpstream, req) - if err == nil { - p.updateRtt(dnsUpstream.Address(), elapsed) - return reply, dnsUpstream, err - } - errs = append(errs, err) - p.updateRtt(dnsUpstream.Address(), int(defaultTimeout/time.Millisecond)) - } - return nil, nil, errorx.DecorateMany("all upstreams failed to exchange request", errs...) -} - -func (p *Proxy) getSortedUpstreams(u []upstream.Upstream) []upstream.Upstream { - // clone upstreams list to avoid race conditions - p.rttLock.Lock() - clone := make([]upstream.Upstream, len(u)) - copy(clone, u) - - sort.Slice(clone, func(i, j int) bool { - if p.upstreamRttStats[clone[i].Address()] < p.upstreamRttStats[clone[j].Address()] { - return true - } - return false - }) - p.rttLock.Unlock() - - return clone -} - -// exchangeWithUpstream returns result of Exchange with elapsed time -func exchangeWithUpstream(u upstream.Upstream, req *dns.Msg) (*dns.Msg, int, error) { - startTime := time.Now() - reply, err := u.Exchange(req) - elapsed := int(time.Since(startTime) / time.Millisecond) - if err != nil { - log.Tracef("upstream %s failed to exchange %s in %d milliseconds. Cause: %s", u.Address(), req.Question[0].String(), elapsed, err) - } else { - log.Tracef("upstream %s successfully finished exchange of %s. Elapsed %d ms.", u.Address(), req.Question[0].String(), elapsed) - } - return reply, elapsed, err -} - -// updateRtt updates rtt in upstreamRttStats for given address -func (p *Proxy) updateRtt(address string, rtt int) { - p.rttLock.Lock() - if p.upstreamRttStats == nil { - p.upstreamRttStats = map[string]int{} - } - p.upstreamRttStats[address] = (p.upstreamRttStats[address] + rtt) / 2 - p.rttLock.Unlock() -} - -// validateConfig verifies that the supplied configuration is valid and returns an error if it's not -func (p *Proxy) validateConfig() error { - if p.started { - return errors.New("server has been already started") - } - - if p.UDPListenAddr == nil && p.TCPListenAddr == nil && p.TLSListenAddr == nil && p.HTTPSListenAddr == nil { - return errors.New("no listen address specified") - } - - if p.TLSListenAddr != nil && p.TLSConfig == nil { - return errors.New("cannot create a TLS listener without TLS config") - } - - if p.HTTPSListenAddr != nil && p.TLSConfig == nil { - return errors.New("cannot create an HTTPS listener without TLS config") - } - - if len(p.Upstreams) == 0 { - if len(p.DomainsReservedUpstreams) == 0 { - return errors.New("no upstreams specified") - } - return errors.New("no default upstreams specified") - } - - if p.CacheMinTTL > 0 || p.CacheMaxTTL > 0 { - log.Info("Cache TTL override is enabled. Min=%d, Max=%d", p.CacheMinTTL, p.CacheMaxTTL) - } - - if p.Ratelimit > 0 { - log.Info("Ratelimit is enabled and set to %d rps", p.Ratelimit) - } - - if p.RefuseAny { - log.Info("The server is configured to refuse ANY requests") - } - - return nil -} - -// startListeners configures and starts listener loops -func (p *Proxy) startListeners() error { - if p.UDPListenAddr != nil { - err := p.udpCreate() - if err != nil { - return err - } - } - - if p.TCPListenAddr != nil { - log.Printf("Creating the TCP server socket") - tcpAddr := p.TCPListenAddr - tcpListen, err := net.ListenTCP("tcp", tcpAddr) - if err != nil { - return errorx.Decorate(err, "couldn't listen to TCP socket") - } - p.tcpListen = tcpListen - log.Printf("Listening to tcp://%s", p.tcpListen.Addr()) - } - - if p.TLSListenAddr != nil { - log.Printf("Creating the TLS server socket") - tlsAddr := p.TLSListenAddr - tcpListen, err := net.ListenTCP("tcp", tlsAddr) - if err != nil { - return errorx.Decorate(err, "could not start TLS listener") - } - p.tlsListen = tls.NewListener(tcpListen, p.TLSConfig) - log.Printf("Listening to tls://%s", p.tlsListen.Addr()) - } - - if p.HTTPSListenAddr != nil { - log.Printf("Creating the HTTPS server") - tcpListen, err := net.ListenTCP("tcp", p.HTTPSListenAddr) - if err != nil { - return errorx.Decorate(err, "could not start HTTPS listener") - } - p.httpsListen = tls.NewListener(tcpListen, p.TLSConfig) - log.Printf("Listening to https://%s", p.httpsListen.Addr()) - p.httpsServer = &http.Server{ - Handler: p, - ReadHeaderTimeout: defaultTimeout, - WriteTimeout: defaultTimeout, - } - } - - if p.udpListen != nil { - go p.udpPacketLoop(p.udpListen) - } - - if p.tcpListen != nil { - go p.tcpPacketLoop(p.tcpListen, ProtoTCP) - } - - if p.tlsListen != nil { - go p.tcpPacketLoop(p.tlsListen, ProtoTLS) +// getUpstreamsForDomain looks for a domain in reserved domains map and returns a list of corresponding upstreams. +// returns default upstreams list if domain isn't found. More specific domains take priority over less specific domains. +// For example, map contains the following keys: host.com and www.host.com +// If we are looking for domain mail.host.com, this method will return value of host.com key +// If we are looking for domain www.host.com, this method will return value of www.host.com key +// If more specific domain value is nil, it means that domain was excluded and should be exchanged with default upstreams +func (p *Proxy) getUpstreamsForDomain(host string) []upstream.Upstream { + if len(p.DomainsReservedUpstreams) == 0 { + return p.Upstreams } - if p.httpsListen != nil { - go p.listenHTTPS() + dotsCount := strings.Count(host, ".") + if dotsCount < 2 { + return p.DomainsReservedUpstreams[UnqualifiedNames] } - return nil -} - -// tcpPacketLoop listens for incoming TCP packets -// proto is either "tcp" or "tls" -func (p *Proxy) tcpPacketLoop(l net.Listener, proto string) { - log.Printf("Entering the %s listener loop on %s", proto, l.Addr()) - for { - clientConn, err := l.Accept() - - if err != nil { - if isConnClosed(err) { - log.Printf("tcpListen.Accept() returned because we're reading from a closed connection, exiting loop") - break + for i := 1; i <= dotsCount; i++ { + h := strings.SplitAfterN(host, ".", i) + name := h[i-1] + if u, ok := p.DomainsReservedUpstreams[strings.ToLower(name)]; ok { + if u == nil { + // domain was excluded from reserved upstreams querying + return p.Upstreams } - log.Printf("got error when reading from TCP listen: %s", err) - } else { - p.guardMaxGoroutines() - go func() { - p.handleTCPConnection(clientConn, proto) - p.freeMaxGoroutines() - }() - } - } -} - -// handleTCPConnection starts a loop that handles an incoming TCP connection -// proto is either "tcp" or "tls" -func (p *Proxy) handleTCPConnection(conn net.Conn, proto string) { - log.Tracef("Start handling the new %s connection %s", proto, conn.RemoteAddr()) - defer conn.Close() - - for { - p.RLock() - if !p.started { - return - } - p.RUnlock() - - conn.SetDeadline(time.Now().Add(defaultTimeout)) //nolint - packet, err := readPrefixed(&conn) - if err != nil { - return - } - - msg := &dns.Msg{} - err = msg.Unpack(packet) - if err != nil { - log.Printf("error handling TCP packet: %s", err) - return - } - - d := &DNSContext{ - Proto: proto, - Req: msg, - Addr: conn.RemoteAddr(), - Conn: conn, - } - - err = p.handleDNSRequest(d) - if err != nil { - log.Tracef("error handling DNS (%s) request: %s", d.Proto, err) - } - } -} - -// Writes a response to the TCP (or TLS) client -func (p *Proxy) respondTCP(d *DNSContext) error { - resp := d.Res - conn := d.Conn - - bytes, err := resp.Pack() - if err != nil { - return errorx.Decorate(err, "couldn't convert message into wire format: %s", resp.String()) - } - - bytes, err = prefixWithSize(bytes) - if err != nil { - return errorx.Decorate(err, "couldn't add prefix with size") - } - - n, err := conn.Write(bytes) - if n == 0 && isConnClosed(err) { - return err - } - if err != nil { - return errorx.Decorate(err, "conn.Write() returned error") - } - if n != len(bytes) { - return fmt.Errorf("conn.Write() returned with %d != %d", n, len(bytes)) - } - return nil -} - -// serveHttps starts the HTTPS server -func (p *Proxy) listenHTTPS() { - log.Printf("Listening to DNS-over-HTTPS on %s", p.httpsListen.Addr()) - err := p.httpsServer.Serve(p.httpsListen) - - if err != http.ErrServerClosed { - log.Printf("HTTPS server was closed unexpectedly: %s", err) - } else { - log.Printf("HTTPS server was closed") - } -} - -// ServeHTTP is the http.RequestHandler implementation that handles DOH queries -// Here is what it returns: -// http.StatusBadRequest - if there is no DNS request data -// http.StatusUnsupportedMediaType - if request content type is not application/dns-message -// http.StatusMethodNotAllowed - if request method is not GET or POST -func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { - log.Tracef("Incoming HTTPS request on %s", r.URL) - - var buf []byte - var err error - - switch r.Method { - case http.MethodGet: - dnsParam := r.URL.Query().Get("dns") - buf, err = base64.RawURLEncoding.DecodeString(dnsParam) - if len(buf) == 0 || err != nil { - log.Tracef("Cannot parse DNS request from %s", dnsParam) - http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - return - } - case http.MethodPost: - contentType := r.Header.Get("Content-Type") - if contentType != "application/dns-message" { - log.Tracef("Unsupported media type: %s", contentType) - http.Error(w, http.StatusText(http.StatusUnsupportedMediaType), http.StatusUnsupportedMediaType) - return - } - - buf, err = ioutil.ReadAll(r.Body) - if err != nil { - log.Tracef("Cannot read the request body: %s", err) - http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - return - } - defer r.Body.Close() - default: - log.Tracef("Wrong HTTP method: %s", r.Method) - http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - return - } - - msg := new(dns.Msg) - if err = msg.Unpack(buf); err != nil { - http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - return - } - - addr, _ := p.remoteAddr(r) - - d := &DNSContext{ - Proto: ProtoHTTPS, - Req: msg, - Addr: addr, - HTTPRequest: r, - HTTPResponseWriter: w, - } - - err = p.handleDNSRequest(d) - if err != nil { - log.Tracef("error handling DNS (%s) request: %s", d.Proto, err) - } -} - -// Get a client IP address from HTTP headers that proxy servers may set -func getIPFromHTTPRequest(r *http.Request) net.IP { - names := []string{ - "CF-Connecting-IP", "True-Client-IP", // set by CloudFlare servers - "X-Real-IP", - } - for _, name := range names { - s := r.Header.Get(name) - ip := net.ParseIP(s) - if ip != nil { - return ip - } - } - - s := r.Header.Get("X-Forwarded-For") - s = splitNext(&s, ',') // get left-most IP address - ip := net.ParseIP(s) - if ip != nil { - return ip - } - - return nil -} - -// Writes a response to the DOH client -func (p *Proxy) respondHTTPS(d *DNSContext) error { - resp := d.Res - w := d.HTTPResponseWriter - - bytes, err := resp.Pack() - if err != nil { - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - return errorx.Decorate(err, "couldn't convert message into wire format: %s", resp.String()) - } - - w.Header().Set("Server", "AdGuard DNS") - w.Header().Set("Content-Type", "application/dns-message") - _, err = w.Write(bytes) - return err -} - -func (p *Proxy) remoteAddr(r *http.Request) (net.Addr, error) { - host, port, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return nil, err - } - - portValue, err := strconv.Atoi(port) - if err != nil { - return nil, err - } - - ip := getIPFromHTTPRequest(r) - if ip != nil { - log.Debug("Using IP address from HTTP request: %s", ip) - } else { - ip = net.ParseIP(host) - if ip == nil { - return nil, fmt.Errorf("invalid IP: %s", host) + return u } } - return &net.TCPAddr{IP: ip, Port: portValue}, nil -} - -// guardMaxGoroutines makes sure that there are no more than p.MaxGoroutines parallel goroutines -func (p *Proxy) guardMaxGoroutines() { - if p.maxGoroutines != nil { - p.maxGoroutines <- true - } -} - -// freeMaxGoroutines allows other goroutines to do the job -func (p *Proxy) freeMaxGoroutines() { - if p.maxGoroutines != nil { - <-p.maxGoroutines - } + return p.Upstreams } -// handleDNSRequest processes the incoming packet bytes and returns with an optional response packet. -func (p *Proxy) handleDNSRequest(d *DNSContext) error { - d.StartTime = time.Now() - p.logDNSMessage(d.Req) - - if p.BeforeRequestHandler != nil { - ok, err := p.BeforeRequestHandler(p, d) - if err != nil { - log.Error("Error in the BeforeRequestHandler: %s", err) - d.Res = p.genServerFailure(d.Req) - p.respond(d) - return nil - } - if !ok { - return nil // do nothing, don't reply - } - } - - // ratelimit based on IP only, protects CPU cycles and outbound connections - if d.Proto == ProtoUDP && p.isRatelimited(d.Addr) { - log.Tracef("Ratelimiting %v based on IP only", d.Addr) - return nil // do nothing, don't reply, we got ratelimited - } - - if len(d.Req.Question) != 1 { - log.Printf("got invalid number of questions: %v", len(d.Req.Question)) - d.Res = p.genServerFailure(d.Req) - } - - // refuse ANY requests (anti-DDOS measure) - if p.RefuseAny && len(d.Req.Question) > 0 && d.Req.Question[0].Qtype == dns.TypeANY { - log.Tracef("Refusing type=ANY request") - d.Res = p.genNotImpl(d.Req) - } - - var err error - - if d.Res == nil { - if len(p.Upstreams) == 0 { - panic("SHOULD NOT HAPPEN: no default upstreams specified") - } +// Set EDNS Client-Subnet data in DNS request +func (p *Proxy) processECS(d *DNSContext) { + d.ecsReqIP = nil + d.ecsReqMask = uint8(0) - // execute the DNS request - // if there is a custom middleware configured, use it - if p.RequestHandler != nil { - err = p.RequestHandler(p, d) + ip, mask, _ := parseECS(d.Req) + if mask == 0 { + // Set EDNS Client-Subnet data + var clientIP net.IP + if p.Config.EDNSAddr != nil { + clientIP = p.Config.EDNSAddr } else { - err = p.Resolve(d) - } - - if err != nil { - err = errorx.Decorate(err, "talking to dnsUpstream failed") + switch addr := d.Addr.(type) { + case *net.UDPAddr: + clientIP = addr.IP + case *net.TCPAddr: + clientIP = addr.IP + } } - } - - p.logDNSMessage(d.Res) - p.respond(d) - return err -} -// respond writes the specified response to the client (or does nothing if d.Res is empty) -func (p *Proxy) respond(d *DNSContext) { - if d.Res == nil { - return - } - - // d.Conn can be nil in the case of a DOH request - if d.Conn != nil { - d.Conn.SetWriteDeadline(time.Now().Add(defaultTimeout)) //nolint - } - - var err error - - switch d.Proto { - case ProtoUDP: - err = p.respondUDP(d) - case ProtoTCP: - err = p.respondTCP(d) - case ProtoTLS: - err = p.respondTCP(d) - case ProtoHTTPS: - err = p.respondHTTPS(d) - default: - err = fmt.Errorf("SHOULD NOT HAPPEN - unknown protocol: %s", d.Proto) - } - - if err != nil { - if strings.HasSuffix(err.Error(), "use of closed network connection") { - // This case may happen while we're restarting DNS server - log.Debug("error while responding to a DNS request: %s", err) - } else { - log.Printf("error while responding to a DNS request: %s", err) + if clientIP != nil && isPublicIP(clientIP) { + ip, mask = setECS(d.Req, clientIP, 0) + log.Debug("Set ECS data: %s/%d", ip, mask) } - } -} - -func (p *Proxy) genServerFailure(request *dns.Msg) *dns.Msg { - resp := dns.Msg{} - resp.SetRcode(request, dns.RcodeServerFailure) - resp.RecursionAvailable = true - return &resp -} - -func (p *Proxy) genNotImpl(request *dns.Msg) *dns.Msg { - resp := dns.Msg{} - resp.SetRcode(request, dns.RcodeNotImplemented) - resp.RecursionAvailable = true - resp.SetEdns0(1452, false) // NOTIMPL without EDNS is treated as 'we don't support EDNS', so explicitly set it - return &resp -} - -func (p *Proxy) logDNSMessage(m *dns.Msg) { - if m.Response { - log.Tracef("OUT: %s", m) } else { - log.Tracef("IN: %s", m) + log.Debug("Passing through ECS data: %s/%d", ip, mask) } + + d.ecsReqIP = ip + d.ecsReqMask = mask } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index fe393a1f4..f17cf80e5 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -49,8 +49,7 @@ func TestHttpsProxy(t *testing.T) { httpsAddr := dnsProxy.Addr(ProtoHTTPS) dialer := &net.Dialer{ - Timeout: defaultTimeout, - DualStack: true, + Timeout: defaultTimeout, } dialContext := func(ctx context.Context, network, addr string) (net.Conn, error) { // Route request to the DOH server address @@ -816,18 +815,6 @@ func getIPFromResponse(resp *dns.Msg) net.IP { return nil } -// Return the first CNAME value in response -func getCNAMEFromResponse(resp *dns.Msg) string { - for _, ans := range resp.Answer { - cn, ok := ans.(*dns.CNAME) - if !ok { - continue - } - return cn.Target - } - return "" -} - type testUpstream struct { cname1Resp *dns.CNAME aResp *dns.A diff --git a/proxy/server.go b/proxy/server.go new file mode 100644 index 000000000..3c78ef40b --- /dev/null +++ b/proxy/server.go @@ -0,0 +1,232 @@ +package proxy + +import ( + "crypto/tls" + "fmt" + "net" + "net/http" + "strings" + "time" + + "github.com/AdguardTeam/golibs/log" + "github.com/joomcode/errorx" + "github.com/miekg/dns" +) + +// startListeners configures and starts listener loops +func (p *Proxy) startListeners() error { + if p.UDPListenAddr != nil { + err := p.udpCreate() + if err != nil { + return err + } + } + + if p.TCPListenAddr != nil { + log.Printf("Creating the TCP server socket") + tcpAddr := p.TCPListenAddr + tcpListen, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return errorx.Decorate(err, "couldn't listen to TCP socket") + } + p.tcpListen = tcpListen + log.Printf("Listening to tcp://%s", p.tcpListen.Addr()) + } + + if p.TLSListenAddr != nil { + log.Printf("Creating the TLS server socket") + tlsAddr := p.TLSListenAddr + tcpListen, err := net.ListenTCP("tcp", tlsAddr) + if err != nil { + return errorx.Decorate(err, "could not start TLS listener") + } + p.tlsListen = tls.NewListener(tcpListen, p.TLSConfig) + log.Printf("Listening to tls://%s", p.tlsListen.Addr()) + } + + if p.HTTPSListenAddr != nil { + log.Printf("Creating the HTTPS server") + tcpListen, err := net.ListenTCP("tcp", p.HTTPSListenAddr) + if err != nil { + return errorx.Decorate(err, "could not start HTTPS listener") + } + p.httpsListen = tls.NewListener(tcpListen, p.TLSConfig) + log.Printf("Listening to https://%s", p.httpsListen.Addr()) + p.httpsServer = &http.Server{ + Handler: p, + ReadHeaderTimeout: defaultTimeout, + WriteTimeout: defaultTimeout, + } + } + + if p.udpListen != nil { + go p.udpPacketLoop(p.udpListen) + } + + if p.tcpListen != nil { + go p.tcpPacketLoop(p.tcpListen, ProtoTCP) + } + + if p.tlsListen != nil { + go p.tcpPacketLoop(p.tlsListen, ProtoTLS) + } + + if p.httpsListen != nil { + go p.listenHTTPS() + } + + return nil +} + +// guardMaxGoroutines makes sure that there are no more than p.MaxGoroutines parallel goroutines +func (p *Proxy) guardMaxGoroutines() { + if p.maxGoroutines != nil { + p.maxGoroutines <- true + } +} + +// freeMaxGoroutines allows other goroutines to do the job +func (p *Proxy) freeMaxGoroutines() { + if p.maxGoroutines != nil { + <-p.maxGoroutines + } +} + +// handleDNSRequest processes the incoming packet bytes and returns with an optional response packet. +func (p *Proxy) handleDNSRequest(d *DNSContext) error { + d.StartTime = time.Now() + p.logDNSMessage(d.Req) + + if p.BeforeRequestHandler != nil { + ok, err := p.BeforeRequestHandler(p, d) + if err != nil { + log.Error("Error in the BeforeRequestHandler: %s", err) + d.Res = p.genServerFailure(d.Req) + p.respond(d) + return nil + } + if !ok { + return nil // do nothing, don't reply + } + } + + // ratelimit based on IP only, protects CPU cycles and outbound connections + if d.Proto == ProtoUDP && p.isRatelimited(d.Addr) { + log.Tracef("Ratelimiting %v based on IP only", d.Addr) + return nil // do nothing, don't reply, we got ratelimited + } + + if len(d.Req.Question) != 1 { + log.Printf("got invalid number of questions: %v", len(d.Req.Question)) + d.Res = p.genServerFailure(d.Req) + } + + // refuse ANY requests (anti-DDOS measure) + if p.RefuseAny && len(d.Req.Question) > 0 && d.Req.Question[0].Qtype == dns.TypeANY { + log.Tracef("Refusing type=ANY request") + d.Res = p.genNotImpl(d.Req) + } + + var err error + + if d.Res == nil { + if len(p.Upstreams) == 0 { + panic("SHOULD NOT HAPPEN: no default upstreams specified") + } + + // execute the DNS request + // if there is a custom middleware configured, use it + if p.RequestHandler != nil { + err = p.RequestHandler(p, d) + } else { + err = p.Resolve(d) + } + + if err != nil { + err = errorx.Decorate(err, "talking to dnsUpstream failed") + } + } + + p.logDNSMessage(d.Res) + p.respond(d) + return err +} + +// respond writes the specified response to the client (or does nothing if d.Res is empty) +func (p *Proxy) respond(d *DNSContext) { + if d.Res == nil { + return + } + + // d.Conn can be nil in the case of a DOH request + if d.Conn != nil { + d.Conn.SetWriteDeadline(time.Now().Add(defaultTimeout)) //nolint + } + + var err error + + switch d.Proto { + case ProtoUDP: + err = p.respondUDP(d) + case ProtoTCP: + err = p.respondTCP(d) + case ProtoTLS: + err = p.respondTCP(d) + case ProtoHTTPS: + err = p.respondHTTPS(d) + default: + err = fmt.Errorf("SHOULD NOT HAPPEN - unknown protocol: %s", d.Proto) + } + + if err != nil { + if strings.HasSuffix(err.Error(), "use of closed network connection") { + // This case may happen while we're restarting DNS server + log.Debug("error while responding to a DNS request: %s", err) + } else { + log.Printf("error while responding to a DNS request: %s", err) + } + } +} + +// Set TTL value of all records according to our settings +func (p *Proxy) setMinMaxTTL(r *dns.Msg) { + for _, rr := range r.Answer { + originalTTL := rr.Header().Ttl + newTTL := respectTTLOverrides(originalTTL, p.CacheMinTTL, p.CacheMaxTTL) + + if originalTTL != newTTL { + log.Debug("Override TTL from %d to %d", originalTTL, newTTL) + rr.Header().Ttl = newTTL + } + } +} + +func (p *Proxy) genServerFailure(request *dns.Msg) *dns.Msg { + resp := dns.Msg{} + resp.SetRcode(request, dns.RcodeServerFailure) + resp.RecursionAvailable = true + return &resp +} + +func (p *Proxy) genNotImpl(request *dns.Msg) *dns.Msg { + resp := dns.Msg{} + resp.SetRcode(request, dns.RcodeNotImplemented) + resp.RecursionAvailable = true + resp.SetEdns0(1452, false) // NOTIMPL without EDNS is treated as 'we don't support EDNS', so explicitly set it + return &resp +} + +func (p *Proxy) genNXDomain(req *dns.Msg) *dns.Msg { + resp := dns.Msg{} + resp.SetRcode(req, dns.RcodeNameError) + resp.RecursionAvailable = true + return &resp +} + +func (p *Proxy) logDNSMessage(m *dns.Msg) { + if m.Response { + log.Tracef("OUT: %s", m) + } else { + log.Tracef("IN: %s", m) + } +} diff --git a/proxy/server_https.go b/proxy/server_https.go new file mode 100644 index 000000000..324b0795c --- /dev/null +++ b/proxy/server_https.go @@ -0,0 +1,154 @@ +package proxy + +import ( + "encoding/base64" + "fmt" + "io/ioutil" + "net" + "net/http" + "strconv" + + "github.com/AdguardTeam/golibs/log" + "github.com/joomcode/errorx" + "github.com/miekg/dns" +) + +// serveHttps starts the HTTPS server +func (p *Proxy) listenHTTPS() { + log.Printf("Listening to DNS-over-HTTPS on %s", p.httpsListen.Addr()) + err := p.httpsServer.Serve(p.httpsListen) + + if err != http.ErrServerClosed { + log.Printf("HTTPS server was closed unexpectedly: %s", err) + } else { + log.Printf("HTTPS server was closed") + } +} + +// ServeHTTP is the http.RequestHandler implementation that handles DOH queries +// Here is what it returns: +// http.StatusBadRequest - if there is no DNS request data +// http.StatusUnsupportedMediaType - if request content type is not application/dns-message +// http.StatusMethodNotAllowed - if request method is not GET or POST +func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + log.Tracef("Incoming HTTPS request on %s", r.URL) + + var buf []byte + var err error + + switch r.Method { + case http.MethodGet: + dnsParam := r.URL.Query().Get("dns") + buf, err = base64.RawURLEncoding.DecodeString(dnsParam) + if len(buf) == 0 || err != nil { + log.Tracef("Cannot parse DNS request from %s", dnsParam) + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } + case http.MethodPost: + contentType := r.Header.Get("Content-Type") + if contentType != "application/dns-message" { + log.Tracef("Unsupported media type: %s", contentType) + http.Error(w, http.StatusText(http.StatusUnsupportedMediaType), http.StatusUnsupportedMediaType) + return + } + + buf, err = ioutil.ReadAll(r.Body) + if err != nil { + log.Tracef("Cannot read the request body: %s", err) + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } + defer r.Body.Close() + default: + log.Tracef("Wrong HTTP method: %s", r.Method) + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return + } + + msg := new(dns.Msg) + if err = msg.Unpack(buf); err != nil { + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } + + addr, _ := p.remoteAddr(r) + + d := &DNSContext{ + Proto: ProtoHTTPS, + Req: msg, + Addr: addr, + HTTPRequest: r, + HTTPResponseWriter: w, + } + + err = p.handleDNSRequest(d) + if err != nil { + log.Tracef("error handling DNS (%s) request: %s", d.Proto, err) + } +} + +// Get a client IP address from HTTP headers that proxy servers may set +func getIPFromHTTPRequest(r *http.Request) net.IP { + names := []string{ + "CF-Connecting-IP", "True-Client-IP", // set by CloudFlare servers + "X-Real-IP", + } + for _, name := range names { + s := r.Header.Get(name) + ip := net.ParseIP(s) + if ip != nil { + return ip + } + } + + s := r.Header.Get("X-Forwarded-For") + s = splitNext(&s, ',') // get left-most IP address + ip := net.ParseIP(s) + if ip != nil { + return ip + } + + return nil +} + +// Writes a response to the DOH client +func (p *Proxy) respondHTTPS(d *DNSContext) error { + resp := d.Res + w := d.HTTPResponseWriter + + bytes, err := resp.Pack() + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return errorx.Decorate(err, "couldn't convert message into wire format: %s", resp.String()) + } + + w.Header().Set("Server", "AdGuard DNS") + w.Header().Set("Content-Type", "application/dns-message") + _, err = w.Write(bytes) + return err +} + +func (p *Proxy) remoteAddr(r *http.Request) (net.Addr, error) { + host, port, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return nil, err + } + + portValue, err := strconv.Atoi(port) + if err != nil { + return nil, err + } + + ip := getIPFromHTTPRequest(r) + if ip != nil { + log.Debug("Using IP address from HTTP request: %s", ip) + } else { + ip = net.ParseIP(host) + if ip == nil { + return nil, fmt.Errorf("invalid IP: %s", host) + } + } + + return &net.TCPAddr{IP: ip, Port: portValue}, nil +} diff --git a/proxy/server_tcp.go b/proxy/server_tcp.go new file mode 100644 index 000000000..2b7ec4901 --- /dev/null +++ b/proxy/server_tcp.go @@ -0,0 +1,102 @@ +package proxy + +import ( + "fmt" + "net" + "time" + + "github.com/AdguardTeam/golibs/log" + "github.com/joomcode/errorx" + "github.com/miekg/dns" +) + +// tcpPacketLoop listens for incoming TCP packets +// proto is either "tcp" or "tls" +func (p *Proxy) tcpPacketLoop(l net.Listener, proto string) { + log.Printf("Entering the %s listener loop on %s", proto, l.Addr()) + for { + clientConn, err := l.Accept() + + if err != nil { + if isConnClosed(err) { + log.Printf("tcpListen.Accept() returned because we're reading from a closed connection, exiting loop") + break + } + log.Printf("got error when reading from TCP listen: %s", err) + } else { + p.guardMaxGoroutines() + go func() { + p.handleTCPConnection(clientConn, proto) + p.freeMaxGoroutines() + }() + } + } +} + +// handleTCPConnection starts a loop that handles an incoming TCP connection +// proto is either "tcp" or "tls" +func (p *Proxy) handleTCPConnection(conn net.Conn, proto string) { + log.Tracef("Start handling the new %s connection %s", proto, conn.RemoteAddr()) + defer conn.Close() + + for { + p.RLock() + if !p.started { + return + } + p.RUnlock() + + conn.SetDeadline(time.Now().Add(defaultTimeout)) //nolint + packet, err := readPrefixed(&conn) + if err != nil { + return + } + + msg := &dns.Msg{} + err = msg.Unpack(packet) + if err != nil { + log.Printf("error handling TCP packet: %s", err) + return + } + + d := &DNSContext{ + Proto: proto, + Req: msg, + Addr: conn.RemoteAddr(), + Conn: conn, + } + + err = p.handleDNSRequest(d) + if err != nil { + log.Tracef("error handling DNS (%s) request: %s", d.Proto, err) + } + } +} + +// Writes a response to the TCP (or TLS) client +func (p *Proxy) respondTCP(d *DNSContext) error { + resp := d.Res + conn := d.Conn + + bytes, err := resp.Pack() + if err != nil { + return errorx.Decorate(err, "couldn't convert message into wire format: %s", resp.String()) + } + + bytes, err = prefixWithSize(bytes) + if err != nil { + return errorx.Decorate(err, "couldn't add prefix with size") + } + + n, err := conn.Write(bytes) + if n == 0 && isConnClosed(err) { + return err + } + if err != nil { + return errorx.Decorate(err, "conn.Write() returned error") + } + if n != len(bytes) { + return fmt.Errorf("conn.Write() returned with %d != %d", n, len(bytes)) + } + return nil +} diff --git a/proxyutil/helpers.go b/proxyutil/helpers.go index e829b6f9c..e010b5ff5 100644 --- a/proxyutil/helpers.go +++ b/proxyutil/helpers.go @@ -10,6 +10,7 @@ import ( ) // GetIPFromDNSRecord - extracts IP address for a DNS record +// returns null if the record is of a wrong type func GetIPFromDNSRecord(r dns.RR) net.IP { switch addr := r.(type) { case *dns.A: @@ -21,6 +22,17 @@ func GetIPFromDNSRecord(r dns.RR) net.IP { return nil } +// ContainsIP checks if the specified IP is in the array +func ContainsIP(ips []net.IP, ip net.IP) bool { + for _, i := range ips { + if i.Equal(ip) { + return true + } + } + + return false +} + // AppendIPAddrs appends the IP addresses got from dns.RR to the specified array func AppendIPAddrs(ipAddrs *[]net.IPAddr, answers []dns.RR) { for _, ans := range answers { diff --git a/proxyutil/helpers_test.go b/proxyutil/helpers_test.go index 0eee12206..27877ae0d 100644 --- a/proxyutil/helpers_test.go +++ b/proxyutil/helpers_test.go @@ -21,3 +21,21 @@ func TestSortIPAddrs(t *testing.T) { assert.Equal(t, ipAddrs[2].String(), "2a00:5a60::bad1:ff") assert.Equal(t, ipAddrs[3].String(), "2a00:5a60::bad2:ff") } + +func TestContainsIP(t *testing.T) { + ips := []net.IP{} + ips = append(ips, net.ParseIP("176.103.130.134")) + ips = append(ips, net.ParseIP("2a00:5a60::bad1:ff")) + + ip := net.ParseIP("176.103.130.134") + assert.True(t, ContainsIP(ips, ip)) + + ip = net.ParseIP("2a00:5a60::bad1:ff") + assert.True(t, ContainsIP(ips, ip)) + + ip = net.ParseIP("2a00:5a60::bad1:ff1") + assert.False(t, ContainsIP(ips, ip)) + + ip = net.ParseIP("127.0.0.1") + assert.False(t, ContainsIP(ips, ip)) +}