diff --git a/builtin/builtin.go b/builtin/builtin.go index cd943bd0..14fc8a4b 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -12,16 +12,6 @@ import ( "github.com/expr-lang/expr/vm/runtime" ) -type Function struct { - Name string - Func func(args ...any) (any, error) - Fast func(arg any) any - ValidateArgs func(args ...any) (any, error) - Types []reflect.Type - Validate func(args []reflect.Type) (reflect.Type, error) - Predicate bool -} - var ( Index map[string]int Names []string diff --git a/builtin/function.go b/builtin/function.go new file mode 100644 index 00000000..e6d88234 --- /dev/null +++ b/builtin/function.go @@ -0,0 +1,22 @@ +package builtin + +import ( + "reflect" +) + +type Function struct { + Name string + Func func(args ...any) (any, error) + Fast func(arg any) any + ValidateArgs func(args ...any) (any, error) + Types []reflect.Type + Validate func(args []reflect.Type) (reflect.Type, error) + Predicate bool +} + +func (f *Function) Type() reflect.Type { + if len(f.Types) > 0 { + return f.Types[0] + } + return reflect.TypeOf(f.Func) +} diff --git a/checker/checker.go b/checker/checker.go index 5dd722fe..0b5a0227 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -169,10 +169,10 @@ func (v *checker) ident(node ast.Node, name string, strict, builtins bool) (refl } if builtins { if fn, ok := v.config.Functions[name]; ok { - return functionType, info{fn: fn} + return fn.Type(), info{fn: fn} } if fn, ok := v.config.Builtins[name]; ok { - return functionType, info{fn: fn} + return fn.Type(), info{fn: fn} } } if v.config.Strict && strict { @@ -833,7 +833,7 @@ func (v *checker) checkFunction(f *builtin.Function, node ast.Node, arguments [] } return t, info{} } else if len(f.Types) == 0 { - t, err := v.checkArguments(f.Name, functionType, false, arguments, node) + t, err := v.checkArguments(f.Name, f.Type(), false, arguments, node) if err != nil { if v.err == nil { v.err = err diff --git a/checker/checker_test.go b/checker/checker_test.go index d03e3a8e..35daeda4 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -994,7 +994,7 @@ func TestCheck_builtin_without_call(t *testing.T) { err string }{ {`len + 1`, "invalid operation: + (mismatched types func(...interface {}) (interface {}, error) and int) (1:5)\n | len + 1\n | ....^"}, - {`string.A`, "type func(...interface {}) (interface {}, error)[string] is undefined (1:8)\n | string.A\n | .......^"}, + {`string.A`, "type func(interface {}) string[string] is undefined (1:8)\n | string.A\n | .......^"}, } for _, test := range tests { diff --git a/checker/types.go b/checker/types.go index 662139c3..8c080504 100644 --- a/checker/types.go +++ b/checker/types.go @@ -18,7 +18,6 @@ var ( anyType = reflect.TypeOf(new(any)).Elem() timeType = reflect.TypeOf(time.Time{}) durationType = reflect.TypeOf(time.Duration(0)) - functionType = reflect.TypeOf(new(func(...any) (any, error))).Elem() ) func combined(a, b reflect.Type) reflect.Type { diff --git a/patcher/with_context_test.go b/patcher/with_context_test.go index 0d0917af..4c2bd048 100644 --- a/patcher/with_context_test.go +++ b/patcher/with_context_test.go @@ -30,3 +30,34 @@ func TestWithContext(t *testing.T) { require.NoError(t, err) require.Equal(t, 42, output) } + +func TestWithContext_with_env_Function(t *testing.T) { + env := map[string]any{ + "ctx": context.TODO(), + } + + fn := expr.Function("fn", + func(params ...any) (any, error) { + ctx := params[0].(context.Context) + a := params[1].(int) + + return ctx.Value("value").(int) + a, nil + }, + new(func(context.Context, int) int), + ) + + program, err := expr.Compile( + `fn(40)`, + expr.Env(env), + expr.WithContext("ctx"), + fn, + ) + require.NoError(t, err) + + ctx := context.WithValue(context.Background(), "value", 2) + env["ctx"] = ctx + + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, 42, output) +}