diff --git a/server/functions/framework/catalog.go b/server/functions/framework/catalog.go index e01874449e..35c09da02a 100644 --- a/server/functions/framework/catalog.go +++ b/server/functions/framework/catalog.go @@ -121,13 +121,7 @@ func Initialize() { // Store the compiled function into the engine's built-in functions createFunc := func(params ...sql.Expression) (sql.Expression, error) { - return &CompiledFunction{ - Name: funcName, - Parameters: params, - Functions: baseOverload, - AllOverloads: baseOverload.collectOverloadPermutations(), - IsOperator: false, - }, nil + return NewCompiledFunction(funcName, params, baseOverload, false), nil } function.BuiltIns = append(function.BuiltIns, sql.FunctionN{ Name: funcName, diff --git a/server/functions/framework/compiled_function.go b/server/functions/framework/compiled_function.go index d0b0a50b80..488958583b 100644 --- a/server/functions/framework/compiled_function.go +++ b/server/functions/framework/compiled_function.go @@ -27,16 +27,58 @@ import ( // CompiledFunction is an expression that represents a fully-analyzed PostgreSQL function. type CompiledFunction struct { - Name string - Parameters []sql.Expression - Functions *OverloadDeduction - AllOverloads [][]pgtypes.DoltgresTypeBaseID - IsOperator bool + Name string + Parameters []sql.Expression + Functions *OverloadDeduction + AllOverloads [][]pgtypes.DoltgresTypeBaseID + IsOperator bool + callableFunc FunctionInterface + casts []TypeCastFunction + originalTypes []pgtypes.DoltgresType + stashedErr error } var _ sql.FunctionExpression = (*CompiledFunction)(nil) var _ sql.NonDeterministicExpression = (*CompiledFunction)(nil) +// NewCompiledFunction returns a newly compiled function. +func NewCompiledFunction(name string, parameters []sql.Expression, functions *OverloadDeduction, isOperator bool) *CompiledFunction { + return newCompiledFunctionInternal(name, parameters, functions, functions.collectOverloadPermutations(), isOperator) +} + +// newCompiledFunctionInternal is called internally, which skips steps that may have already been processed. +func newCompiledFunctionInternal(name string, params []sql.Expression, funcs *OverloadDeduction, allFuncs [][]pgtypes.DoltgresTypeBaseID, isOperator bool) *CompiledFunction { + c := &CompiledFunction{ + Name: name, + Parameters: params, + Functions: funcs, + AllOverloads: allFuncs, + IsOperator: isOperator, + } + // First we'll analyze all of the parameters. + originalTypes, sources, err := c.analyzeParameters() + if err != nil { + // Errors should be returned from the call to Eval, so we'll stash it for now + c.stashedErr = err + return c + } + // Next we'll resolve the overload based on the parameters given. + overload, casts, err := c.resolve(originalTypes, sources) + if err != nil { + c.stashedErr = err + return c + } + // If we do not receive an overload, then the parameters given did not result in a valid match + if overload == nil || overload.Function == nil { + c.stashedErr = fmt.Errorf("function %s does not exist", c.OverloadString(originalTypes)) + return c + } + c.callableFunc = overload.Function + c.casts = casts + c.originalTypes = originalTypes + return c +} + // FunctionName implements the interface sql.Expression. func (c *CompiledFunction) FunctionName() string { return c.Name @@ -122,41 +164,32 @@ func (c *CompiledFunction) IsNonDeterministic() bool { // Eval implements the interface sql.Expression. func (c *CompiledFunction) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - // First we'll analyze all of the parameters. - originalTypes, sources, err := c.analyzeParameters() - if err != nil { - return nil, err - } - // Next we'll resolve the overload based on the parameters given. - overload, casts, err := c.resolve(originalTypes, sources) - if err != nil { - return nil, err - } - // If we do not receive an overload, then the parameters given did not result in a valid match - if overload == nil { - return nil, fmt.Errorf("function %s does not exist", c.OverloadString(originalTypes)) + // If we have a stashed error, then we should return that now. Errors are stashed when they're supposed to be + // returned during the call to Eval. This helps to ensure consistency with how errors are returned in Postgres. + if c.stashedErr != nil { + return nil, c.stashedErr } - // With the overload figured out, we evaluate all of the parameters. + // Evaluate all of the parameters. parameters, err := c.evalParameters(ctx, row) if err != nil { return nil, err } // Convert the parameter values into their correct types - resultTypes := overload.Function.GetParameters() - if len(casts) > 0 { + resultTypes := c.callableFunc.GetParameters() + if len(c.casts) > 0 { for i := range parameters { - if casts[i] != nil { - parameters[i], err = casts[i](ctx, parameters[i], resultTypes[i]) + if c.casts[i] != nil { + parameters[i], err = c.casts[i](ctx, parameters[i], resultTypes[i]) if err != nil { return nil, err } } else { - return nil, fmt.Errorf("function %s is missing the appropriate implicit cast", c.OverloadString(originalTypes)) + return nil, fmt.Errorf("function %s is missing the appropriate implicit cast", c.OverloadString(c.originalTypes)) } } } // Pass the parameters to the function - switch f := overload.Function.(type) { + switch f := c.callableFunc.(type) { case Function0: return f.Callable(ctx) case Function1: @@ -179,13 +212,7 @@ func (c *CompiledFunction) Children() []sql.Expression { // WithChildren implements the interface sql.Expression. func (c *CompiledFunction) WithChildren(children ...sql.Expression) (sql.Expression, error) { - return &CompiledFunction{ - Name: c.Name, - Parameters: children, - Functions: c.Functions, - AllOverloads: c.AllOverloads, - IsOperator: c.IsOperator, - }, nil + return newCompiledFunctionInternal(c.Name, children, c.Functions, c.AllOverloads, c.IsOperator), nil } // resolve returns an overload that either matches the given parameters exactly, or is a viable match after casting. @@ -385,7 +412,6 @@ func (c *CompiledFunction) evalParameters(ctx *sql.Context, row sql.Row) ([]any, // analyzeParameters analyzes the parameters within an Eval call. func (c *CompiledFunction) analyzeParameters() (originalTypes []pgtypes.DoltgresType, sources []Source, err error) { - // TODO: should this be within Eval or sometime before that? originalTypes = make([]pgtypes.DoltgresType, len(c.Parameters)) sources = make([]Source, len(c.Parameters)) for i, param := range c.Parameters { diff --git a/server/functions/framework/intermediate_function.go b/server/functions/framework/intermediate_function.go index dd1b2a1df0..b4809ec7e6 100644 --- a/server/functions/framework/intermediate_function.go +++ b/server/functions/framework/intermediate_function.go @@ -33,11 +33,5 @@ func (f IntermediateFunction) Compile(name string, parameters ...sql.Expression) if f.Functions == nil { return nil } - return &CompiledFunction{ - Name: name, - Parameters: parameters, - Functions: f.Functions, - AllOverloads: f.AllOverloads, - IsOperator: f.IsOperator, - } + return newCompiledFunctionInternal(name, parameters, f.Functions, f.AllOverloads, f.IsOperator) }