From 16265dfd105967c408e84a6473bdcb862f39d92b Mon Sep 17 00:00:00 2001 From: Chandra Sanapala Date: Wed, 12 Feb 2025 08:59:55 +0530 Subject: [PATCH 1/3] feat: remove extension package init --- expr/binding_test.go | 2 +- expr/builder_test.go | 2 +- expr/expressions_test.go | 8 +-- extensions/extension_mgr.go | 89 ++++++++++++++++++++----------- extensions/extension_mgr_test.go | 22 ++++---- extensions/variants_test.go | 2 +- functions/dialect_test.go | 2 +- functions/local_functions_test.go | 4 +- plan/builders.go | 2 +- plan/internal/helper_test.go | 2 +- plan/plan_builder_test.go | 6 +-- plan/plan_test.go | 2 +- plan/relations_test.go | 4 +- testcases/parser/parse_test.go | 8 +-- 14 files changed, 92 insertions(+), 63 deletions(-) diff --git a/expr/binding_test.go b/expr/binding_test.go index c3f6537..c7c11d2 100644 --- a/expr/binding_test.go +++ b/expr/binding_test.go @@ -12,7 +12,7 @@ import ( ) var ( - extReg = NewEmptyExtensionRegistry(&extensions.DefaultCollection) + extReg = NewEmptyExtensionRegistry(extensions.GetDefaultCollection()) uPointRef = extReg.GetTypeAnchor(extensions.ID{ URI: extensions.SubstraitDefaultURIPrefix + "extension_types.yaml", Name: "point", diff --git a/expr/builder_test.go b/expr/builder_test.go index f517ec7..b7c89ed 100644 --- a/expr/builder_test.go +++ b/expr/builder_test.go @@ -14,7 +14,7 @@ import ( func TestExprBuilder(t *testing.T) { b := expr.ExprBuilder{ - Reg: expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection), + Reg: expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollection()), BaseSchema: types.NewRecordTypeFromStruct(boringSchema.Struct), } precomputedLiteral, _ := expr.NewLiteral(int32(3), false) diff --git a/expr/expressions_test.go b/expr/expressions_test.go index 9473872..c2f3cf5 100644 --- a/expr/expressions_test.go +++ b/expr/expressions_test.go @@ -224,7 +224,7 @@ func TestExpressionsRoundtrip(t *testing.T) { } // get the extension set extSet := ext.GetExtensionSet(&plan) - reg := expr.NewExtensionRegistry(extSet, &ext.DefaultCollection) + reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollection()) tests := []expr.Expression{ sampleNestedExpr(reg, substraitExtURI), } @@ -240,7 +240,7 @@ func TestExpressionsRoundtrip(t *testing.T) { func ExampleExpression_Visit() { const substraitExtURI = "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" var ( - exp = sampleNestedExpr(expr.NewEmptyExtensionRegistry(&ext.DefaultCollection), substraitExtURI) + exp = sampleNestedExpr(expr.NewEmptyExtensionRegistry(ext.GetDefaultCollection()), substraitExtURI) preVisit, postVisit expr.VisitFunc ) @@ -347,7 +347,7 @@ func TestRoundTripUsingTestData(t *testing.T) { require.NoError(t, err) require.NoError(t, protojson.Unmarshal(raw, &protoSchema)) baseSchema := types.NewNamedStructFromProto(&protoSchema) - reg := expr.NewExtensionRegistry(extSet, &ext.DefaultCollection) + reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollection()) for _, tc := range tmp["cases"].([]any) { tt := tc.(map[string]any) t.Run(tt["name"].(string), func(t *testing.T) { @@ -403,7 +403,7 @@ func TestRoundTripExtendedExpression(t *testing.T) { var ex proto.ExtendedExpression require.NoError(t, protojson.Unmarshal(buf.Bytes(), &ex)) - result, err := expr.ExtendedFromProto(&ex, &ext.DefaultCollection) + result, err := expr.ExtendedFromProto(&ex, ext.GetDefaultCollection()) require.NoError(t, err) out := result.ToProto() diff --git a/extensions/extension_mgr.go b/extensions/extension_mgr.go index 1f46929..9572903 100644 --- a/extensions/extension_mgr.go +++ b/extensions/extension_mgr.go @@ -3,15 +3,17 @@ package extensions import ( + "embed" "fmt" "io" "io/fs" "path" "sort" + "sync" "github.com/creasty/defaults" "github.com/goccy/go-yaml" - substrait "github.com/substrait-io/substrait" + "github.com/substrait-io/substrait" substraitgo "github.com/substrait-io/substrait-go/v3" "github.com/substrait-io/substrait-go/v3/proto/extensions" ) @@ -20,45 +22,72 @@ type AdvancedExtension = extensions.AdvancedExtension const SubstraitDefaultURIPrefix = "https://github.com/substrait-io/substrait/blob/main/extensions/" -// DefaultCollection is loaded with the default Substrait extension -// definitions. Not all files are currently parsable. -// Parser needs to enhanced to support all files -var DefaultCollection Collection +var ( + defaultCollection Collection + collectionOnce sync.Once + collectionLoadError error +) + +// GetDefaultCollection returns a Collection that is loaded with the default Substrait extension definitions. +func GetDefaultCollection() *Collection { + c, err := GetDefaultCollectionWithError() + if err != nil { + panic(err) + } + return c +} + +// GetDefaultCollectionWithError returns a Collection that is loaded with the default Substrait extension definitions. +func GetDefaultCollectionWithError() (*Collection, error) { + collectionOnce.Do(func() { + collectionLoadError = loadDefaultCollection() + }) -func init() { + if collectionLoadError != nil { + return nil, collectionLoadError + } + return &defaultCollection, nil +} + +func loadDefaultCollection() error { substraitFS := substrait.GetSubstraitExtensionsFS() entries, err := substraitFS.ReadDir("extensions") if err != nil { - return + return err } for _, ent := range entries { - f, err := substraitFS.Open(path.Join("extensions/", ent.Name())) - if err != nil { - panic(err) + err2, done := loadExtensionFile(substraitFS, ent) + if done { + return err2 } - fileStat, err := f.Stat() - if err != nil { - panic(err) - } - fileName := path.Base(fileStat.Name()) - // Catch and ignore load error for a file - // Currently extension grammar is not fully parseable - // There is a parser fix planned, once that is done, - // we can panic instead of ignoring failed extension file load - defer func(f fs.File, fileName string) { - if r := recover(); r != nil { - fmt.Printf("Ignoring extension file:%s, Recovered from panic: %v\n", fileName, r) - } - if err1 := f.Close(); err1 != nil { - panic(err1) - } - }(f, fileName) - err1 := DefaultCollection.Load(SubstraitDefaultURIPrefix+ent.Name(), f) - if err1 != nil { - fmt.Printf("Ignoring extension file:%s err:%v, Skipping it \n", fileName, err1) + } + return nil +} + +func loadExtensionFile(substraitFS embed.FS, ent fs.DirEntry) (error, bool) { + f, err := substraitFS.Open(path.Join("extensions/", ent.Name())) + if err != nil { + return err, true + } + defer func() { + _ = f.Close() + }() + fileStat, err := f.Stat() + if err != nil { + return err, true + } + fileName := path.Base(fileStat.Name()) + err = defaultCollection.Load(SubstraitDefaultURIPrefix+ent.Name(), f) + if err != nil { + if fileName == "unknown.yaml" { + // TODO: Remove this once extension parser is fixed to support unknown.yaml extension file + fmt.Printf("Ignoring extension file:%s err:%v, Skipping it \n", fileName, err) + } else { + return err, true } } + return nil, false } // ID is the unique identifier for a substrait object diff --git a/extensions/extension_mgr_test.go b/extensions/extension_mgr_test.go index 8184402..d8f7bb1 100644 --- a/extensions/extension_mgr_test.go +++ b/extensions/extension_mgr_test.go @@ -277,11 +277,11 @@ func TestDefaultCollection(t *testing.T) { ) switch tt.typ { case scalarFunc: - variant, ok = extensions.DefaultCollection.GetScalarFunc(id) + variant, ok = extensions.GetDefaultCollection().GetScalarFunc(id) case aggFunc: - variant, ok = extensions.DefaultCollection.GetAggregateFunc(id) + variant, ok = extensions.GetDefaultCollection().GetAggregateFunc(id) case windowFunc: - variant, ok = extensions.DefaultCollection.GetWindowFunc(id) + variant, ok = extensions.GetDefaultCollection().GetWindowFunc(id) } require.True(t, ok) @@ -295,7 +295,7 @@ func TestDefaultCollection(t *testing.T) { }) } - et, ok := extensions.DefaultCollection.GetType(extensions.ID{ + et, ok := extensions.GetDefaultCollection().GetType(extensions.ID{ URI: extensions.SubstraitDefaultURIPrefix + "extension_types.yaml", Name: "point"}) assert.True(t, ok) assert.Equal(t, "point", et.Name) @@ -303,9 +303,10 @@ func TestDefaultCollection(t *testing.T) { } func TestCollection_GetAllScalarFunctions(t *testing.T) { - scalarFunctions := extensions.DefaultCollection.GetAllScalarFunctions() - aggregateFunctions := extensions.DefaultCollection.GetAllAggregateFunctions() - windowFunctions := extensions.DefaultCollection.GetAllWindowFunctions() + defaultExtensions := extensions.GetDefaultCollection() + scalarFunctions := defaultExtensions.GetAllScalarFunctions() + aggregateFunctions := defaultExtensions.GetAllAggregateFunctions() + windowFunctions := defaultExtensions.GetAllWindowFunctions() assert.GreaterOrEqual(t, len(scalarFunctions), 309) assert.GreaterOrEqual(t, len(aggregateFunctions), 62) assert.GreaterOrEqual(t, len(windowFunctions), 7) @@ -323,21 +324,20 @@ func TestCollection_GetAllScalarFunctions(t *testing.T) { for _, tt := range tests { t.Run(tt.signature, func(t *testing.T) { assert.True(t, tt.isScalar || tt.isAggregate || tt.isWindow) - c := extensions.DefaultCollection if tt.isScalar { - sf, ok := c.GetScalarFunc(extensions.ID{URI: tt.uri, Name: tt.signature}) + sf, ok := defaultExtensions.GetScalarFunc(extensions.ID{URI: tt.uri, Name: tt.signature}) assert.True(t, ok) assert.Contains(t, scalarFunctions, sf) // verify that default nullability is set to MIRROR assert.Equal(t, extensions.MirrorNullability, sf.Nullability()) } if tt.isAggregate { - af, ok := c.GetAggregateFunc(extensions.ID{URI: tt.uri, Name: tt.signature}) + af, ok := defaultExtensions.GetAggregateFunc(extensions.ID{URI: tt.uri, Name: tt.signature}) assert.True(t, ok) assert.Contains(t, aggregateFunctions, af) } if tt.isWindow { - wf, ok := c.GetWindowFunc(extensions.ID{URI: tt.uri, Name: tt.signature}) + wf, ok := defaultExtensions.GetWindowFunc(extensions.ID{URI: tt.uri, Name: tt.signature}) assert.True(t, ok) assert.Contains(t, windowFunctions, wf) } diff --git a/extensions/variants_test.go b/extensions/variants_test.go index 868e33d..4735346 100644 --- a/extensions/variants_test.go +++ b/extensions/variants_test.go @@ -229,7 +229,7 @@ func TestMatchWithSyncParams(t *testing.T) { require.NotNil(t, testFile) assert.Len(t, testFile.TestCases, testFileInfo.numTests) - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection()) for _, tc := range testFile.TestCases { t.Run(tc.FuncName, func(t *testing.T) { switch tc.FuncType { diff --git a/functions/dialect_test.go b/functions/dialect_test.go index dba1735..f90751b 100644 --- a/functions/dialect_test.go +++ b/functions/dialect_test.go @@ -16,7 +16,7 @@ import ( var gFunctionRegistry FunctionRegistry func TestMain(m *testing.M) { - gFunctionRegistry = NewFunctionRegistry(&extensions.DefaultCollection) + gFunctionRegistry = NewFunctionRegistry(extensions.GetDefaultCollection()) m.Run() } diff --git a/functions/local_functions_test.go b/functions/local_functions_test.go index 05f2821..7653487 100644 --- a/functions/local_functions_test.go +++ b/functions/local_functions_test.go @@ -94,7 +94,7 @@ add(120::i8, 10::i8) [overflow:SILENT] = assert.Len(t, testFile.TestCases, len(testResults)) require.GreaterOrEqual(t, len(testFile.TestCases), len(testResults)) - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection()) for i, result := range testResults { tc := testFile.TestCases[i] t.Run(result.name, func(t *testing.T) { @@ -220,7 +220,7 @@ sum((2.5000007152557373046875, 7.0000007152557373046875, 0, 7.000000715255737304 testCases := append(testFile.TestCases, testFile1.TestCases...) require.GreaterOrEqual(t, len(testCases), len(testResults)) - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection()) for i, result := range testResults { tc := testCases[i] t.Run(result.name, func(t *testing.T) { diff --git a/plan/builders.go b/plan/builders.go index 1bc7910..e1d6de7 100644 --- a/plan/builders.go +++ b/plan/builders.go @@ -154,7 +154,7 @@ type Builder interface { const FETCH_COUNT_ALL_RECORDS = -1 func NewBuilderDefault() Builder { - return NewBuilder(&extensions.DefaultCollection) + return NewBuilder(extensions.GetDefaultCollection()) } func NewBuilder(c *extensions.Collection) Builder { diff --git a/plan/internal/helper_test.go b/plan/internal/helper_test.go index 5686a10..29e05a5 100644 --- a/plan/internal/helper_test.go +++ b/plan/internal/helper_test.go @@ -41,7 +41,7 @@ func TestVirtualTableExpressionFromProto(t *testing.T) { literal1 := expr.NewPrimitiveLiteral(int32(1), false) expr1 := literal1.ToProto() - reg := expr.NewExtensionRegistry(extSet, &ext.DefaultCollection) + reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollection()) rows := &proto.Expression_Nested_Struct{Fields: []*proto.Expression{ expr1, }} diff --git a/plan/plan_builder_test.go b/plan/plan_builder_test.go index b8b3d58..1aabbc7 100644 --- a/plan/plan_builder_test.go +++ b/plan/plan_builder_test.go @@ -65,7 +65,7 @@ func TestBasicEmitPlan(t *testing.T) { protoPlan, err := p.ToProto() require.NoError(t, err) - roundTrip, err := plan.FromProto(protoPlan, &extensions.DefaultCollection) + roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollection()) require.NoError(t, err) assert.Equal(t, p, roundTrip) @@ -105,7 +105,7 @@ func TestEmitEmptyPlan(t *testing.T) { protoPlan, err := p.ToProto() require.NoError(t, err) - roundTrip, err := plan.FromProto(protoPlan, &extensions.DefaultCollection) + roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollection()) require.NoError(t, err) assert.Equal(t, p, roundTrip) @@ -169,7 +169,7 @@ func checkRoundTrip(t *testing.T, expectedJSON string, p *plan.Plan) { assert.Truef(t, proto.Equal(&expectedProto, protoPlan), "JSON expected: %s\ngot: %s", protojson.Format(&expectedProto), protojson.Format(protoPlan)) - roundTrip, err := plan.FromProto(&expectedProto, &extensions.DefaultCollection) + roundTrip, err := plan.FromProto(&expectedProto, extensions.GetDefaultCollection()) require.NoError(t, err) roundTripProto, err := roundTrip.ToProto() diff --git a/plan/plan_test.go b/plan/plan_test.go index 363367b..cd831cd 100644 --- a/plan/plan_test.go +++ b/plan/plan_test.go @@ -13,7 +13,7 @@ import ( func TestRelFromProto(t *testing.T) { - registry := expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection) + registry := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollection()) literal5 := &proto.Expression_Literal{LiteralType: &proto.Expression_Literal_I64{I64: 5}} exprLiteral5 := &proto.Expression{RexType: &proto.Expression_Literal_{Literal: literal5}} diff --git a/plan/relations_test.go b/plan/relations_test.go index b261cfc..cb90b76 100644 --- a/plan/relations_test.go +++ b/plan/relations_test.go @@ -28,7 +28,7 @@ func createPrimitiveBool(value bool) expr.Expression { } func TestRelations_Copy(t *testing.T) { - extReg := expr.NewExtensionRegistry(extensions.NewSet(), &extensions.DefaultCollection) + extReg := expr.NewExtensionRegistry(extensions.NewSet(), extensions.GetDefaultCollection()) aggregateFnID := extensions.ID{ URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml", Name: "avg", @@ -414,7 +414,7 @@ func TestRelations_Copy(t *testing.T) { } func TestAggregateRelToBuilder(t *testing.T) { - extReg := expr.NewExtensionRegistry(extensions.NewSet(), &extensions.DefaultCollection) + extReg := expr.NewExtensionRegistry(extensions.NewSet(), extensions.GetDefaultCollection()) aggregateFnID := extensions.ID{ URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml", Name: "avg", diff --git a/testcases/parser/parse_test.go b/testcases/parser/parse_test.go index 86e3b63..0f27e22 100644 --- a/testcases/parser/parse_test.go +++ b/testcases/parser/parse_test.go @@ -50,7 +50,7 @@ add(120::i8, 10::i8) [overflow:ERROR] = {&types.Int16Type{Nullability: types.NullabilityRequired}, &types.Int16Type{Nullability: types.NullabilityRequired}}, {&types.Int8Type{Nullability: types.NullabilityRequired}, &types.Int8Type{Nullability: types.NullabilityRequired}}, } - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection()) basicGroupDesc := "'Basic examples without any special cases'" overflowGroupDesc := "Overflow examples demonstrating overflow behavior" groupDescs := []string{basicGroupDesc, basicGroupDesc, overflowGroupDesc} @@ -325,7 +325,7 @@ func TestParseAggregateFunc(t *testing.T) { avg((1,2,3)::fp32) = 2::fp64 sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = ` - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection()) arithUri := "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" testFile, err := ParseTestCasesFromString(header + tests) require.NoError(t, err) @@ -544,7 +544,7 @@ LIST_AGG(t1.col0, ','::string) = 1::fp64 require.NotNil(t, testFile) assert.Len(t, testFile.TestCases, 1) assert.Equal(t, AggregateFuncType, testFile.TestCases[0].FuncType) - reg := expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection) + reg := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollection()) aggFun, err := testFile.TestCases[0].GetAggregateFunctionInvocation(®, nil) require.NoError(t, err) assert.Equal(t, "string_agg", aggFun.Name()) @@ -724,7 +724,7 @@ func TestLoadAllSubstraitTestFiles(t *testing.T) { testFile, err := ParseTestCaseFileFromFS(got, filePath) require.NoError(t, err) require.NotNil(t, testFile) - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection()) for _, tc := range testFile.TestCases { testGetFunctionInvocation(t, tc, ®, funcRegistry) } From bb543ca27fd6d4066bc069464fb7160cdb5530b4 Mon Sep 17 00:00:00 2001 From: Chandra Sanapala Date: Wed, 12 Feb 2025 11:45:22 +0530 Subject: [PATCH 2/3] Address review comments --- extensions/extension_mgr.go | 49 ++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 28 deletions(-) diff --git a/extensions/extension_mgr.go b/extensions/extension_mgr.go index 9572903..1dabdb3 100644 --- a/extensions/extension_mgr.go +++ b/extensions/extension_mgr.go @@ -23,12 +23,14 @@ type AdvancedExtension = extensions.AdvancedExtension const SubstraitDefaultURIPrefix = "https://github.com/substrait-io/substrait/blob/main/extensions/" var ( - defaultCollection Collection - collectionOnce sync.Once - collectionLoadError error + getDefaultCollectionOnce = sync.OnceValues[*Collection, error](loadDefaultCollection) + unsupportedExtensions = map[string]bool{ + "unknown.yaml": true, + } ) // GetDefaultCollection returns a Collection that is loaded with the default Substrait extension definitions. +// This version is provided for the ease of use of legacy code. Please use GetDefaultCollectionWithError instead. func GetDefaultCollection() *Collection { c, err := GetDefaultCollectionWithError() if err != nil { @@ -39,55 +41,46 @@ func GetDefaultCollection() *Collection { // GetDefaultCollectionWithError returns a Collection that is loaded with the default Substrait extension definitions. func GetDefaultCollectionWithError() (*Collection, error) { - collectionOnce.Do(func() { - collectionLoadError = loadDefaultCollection() - }) - - if collectionLoadError != nil { - return nil, collectionLoadError - } - return &defaultCollection, nil + return getDefaultCollectionOnce() } -func loadDefaultCollection() error { +func loadDefaultCollection() (*Collection, error) { substraitFS := substrait.GetSubstraitExtensionsFS() entries, err := substraitFS.ReadDir("extensions") if err != nil { - return err + return nil, err } + var defaultCollection Collection for _, ent := range entries { - err2, done := loadExtensionFile(substraitFS, ent) - if done { - return err2 + err2 := loadExtensionFile(&defaultCollection, substraitFS, ent) + if err2 != nil { + return nil, err2 } } - return nil + return &defaultCollection, nil } -func loadExtensionFile(substraitFS embed.FS, ent fs.DirEntry) (error, bool) { +func loadExtensionFile(collection *Collection, substraitFS embed.FS, ent fs.DirEntry) error { f, err := substraitFS.Open(path.Join("extensions/", ent.Name())) if err != nil { - return err, true + return err } defer func() { _ = f.Close() }() fileStat, err := f.Stat() if err != nil { - return err, true + return err } fileName := path.Base(fileStat.Name()) - err = defaultCollection.Load(SubstraitDefaultURIPrefix+ent.Name(), f) - if err != nil { - if fileName == "unknown.yaml" { - // TODO: Remove this once extension parser is fixed to support unknown.yaml extension file - fmt.Printf("Ignoring extension file:%s err:%v, Skipping it \n", fileName, err) - } else { - return err, true + if _, ok := unsupportedExtensions[fileName]; !ok { + err = collection.Load(SubstraitDefaultURIPrefix+ent.Name(), f) + if err != nil { + return err } } - return nil, false + return nil } // ID is the unique identifier for a substrait object From 513f32bce7462a7ea84cfa915f9ab9b3ad3d4618 Mon Sep 17 00:00:00 2001 From: Chandra Sanapala Date: Wed, 12 Feb 2025 11:50:58 +0530 Subject: [PATCH 3/3] rename GetDefaultCollectionWithError to GetDefaultCollection --- expr/binding_test.go | 2 +- expr/builder_test.go | 2 +- expr/expressions_test.go | 8 ++++---- extensions/extension_mgr.go | 12 ++++++------ extensions/extension_mgr_test.go | 10 +++++----- extensions/variants_test.go | 2 +- functions/dialect_test.go | 2 +- functions/local_functions_test.go | 4 ++-- plan/builders.go | 2 +- plan/internal/helper_test.go | 2 +- plan/plan_builder_test.go | 6 +++--- plan/plan_test.go | 2 +- plan/relations_test.go | 4 ++-- testcases/parser/parse_test.go | 8 ++++---- 14 files changed, 33 insertions(+), 33 deletions(-) diff --git a/expr/binding_test.go b/expr/binding_test.go index c7c11d2..2025241 100644 --- a/expr/binding_test.go +++ b/expr/binding_test.go @@ -12,7 +12,7 @@ import ( ) var ( - extReg = NewEmptyExtensionRegistry(extensions.GetDefaultCollection()) + extReg = NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()) uPointRef = extReg.GetTypeAnchor(extensions.ID{ URI: extensions.SubstraitDefaultURIPrefix + "extension_types.yaml", Name: "point", diff --git a/expr/builder_test.go b/expr/builder_test.go index b7c89ed..e62d4b4 100644 --- a/expr/builder_test.go +++ b/expr/builder_test.go @@ -14,7 +14,7 @@ import ( func TestExprBuilder(t *testing.T) { b := expr.ExprBuilder{ - Reg: expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollection()), + Reg: expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()), BaseSchema: types.NewRecordTypeFromStruct(boringSchema.Struct), } precomputedLiteral, _ := expr.NewLiteral(int32(3), false) diff --git a/expr/expressions_test.go b/expr/expressions_test.go index c2f3cf5..a5a9fff 100644 --- a/expr/expressions_test.go +++ b/expr/expressions_test.go @@ -224,7 +224,7 @@ func TestExpressionsRoundtrip(t *testing.T) { } // get the extension set extSet := ext.GetExtensionSet(&plan) - reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollection()) + reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollectionWithNoError()) tests := []expr.Expression{ sampleNestedExpr(reg, substraitExtURI), } @@ -240,7 +240,7 @@ func TestExpressionsRoundtrip(t *testing.T) { func ExampleExpression_Visit() { const substraitExtURI = "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" var ( - exp = sampleNestedExpr(expr.NewEmptyExtensionRegistry(ext.GetDefaultCollection()), substraitExtURI) + exp = sampleNestedExpr(expr.NewEmptyExtensionRegistry(ext.GetDefaultCollectionWithNoError()), substraitExtURI) preVisit, postVisit expr.VisitFunc ) @@ -347,7 +347,7 @@ func TestRoundTripUsingTestData(t *testing.T) { require.NoError(t, err) require.NoError(t, protojson.Unmarshal(raw, &protoSchema)) baseSchema := types.NewNamedStructFromProto(&protoSchema) - reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollection()) + reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollectionWithNoError()) for _, tc := range tmp["cases"].([]any) { tt := tc.(map[string]any) t.Run(tt["name"].(string), func(t *testing.T) { @@ -403,7 +403,7 @@ func TestRoundTripExtendedExpression(t *testing.T) { var ex proto.ExtendedExpression require.NoError(t, protojson.Unmarshal(buf.Bytes(), &ex)) - result, err := expr.ExtendedFromProto(&ex, ext.GetDefaultCollection()) + result, err := expr.ExtendedFromProto(&ex, ext.GetDefaultCollectionWithNoError()) require.NoError(t, err) out := result.ToProto() diff --git a/extensions/extension_mgr.go b/extensions/extension_mgr.go index 1dabdb3..bbbea2b 100644 --- a/extensions/extension_mgr.go +++ b/extensions/extension_mgr.go @@ -29,18 +29,18 @@ var ( } ) -// GetDefaultCollection returns a Collection that is loaded with the default Substrait extension definitions. -// This version is provided for the ease of use of legacy code. Please use GetDefaultCollectionWithError instead. -func GetDefaultCollection() *Collection { - c, err := GetDefaultCollectionWithError() +// GetDefaultCollectionWithNoError returns a Collection that is loaded with the default Substrait extension definitions. +// This version is provided for the ease of use of legacy code. Please use GetDefaultCollection instead. +func GetDefaultCollectionWithNoError() *Collection { + c, err := GetDefaultCollection() if err != nil { panic(err) } return c } -// GetDefaultCollectionWithError returns a Collection that is loaded with the default Substrait extension definitions. -func GetDefaultCollectionWithError() (*Collection, error) { +// GetDefaultCollection returns a Collection that is loaded with the default Substrait extension definitions. +func GetDefaultCollection() (*Collection, error) { return getDefaultCollectionOnce() } diff --git a/extensions/extension_mgr_test.go b/extensions/extension_mgr_test.go index d8f7bb1..0ef7bd7 100644 --- a/extensions/extension_mgr_test.go +++ b/extensions/extension_mgr_test.go @@ -277,11 +277,11 @@ func TestDefaultCollection(t *testing.T) { ) switch tt.typ { case scalarFunc: - variant, ok = extensions.GetDefaultCollection().GetScalarFunc(id) + variant, ok = extensions.GetDefaultCollectionWithNoError().GetScalarFunc(id) case aggFunc: - variant, ok = extensions.GetDefaultCollection().GetAggregateFunc(id) + variant, ok = extensions.GetDefaultCollectionWithNoError().GetAggregateFunc(id) case windowFunc: - variant, ok = extensions.GetDefaultCollection().GetWindowFunc(id) + variant, ok = extensions.GetDefaultCollectionWithNoError().GetWindowFunc(id) } require.True(t, ok) @@ -295,7 +295,7 @@ func TestDefaultCollection(t *testing.T) { }) } - et, ok := extensions.GetDefaultCollection().GetType(extensions.ID{ + et, ok := extensions.GetDefaultCollectionWithNoError().GetType(extensions.ID{ URI: extensions.SubstraitDefaultURIPrefix + "extension_types.yaml", Name: "point"}) assert.True(t, ok) assert.Equal(t, "point", et.Name) @@ -303,7 +303,7 @@ func TestDefaultCollection(t *testing.T) { } func TestCollection_GetAllScalarFunctions(t *testing.T) { - defaultExtensions := extensions.GetDefaultCollection() + defaultExtensions := extensions.GetDefaultCollectionWithNoError() scalarFunctions := defaultExtensions.GetAllScalarFunctions() aggregateFunctions := defaultExtensions.GetAllAggregateFunctions() windowFunctions := defaultExtensions.GetAllWindowFunctions() diff --git a/extensions/variants_test.go b/extensions/variants_test.go index 4735346..b279549 100644 --- a/extensions/variants_test.go +++ b/extensions/variants_test.go @@ -229,7 +229,7 @@ func TestMatchWithSyncParams(t *testing.T) { require.NotNil(t, testFile) assert.Len(t, testFile.TestCases, testFileInfo.numTests) - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection()) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError()) for _, tc := range testFile.TestCases { t.Run(tc.FuncName, func(t *testing.T) { switch tc.FuncType { diff --git a/functions/dialect_test.go b/functions/dialect_test.go index f90751b..13ade37 100644 --- a/functions/dialect_test.go +++ b/functions/dialect_test.go @@ -16,7 +16,7 @@ import ( var gFunctionRegistry FunctionRegistry func TestMain(m *testing.M) { - gFunctionRegistry = NewFunctionRegistry(extensions.GetDefaultCollection()) + gFunctionRegistry = NewFunctionRegistry(extensions.GetDefaultCollectionWithNoError()) m.Run() } diff --git a/functions/local_functions_test.go b/functions/local_functions_test.go index 7653487..c1068f7 100644 --- a/functions/local_functions_test.go +++ b/functions/local_functions_test.go @@ -94,7 +94,7 @@ add(120::i8, 10::i8) [overflow:SILENT] = assert.Len(t, testFile.TestCases, len(testResults)) require.GreaterOrEqual(t, len(testFile.TestCases), len(testResults)) - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection()) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError()) for i, result := range testResults { tc := testFile.TestCases[i] t.Run(result.name, func(t *testing.T) { @@ -220,7 +220,7 @@ sum((2.5000007152557373046875, 7.0000007152557373046875, 0, 7.000000715255737304 testCases := append(testFile.TestCases, testFile1.TestCases...) require.GreaterOrEqual(t, len(testCases), len(testResults)) - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection()) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError()) for i, result := range testResults { tc := testCases[i] t.Run(result.name, func(t *testing.T) { diff --git a/plan/builders.go b/plan/builders.go index e1d6de7..0e3aae9 100644 --- a/plan/builders.go +++ b/plan/builders.go @@ -154,7 +154,7 @@ type Builder interface { const FETCH_COUNT_ALL_RECORDS = -1 func NewBuilderDefault() Builder { - return NewBuilder(extensions.GetDefaultCollection()) + return NewBuilder(extensions.GetDefaultCollectionWithNoError()) } func NewBuilder(c *extensions.Collection) Builder { diff --git a/plan/internal/helper_test.go b/plan/internal/helper_test.go index 29e05a5..ef48306 100644 --- a/plan/internal/helper_test.go +++ b/plan/internal/helper_test.go @@ -41,7 +41,7 @@ func TestVirtualTableExpressionFromProto(t *testing.T) { literal1 := expr.NewPrimitiveLiteral(int32(1), false) expr1 := literal1.ToProto() - reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollection()) + reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollectionWithNoError()) rows := &proto.Expression_Nested_Struct{Fields: []*proto.Expression{ expr1, }} diff --git a/plan/plan_builder_test.go b/plan/plan_builder_test.go index 1aabbc7..82791a7 100644 --- a/plan/plan_builder_test.go +++ b/plan/plan_builder_test.go @@ -65,7 +65,7 @@ func TestBasicEmitPlan(t *testing.T) { protoPlan, err := p.ToProto() require.NoError(t, err) - roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollection()) + roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollectionWithNoError()) require.NoError(t, err) assert.Equal(t, p, roundTrip) @@ -105,7 +105,7 @@ func TestEmitEmptyPlan(t *testing.T) { protoPlan, err := p.ToProto() require.NoError(t, err) - roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollection()) + roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollectionWithNoError()) require.NoError(t, err) assert.Equal(t, p, roundTrip) @@ -169,7 +169,7 @@ func checkRoundTrip(t *testing.T, expectedJSON string, p *plan.Plan) { assert.Truef(t, proto.Equal(&expectedProto, protoPlan), "JSON expected: %s\ngot: %s", protojson.Format(&expectedProto), protojson.Format(protoPlan)) - roundTrip, err := plan.FromProto(&expectedProto, extensions.GetDefaultCollection()) + roundTrip, err := plan.FromProto(&expectedProto, extensions.GetDefaultCollectionWithNoError()) require.NoError(t, err) roundTripProto, err := roundTrip.ToProto() diff --git a/plan/plan_test.go b/plan/plan_test.go index cd831cd..e74d6dc 100644 --- a/plan/plan_test.go +++ b/plan/plan_test.go @@ -13,7 +13,7 @@ import ( func TestRelFromProto(t *testing.T) { - registry := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollection()) + registry := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()) literal5 := &proto.Expression_Literal{LiteralType: &proto.Expression_Literal_I64{I64: 5}} exprLiteral5 := &proto.Expression{RexType: &proto.Expression_Literal_{Literal: literal5}} diff --git a/plan/relations_test.go b/plan/relations_test.go index cb90b76..08b385e 100644 --- a/plan/relations_test.go +++ b/plan/relations_test.go @@ -28,7 +28,7 @@ func createPrimitiveBool(value bool) expr.Expression { } func TestRelations_Copy(t *testing.T) { - extReg := expr.NewExtensionRegistry(extensions.NewSet(), extensions.GetDefaultCollection()) + extReg := expr.NewExtensionRegistry(extensions.NewSet(), extensions.GetDefaultCollectionWithNoError()) aggregateFnID := extensions.ID{ URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml", Name: "avg", @@ -414,7 +414,7 @@ func TestRelations_Copy(t *testing.T) { } func TestAggregateRelToBuilder(t *testing.T) { - extReg := expr.NewExtensionRegistry(extensions.NewSet(), extensions.GetDefaultCollection()) + extReg := expr.NewExtensionRegistry(extensions.NewSet(), extensions.GetDefaultCollectionWithNoError()) aggregateFnID := extensions.ID{ URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml", Name: "avg", diff --git a/testcases/parser/parse_test.go b/testcases/parser/parse_test.go index 0f27e22..9b5d13f 100644 --- a/testcases/parser/parse_test.go +++ b/testcases/parser/parse_test.go @@ -50,7 +50,7 @@ add(120::i8, 10::i8) [overflow:ERROR] = {&types.Int16Type{Nullability: types.NullabilityRequired}, &types.Int16Type{Nullability: types.NullabilityRequired}}, {&types.Int8Type{Nullability: types.NullabilityRequired}, &types.Int8Type{Nullability: types.NullabilityRequired}}, } - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection()) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError()) basicGroupDesc := "'Basic examples without any special cases'" overflowGroupDesc := "Overflow examples demonstrating overflow behavior" groupDescs := []string{basicGroupDesc, basicGroupDesc, overflowGroupDesc} @@ -325,7 +325,7 @@ func TestParseAggregateFunc(t *testing.T) { avg((1,2,3)::fp32) = 2::fp64 sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = ` - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection()) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError()) arithUri := "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" testFile, err := ParseTestCasesFromString(header + tests) require.NoError(t, err) @@ -544,7 +544,7 @@ LIST_AGG(t1.col0, ','::string) = 1::fp64 require.NotNil(t, testFile) assert.Len(t, testFile.TestCases, 1) assert.Equal(t, AggregateFuncType, testFile.TestCases[0].FuncType) - reg := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollection()) + reg := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()) aggFun, err := testFile.TestCases[0].GetAggregateFunctionInvocation(®, nil) require.NoError(t, err) assert.Equal(t, "string_agg", aggFun.Name()) @@ -724,7 +724,7 @@ func TestLoadAllSubstraitTestFiles(t *testing.T) { testFile, err := ParseTestCaseFileFromFS(got, filePath) require.NoError(t, err) require.NotNil(t, testFile) - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection()) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError()) for _, tc := range testFile.TestCases { testGetFunctionInvocation(t, tc, ®, funcRegistry) }