Skip to content

Commit

Permalink
Merge pull request #14 from projectdiscovery/atomic-misc
Browse files Browse the repository at this point in the history
Use atomics + improved deduplication and error handling
  • Loading branch information
Ice3man543 authored Feb 2, 2021
2 parents 118433d + c34b4d3 commit 99b511e
Showing 1 changed file with 71 additions and 120 deletions.
191 changes: 71 additions & 120 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,26 @@ import (
"bytes"
"encoding/gob"
"encoding/json"
"math/rand"
"errors"
"net"
"reflect"
"sort"
"strings"
"sync"
"time"
"sync/atomic"

"github.com/miekg/dns"
)

const defaultPort = "53"

// Client is a DNS resolver client to resolve hostnames.
type Client struct {
resolvers []string
maxRetries int
rand *rand.Rand
mutex *sync.Mutex
resolvers []string
maxRetries int
serversIndex uint32
}

const defaultPort = "53"

// New creates a new dns client
func New(baseResolvers []string, maxRetries int) *Client {
client := Client{
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
mutex: &sync.Mutex{},
maxRetries: maxRetries,
resolvers: baseResolvers,
}
Expand All @@ -43,22 +37,24 @@ func (c *Client) Resolve(host string) (*DNSData, error) {
}

// Do sends a provided dns request and return the raw native response
func (c *Client) Do(msg *dns.Msg) (resp *dns.Msg, err error) {

func (c *Client) Do(msg *dns.Msg) (*dns.Msg, error) {
for i := 0; i < c.maxRetries; i++ {
resolver := c.resolvers[rand.Intn(len(c.resolvers))]
resp, err = dns.Exchange(msg, resolver)
if err != nil {
index := atomic.AddUint32(&c.serversIndex, 1)
resolver := c.resolvers[index%uint32(len(c.resolvers))]

resp, err := dns.Exchange(msg, resolver)
if err != nil || resp == nil {
continue
}

// In case we get a non empty answer stop retrying
if resp != nil {
return
if resp.Rcode != dns.RcodeSuccess {
continue
}
}

return
// In case we get a non empty answer stop retrying
return resp, nil
}
return nil, errors.New("could not resolve, max retries exceeded")
}

// Query sends a provided dns request and return enriched response
Expand All @@ -71,9 +67,9 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er
var (
dnsdata DNSData
err error
msg dns.Msg
)

msg := dns.Msg{}
msg.Id = dns.Id()
msg.RecursionDesired = true
msg.Question = make([]dns.Question, 1)
Expand All @@ -100,75 +96,36 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er
}
msg.Question[0] = question

var resp *dns.Msg
for i := 0; i < c.maxRetries; i++ {
resolver := c.resolvers[rand.Intn(len(c.resolvers))]
var resp *dns.Msg
index := atomic.AddUint32(&c.serversIndex, 1)
resolver := c.resolvers[index%uint32(len(c.resolvers))]

resp, err = dns.Exchange(&msg, resolver)
if err != nil {
if err != nil || resp == nil {
continue
}

dnsdata.Host = host
dnsdata.Raw += resp.String()
dnsdata.StatusCode = dns.RcodeToString[resp.Rcode]
dnsdata.Resolver = append(dnsdata.Resolver, resolver)

// In case we got some error from the server, return.
if resp != nil && resp.Rcode != dns.RcodeSuccess {
break
if resp.Rcode != dns.RcodeSuccess {
continue
}

dnsdata.ParseFromMsg(resp)
break
}
}

dnsdata.dedupe()

return &dnsdata, err
}

func parse(answer *dns.Msg, requestType uint16) (results []string) {
for _, record := range answer.Answer {
switch requestType {
case dns.TypeA:
if t, ok := record.(*dns.A); ok {
results = append(results, t.A.String())
}
case dns.TypeNS:
if t, ok := record.(*dns.NS); ok {
results = append(results, t.Ns)
}
case dns.TypeCNAME:
if t, ok := record.(*dns.CNAME); ok {
results = append(results, t.Target)
}
case dns.TypeSOA:
if t, ok := record.(*dns.SOA); ok {
results = append(results, t.Mbox)
}
case dns.TypePTR:
if t, ok := record.(*dns.PTR); ok {
results = append(results, t.Ptr)
}
case dns.TypeMX:
if t, ok := record.(*dns.MX); ok {
results = append(results, t.Mx)
}
case dns.TypeTXT:
if t, ok := record.(*dns.TXT); ok {
results = append(results, t.Txt...)
}
case dns.TypeAAAA:
if t, ok := record.(*dns.AAAA); ok {
results = append(results, t.AAAA.String())
if !dnsdata.contains() {
continue
}
dnsdata.Host = host
dnsdata.Raw += resp.String()
dnsdata.StatusCode = dns.RcodeToString[resp.Rcode]
dnsdata.Resolver = append(dnsdata.Resolver, resolver)
dnsdata.dedupe()
return &dnsdata, err
}
}

return
return nil, err
}

// DNSData is the data for a DNS request response
type DNSData struct {
Host string `json:"host,omitempty"`
TTL int `json:"ttl,omitempty"`
Expand Down Expand Up @@ -209,10 +166,16 @@ func (d *DNSData) ParseFromMsg(msg *dns.Msg) error {
d.AAAA = append(d.AAAA, trimChars(record.(*dns.AAAA).AAAA.String()))
}
}

return nil
}

func (d *DNSData) contains() bool {
if len(d.A) > 0 || len(d.AAAA) > 0 || len(d.CNAME) > 0 || len(d.MX) > 0 || len(d.NS) > 0 || len(d.PTR) > 0 || len(d.TXT) > 0 {
return true
}
return false
}

// JSON returns the object as json string
func (d *DNSData) JSON() (string, error) {
b, err := json.Marshal(&d)
Expand All @@ -223,59 +186,47 @@ func trimChars(s string) string {
return strings.TrimRight(s, ".")
}

func (r *DNSData) dedupe() {
// dedupe all records
dedupeSlice(&r.Resolver, less(&r.Resolver))
dedupeSlice(&r.A, less(&r.A))
dedupeSlice(&r.AAAA, less(&r.AAAA))
dedupeSlice(&r.CNAME, less(&r.CNAME))
dedupeSlice(&r.MX, less(&r.MX))
dedupeSlice(&r.PTR, less(&r.PTR))
dedupeSlice(&r.SOA, less(&r.SOA))
dedupeSlice(&r.NS, less(&r.NS))
dedupeSlice(&r.TXT, less(&r.TXT))
func (d *DNSData) dedupe() {
d.Resolver = deduplicate(d.Resolver)
d.A = deduplicate(d.A)
d.AAAA = deduplicate(d.AAAA)
d.CNAME = deduplicate(d.CNAME)
d.MX = deduplicate(d.MX)
d.PTR = deduplicate(d.PTR)
d.SOA = deduplicate(d.SOA)
d.NS = deduplicate(d.NS)
d.TXT = deduplicate(d.TXT)
}

func (r *DNSData) Marshal() ([]byte, error) {
// Marshal encodes the dnsdata to a binary representation
func (d *DNSData) Marshal() ([]byte, error) {
var b bytes.Buffer
enc := gob.NewEncoder(&b)
err := enc.Encode(r)
err := enc.Encode(d)
if err != nil {
return nil, err
}

return b.Bytes(), nil
}

func (r *DNSData) Unmarshal(b []byte) error {
// Unmarshal decodes the dnsdata from a binary representation
func (d *DNSData) Unmarshal(b []byte) error {
dec := gob.NewDecoder(bytes.NewBuffer(b))
err := dec.Decode(&r)
if err != nil {
return err
}
return nil
}

func less(v interface{}) func(i, j int) bool {
s := *v.(*[]string)
return func(i, j int) bool { return s[i] < s[j] }
return dec.Decode(&d)
}

func dedupeSlice(slicePtr interface{}, less func(i, j int) bool) {
v := reflect.ValueOf(slicePtr).Elem()
if v.Len() <= 1 {
return
// deduplicate returns a new slice with duplicates values removed.
func deduplicate(s []string) []string {
if len(s) < 2 {
return s
}
sort.Slice(v.Interface(), less)

i := 0
for j := 1; j < v.Len(); j++ {
if !less(i, j) {
continue
var results []string
seen := make(map[string]struct{})
for _, val := range s {
if _, ok := seen[val]; !ok {
results = append(results, val)
seen[val] = struct{}{}
}
i++
v.Index(i).Set(v.Index(j))
}
i++
v.SetLen(i)
return results
}

0 comments on commit 99b511e

Please sign in to comment.