Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: abstract away Wazero's api.ValueType and api.EncodeXX/api.DecodeXX #44

Merged
merged 1 commit into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ kvRead := extism.NewHostFunctionWithStack(

stack[0], err = p.WriteBytes(value)
},
[]api.ValueType{extism.PTR},
[]api.ValueType{extism.PTR},
[]ValueType{ValueTypePTR},
[]ValueType{ValueTypePTR},
)

kvWrite := extism.NewHostFunctionWithStack(
Expand All @@ -184,8 +184,8 @@ kvWrite := extism.NewHostFunctionWithStack(

kvStore[key] = value
},
[]api.ValueType{extism.PTR, extism.PTR},
[]api.ValueType{},
[]ValueType{ValueTypePTR, ValueTypePTR},
[]ValueType{},
)
```

Expand Down
23 changes: 11 additions & 12 deletions extism_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (

"github.com/stretchr/testify/assert"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/sys"
)

Expand Down Expand Up @@ -227,13 +226,13 @@ func TestHost_simple(t *testing.T) {
mult := NewHostFunctionWithStack(
"mult",
func(ctx context.Context, plugin *CurrentPlugin, stack []uint64) {
a := api.DecodeI32(stack[0])
b := api.DecodeI32(stack[1])
a := DecodeI32(stack[0])
b := DecodeI32(stack[1])

stack[0] = api.EncodeI32(a * b)
stack[0] = EncodeI32(a * b)
},
[]api.ValueType{PTR, PTR},
[]api.ValueType{PTR},
[]ValueType{ValueTypePTR, ValueTypePTR},
[]ValueType{ValueTypePTR},
)

if plugin, ok := plugin(t, manifest, mult); ok {
Expand Down Expand Up @@ -274,8 +273,8 @@ func TestHost_memory(t *testing.T) {

stack[0] = offset
},
[]api.ValueType{PTR},
[]api.ValueType{PTR},
[]ValueType{ValueTypePTR},
[]ValueType{ValueTypePTR},
)

mult.SetNamespace("host")
Expand Down Expand Up @@ -323,8 +322,8 @@ func TestHost_multiple(t *testing.T) {

stack[0] = offset
},
[]api.ValueType{PTR},
[]api.ValueType{PTR},
[]ValueType{ValueTypePTR},
[]ValueType{ValueTypePTR},
)

purple_message := NewHostFunctionWithStack(
Expand All @@ -348,8 +347,8 @@ func TestHost_multiple(t *testing.T) {

stack[0] = offset
},
[]api.ValueType{PTR},
[]api.ValueType{PTR},
[]ValueType{ValueTypePTR},
[]ValueType{ValueTypePTR},
)

hostFunctions := []HostFunction{
Expand Down
132 changes: 97 additions & 35 deletions host.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,20 @@ import (
"github.com/tetratelabs/wazero/api"
)

type ValType = api.ValueType

const I32 = api.ValueTypeI32
const I64 = api.ValueTypeI64
const PTR = I64
type ValueType = byte

const (
// ValueTypeI32 is a 32-bit integer.
ValueTypeI32 = api.ValueTypeI32
// ValueTypeI64 is a 64-bit integer.
ValueTypeI64 = api.ValueTypeI64
// ValueTypeF32 is a 32-bit floating point number.
ValueTypeF32 = api.ValueTypeF32
// ValueTypeF64 is a 64-bit floating point number.
ValueTypeF64 = api.ValueTypeF64
// ValueTypePTR represents a pointer to an Extism memory block. Alias for ValueTypeI64
ValueTypePTR = ValueTypeI64
)

// HostFunctionStackCallback is a Function implemented in Go instead of a wasm binary.
// The plugin parameter is the calling plugin, used to access memory or
Expand All @@ -33,18 +42,18 @@ const PTR = I64
// Here's a typical way to read three parameters and write back one.
//
// // read parameters in index order
// argv, argvBuf := api.DecodeU32(inputs[0]), api.DecodeU32(inputs[1])
// argv, argvBuf := DecodeU32(inputs[0]), DecodeU32(inputs[1])
//
// // write results back to the stack in index order
// stack[0] = api.EncodeU32(ErrnoSuccess)
// stack[0] = EncodeU32(ErrnoSuccess)
//
// This function can be non-deterministic or cause side effects. It also
// has special properties not defined in the WebAssembly Core specification.
// Notably, this uses the caller's memory (via Module.Memory). See
// https://www.w3.org/TR/wasm-core-1/#host-functions%E2%91%A0
//
// To safely decode/encode values from/to the uint64 inputs/ouputs, users are encouraged to use
// Wazero's api.EncodeXXX or api.DecodeXXX functions.
// Extism's EncodeXXX or DecodeXXX functions.
type HostFunctionStackCallback func(ctx context.Context, p *CurrentPlugin, stack []uint64)

// HostFunction represents a custom function defined by the host.
Expand All @@ -67,19 +76,19 @@ func (f *HostFunction) SetNamespace(namespace string) {
// mult := NewHostFunctionWithStack(
// "mult",
// func(ctx context.Context, plugin *CurrentPlugin, stack []uint64) {
// a := api.DecodeI32(stack[0])
// b := api.DecodeI32(stack[1])
// a := DecodeI32(stack[0])
// b := DecodeI32(stack[1])
//
// stack[0] = api.EncodeI32(a * b)
// stack[0] = EncodeI32(a * b)
// },
// []api.ValueType{api.ValueTypeI64, api.ValueTypeI64},
// api.ValueTypeI64
// []ValueType{ValueTypeI64, ValueTypeI64},
// ValueTypeI64
// )
func NewHostFunctionWithStack(
name string,
callback HostFunctionStackCallback,
params []api.ValueType,
returnTypes []api.ValueType) HostFunction {
params []ValueType,
returnTypes []ValueType) HostFunction {

return HostFunction{
stackCallback: callback,
Expand Down Expand Up @@ -225,7 +234,7 @@ func defineCustomHostFunctions(builder wazero.HostModuleBuilder, funcs []HostFun
func buildEnvModule(ctx context.Context, rt wazero.Runtime, extism api.Module) (api.Module, error) {
builder := rt.NewHostModuleBuilder("extism:host/env")

wrap := func(name string, params []ValType, results []ValType) {
wrap := func(name string, params []ValueType, results []ValueType) {
f := extism.ExportedFunction(name)
builder.
NewFunctionBuilder().
Expand All @@ -238,33 +247,33 @@ func buildEnvModule(ctx context.Context, rt wazero.Runtime, extism api.Module) (
Export(name)
}

wrap("alloc", []ValType{I64}, []ValType{I64})
wrap("free", []ValType{I64}, []ValType{})
wrap("load_u8", []ValType{I64}, []ValType{I32})
wrap("input_load_u8", []ValType{I64}, []ValType{I32})
wrap("store_u64", []ValType{I64, I64}, []ValType{})
wrap("store_u8", []ValType{I64, I32}, []ValType{})
wrap("input_set", []ValType{I64, I64}, []ValType{})
wrap("output_set", []ValType{I64, I64}, []ValType{})
wrap("input_length", []ValType{}, []ValType{I64})
wrap("output_length", []ValType{}, []ValType{I64})
wrap("output_offset", []ValType{}, []ValType{I64})
wrap("length", []ValType{I64}, []ValType{I64})
wrap("reset", []ValType{}, []ValType{})
wrap("error_set", []ValType{I64}, []ValType{})
wrap("error_get", []ValType{}, []ValType{I64})
wrap("memory_bytes", []ValType{}, []ValType{I64})
wrap("alloc", []ValueType{ValueTypeI64}, []ValueType{ValueTypeI64})
wrap("free", []ValueType{ValueTypeI64}, []ValueType{})
wrap("load_u8", []ValueType{ValueTypeI64}, []ValueType{ValueTypeI32})
wrap("input_load_u8", []ValueType{ValueTypeI64}, []ValueType{ValueTypeI32})
wrap("store_u64", []ValueType{ValueTypeI64, ValueTypeI64}, []ValueType{})
wrap("store_u8", []ValueType{ValueTypeI64, ValueTypeI32}, []ValueType{})
wrap("input_set", []ValueType{ValueTypeI64, ValueTypeI64}, []ValueType{})
wrap("output_set", []ValueType{ValueTypeI64, ValueTypeI64}, []ValueType{})
wrap("input_length", []ValueType{}, []ValueType{ValueTypeI64})
wrap("output_length", []ValueType{}, []ValueType{ValueTypeI64})
wrap("output_offset", []ValueType{}, []ValueType{ValueTypeI64})
wrap("length", []ValueType{ValueTypeI64}, []ValueType{ValueTypeI64})
wrap("reset", []ValueType{}, []ValueType{})
wrap("error_set", []ValueType{ValueTypeI64}, []ValueType{})
wrap("error_get", []ValueType{}, []ValueType{ValueTypeI64})
wrap("memory_bytes", []ValueType{}, []ValueType{ValueTypeI64})

builder.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(api.GoModuleFunc(inputLoad_u64)), []ValType{I64}, []ValType{I64}).
WithGoModuleFunction(api.GoModuleFunc(api.GoModuleFunc(inputLoad_u64)), []ValueType{ValueTypeI64}, []ValueType{ValueTypeI64}).
Export("input_load_u64")

builder.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(load_u64), []ValType{I64}, []ValType{I64}).
WithGoModuleFunction(api.GoModuleFunc(load_u64), []ValueType{ValueTypeI64}, []ValueType{ValueTypeI64}).
Export("load_u64")

builder.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(store_u64), []ValType{I64, I64}, []ValType{}).
WithGoModuleFunction(api.GoModuleFunc(store_u64), []ValueType{ValueTypeI64, ValueTypeI64}, []ValueType{}).
Export("store_u64")

hostFunc := func(name string, f interface{}) {
Expand Down Expand Up @@ -528,3 +537,56 @@ func httpStatusCode(ctx context.Context, m api.Module) int32 {

panic("Invalid context, `plugin` key not found")
}

// EncodeI32 encodes the input as a ValueTypeI32.
func EncodeI32(input int32) uint64 {
return api.EncodeI32(input)
}

// DecodeI32 decodes the input as a ValueTypeI32.
func DecodeI32(input uint64) int32 {
return api.DecodeI32(input)
}

// EncodeU32 encodes the input as a ValueTypeI32.
func EncodeU32(input uint32) uint64 {
return api.EncodeU32(input)
}

// DecodeU32 decodes the input as a ValueTypeI32.
func DecodeU32(input uint64) uint32 {
return api.DecodeU32(input)
}

// EncodeI64 encodes the input as a ValueTypeI64.
func EncodeI64(input int64) uint64 {
return api.EncodeI64(input)
}

// EncodeF32 encodes the input as a ValueTypeF32.
//
// See DecodeF32
func EncodeF32(input float32) uint64 {
return api.EncodeF32(input)
}

// DecodeF32 decodes the input as a ValueTypeF32.
//
// See EncodeF32
func DecodeF32(input uint64) float32 {
return api.DecodeF32(input)
}

// EncodeF64 encodes the input as a ValueTypeF64.
//
// See EncodeF32
func EncodeF64(input float64) uint64 {
return api.EncodeF64(input)
}

// DecodeF64 decodes the input as a ValueTypeF64.
//
// See EncodeF64
func DecodeF64(input uint64) float64 {
return api.DecodeF64(input)
}