Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perform function analysis once #362

Merged
merged 1 commit into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions server/functions/framework/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,7 @@ func Initialize() {

// Store the compiled function into the engine's built-in functions
createFunc := func(params ...sql.Expression) (sql.Expression, error) {
return &CompiledFunction{
Name: funcName,
Parameters: params,
Functions: baseOverload,
AllOverloads: baseOverload.collectOverloadPermutations(),
IsOperator: false,
}, nil
return NewCompiledFunction(funcName, params, baseOverload, false), nil
}
function.BuiltIns = append(function.BuiltIns, sql.FunctionN{
Name: funcName,
Expand Down
92 changes: 59 additions & 33 deletions server/functions/framework/compiled_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,58 @@ import (

// CompiledFunction is an expression that represents a fully-analyzed PostgreSQL function.
type CompiledFunction struct {
Name string
Parameters []sql.Expression
Functions *OverloadDeduction
AllOverloads [][]pgtypes.DoltgresTypeBaseID
IsOperator bool
Name string
Parameters []sql.Expression
Functions *OverloadDeduction
AllOverloads [][]pgtypes.DoltgresTypeBaseID
IsOperator bool
callableFunc FunctionInterface
casts []TypeCastFunction
originalTypes []pgtypes.DoltgresType
stashedErr error
}

var _ sql.FunctionExpression = (*CompiledFunction)(nil)
var _ sql.NonDeterministicExpression = (*CompiledFunction)(nil)

// NewCompiledFunction returns a newly compiled function.
func NewCompiledFunction(name string, parameters []sql.Expression, functions *OverloadDeduction, isOperator bool) *CompiledFunction {
return newCompiledFunctionInternal(name, parameters, functions, functions.collectOverloadPermutations(), isOperator)
}

// newCompiledFunctionInternal is called internally, which skips steps that may have already been processed.
func newCompiledFunctionInternal(name string, params []sql.Expression, funcs *OverloadDeduction, allFuncs [][]pgtypes.DoltgresTypeBaseID, isOperator bool) *CompiledFunction {
c := &CompiledFunction{
Name: name,
Parameters: params,
Functions: funcs,
AllOverloads: allFuncs,
IsOperator: isOperator,
}
// First we'll analyze all of the parameters.
originalTypes, sources, err := c.analyzeParameters()
if err != nil {
// Errors should be returned from the call to Eval, so we'll stash it for now
c.stashedErr = err
return c
}
// Next we'll resolve the overload based on the parameters given.
overload, casts, err := c.resolve(originalTypes, sources)
if err != nil {
c.stashedErr = err
return c
}
// If we do not receive an overload, then the parameters given did not result in a valid match
if overload == nil || overload.Function == nil {
c.stashedErr = fmt.Errorf("function %s does not exist", c.OverloadString(originalTypes))
return c
}
c.callableFunc = overload.Function
c.casts = casts
c.originalTypes = originalTypes
return c
}

// FunctionName implements the interface sql.Expression.
func (c *CompiledFunction) FunctionName() string {
return c.Name
Expand Down Expand Up @@ -122,41 +164,32 @@ func (c *CompiledFunction) IsNonDeterministic() bool {

// Eval implements the interface sql.Expression.
func (c *CompiledFunction) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
// First we'll analyze all of the parameters.
originalTypes, sources, err := c.analyzeParameters()
if err != nil {
return nil, err
}
// Next we'll resolve the overload based on the parameters given.
overload, casts, err := c.resolve(originalTypes, sources)
if err != nil {
return nil, err
}
// If we do not receive an overload, then the parameters given did not result in a valid match
if overload == nil {
return nil, fmt.Errorf("function %s does not exist", c.OverloadString(originalTypes))
// If we have a stashed error, then we should return that now. Errors are stashed when they're supposed to be
// returned during the call to Eval. This helps to ensure consistency with how errors are returned in Postgres.
if c.stashedErr != nil {
return nil, c.stashedErr
}
// With the overload figured out, we evaluate all of the parameters.
// Evaluate all of the parameters.
parameters, err := c.evalParameters(ctx, row)
if err != nil {
return nil, err
}
// Convert the parameter values into their correct types
resultTypes := overload.Function.GetParameters()
if len(casts) > 0 {
resultTypes := c.callableFunc.GetParameters()
if len(c.casts) > 0 {
for i := range parameters {
if casts[i] != nil {
parameters[i], err = casts[i](ctx, parameters[i], resultTypes[i])
if c.casts[i] != nil {
parameters[i], err = c.casts[i](ctx, parameters[i], resultTypes[i])
if err != nil {
return nil, err
}
} else {
return nil, fmt.Errorf("function %s is missing the appropriate implicit cast", c.OverloadString(originalTypes))
return nil, fmt.Errorf("function %s is missing the appropriate implicit cast", c.OverloadString(c.originalTypes))
}
}
}
// Pass the parameters to the function
switch f := overload.Function.(type) {
switch f := c.callableFunc.(type) {
case Function0:
return f.Callable(ctx)
case Function1:
Expand All @@ -179,13 +212,7 @@ func (c *CompiledFunction) Children() []sql.Expression {

// WithChildren implements the interface sql.Expression.
func (c *CompiledFunction) WithChildren(children ...sql.Expression) (sql.Expression, error) {
return &CompiledFunction{
Name: c.Name,
Parameters: children,
Functions: c.Functions,
AllOverloads: c.AllOverloads,
IsOperator: c.IsOperator,
}, nil
return newCompiledFunctionInternal(c.Name, children, c.Functions, c.AllOverloads, c.IsOperator), nil
}

// resolve returns an overload that either matches the given parameters exactly, or is a viable match after casting.
Expand Down Expand Up @@ -385,7 +412,6 @@ func (c *CompiledFunction) evalParameters(ctx *sql.Context, row sql.Row) ([]any,

// analyzeParameters analyzes the parameters within an Eval call.
func (c *CompiledFunction) analyzeParameters() (originalTypes []pgtypes.DoltgresType, sources []Source, err error) {
// TODO: should this be within Eval or sometime before that?
originalTypes = make([]pgtypes.DoltgresType, len(c.Parameters))
sources = make([]Source, len(c.Parameters))
for i, param := range c.Parameters {
Expand Down
8 changes: 1 addition & 7 deletions server/functions/framework/intermediate_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,5 @@ func (f IntermediateFunction) Compile(name string, parameters ...sql.Expression)
if f.Functions == nil {
return nil
}
return &CompiledFunction{
Name: name,
Parameters: parameters,
Functions: f.Functions,
AllOverloads: f.AllOverloads,
IsOperator: f.IsOperator,
}
return newCompiledFunctionInternal(name, parameters, f.Functions, f.AllOverloads, f.IsOperator)
}
Loading