Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: change host function API and hide GuestRuntime #13

Merged
merged 1 commit into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions extism.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ type Plugin struct {
LastStatusCode int
log func(LogLevel, string)
logLevel LogLevel
guestRuntime GuestRuntime
guestRuntime guestRuntime
}

func logStd(level LogLevel, message string) {
Expand Down Expand Up @@ -363,7 +363,7 @@ func NewPlugin(
log: logStd,
logLevel: logLevel}

p.guestRuntime = guestRuntime(p)
p.guestRuntime = detectGuestRuntime(p)
return p, nil
}

Expand Down Expand Up @@ -470,8 +470,8 @@ func (plugin *Plugin) Call(name string, data []byte) (uint32, []byte, error) {
}

var isStart = name == "_start"
if plugin.guestRuntime.Init != nil && !isStart && !plugin.guestRuntime.initialized {
err := plugin.guestRuntime.Init()
if plugin.guestRuntime.init != nil && !isStart && !plugin.guestRuntime.initialized {
err := plugin.guestRuntime.init()
if err != nil {
return 1, []byte{}, errors.New(fmt.Sprintf("failed to initialize runtime: %v", err))
}
Expand Down
56 changes: 28 additions & 28 deletions extism_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,18 +224,18 @@ func TestExit(t *testing.T) {
func TestHost_simple(t *testing.T) {
manifest := manifest("host.wasm")

mult := HostFunction{
Name: "mult",
Namespace: "env",
Callback: func(ctx context.Context, plugin *CurrentPlugin, userData interface{}, stack []uint64) {
mult := NewHostFunctionWithStack(
"mult",
"env",
func(ctx context.Context, plugin *CurrentPlugin, stack []uint64) {
a := api.DecodeI32(stack[0])
b := api.DecodeI32(stack[1])

stack[0] = api.EncodeI32(a * b)
},
Params: []api.ValueType{api.ValueTypeI64, api.ValueTypeI64},
Results: []api.ValueType{api.ValueTypeI64},
}
[]api.ValueType{api.ValueTypeI64, api.ValueTypeI64},
api.ValueTypeI64,
)

if plugin, ok := plugin(t, manifest, mult); ok {
defer plugin.Close()
Expand All @@ -254,10 +254,10 @@ func TestHost_simple(t *testing.T) {
func TestHost_memory(t *testing.T) {
manifest := manifest("host_memory.wasm")

mult := HostFunction{
Name: "to_upper",
Namespace: "host",
Callback: func(ctx context.Context, plugin *CurrentPlugin, userData interface{}, stack []uint64) {
mult := NewHostFunctionWithStack(
"to_upper",
"host",
func(ctx context.Context, plugin *CurrentPlugin, stack []uint64) {
offset := stack[0]
buffer, err := plugin.ReadBytes(offset)
if err != nil {
Expand All @@ -276,9 +276,9 @@ func TestHost_memory(t *testing.T) {

stack[0] = offset
},
Params: []api.ValueType{api.ValueTypeI64},
Results: []api.ValueType{api.ValueTypeI64},
}
[]api.ValueType{api.ValueTypeI64},
api.ValueTypeI64,
)

if plugin, ok := plugin(t, manifest, mult); ok {
defer plugin.Close()
Expand All @@ -302,10 +302,10 @@ func TestHost_multiple(t *testing.T) {
EnableWasi: true,
}

green_message := HostFunction{
Name: "hostGreenMessage",
Namespace: "env",
Callback: func(ctx context.Context, plugin *CurrentPlugin, userData interface{}, stack []uint64) {
green_message := NewHostFunctionWithStack(
"hostGreenMessage",
"env",
func(ctx context.Context, plugin *CurrentPlugin, stack []uint64) {
offset := stack[0]
input, err := plugin.ReadString(offset)

Expand All @@ -324,14 +324,14 @@ func TestHost_multiple(t *testing.T) {

stack[0] = offset
},
Params: []api.ValueType{api.ValueTypeI64},
Results: []api.ValueType{api.ValueTypeI64},
}
[]api.ValueType{api.ValueTypeI64},
api.ValueTypeI64,
)

purple_message := HostFunction{
Name: "hostPurpleMessage",
Namespace: "env",
Callback: func(ctx context.Context, plugin *CurrentPlugin, userData interface{}, stack []uint64) {
purple_message := NewHostFunctionWithStack(
"hostPurpleMessage",
"env",
func(ctx context.Context, plugin *CurrentPlugin, stack []uint64) {
offset := stack[0]
input, err := plugin.ReadString(offset)

Expand All @@ -350,9 +350,9 @@ func TestHost_multiple(t *testing.T) {

stack[0] = offset
},
Params: []api.ValueType{api.ValueTypeI64},
Results: []api.ValueType{api.ValueTypeI64},
}
[]api.ValueType{api.ValueTypeI64},
api.ValueTypeI64,
)

hostFunctions := []HostFunction{
purple_message,
Expand Down
56 changes: 36 additions & 20 deletions host.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type ValType = api.ValueType
const I32 = api.ValueTypeI32
const I64 = api.ValueTypeI64

// HostFunctionCallback is a Function implemented in Go instead of a wasm binary.
// HostFunctionStackCallback is a Function implemented in Go instead of a wasm binary.
// The plugin parameter is the calling plugin, used to access memory or
// exported functions and logging.
//
Expand All @@ -45,30 +45,47 @@ const I64 = api.ValueTypeI64
//
// To safely decode/encode values from/to the uint64 inputs/ouputs, users are encouraged to use
// Wazero's api.EncodeXXX or api.DecodeXXX functions.
type HostFunctionCallback func(ctx context.Context, p *CurrentPlugin, userData interface{}, stack []uint64)
type HostFunctionStackCallback func(ctx context.Context, p *CurrentPlugin, stack []uint64)

// HostFunction represents a custom function defined by the host.
type HostFunction struct {
stackCallback HostFunctionStackCallback
Name string
Namespace string
Params []api.ValueType
Returns []api.ValueType
}

// NewHostFunctionWithStack creates a new instance of a HostFunction, which is designed
// to provide custom functionality in a given host environment.
// Here's an example multiplication function that loads operands from memory:
//
// mult := HostFunction{
// Name: "mult",
// Namespace: "env",
// Callback: func(ctx context.Context, plugin *CurrentPlugin, userData interface{}, stack []uint64) {
// mult := NewHostFunctionWithStack(
// "mult",
// "env",
// func(ctx context.Context, plugin *CurrentPlugin, stack []uint64) {
// a := api.DecodeI32(stack[0])
// b := api.DecodeI32(stack[1])
//
// stack[0] = api.EncodeI32(a * b)
// },
// Params: []api.ValueType{api.ValueTypeI64, api.ValueTypeI64},
// Results: []api.ValueType{api.ValueTypeI64},
// }
type HostFunction struct {
Callback HostFunctionCallback
Name string
Namespace string
Params []api.ValueType
Results []api.ValueType
UserData interface{}
// []api.ValueType{api.ValueTypeI64, api.ValueTypeI64},
// api.ValueTypeI64
// )
func NewHostFunctionWithStack(
name string,
namespace string,
callback HostFunctionStackCallback,
params []api.ValueType,
returnType api.ValueType) HostFunction {

return HostFunction{
stackCallback: callback,
Name: name,
Namespace: namespace,
Params: params,
Returns: []api.ValueType{returnType},
}
}

type CurrentPlugin struct {
Expand Down Expand Up @@ -187,17 +204,16 @@ func defineCustomHostFunctions(builder wazero.HostModuleBuilder, funcs []HostFun
// a separate variable (closure) and assigning the value of f to it, you might run into unexpected behavior.
// All the closures created in the loop would end up referencing the same f, which could lead to incorrect or unintended results.
// See: https://github.com/extism/go-sdk/issues/5#issuecomment-1666774486
closure := f.Callback
userData := f.UserData
closure := f.stackCallback

builder.NewFunctionBuilder().WithGoFunction(api.GoFunc(func(ctx context.Context, stack []uint64) {
if plugin, ok := ctx.Value("plugin").(*Plugin); ok {
closure(ctx, &CurrentPlugin{plugin}, userData, stack)
closure(ctx, &CurrentPlugin{plugin}, stack)
return
}

panic("Invalid context, `plugin` key not found")
}), f.Params, f.Results).Export(f.Name)
}), f.Params, f.Returns).Export(f.Name)
}
}

Expand Down
36 changes: 18 additions & 18 deletions runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@ import (

// TODO: test runtime initialization for WASI and Haskell

type RuntimeType uint8
type runtimeType uint8

const (
None RuntimeType = iota
None runtimeType = iota
Haskell
Wasi
)

type GuestRuntime struct {
Type RuntimeType
Init func() error
type guestRuntime struct {
runtimeType runtimeType
init func() error
initialized bool
}

func guestRuntime(p *Plugin) GuestRuntime {
func detectGuestRuntime(p *Plugin) guestRuntime {
m := p.Main

runtime, ok := haskellRuntime(p, m)
Expand All @@ -34,16 +34,16 @@ func guestRuntime(p *Plugin) GuestRuntime {
}

p.Log(Trace, "No runtime detected")
return GuestRuntime{Type: None, Init: func() error { return nil }, initialized: true}
return guestRuntime{runtimeType: None, init: func() error { return nil }, initialized: true}
}

// Check for Haskell runtime initialization functions
// Initialize Haskell runtime if `hs_init` and `hs_exit` are present,
// by calling the `hs_init` export
func haskellRuntime(p *Plugin, m api.Module) (GuestRuntime, bool) {
func haskellRuntime(p *Plugin, m api.Module) (guestRuntime, bool) {
initFunc := m.ExportedFunction("hs_init")
if initFunc == nil {
return GuestRuntime{}, false
return guestRuntime{}, false
}

params := initFunc.Definition().ParamTypes()
Expand All @@ -70,13 +70,13 @@ func haskellRuntime(p *Plugin, m api.Module) (GuestRuntime, bool) {
}

p.Log(Trace, "Haskell runtime detected")
return GuestRuntime{Type: Haskell, Init: init}, true
return guestRuntime{runtimeType: Haskell, init: init}, true
}

// Check for initialization functions defined by the WASI standard
func wasiRuntime(p *Plugin, m api.Module) (GuestRuntime, bool) {
func wasiRuntime(p *Plugin, m api.Module) (guestRuntime, bool) {
if !p.Runtime.hasWasi {
return GuestRuntime{}, false
return guestRuntime{}, false
}

// WASI supports two modules: Reactors and Commands
Expand All @@ -90,30 +90,30 @@ func wasiRuntime(p *Plugin, m api.Module) (GuestRuntime, bool) {
}

// Check for `_initialize` this is used by WASI to initialize certain interfaces.
func reactorModule(m api.Module, p *Plugin) (GuestRuntime, bool) {
func reactorModule(m api.Module, p *Plugin) (guestRuntime, bool) {
init := findFunc(m, p, "_initialize")
if init == nil {
return GuestRuntime{}, false
return guestRuntime{}, false
}

p.Logf(Trace, "WASI runtime detected")
p.Logf(Trace, "Reactor module detected")

return GuestRuntime{Type: Wasi, Init: init}, true
return guestRuntime{runtimeType: Wasi, init: init}, true
}

// Check for `__wasm__call_ctors`, this is used by WASI to
// initialize certain interfaces.
func commandModule(m api.Module, p *Plugin) (GuestRuntime, bool) {
func commandModule(m api.Module, p *Plugin) (guestRuntime, bool) {
init := findFunc(m, p, "__wasm_call_ctors")
if init == nil {
return GuestRuntime{}, false
return guestRuntime{}, false
}

p.Logf(Trace, "WASI runtime detected")
p.Logf(Trace, "Command module detected")

return GuestRuntime{Type: Wasi, Init: init}, true
return guestRuntime{runtimeType: Wasi, init: init}, true
}

func findFunc(m api.Module, p *Plugin, name string) func() error {
Expand Down