Skip to content
This repository was archived by the owner on Jun 20, 2024. It is now read-only.

Commit

Permalink
Refactor dns server to pass configurables to handlers in a struct. Al…
Browse files Browse the repository at this point in the history
…so ignore EDNS on tcp requests.
  • Loading branch information
Tom Wilkie committed Aug 19, 2015
1 parent 99ab0d3 commit c485b1e
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 123 deletions.
248 changes: 129 additions & 119 deletions nameserver/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ func (d *DNSServer) listen(address string) error {
if err != nil {
return err
}
udpServer := &dns.Server{PacketConn: udpListener, Handler: d.createMux(d.udpClient, minUDPSize)}
udpServer := &dns.Server{PacketConn: udpListener, Handler: d.createMux(d.udpClient, minUDPSize).mux}

tcpListener, err := net.Listen("tcp", address)
if err != nil {
udpServer.Shutdown()
return err
}
tcpServer := &dns.Server{Listener: tcpListener, Handler: d.createMux(d.tcpClient, -1)}
tcpServer := &dns.Server{Listener: tcpListener, Handler: d.createMux(d.tcpClient, -1).mux}

d.servers = []*dns.Server{udpServer, tcpServer}
return nil
Expand Down Expand Up @@ -112,144 +112,146 @@ func (d *DNSServer) errorResponse(r *dns.Msg, code int, w dns.ResponseWriter) {
}
}

