diff --git a/hack/generator/pkg/astmodel/types.go b/hack/generator/pkg/astmodel/types.go index f7801912ee1..8f3766ba92f 100644 --- a/hack/generator/pkg/astmodel/types.go +++ b/hack/generator/pkg/astmodel/types.go @@ -111,6 +111,14 @@ func (types Types) Contains(name TypeName) bool { return ok } +// OverlayWith creates a new set containing all the type definitions from both this and the provided set. Any name +// collisions are resolved in favour of the provided set. Returns a new independent set, leaving the original unmodified. +func (types Types) OverlayWith(t Types) Types { + result := t.Copy() + result.AddTypes(types.Except(t)) + return result +} + // TypesDisjointUnion merges this and other, with a safety check that no type is overwritten. // If an attempt is made to overwrite a type, this function panics func TypesDisjointUnion(s1 Types, s2 Types) Types { diff --git a/hack/generator/pkg/astmodel/types_test.go b/hack/generator/pkg/astmodel/types_test.go index 5e30f549d86..a03335c31b8 100644 --- a/hack/generator/pkg/astmodel/types_test.go +++ b/hack/generator/pkg/astmodel/types_test.go @@ -12,11 +12,12 @@ import ( ) var ( - pkg = MakeExternalPackageReference("foo") - alphaDefinition = createTestDefinition("alpha") - betaDefinition = createTestDefinition("beta") - gammaDefinition = createTestDefinition("gamma") - deltaDefinition = createTestDefinition("delta") + pkg = MakeExternalPackageReference("foo") + alphaDefinition = createTestDefinition("alpha", StringType) + betaDefinition = createTestDefinition("beta", StringType) + gammaDefinition = createTestDefinition("gamma", StringType) + deltaDefinition = createTestDefinition("delta", StringType) + deltaIntDefinition = createTestDefinition("delta", IntType) ) /* @@ -146,13 +147,45 @@ func Test_TypesExcept_GivenSubset_ReturnsExpectedSet(t *testing.T) { g.Expect(set).To(ContainElement(deltaDefinition)) } +/* + * Overlay() tests + */ + +func Test_TypesOverlayWith_GivenDisjointSets_ReturnsUnionSet(t *testing.T) { + g := NewGomegaWithT(t) + left := createTestTypes(alphaDefinition, betaDefinition) + right := createTestTypes(gammaDefinition, deltaDefinition) + + set := left.OverlayWith(right) + + g.Expect(len(set)).To(Equal(4)) + g.Expect(set).To(ContainElement(alphaDefinition)) + g.Expect(set).To(ContainElement(betaDefinition)) + g.Expect(set).To(ContainElement(gammaDefinition)) + g.Expect(set).To(ContainElement(deltaDefinition)) +} + +func Test_TypesOverlayWith_GivenOverlappingSets_PrefersTypeInOverlay(t *testing.T) { + g := NewGomegaWithT(t) + left := createTestTypes(alphaDefinition, deltaDefinition) + right := createTestTypes(gammaDefinition, deltaIntDefinition) + + set := left.OverlayWith(right) + + g.Expect(len(set)).To(Equal(3)) + g.Expect(set).To(ContainElement(alphaDefinition)) + g.Expect(set).To(ContainElement(gammaDefinition)) + g.Expect(set).To(ContainElement(deltaIntDefinition)) + g.Expect(set).NotTo(ContainElement(deltaDefinition)) +} + /* * Utility functions */ -func createTestDefinition(name string) TypeDefinition { +func createTestDefinition(name string, underlyingType Type) TypeDefinition { n := MakeTypeName(pkg, name) - return MakeTypeDefinition(n, StringType) + return MakeTypeDefinition(n, underlyingType) } func createTestTypes(defs ...TypeDefinition) Types {