Skip to content

Commit

Permalink
Add OpCallTyped (#269)
Browse files Browse the repository at this point in the history
  • Loading branch information
antonmedv authored Nov 5, 2022
1 parent 07e6b41 commit 848ff39
Show file tree
Hide file tree
Showing 8 changed files with 376 additions and 6 deletions.
1 change: 1 addition & 0 deletions ast/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ type CallNode struct {
base
Callee Node
Arguments []Node
Typed int
Fast bool
}

Expand Down
33 changes: 32 additions & 1 deletion checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/antonmedv/expr/conf"
"github.com/antonmedv/expr/file"
"github.com/antonmedv/expr/parser"
"github.com/antonmedv/expr/vm"
)

func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) {
Expand Down Expand Up @@ -498,7 +499,7 @@ func (v *visitor) CallNode(node *ast.CallNode) (reflect.Type, info) {
}

// checkFunc checks func arguments and returns "return type" of func or method.
func (v *visitor) checkFunc(fn reflect.Type, method bool, node ast.Node, name string, arguments []ast.Node) (reflect.Type, info) {
func (v *visitor) checkFunc(fn reflect.Type, method bool, node *ast.CallNode, name string, arguments []ast.Node) (reflect.Type, info) {
if isAny(fn) {
return anyType, info{}
}
Expand Down Expand Up @@ -564,6 +565,36 @@ func (v *visitor) checkFunc(fn reflect.Type, method bool, node ast.Node, name st
}
}

if !fn.IsVariadic() {
funcTypes:
for i := range vm.FuncTypes {
if i == 0 {
continue
}
typed := reflect.ValueOf(vm.FuncTypes[i]).Elem().Type()
if typed.Kind() != reflect.Func {
continue
}
if typed.NumOut() != fn.NumOut() {
continue
}
for j := 0; j < typed.NumOut(); j++ {
if typed.Out(j) != fn.Out(j) {
continue funcTypes
}
}
if typed.NumIn() != len(arguments) {
continue
}
for j, arg := range arguments {
if typed.In(j) != arg.Type() {
continue funcTypes
}
}
node.Typed = i
}
}

return fn.Out(0), info{}
}

Expand Down
13 changes: 8 additions & 5 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,12 +523,15 @@ func (c *compiler) CallNode(node *ast.CallNode) {
for _, arg := range node.Arguments {
c.compile(arg)
}
op := OpCall
if node.Fast {
op = OpCallFast
}
c.compile(node.Callee)
c.emit(op, len(node.Arguments))
if node.Typed > 0 {
c.emit(OpCallTyped, node.Typed)
return
} else if node.Fast {
c.emit(OpCallFast, len(node.Arguments))
} else {
c.emit(OpCall, len(node.Arguments))
}
}

func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
Expand Down
122 changes: 122 additions & 0 deletions vm/func_types/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package main

import (
"bytes"
"fmt"
"go/format"
"reflect"
"strings"
"text/template"
. "time"
)

// Keep sorted.
var types = []interface{}{
nil,
new(func() Duration),
new(func() Month),
new(func() Time),
new(func() Weekday),
new(func() []byte),
new(func() []interface{}),
new(func() bool),
new(func() byte),
new(func() float64),
new(func() int),
new(func() int64),
new(func() interface{}),
new(func() map[string]interface{}),
new(func() rune),
new(func() string),
new(func() uint),
new(func() uint64),
new(func(Duration) Duration),
new(func(Duration) Time),
new(func(Time) Duration),
new(func(Time) bool),
new(func([]interface{}, string) string),
new(func([]string, string) string),
new(func(bool) bool),
new(func(bool) float64),
new(func(bool) int),
new(func(bool) string),
new(func(float64) bool),
new(func(float64) float64),
new(func(float64) int),
new(func(float64) string),
new(func(int) bool),
new(func(int) float64),
new(func(int) int),
new(func(int) string),
new(func(int, int) int),
new(func(int, int) string),
new(func(int64) Time),
new(func(string) []string),
new(func(string) bool),
new(func(string) float64),
new(func(string) int),
new(func(string) string),
new(func(string, byte) int),
new(func(string, int) int),
new(func(string, rune) int),
new(func(string, string) bool),
new(func(string, string) string),
}

func main() {
data := struct {
Index string
Code string
}{}

for i, t := range types {
if i == 0 {
continue
}
fn := reflect.ValueOf(t).Elem().Type()
data.Index += fmt.Sprintf("%v: new(%v),\n", i, fn)
data.Code += fmt.Sprintf("case %d:\n", i)
args := make([]string, fn.NumIn())
for j := fn.NumIn() - 1; j >= 0; j-- {
data.Code += fmt.Sprintf("arg%v := vm.pop().(%v)\n", j+1, fn.In(j))
args[j] = fmt.Sprintf("arg%v", j+1)
}
data.Code += fmt.Sprintf("return fn.(%v)(%v)\n", fn, strings.Join(args, ", "))
}

var b bytes.Buffer
err := template.Must(
template.New("func_types").
Parse(source),
).Execute(&b, data)
if err != nil {
panic(err)
}

formatted, err := format.Source(b.Bytes())
if err != nil {
panic(err)
}
fmt.Print(string(formatted))
}

const source = `// Code generated by vm/func_types/main.go. DO NOT EDIT.
package vm
import (
"fmt"
"time"
)
var FuncTypes = []interface{}{
{{ .Index }}
}
func (vm *VM) call(fn interface{}, kind int) interface{} {
switch kind {
{{ .Code }}
}
panic(fmt.Sprintf("unknown function kind (%v)", kind))
}
`
Loading

0 comments on commit 848ff39

Please sign in to comment.