diff --git a/client.go b/client.go index 6f90dac..e98797f 100644 --- a/client.go +++ b/client.go @@ -22,7 +22,15 @@ import ( sliceutil "github.com/projectdiscovery/utils/slice" ) -var ErrRetriesExceeded = errors.New("could not resolve, max retries exceeded") +var () + +var ( + // DefaultMaxPerCNAMEFollows is the default number of times a CNAME can be followed within a trace + DefaultMaxPerCNAMEFollows = 32 + + // ErrRetriesExceeded is the error returned when the max retries are exceeded + ErrRetriesExceeded = errors.New("could not resolve, max retries exceeded") +) var internalRangeCheckerInstance *internalRangeChecker @@ -64,6 +72,10 @@ func NewWithOptions(options Options) (*Client, error) { knownHosts, _ = hostsfile.ParseDefault() } + if options.MaxPerCNAMEFollows == 0 { + options.MaxPerCNAMEFollows = DefaultMaxPerCNAMEFollows + } + httpClient := doh.NewHttpClientWithTimeout(options.Timeout) client := Client{ @@ -480,6 +492,7 @@ func (c *Client) Trace(host string, requestType uint16, maxrecursion int) (*Trac msg.SetQuestion(host, requestType) servers := RootDNSServersIPv4 seenNS := make(map[string]struct{}) + seenCName := make(map[string]int) for i := 1; i < maxrecursion; i++ { msg.SetQuestion(host, requestType) dnsdatas, err := c.QueryParallel(host, requestType, servers) @@ -542,6 +555,10 @@ func (c *Client) Trace(host string, requestType uint16, maxrecursion int) (*Trac // follow cname if any if nextCname != "" { + seenCName[nextCname]++ + if seenCName[nextCname] > c.options.MaxPerCNAMEFollows { + break + } host = nextCname } } diff --git a/options.go b/options.go index df94050..7e0967e 100644 --- a/options.go +++ b/options.go @@ -20,6 +20,7 @@ type Options struct { LocalAddrIP net.IP LocalAddrPort uint16 ConnectionPoolThreads int + MaxPerCNAMEFollows int } // Returns a net.Addr of a UDP or TCP type depending on whats required