func (d *DNSServer) createMux(client *dns.Client, defaultMaxResponseSize int) *dns.ServeMux {
m := dns.NewServeMux()
m.HandleFunc(d.domain, d.handleLocal(defaultMaxResponseSize))
m.HandleFunc(reverseDNSdomain, d.handleReverse(client, defaultMaxResponseSize))
m.HandleFunc(topDomain, d.handleRecursive(client, defaultMaxResponseSize))
return m
type handler struct {
*DNSServer
maxResponseSize int
client *dns.Client
mux *dns.ServeMux
}

func (d *DNSServer) handleLocal(defaultMaxResponseSize int) func(dns.ResponseWriter, *dns.Msg) {
return func(w dns.ResponseWriter, req *dns.Msg) {
d.ns.debugf("local request: %+v", *req)
if len(req.Question) != 1 || req.Question[0].Qtype != dns.TypeA {
d.errorResponse(req, dns.RcodeNameError, w)
return
}
func (d *DNSServer) createMux(client *dns.Client, defaultMaxResponseSize int) *handler {
h := &handler{
DNSServer: d,
maxResponseSize: defaultMaxResponseSize,
client: client,
mux: dns.NewServeMux(),
}
h.mux.HandleFunc(d.domain, h.handleLocal)
h.mux.HandleFunc(reverseDNSdomain, h.handleReverse)
h.mux.HandleFunc(topDomain, h.handleRecursive)
return h
}

hostname := dns.Fqdn(req.Question[0].Name)
if strings.Count(hostname, ".") == 1 {
hostname = hostname + d.domain
}
func (h *handler) handleLocal(w dns.ResponseWriter, req *dns.Msg) {
h.ns.debugf("local request: %+v", *req)
if len(req.Question) != 1 || req.Question[0].Qtype != dns.TypeA {
h.errorResponse(req, dns.RcodeNameError, w)
return
}

addrs := d.ns.Lookup(hostname)
if len(addrs) == 0 {
d.errorResponse(req, dns.RcodeNameError, w)
return
}
hostname := dns.Fqdn(req.Question[0].Name)
if strings.Count(hostname, ".") == 1 {
hostname = hostname + h.domain
}

response := dns.Msg{}
response.RecursionAvailable = true
response.Authoritative = true
response.SetReply(req)
response.Answer = make([]dns.RR, len(addrs))

header := dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: d.ttl,
}
addrs := h.ns.Lookup(hostname)
if len(addrs) == 0 {
h.errorResponse(req, dns.RcodeNameError, w)
return
}

for i, addr := range addrs {
ip := addr.IP4()
response.Answer[i] = &dns.A{Hdr: header, A: ip}
}
response := dns.Msg{}
response.RecursionAvailable = true
response.Authoritative = true
response.SetReply(req)
response.Answer = make([]dns.RR, len(addrs))

header := dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: h.ttl,
}

shuffleAnswers(&response.Answer)
maxResponseSize := getMaxResponseSize(req, defaultMaxResponseSize)
truncateResponse(&response, maxResponseSize)
for i, addr := range addrs {
ip := addr.IP4()
response.Answer[i] = &dns.A{Hdr: header, A: ip}
}

d.ns.debugf("response: %+v", response)
if err := w.WriteMsg(&response); err != nil {
d.ns.infof("error responding: %v", err)
}
shuffleAnswers(&response.Answer)
h.truncateResponse(req, &response)
h.ns.debugf("response: %+v", response)
if err := w.WriteMsg(&response); err != nil {
h.ns.infof("error responding: %v", err)
}
}

func (d *DNSServer) handleReverse(client *dns.Client, defaultMaxResponseSize int) func(dns.ResponseWriter, *dns.Msg) {
return func(w dns.ResponseWriter, req *dns.Msg) {
d.ns.debugf("reverse request: %+v", *req)
if len(req.Question) != 1 || req.Question[0].Qtype != dns.TypePTR {
d.errorResponse(req, dns.RcodeNameError, w)
return
}

ipStr := strings.TrimSuffix(req.Question[0].Name, "."+reverseDNSdomain)
ip, err := address.ParseIP(ipStr)
if err != nil {
d.errorResponse(req, dns.RcodeNameError, w)
return
}
func (h *handler) handleReverse(w dns.ResponseWriter, req *dns.Msg) {
h.ns.debugf("reverse request: %+v", *req)
if len(req.Question) != 1 || req.Question[0].Qtype != dns.TypePTR {
h.errorResponse(req, dns.RcodeNameError, w)
return
}

hostname, err := d.ns.ReverseLookup(ip.Reverse())
if err != nil {
d.handleRecursive(client, defaultMaxResponseSize)(w, req)
return
}
ipStr := strings.TrimSuffix(req.Question[0].Name, "."+reverseDNSdomain)
ip, err := address.ParseIP(ipStr)
if err != nil {
h.errorResponse(req, dns.RcodeNameError, w)
return
}

response := dns.Msg{}
response.RecursionAvailable = true
response.Authoritative = true
response.SetReply(req)
hostname, err := h.ns.ReverseLookup(ip.Reverse())
if err != nil {
h.handleRecursive(w, req)
return
}

header := dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: d.ttl,
}
response := dns.Msg{}
response.RecursionAvailable = true
response.Authoritative = true
response.SetReply(req)

response.Answer = []dns.RR{&dns.PTR{
Hdr: header,
Ptr: hostname,
}}
header := dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: h.ttl,
}

maxResponseSize := getMaxResponseSize(req, defaultMaxResponseSize)
truncateResponse(&response, maxResponseSize)
response.Answer = []dns.RR{&dns.PTR{
Hdr: header,
Ptr: hostname,
}}

d.ns.debugf("response: %+v", response)
if err := w.WriteMsg(&response); err != nil {
d.ns.infof("error responding: %v", err)
}
h.truncateResponse(req, &response)
h.ns.debugf("response: %+v", response)
if err := w.WriteMsg(&response); err != nil {
h.ns.infof("error responding: %v", err)
}
}

func (d *DNSServer) handleRecursive(client *dns.Client, defaultMaxResponseSize int) func(dns.ResponseWriter, *dns.Msg) {
return func(w dns.ResponseWriter, req *dns.Msg) {
d.ns.debugf("recursive request: %+v", *req)

// Resolve unqualified names locally
if len(req.Question) == 1 && req.Question[0].Qtype == dns.TypeA {
hostname := dns.Fqdn(req.Question[0].Name)
if strings.Count(hostname, ".") == 1 {
d.handleLocal(defaultMaxResponseSize)(w, req)
return
}
}
func (h *handler) handleRecursive(w dns.ResponseWriter, req *dns.Msg) {
h.ns.debugf("recursive request: %+v", *req)

for _, server := range d.upstream.Servers {
reqCopy := req.Copy()
reqCopy.Id = dns.Id()
response, _, err := client.Exchange(reqCopy, fmt.Sprintf("%s:%s", server, d.upstream.Port))
if err != nil || response == nil {
d.ns.debugf("error trying %s: %v", server, err)
continue
}
d.ns.debugf("response: %+v", response)
response.Id = req.Id
if response.Len() > getMaxResponseSize(req, defaultMaxResponseSize) {
response.Compress = true
}
if err := w.WriteMsg(response); err != nil {
d.ns.infof("error responding: %v", err)
}
// Resolve unqualified names locally
if len(req.Question) == 1 && req.Question[0].Qtype == dns.TypeA {
hostname := dns.Fqdn(req.Question[0].Name)
if strings.Count(hostname, ".") == 1 {
h.handleLocal(w, req)
return
}
}

d.errorResponse(req, dns.RcodeServerFailure, w)
for _, server := range h.upstream.Servers {
reqCopy := req.Copy()
reqCopy.Id = dns.Id()
response, _, err := h.client.Exchange(reqCopy, fmt.Sprintf("%s:%s", server, h.upstream.Port))
if err != nil || response == nil {
h.ns.debugf("error trying %s: %v", server, err)
continue
}
h.ns.debugf("response: %+v", response)
response.Id = req.Id
if h.responseTooBig(req, response) {
response.Compress = true
}
if err := w.WriteMsg(response); err != nil {
h.ns.infof("error responding: %v", err)
}
return
}

h.errorResponse(req, dns.RcodeServerFailure, w)
}

func shuffleAnswers(answers *[]dns.RR) {
Expand All @@ -263,13 +265,14 @@ func shuffleAnswers(answers *[]dns.RR) {
}
}

func truncateResponse(response *dns.Msg, maxSize int) {
if len(response.Answer) <= 1 || maxSize <= 0 {
func (h *handler) truncateResponse(request, response *dns.Msg) {
if !h.responseTooBig(request, response) {
return
}

// take a copy of answers, as we're going to mutate response
answers := response.Answer
maxSize := h.getMaxResponseSize(request)

// search for smallest i that is too big
i := sort.Search(len(response.Answer), func(i int) bool {
Expand All @@ -286,9 +289,16 @@ func truncateResponse(response *dns.Msg, maxSize int) {
response.Truncated = true
}

func getMaxResponseSize(req *dns.Msg, defaultMaxResponseSize int) int {
func (h *handler) responseTooBig(request, response *dns.Msg) bool {
if len(response.Answer) <= 1 || h.maxResponseSize <= 0 {
return false
}
return response.Len() > h.getMaxResponseSize(request)
}

func (h *handler) getMaxResponseSize(req *dns.Msg) int {
if opt := req.IsEdns0(); opt != nil {
return int(opt.UDPSize())
}
return defaultMaxResponseSize
return h.maxResponseSize
}
4 changes: 2 additions & 2 deletions nameserver/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ func TestTruncation(t *testing.T) {
}

func TestTruncateResponse(t *testing.T) {

header := dns.RR_Header{
Name: "host.domain.com",
Rrtype: dns.TypePTR,
Expand All @@ -93,7 +92,8 @@ func TestTruncateResponse(t *testing.T) {

// pick a random max size, truncate response to that, check it
maxSize := 512 + rand.Intn(response.Len()-512)
truncateResponse(response, maxSize)
h := handler{maxResponseSize: maxSize}
h.truncateResponse(&dns.Msg{}, response)
require.True(t, response.Len() <= maxSize)
}
}
Expand Down
14 changes: 12 additions & 2 deletions test/295_dns_large_response_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,18 @@ for i in $(seq $N); do
done
weave_on $HOST1 dns-add $IPS $CID -h $NAME

assert_dns_record $HOST1 c0 $NAME $IPS

assert_raises "exec_on $HOST1 c0 dig MX $NAME | grep -q 'status: NXDOMAIN'"

M=$(exec_on $HOST1 c0 dig +short $NAME A | grep -v ';;' | wc -l)
assert_raises "test $M -eq $N"

M=$(exec_on $HOST1 c0 dig +tcp +short $NAME A | grep -v ';;' | wc -l)
assert_raises "test $M -eq $N"

M=$(exec_on $HOST1 c0 dig +bufsize=700 +short $NAME A | grep -v ';;' | wc -l)
assert_raises "test $M -eq $N"

M=$(exec_on $HOST1 c0 dig +tcp +bufsize=700 +short $NAME A | grep -v ';;' | wc -l)
assert_raises "test $M -eq $N"

end_suite

0 comments on commit c485b1e

Please sign in to comment.