Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: load the default extension collection lazily #118

Merged
merged 3 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion expr/binding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

var (
extReg = NewEmptyExtensionRegistry(&extensions.DefaultCollection)
extReg = NewEmptyExtensionRegistry(extensions.GetDefaultCollection())
uPointRef = extReg.GetTypeAnchor(extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "extension_types.yaml",
Name: "point",
Expand Down
2 changes: 1 addition & 1 deletion expr/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

func TestExprBuilder(t *testing.T) {
b := expr.ExprBuilder{
Reg: expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
Reg: expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollection()),
BaseSchema: types.NewRecordTypeFromStruct(boringSchema.Struct),
}
precomputedLiteral, _ := expr.NewLiteral(int32(3), false)
Expand Down
8 changes: 4 additions & 4 deletions expr/expressions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func TestExpressionsRoundtrip(t *testing.T) {
}
// get the extension set
extSet := ext.GetExtensionSet(&plan)
reg := expr.NewExtensionRegistry(extSet, &ext.DefaultCollection)
reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollection())
tests := []expr.Expression{
sampleNestedExpr(reg, substraitExtURI),
}
Expand All @@ -240,7 +240,7 @@ func TestExpressionsRoundtrip(t *testing.T) {
func ExampleExpression_Visit() {
const substraitExtURI = "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
var (
exp = sampleNestedExpr(expr.NewEmptyExtensionRegistry(&ext.DefaultCollection), substraitExtURI)
exp = sampleNestedExpr(expr.NewEmptyExtensionRegistry(ext.GetDefaultCollection()), substraitExtURI)
preVisit, postVisit expr.VisitFunc
)

Expand Down Expand Up @@ -347,7 +347,7 @@ func TestRoundTripUsingTestData(t *testing.T) {
require.NoError(t, err)
require.NoError(t, protojson.Unmarshal(raw, &protoSchema))
baseSchema := types.NewNamedStructFromProto(&protoSchema)
reg := expr.NewExtensionRegistry(extSet, &ext.DefaultCollection)
reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollection())
for _, tc := range tmp["cases"].([]any) {
tt := tc.(map[string]any)
t.Run(tt["name"].(string), func(t *testing.T) {
Expand Down Expand Up @@ -403,7 +403,7 @@ func TestRoundTripExtendedExpression(t *testing.T) {
var ex proto.ExtendedExpression
require.NoError(t, protojson.Unmarshal(buf.Bytes(), &ex))

result, err := expr.ExtendedFromProto(&ex, &ext.DefaultCollection)
result, err := expr.ExtendedFromProto(&ex, ext.GetDefaultCollection())
require.NoError(t, err)

out := result.ToProto()
Expand Down
89 changes: 59 additions & 30 deletions extensions/extension_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
package extensions

import (
"embed"
"fmt"
"io"
"io/fs"
"path"
"sort"
"sync"

"github.com/creasty/defaults"
"github.com/goccy/go-yaml"
substrait "github.com/substrait-io/substrait"
"github.com/substrait-io/substrait"
substraitgo "github.com/substrait-io/substrait-go/v3"
"github.com/substrait-io/substrait-go/v3/proto/extensions"
)
Expand All @@ -20,45 +22,72 @@

const SubstraitDefaultURIPrefix = "https://github.com/substrait-io/substrait/blob/main/extensions/"

// DefaultCollection is loaded with the default Substrait extension
// definitions. Not all files are currently parsable.
// The parser needs to be enhanced to support all files.
var DefaultCollection Collection
var (
defaultCollection Collection
collectionOnce sync.Once
collectionLoadError error
)

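// GetDefaultCollection returns a Collection that is loaded with the default Substrait extension definitions.
// It panics if the default extension definitions fail to load; use GetDefaultCollectionWithError to receive the error instead.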
func GetDefaultCollection() *Collection {
c, err := GetDefaultCollectionWithError()
if err != nil {
panic(err)

Check warning on line 35 in extensions/extension_mgr.go

View check run for this annotation

Codecov / codecov/patch

extensions/extension_mgr.go#L35

Added line #L35 was not covered by tests
}
return c
}

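// GetDefaultCollectionWithError returns a Collection that is loaded with the default Substrait extension definitions.
// The definitions are loaded at most once, on the first call; if that load fails, the same error is returned by every call and the returned Collection is nil.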
func GetDefaultCollectionWithError() (*Collection, error) {
collectionOnce.Do(func() {
collectionLoadError = loadDefaultCollection()
})

func init() {
if collectionLoadError != nil {
return nil, collectionLoadError
}

Check warning on line 48 in extensions/extension_mgr.go

View check run for this annotation

Codecov / codecov/patch

extensions/extension_mgr.go#L47-L48

Added lines #L47 - L48 were not covered by tests
return &defaultCollection, nil
}

func loadDefaultCollection() error {
substraitFS := substrait.GetSubstraitExtensionsFS()
entries, err := substraitFS.ReadDir("extensions")
if err != nil {
return
return err

Check warning on line 56 in extensions/extension_mgr.go

View check run for this annotation

Codecov / codecov/patch

extensions/extension_mgr.go#L56

Added line #L56 was not covered by tests
}

for _, ent := range entries {
f, err := substraitFS.Open(path.Join("extensions/", ent.Name()))
if err != nil {
panic(err)
err2, done := loadExtensionFile(substraitFS, ent)
if done {
return err2

Check warning on line 62 in extensions/extension_mgr.go

View check run for this annotation

Codecov / codecov/patch

extensions/extension_mgr.go#L62

Added line #L62 was not covered by tests
}
fileStat, err := f.Stat()
if err != nil {
panic(err)
}
fileName := path.Base(fileStat.Name())
// Catch and ignore the load error for a file.
// Currently the extension grammar is not fully parseable.
// A parser fix is planned; once that is done,
// we can panic instead of ignoring a failed extension file load.
defer func(f fs.File, fileName string) {
if r := recover(); r != nil {
fmt.Printf("Ignoring extension file:%s, Recovered from panic: %v\n", fileName, r)
}
if err1 := f.Close(); err1 != nil {
panic(err1)
}
}(f, fileName)
err1 := DefaultCollection.Load(SubstraitDefaultURIPrefix+ent.Name(), f)
if err1 != nil {
fmt.Printf("Ignoring extension file:%s err:%v, Skipping it \n", fileName, err1)
}
return nil
}

func loadExtensionFile(substraitFS embed.FS, ent fs.DirEntry) (error, bool) {
f, err := substraitFS.Open(path.Join("extensions/", ent.Name()))
if err != nil {
return err, true
}

Check warning on line 72 in extensions/extension_mgr.go

View check run for this annotation

Codecov / codecov/patch

extensions/extension_mgr.go#L71-L72

Added lines #L71 - L72 were not covered by tests
defer func() {
_ = f.Close()
}()
fileStat, err := f.Stat()
if err != nil {
return err, true
}

Check warning on line 79 in extensions/extension_mgr.go

View check run for this annotation

Codecov / codecov/patch

extensions/extension_mgr.go#L78-L79

Added lines #L78 - L79 were not covered by tests
fileName := path.Base(fileStat.Name())
err = defaultCollection.Load(SubstraitDefaultURIPrefix+ent.Name(), f)
if err != nil {
if fileName == "unknown.yaml" {
// TODO: Remove this once extension parser is fixed to support unknown.yaml extension file
fmt.Printf("Ignoring extension file:%s err:%v, Skipping it \n", fileName, err)
} else {
return err, true

Check warning on line 87 in extensions/extension_mgr.go

View check run for this annotation

Codecov / codecov/patch

extensions/extension_mgr.go#L87

Added line #L87 was not covered by tests
}
}
return nil, false
}

// ID is the unique identifier for a substrait object
Expand Down
22 changes: 11 additions & 11 deletions extensions/extension_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,11 +277,11 @@ func TestDefaultCollection(t *testing.T) {
)
switch tt.typ {
case scalarFunc:
variant, ok = extensions.DefaultCollection.GetScalarFunc(id)
variant, ok = extensions.GetDefaultCollection().GetScalarFunc(id)
case aggFunc:
variant, ok = extensions.DefaultCollection.GetAggregateFunc(id)
variant, ok = extensions.GetDefaultCollection().GetAggregateFunc(id)
case windowFunc:
variant, ok = extensions.DefaultCollection.GetWindowFunc(id)
variant, ok = extensions.GetDefaultCollection().GetWindowFunc(id)
}

require.True(t, ok)
Expand All @@ -295,17 +295,18 @@ func TestDefaultCollection(t *testing.T) {
})
}

et, ok := extensions.DefaultCollection.GetType(extensions.ID{
et, ok := extensions.GetDefaultCollection().GetType(extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "extension_types.yaml", Name: "point"})
assert.True(t, ok)
assert.Equal(t, "point", et.Name)
assert.Equal(t, map[string]interface{}{"latitude": "i32", "longitude": "i32"}, et.Structure)
}

func TestCollection_GetAllScalarFunctions(t *testing.T) {
scalarFunctions := extensions.DefaultCollection.GetAllScalarFunctions()
aggregateFunctions := extensions.DefaultCollection.GetAllAggregateFunctions()
windowFunctions := extensions.DefaultCollection.GetAllWindowFunctions()
defaultExtensions := extensions.GetDefaultCollection()
scalarFunctions := defaultExtensions.GetAllScalarFunctions()
aggregateFunctions := defaultExtensions.GetAllAggregateFunctions()
windowFunctions := defaultExtensions.GetAllWindowFunctions()
assert.GreaterOrEqual(t, len(scalarFunctions), 309)
assert.GreaterOrEqual(t, len(aggregateFunctions), 62)
assert.GreaterOrEqual(t, len(windowFunctions), 7)
Expand All @@ -323,21 +324,20 @@ func TestCollection_GetAllScalarFunctions(t *testing.T) {
for _, tt := range tests {
t.Run(tt.signature, func(t *testing.T) {
assert.True(t, tt.isScalar || tt.isAggregate || tt.isWindow)
c := extensions.DefaultCollection
if tt.isScalar {
sf, ok := c.GetScalarFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
sf, ok := defaultExtensions.GetScalarFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
assert.True(t, ok)
assert.Contains(t, scalarFunctions, sf)
// verify that default nullability is set to MIRROR
assert.Equal(t, extensions.MirrorNullability, sf.Nullability())
}
if tt.isAggregate {
af, ok := c.GetAggregateFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
af, ok := defaultExtensions.GetAggregateFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
assert.True(t, ok)
assert.Contains(t, aggregateFunctions, af)
}
if tt.isWindow {
wf, ok := c.GetWindowFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
wf, ok := defaultExtensions.GetWindowFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
assert.True(t, ok)
assert.Contains(t, windowFunctions, wf)
}
Expand Down
2 changes: 1 addition & 1 deletion extensions/variants_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func TestMatchWithSyncParams(t *testing.T) {
require.NotNil(t, testFile)
assert.Len(t, testFile.TestCases, testFileInfo.numTests)

reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection)
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection())
for _, tc := range testFile.TestCases {
t.Run(tc.FuncName, func(t *testing.T) {
switch tc.FuncType {
Expand Down
2 changes: 1 addition & 1 deletion functions/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
var gFunctionRegistry FunctionRegistry

func TestMain(m *testing.M) {
gFunctionRegistry = NewFunctionRegistry(&extensions.DefaultCollection)
gFunctionRegistry = NewFunctionRegistry(extensions.GetDefaultCollection())
m.Run()
}

Expand Down
4 changes: 2 additions & 2 deletions functions/local_functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ add(120::i8, 10::i8) [overflow:SILENT] = <!UNDEFINED>
assert.Len(t, testFile.TestCases, len(testResults))
require.GreaterOrEqual(t, len(testFile.TestCases), len(testResults))

reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection)
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection())
for i, result := range testResults {
tc := testFile.TestCases[i]
t.Run(result.name, func(t *testing.T) {
Expand Down Expand Up @@ -220,7 +220,7 @@ sum((2.5000007152557373046875, 7.0000007152557373046875, 0, 7.000000715255737304
testCases := append(testFile.TestCases, testFile1.TestCases...)
require.GreaterOrEqual(t, len(testCases), len(testResults))

reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection)
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection())
for i, result := range testResults {
tc := testCases[i]
t.Run(result.name, func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion plan/builders.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ type Builder interface {
const FETCH_COUNT_ALL_RECORDS = -1

func NewBuilderDefault() Builder {
return NewBuilder(&extensions.DefaultCollection)
return NewBuilder(extensions.GetDefaultCollection())
}

func NewBuilder(c *extensions.Collection) Builder {
Expand Down
2 changes: 1 addition & 1 deletion plan/internal/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestVirtualTableExpressionFromProto(t *testing.T) {
literal1 := expr.NewPrimitiveLiteral(int32(1), false)
expr1 := literal1.ToProto()

reg := expr.NewExtensionRegistry(extSet, &ext.DefaultCollection)
reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollection())
rows := &proto.Expression_Nested_Struct{Fields: []*proto.Expression{
expr1,
}}
Expand Down
6 changes: 3 additions & 3 deletions plan/plan_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestBasicEmitPlan(t *testing.T) {
protoPlan, err := p.ToProto()
require.NoError(t, err)

roundTrip, err := plan.FromProto(protoPlan, &extensions.DefaultCollection)
roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollection())
require.NoError(t, err)

assert.Equal(t, p, roundTrip)
Expand Down Expand Up @@ -105,7 +105,7 @@ func TestEmitEmptyPlan(t *testing.T) {
protoPlan, err := p.ToProto()
require.NoError(t, err)

roundTrip, err := plan.FromProto(protoPlan, &extensions.DefaultCollection)
roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollection())
require.NoError(t, err)

assert.Equal(t, p, roundTrip)
Expand Down Expand Up @@ -169,7 +169,7 @@ func checkRoundTrip(t *testing.T, expectedJSON string, p *plan.Plan) {
assert.Truef(t, proto.Equal(&expectedProto, protoPlan), "JSON expected: %s\ngot: %s",
protojson.Format(&expectedProto), protojson.Format(protoPlan))

roundTrip, err := plan.FromProto(&expectedProto, &extensions.DefaultCollection)
roundTrip, err := plan.FromProto(&expectedProto, extensions.GetDefaultCollection())
require.NoError(t, err)

roundTripProto, err := roundTrip.ToProto()
Expand Down
2 changes: 1 addition & 1 deletion plan/plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (

func TestRelFromProto(t *testing.T) {

registry := expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection)
registry := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollection())
literal5 := &proto.Expression_Literal{LiteralType: &proto.Expression_Literal_I64{I64: 5}}
exprLiteral5 := &proto.Expression{RexType: &proto.Expression_Literal_{Literal: literal5}}

Expand Down
4 changes: 2 additions & 2 deletions plan/relations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func createPrimitiveBool(value bool) expr.Expression {
}

func TestRelations_Copy(t *testing.T) {
extReg := expr.NewExtensionRegistry(extensions.NewSet(), &extensions.DefaultCollection)
extReg := expr.NewExtensionRegistry(extensions.NewSet(), extensions.GetDefaultCollection())
aggregateFnID := extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
Name: "avg",
Expand Down Expand Up @@ -414,7 +414,7 @@ func TestRelations_Copy(t *testing.T) {
}

func TestAggregateRelToBuilder(t *testing.T) {
extReg := expr.NewExtensionRegistry(extensions.NewSet(), &extensions.DefaultCollection)
extReg := expr.NewExtensionRegistry(extensions.NewSet(), extensions.GetDefaultCollection())
aggregateFnID := extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
Name: "avg",
Expand Down
8 changes: 4 additions & 4 deletions testcases/parser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ add(120::i8, 10::i8) [overflow:ERROR] = <!ERROR>
{&types.Int16Type{Nullability: types.NullabilityRequired}, &types.Int16Type{Nullability: types.NullabilityRequired}},
{&types.Int8Type{Nullability: types.NullabilityRequired}, &types.Int8Type{Nullability: types.NullabilityRequired}},
}
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection)
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection())
basicGroupDesc := "'Basic examples without any special cases'"
overflowGroupDesc := "Overflow examples demonstrating overflow behavior"
groupDescs := []string{basicGroupDesc, basicGroupDesc, overflowGroupDesc}
Expand Down Expand Up @@ -325,7 +325,7 @@ func TestParseAggregateFunc(t *testing.T) {
avg((1,2,3)::fp32) = 2::fp64
sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = <!ERROR>`

reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection)
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection())
arithUri := "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
testFile, err := ParseTestCasesFromString(header + tests)
require.NoError(t, err)
Expand Down Expand Up @@ -544,7 +544,7 @@ LIST_AGG(t1.col0, ','::string) = 1::fp64
require.NotNil(t, testFile)
assert.Len(t, testFile.TestCases, 1)
assert.Equal(t, AggregateFuncType, testFile.TestCases[0].FuncType)
reg := expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection)
reg := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollection())
aggFun, err := testFile.TestCases[0].GetAggregateFunctionInvocation(&reg, nil)
require.NoError(t, err)
assert.Equal(t, "string_agg", aggFun.Name())
Expand Down Expand Up @@ -724,7 +724,7 @@ func TestLoadAllSubstraitTestFiles(t *testing.T) {
testFile, err := ParseTestCaseFileFromFS(got, filePath)
require.NoError(t, err)
require.NotNil(t, testFile)
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection)
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollection())
for _, tc := range testFile.TestCases {
testGetFunctionInvocation(t, tc, &reg, funcRegistry)
}
Expand Down
Loading