diff --git a/endpoint/domain_filter.go b/endpoint/domain_filter.go index 8ddde179b6..7bfe1b3e67 100644 --- a/endpoint/domain_filter.go +++ b/endpoint/domain_filter.go @@ -17,7 +17,11 @@ limitations under the License. package endpoint import ( + "encoding/json" + "errors" + "fmt" "regexp" + "sort" "strings" ) @@ -52,6 +56,14 @@ type DomainFilter struct { regexExclusion *regexp.Regexp } +// domainFilterSerde is a helper type for serializing and deserializing DomainFilter. +type domainFilterSerde struct { + Include []string `json:"include,omitempty"` + Exclude []string `json:"exclude,omitempty"` + RegexInclude string `json:"regexInclude,omitempty"` + RegexExclude string `json:"regexExclude,omitempty"` +} + // prepareFilters provides consistent trimming for filters/exclude params func prepareFilters(filters []string) []string { var fs []string @@ -159,3 +171,58 @@ func (df DomainFilter) IsConfigured() bool { } return len(df.Filters) > 0 || len(df.exclude) > 0 } + +func (df DomainFilter) MarshalJSON() ([]byte, error) { + if df.regex != nil || df.regexExclusion != nil { + var include, exclude string + if df.regex != nil { + include = df.regex.String() + } + if df.regexExclusion != nil { + exclude = df.regexExclusion.String() + } + return json.Marshal(domainFilterSerde{ + RegexInclude: include, + RegexExclude: exclude, + }) + } + sort.Strings(df.Filters) + sort.Strings(df.exclude) + return json.Marshal(domainFilterSerde{ + Include: df.Filters, + Exclude: df.exclude, + }) +} + +func (df *DomainFilter) UnmarshalJSON(b []byte) error { + var deserialized domainFilterSerde + err := json.Unmarshal(b, &deserialized) + if err != nil { + return err + } + + if deserialized.RegexInclude == "" && deserialized.RegexExclude == "" { + *df = NewDomainFilterWithExclusions(deserialized.Include, deserialized.Exclude) + return nil + } + + if len(deserialized.Include) > 0 || len(deserialized.Exclude) > 0 { + return errors.New("cannot have both domain list and regex") + } + + var include, exclude *regexp.Regexp + if deserialized.RegexInclude != "" { + include, err = regexp.Compile(deserialized.RegexInclude) + if err != nil { + return fmt.Errorf("invalid regexInclude: %w", err) + } + } + if deserialized.RegexExclude != "" { + exclude, err = regexp.Compile(deserialized.RegexExclude) + if err != nil { + return fmt.Errorf("invalid regexExclude: %w", err) + } + } + *df = NewRegexDomainFilter(include, exclude) + return nil +} diff --git a/endpoint/domain_filter_test.go b/endpoint/domain_filter_test.go index 7cfd781263..dcd518bab4 100644 --- a/endpoint/domain_filter_test.go +++ b/endpoint/domain_filter_test.go @@ -17,24 +17,29 @@ limitations under the License. package endpoint import ( + "encoding/json" + "fmt" "regexp" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type domainFilterTest struct { - domainFilter []string - exclusions []string - domains []string - expected bool + domainFilter []string + exclusions []string + domains []string + expected bool + expectedSerialization map[string][]string } type regexDomainFilterTest struct { - regex *regexp.Regexp - regexExclusion *regexp.Regexp - domains []string - expected bool + regex *regexp.Regexp + regexExclusion *regexp.Regexp + domains []string + expected bool + expectedSerialization map[string]string } var domainFilterTests = []domainFilterTest{ @@ -43,180 +48,274 @@ var domainFilterTests = []domainFilterTest{ []string{}, []string{"google.com", "exaring.de", "inovex.de"}, true, + map[string][]string{ + "include": {"exaring.de", "google.com", "inovex.de"}, + }, }, { []string{"google.com.", "exaring.de", "inovex.de"}, []string{}, []string{"google.com", "exaring.de", "inovex.de"}, true, + map[string][]string{ + "include": {"exaring.de", "google.com", "inovex.de"}, + }, }, { []string{"google.com.", "exaring.de.", "inovex.de"}, []string{}, []string{"google.com", "exaring.de", "inovex.de"}, true, + map[string][]string{ + "include": {"exaring.de", "google.com", "inovex.de"}, + }, }, { []string{"foo.org. "}, []string{}, []string{"foo.org"}, true, + map[string][]string{ + "include": {"foo.org"}, + }, }, { []string{" foo.org"}, []string{}, []string{"foo.org"}, true, + map[string][]string{ + "include": {"foo.org"}, + }, }, { []string{"foo.org."}, []string{}, []string{"foo.org"}, true, + map[string][]string{ + "include": {"foo.org"}, + }, }, { []string{"foo.org."}, []string{}, []string{"baz.org"}, false, + map[string][]string{ + "include": {"foo.org"}, + }, }, { []string{"baz.foo.org."}, []string{}, []string{"foo.org"}, false, + map[string][]string{ + "include": {"baz.foo.org"}, + }, }, { []string{"", "foo.org."}, []string{}, []string{"foo.org"}, true, + map[string][]string{ + "include": {"foo.org"}, + }, }, { []string{"", "foo.org."}, []string{}, []string{}, true, + map[string][]string{ + "include": {"foo.org"}, + }, }, { []string{""}, []string{}, []string{"foo.org"}, true, + map[string][]string{}, }, { []string{""}, []string{}, []string{}, true, + map[string][]string{}, }, { []string{" "}, []string{}, []string{}, true, + map[string][]string{}, }, { []string{"bar.sub.example.org"}, []string{}, []string{"foo.bar.sub.example.org"}, true, + map[string][]string{ + "include": {"bar.sub.example.org"}, + }, }, { []string{"example.org"}, []string{}, []string{"anexample.org", "test.anexample.org"}, false, + map[string][]string{ + "include": {"example.org"}, + }, }, { []string{".example.org"}, []string{}, []string{"anexample.org", "test.anexample.org"}, false, + map[string][]string{ + "include": {".example.org"}, + }, }, { []string{".example.org"}, []string{}, []string{"example.org"}, false, + map[string][]string{ + "include": {".example.org"}, + }, }, { []string{".example.org"}, []string{}, []string{"test.example.org"}, true, + map[string][]string{ + "include": {".example.org"}, + }, }, { []string{"anexample.org"}, []string{}, []string{"example.org", "test.example.org"}, false, + map[string][]string{ + "include": {"anexample.org"}, + }, }, { []string{".org"}, []string{}, []string{"example.org", "test.example.org", "foo.test.example.org"}, true, + map[string][]string{ + "include": {".org"}, + }, }, { []string{"example.org"}, []string{"api.example.org"}, []string{"example.org", "test.example.org", "foo.test.example.org"}, true, + map[string][]string{ + "include": {"example.org"}, + "exclude": {"api.example.org"}, + }, }, { []string{"example.org"}, []string{"api.example.org"}, []string{"foo.api.example.org", "api.example.org"}, false, + map[string][]string{ + "include": {"example.org"}, + "exclude": {"api.example.org"}, + }, }, { []string{" example.org. "}, []string{" .api.example.org "}, []string{"foo.api.example.org", "bar.baz.api.example.org."}, false, + map[string][]string{ + "include": {"example.org"}, + "exclude": {".api.example.org"}, + }, }, { []string{"example.org."}, []string{"api.example.org"}, []string{"dev-api.example.org", "qa-api.example.org"}, true, + map[string][]string{ + "include": {"example.org"}, + "exclude": {"api.example.org"}, + }, }, { []string{"example.org."}, []string{"api.example.org"}, []string{"dev.api.example.org", "qa.api.example.org"}, false, + map[string][]string{ + "include": {"example.org"}, + "exclude": {"api.example.org"}, + }, }, { []string{"example.org", "api.example.org"}, []string{"internal.api.example.org"}, []string{"foo.api.example.org"}, true, + map[string][]string{ + "include": {"api.example.org", "example.org"}, + "exclude": {"internal.api.example.org"}, + }, }, { []string{"example.org", "api.example.org"}, []string{"internal.api.example.org"}, []string{"foo.internal.api.example.org"}, false, + map[string][]string{ + "include": {"api.example.org", "example.org"}, + "exclude": {"internal.api.example.org"}, + }, }, { []string{"eXaMPle.ORG", "API.example.ORG"}, []string{"Foo-Bar.Example.Org"}, []string{"FoOoo.Api.Example.Org"}, true, + map[string][]string{ + "include": {"api.example.org", "example.org"}, + "exclude": {"foo-bar.example.org"}, + }, }, { []string{"eXaMPle.ORG", "API.example.ORG"}, []string{"api.example.org"}, []string{"foobar.Example.Org"}, true, + map[string][]string{ + "include": {"api.example.org", "example.org"}, + "exclude": {"api.example.org"}, + }, }, { []string{"eXaMPle.ORG", "API.example.ORG"}, []string{"api.example.org"}, []string{"foobar.API.Example.Org"}, false, + map[string][]string{ + "include": {"api.example.org", "example.org"}, + "exclude": {"api.example.org"}, + }, }, } @@ -226,36 +325,57 @@ var regexDomainFilterTests = []regexDomainFilterTest{ regexp.MustCompile(""), []string{"foo.org", "bar.org", "foo.bar.org"}, true, + map[string]string{ + "regexInclude": "\\.org$", + }, }, { regexp.MustCompile("\\.bar\\.org$"), regexp.MustCompile(""), []string{"foo.org", "bar.org", "example.com"}, false, + map[string]string{ + "regexInclude": "\\.bar\\.org$", + }, }, { regexp.MustCompile("(?:foo|bar)\\.org$"), regexp.MustCompile(""), []string{"foo.org", "bar.org", "example.foo.org", "example.bar.org", "a.example.foo.org", "a.example.bar.org"}, true, + map[string]string{ + "regexInclude": "(?:foo|bar)\\.org$", + }, }, { regexp.MustCompile("(?:foo|bar)\\.org$"), regexp.MustCompile("^example\\.(?:foo|bar)\\.org$"), []string{"foo.org", "bar.org", "a.example.foo.org", "a.example.bar.org"}, true, + map[string]string{ + "regexInclude": "(?:foo|bar)\\.org$", + "regexExclude": "^example\\.(?:foo|bar)\\.org$", + }, }, { regexp.MustCompile("(?:foo|bar)\\.org$"), regexp.MustCompile("^example\\.(?:foo|bar)\\.org$"), []string{"example.foo.org", "example.bar.org"}, false, + map[string]string{ + "regexInclude": "(?:foo|bar)\\.org$", + "regexExclude": "^example\\.(?:foo|bar)\\.org$", + }, }, { regexp.MustCompile("(?:foo|bar)\\.org$"), regexp.MustCompile("^example\\.(?:foo|bar)\\.org$"), []string{"foo.org", "bar.org", "a.example.foo.org", "a.example.bar.org"}, true, + map[string]string{ + "regexInclude": "(?:foo|bar)\\.org$", + "regexExclude": "^example\\.(?:foo|bar)\\.org$", + }, }, } @@ -265,24 +385,47 @@ func TestDomainFilterMatch(t *testing.T) { t.Logf("NewDomainFilter() doesn't support exclusions - skipping test %+v", tt) continue } - domainFilter := NewDomainFilter(tt.domainFilter) - for _, domain := range tt.domains { - assert.Equal(t, tt.expected, domainFilter.Match(domain), "should not fail: %v in test-case #%v", domain, i) - assert.Equal(t, tt.expected, domainFilter.Match(domain+"."), "should not fail: %v in test-case #%v", domain+".", i) - } + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + domainFilter := NewDomainFilter(tt.domainFilter) + + assertSerializes(t, domainFilter, tt.expectedSerialization) + deserialized := deserialize(t, map[string][]string{ + "include": tt.domainFilter, + }) + + for _, domain := range tt.domains { + assert.Equal(t, tt.expected, domainFilter.Match(domain), "%v", domain) + assert.Equal(t, tt.expected, domainFilter.Match(domain+"."), "%v", domain+".") + + assert.Equal(t, tt.expected, deserialized.Match(domain), "deserialized %v", domain) + assert.Equal(t, tt.expected, deserialized.Match(domain+"."), "deserialized %v", domain+".") + } + }) } } func TestDomainFilterWithExclusions(t *testing.T) { for i, tt := range domainFilterTests { - if len(tt.exclusions) == 0 { - tt.exclusions = append(tt.exclusions, "") - } - domainFilter := NewDomainFilterWithExclusions(tt.domainFilter, tt.exclusions) - for _, domain := range tt.domains { - assert.Equal(t, tt.expected, domainFilter.Match(domain), "should not fail: %v in test-case #%v", domain, i) - assert.Equal(t, tt.expected, domainFilter.Match(domain+"."), "should not fail: %v in test-case #%v", domain+".", i) - } + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + if len(tt.exclusions) == 0 { + tt.exclusions = append(tt.exclusions, "") + } + domainFilter := NewDomainFilterWithExclusions(tt.domainFilter, tt.exclusions) + + assertSerializes(t, domainFilter, tt.expectedSerialization) + deserialized := deserialize(t, map[string][]string{ + "include": tt.domainFilter, + "exclude": tt.exclusions, + }) + + for _, domain := range tt.domains { + assert.Equal(t, tt.expected, domainFilter.Match(domain), "%v", domain) + assert.Equal(t, tt.expected, domainFilter.Match(domain+"."), "%v", domain+".") + + assert.Equal(t, tt.expected, deserialized.Match(domain), "deserialized %v", domain) + assert.Equal(t, tt.expected, deserialized.Match(domain+"."), "deserialized %v", domain+".") + } + }) } } @@ -303,72 +446,119 @@ func TestDomainFilterMatchParent(t *testing.T) { []string{}, []string{"example.com"}, true, + map[string][]string{ + "include": {"a.example.com"}, + }, }, { []string{" a.example.com "}, []string{}, []string{"example.com"}, true, + map[string][]string{ + "include": {"a.example.com"}, + }, }, { []string{""}, []string{}, []string{"example.com"}, true, + map[string][]string{}, }, { []string{".a.example.com."}, []string{}, []string{"example.com"}, false, + map[string][]string{ + "include": {".a.example.com"}, + }, }, { []string{"a.example.com.", "b.example.com"}, []string{}, []string{"example.com"}, true, + map[string][]string{ + "include": {"a.example.com", "b.example.com"}, + }, }, { []string{"a.example.com"}, []string{}, []string{"b.example.com"}, false, + map[string][]string{ + "include": {"a.example.com"}, + }, }, { []string{"example.com"}, []string{}, []string{"example.com"}, false, + map[string][]string{ + "include": {"example.com"}, + }, }, { []string{"example.com"}, []string{}, []string{"anexample.com"}, false, + map[string][]string{ + "include": {"example.com"}, + }, }, { []string{""}, []string{}, []string{""}, true, + map[string][]string{}, }, } for i, tt := range parentMatchTests { - domainFilter := NewDomainFilterWithExclusions(tt.domainFilter, tt.exclusions) - for _, domain := range tt.domains { - assert.Equal(t, tt.expected, domainFilter.MatchParent(domain), "should not fail: %v in test-case #%v", domain, i) - assert.Equal(t, tt.expected, domainFilter.MatchParent(domain+"."), "should not fail: %v in test-case #%v", domain+".", i) - } + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + domainFilter := NewDomainFilterWithExclusions(tt.domainFilter, tt.exclusions) + + assertSerializes(t, domainFilter, tt.expectedSerialization) + deserialized := deserialize(t, map[string][]string{ + "include": tt.domainFilter, + "exclude": tt.exclusions, + }) + + for _, domain := range tt.domains { + assert.Equal(t, tt.expected, domainFilter.MatchParent(domain), "%v", domain) + assert.Equal(t, tt.expected, domainFilter.MatchParent(domain+"."), "%v", domain+".") + + assert.Equal(t, tt.expected, deserialized.MatchParent(domain), "deserialized %v", domain) + assert.Equal(t, tt.expected, deserialized.MatchParent(domain+"."), "deserialized %v", domain+".") + } + }) } } func TestRegexDomainFilter(t *testing.T) { for i, tt := range regexDomainFilterTests { - domainFilter := NewRegexDomainFilter(tt.regex, tt.regexExclusion) - for _, domain := range tt.domains { - assert.Equal(t, tt.expected, domainFilter.Match(domain), "should not fail: %v in test-case #%v", domain, i) - assert.Equal(t, tt.expected, domainFilter.Match(domain+"."), "should not fail: %v in test-case #%v", domain+".", i) - } + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + domainFilter := NewRegexDomainFilter(tt.regex, tt.regexExclusion) + + assertSerializes(t, domainFilter, tt.expectedSerialization) + deserialized := deserialize(t, map[string]string{ + "regexInclude": tt.regex.String(), + "regexExclude": tt.regexExclusion.String(), + }) + + for _, domain := range tt.domains { + assert.Equal(t, tt.expected, domainFilter.Match(domain), "%v", domain) + assert.Equal(t, tt.expected, domainFilter.Match(domain+"."), "%v", domain+".") + + assert.Equal(t, tt.expected, deserialized.Match(domain), "deserialized %v", domain) + assert.Equal(t, tt.expected, deserialized.Match(domain+"."), "deserialized %v", domain+".") + } + }) } } @@ -411,7 +601,7 @@ func TestMatchFilterReturnsProperEmptyVal(t *testing.T) { } func TestDomainFilterIsConfigured(t *testing.T) { - for _, tt := range []struct { + for i, tt := range []struct { filters []string exclude []string expected bool @@ -452,9 +642,130 @@ func TestDomainFilterIsConfigured(t *testing.T) { true, }, } { - t.Run("test IsConfigured", func(t *testing.T) { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { df := NewDomainFilterWithExclusions(tt.filters, tt.exclude) assert.Equal(t, tt.expected, df.IsConfigured()) }) } } + +func TestRegexDomainFilterIsConfigured(t *testing.T) { + for i, tt := range []struct { + regex string + regexExclude string + expected bool + }{ + { + "", + "", + false, + }, + { + "(?:foo|bar)\\.org$", + "", + true, + }, + { + "", + "\\.org$", + true, + }, + { + "(?:foo|bar)\\.org$", + "\\.org$", + true, + }, + } { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + df := NewRegexDomainFilter(regexp.MustCompile(tt.regex), regexp.MustCompile(tt.regexExclude)) + assert.Equal(t, tt.expected, df.IsConfigured()) + }) + } +} + +func TestDomainFilterDeserializeError(t *testing.T) { + for _, tt := range []struct { + name string + serialized map[string]interface{} + expectedError string + }{ + { + name: "invalid json", + serialized: map[string]interface{}{ + "include": 3, + }, + expectedError: "json: cannot unmarshal number into Go struct field domainFilterSerde.include of type []string", + }, + { + name: "include and regex", + serialized: map[string]interface{}{ + "include": []string{"example.com"}, + "regexInclude": "example.com", + }, + expectedError: "cannot have both domain list and regex", + }, + { + name: "exclude and regex", + serialized: map[string]interface{}{ + "exclude": []string{"example.com"}, + "regexInclude": "example.com", + }, + expectedError: "cannot have both domain list and regex", + }, + { + name: "include and regexExclude", + serialized: map[string]interface{}{ + "include": []string{"example.com"}, + "regexExclude": "example.com", + }, + expectedError: "cannot have both domain list and regex", + }, + { + name: "exclude and regexExclude", + serialized: map[string]interface{}{ + "exclude": []string{"example.com"}, + "regexExclude": "example.com", + }, + expectedError: "cannot have both domain list and regex", + }, + { + name: "invalid regex", + serialized: map[string]interface{}{ + "regexInclude": "*", + }, + expectedError: "invalid regexInclude: error parsing regexp: missing argument to repetition operator: `*`", + }, + { + name: "invalid regexExclude", + serialized: map[string]interface{}{ + "regexExclude": "*", + }, + expectedError: "invalid regexExclude: error parsing regexp: missing argument to repetition operator: `*`", + }, + } { + t.Run(tt.name, func(t *testing.T) { + var deserialized DomainFilter + toJson, _ := json.Marshal(tt.serialized) + err := json.Unmarshal(toJson, &deserialized) + assert.EqualError(t, err, tt.expectedError) + }) + } +} + +func assertSerializes[T any](t *testing.T, domainFilter DomainFilter, expectedSerialization map[string]T) { + serialized, err := json.Marshal(domainFilter) + assert.NoError(t, err, "serializing") + expected, err := json.Marshal(expectedSerialization) + require.NoError(t, err) + assert.JSONEq(t, string(expected), string(serialized), "json serialization") +} + +func deserialize[T any](t *testing.T, serialized map[string]T) DomainFilter { + inJson, err := json.Marshal(serialized) + require.NoError(t, err) + var deserialized DomainFilter + err = json.Unmarshal(inJson, &deserialized) + assert.NoError(t, err, "deserializing") + + return deserialized +}