diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 4ea0fd092cb..4744ef5e7d6 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -10,6 +10,7 @@ import ( "strings" "sync" + "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" @@ -54,7 +55,7 @@ type Config struct { BlockedServices []string `yaml:"blocked_services"` // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files - AutoHosts *AutoHosts `yaml:"-"` + AutoHosts *util.AutoHosts `yaml:"-"` // Called when the configuration is changed by HTTP request ConfigModified func() `yaml:"-"` @@ -143,8 +144,8 @@ const ( // ReasonRewrite - rewrite rule was applied ReasonRewrite - // ReasonRewriteAuto - automatic DNS record - ReasonRewriteAuto + // RewriteEtcHosts - rewrite by /etc/hosts rule + RewriteEtcHosts ) var reasonNames = []string{ @@ -160,7 +161,7 @@ var reasonNames = []string{ "FilteredBlockedService", "Rewrite", - "RewriteAuto", + "RewriteEtcHosts", } func (r Reason) String() string { @@ -313,7 +314,7 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering if d.Config.AutoHosts != nil { ips := d.Config.AutoHosts.Process(host) if ips != nil { - result.Reason = ReasonRewriteAuto + result.Reason = RewriteEtcHosts result.IPList = ips return result, nil } diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index 2df228e0578..532e1c3c0db 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -2,7 +2,6 @@ package dnsfilter import ( "fmt" - "io/ioutil" "net" "os" "path" @@ -630,24 +629,6 @@ func prepareTestDir() string { return dir } -func TestAutoHosts(t *testing.T) { - ah := AutoHosts{} - ah.table = make(map[string][]net.IP) - - dir := prepareTestDir() - defer func() { _ = os.RemoveAll(dir) }() - - f, _ := ioutil.TempFile(dir, "") - defer os.Remove(f.Name()) - defer f.Close() - - f.WriteString(" 127.0.0.1 host localhost ") - - ah.load(ah.table, f.Name()) - ips := ah.Process("localhost") - assert.True(t, ips[0].Equal(net.ParseIP("127.0.0.1"))) -} - // BENCHMARKS func BenchmarkSafeBrowsing(b *testing.B) { diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index d536065383a..9fe3b20078e 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -680,7 +680,7 @@ func processFilteringAfterResponse(ctx *dnsContext) int { d.Res.Answer = answer } - case dnsfilter.ReasonRewriteAuto: + case dnsfilter.RewriteEtcHosts: case dnsfilter.NotFilteredWhiteList: // nothing @@ -856,7 +856,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) { // log.Tracef("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule) d.Res = s.genDNSFilterMessage(d, &res) - } else if (res.Reason == dnsfilter.ReasonRewrite || res.Reason == dnsfilter.ReasonRewriteAuto) && + } else if (res.Reason == dnsfilter.ReasonRewrite || res.Reason == dnsfilter.RewriteEtcHosts) && len(res.IPList) != 0 { resp := s.makeResponse(req) diff --git a/home/clients.go b/home/clients.go index 0da39d288c4..b943d25d5da 100644 --- a/home/clients.go +++ b/home/clients.go @@ -14,6 +14,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" + "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/utils" @@ -77,14 +78,14 @@ type clientsContainer struct { // dhcpServer is used for looking up clients IP addresses by MAC addresses dhcpServer *dhcpd.Server - autoHosts *dnsfilter.AutoHosts // get entries from system hosts-files + autoHosts *util.AutoHosts // get entries from system hosts-files testing bool // if TRUE, this object is used for internal tests } // Init initializes clients container // Note: this function must be called only once -func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd.Server, autoHosts *dnsfilter.AutoHosts) { +func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd.Server, autoHosts *util.AutoHosts) { if clients.list != nil { log.Fatal("clients.list != nil") } diff --git a/home/home.go b/home/home.go index a34e246ac4b..9712373fb4f 100644 --- a/home/home.go +++ b/home/home.go @@ -68,7 +68,7 @@ type homeContext struct { filters Filtering // DNS filtering module web *Web // Web (HTTP, HTTPS) module tls *TLSMod // TLS module - autoHosts dnsfilter.AutoHosts // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files + autoHosts util.AutoHosts // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files // Runtime properties // -- diff --git a/dnsfilter/auto_hosts.go b/util/auto_hosts.go similarity index 96% rename from dnsfilter/auto_hosts.go rename to util/auto_hosts.go index 59562f45a83..aa64151e048 100644 --- a/dnsfilter/auto_hosts.go +++ b/util/auto_hosts.go @@ -1,4 +1,4 @@ -package dnsfilter +package util import ( "bufio" @@ -10,7 +10,6 @@ import ( "strings" "sync" - "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/golibs/log" "github.com/fsnotify/fsnotify" ) @@ -52,7 +51,7 @@ func (a *AutoHosts) Init() { a.hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts") } - if util.IsOpenWrt() { + if IsOpenWrt() { a.hostsDirs = append(a.hostsDirs, "/tmp/hosts") // OpenWRT: "/tmp/hosts/dhcp.cfg01411c" } @@ -112,13 +111,13 @@ func (a *AutoHosts) load(table map[string][]net.IP, fn string) { } line = strings.TrimSpace(line) - ip := util.SplitNext(&line, ' ') + ip := SplitNext(&line, ' ') ipAddr := net.ParseIP(ip) if ipAddr == nil { continue } for { - host := util.SplitNext(&line, ' ') + host := SplitNext(&line, ' ') if len(host) == 0 { break } diff --git a/util/auto_hosts_test.go b/util/auto_hosts_test.go new file mode 100644 index 00000000000..32d57d540b4 --- /dev/null +++ b/util/auto_hosts_test.go @@ -0,0 +1,28 @@ +package util + +import ( + "io/ioutil" + "net" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAutoHosts(t *testing.T) { + ah := AutoHosts{} + ah.table = make(map[string][]net.IP) + + dir := prepareTestDir() + defer func() { _ = os.RemoveAll(dir) }() + + f, _ := ioutil.TempFile(dir, "") + defer os.Remove(f.Name()) + defer f.Close() + + f.WriteString(" 127.0.0.1 host localhost ") + + ah.load(ah.table, f.Name()) + ips := ah.Process("localhost") + assert.True(t, ips[0].Equal(net.ParseIP("127.0.0.1"))) +}