Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make sure we instantiate non-main modules #93

Merged
merged 11 commits into from
Jan 30, 2025
30 changes: 14 additions & 16 deletions extism.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@ func RuntimeVersion() string {

// Runtime represents the Extism plugin's runtime environment, including the underlying Wazero runtime and modules.
type Runtime struct {
Wazero wazero.Runtime
Extism api.Module
Env api.Module
hasWasi bool
Wazero wazero.Runtime
Extism api.Module
Env api.Module
}

// PluginInstanceConfig contains configuration options for the Extism plugin.
Expand Down Expand Up @@ -112,13 +111,12 @@ func (l LogLevel) String() string {

// Plugin is used to call WASM functions
type Plugin struct {
close []func(ctx context.Context) error
extism api.Module

module api.Module
Timeout time.Duration
Config map[string]string
// NOTE: maybe we can have some nice methods for getting/setting vars
close []func(ctx context.Context) error
extism api.Module
mainModule api.Module
modules map[string]api.Module
Timeout time.Duration
Config map[string]string
Var map[string][]byte
AllowedHosts []string
AllowedPaths map[string]string
Expand All @@ -138,7 +136,7 @@ func logStd(level LogLevel, message string) {
}

func (p *Plugin) Module() *Module {
return &Module{inner: p.module}
return &Module{inner: p.mainModule}
}

// SetLogger sets a custom logging callback
Expand Down Expand Up @@ -443,7 +441,7 @@ func (p *Plugin) GetErrorWithContext(ctx context.Context) string {

// FunctionExists returns true when the named function is present in the plugin's main Module
func (p *Plugin) FunctionExists(name string) bool {
return p.module.ExportedFunction(name) != nil
return p.mainModule.ExportedFunction(name) != nil
}

// Call a function by name with the given input, returning the output
Expand All @@ -469,15 +467,15 @@ func (p *Plugin) CallWithContext(ctx context.Context, name string, data []byte)

ctx = context.WithValue(ctx, InputOffsetKey("inputOffset"), intputOffset)

var f = p.module.ExportedFunction(name)
var f = p.mainModule.ExportedFunction(name)

if f == nil {
return 1, []byte{}, fmt.Errorf("unknown function: %s", name)
} else if n := len(f.Definition().ResultTypes()); n > 1 {
return 1, []byte{}, fmt.Errorf("function %s has %v results, expected 0 or 1", name, n)
}

var isStart = name == "_start"
var isStart = name == "_start" || name == "_initialize"
if p.guestRuntime.init != nil && !isStart && !p.guestRuntime.initialized {
err := p.guestRuntime.init(ctx)
if err != nil {
Expand All @@ -501,7 +499,7 @@ func (p *Plugin) CallWithContext(ctx context.Context, name string, data []byte)
if exitCode == 0 {
// It's possible for the function to return 0 as an error code, even
// if the module is closed.
if p.module.IsClosed() {
if p.mainModule.IsClosed() {
return 0, nil, fmt.Errorf("module is closed")
}
err = nil
Expand Down
177 changes: 171 additions & 6 deletions extism_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"strings"
"sync"
"testing"
"time"

observe "github.com/dylibso/observe-sdk/go"
"github.com/dylibso/observe-sdk/go/adapter/stdout"
"github.com/stretchr/testify/assert"
Expand All @@ -13,12 +20,6 @@ import (
"github.com/tetratelabs/wazero/experimental"
"github.com/tetratelabs/wazero/experimental/logging"
"github.com/tetratelabs/wazero/sys"
"log"
"os"
"strings"
"sync"
"testing"
"time"
)

func TestWasmUrl(t *testing.T) {
Expand Down Expand Up @@ -1038,6 +1039,170 @@ func TestEnableExperimentalFeature(t *testing.T) {
}
}

func TestModuleLinking(t *testing.T) {
manifest := Manifest{
Wasm: []Wasm{
WasmFile{
Path: "wasm/lib.wasm",
Name: "lib",
},
WasmFile{
Path: "wasm/main.wasm",
Name: "main",
},
},
}

if plugin, ok := pluginInstance(t, manifest); ok {
defer plugin.Close(context.Background())

exit, output, err := plugin.Call("run_test", []byte("benjamin"))

if assertCall(t, err, exit) {
expected := "Hello, BENJAMIN"

actual := string(output)

assert.Equal(t, expected, actual)
}
}
}

func TestModuleLinkingMultipleInstances(t *testing.T) {
manifest := Manifest{
Wasm: []Wasm{
WasmFile{
Path: "wasm/lib.wasm",
Name: "lib",
},
WasmFile{
Path: "wasm/main.wasm",
Name: "main",
},
},
}

ctx := context.Background()
config := wasiPluginConfig()

compiledPlugin, err := NewCompiledPlugin(ctx, manifest, PluginConfig{
EnableWasi: true,
}, []HostFunction{})

if err != nil {
t.Fatalf("Could not create plugin: %v", err)
}

for i := 0; i < 3; i++ {
plugin, err := compiledPlugin.Instance(ctx, config)
if err != nil {
t.Fatalf("Could not create plugin instance: %v", err)
}
// purposefully not closing the plugin instance

for j := 0; j < 3; j++ {

exit, output, err := plugin.Call("run_test", []byte("benjamin"))

if assertCall(t, err, exit) {
expected := "Hello, BENJAMIN"

actual := string(output)

assert.Equal(t, expected, actual)
}
}
}
}

func TestCompiledModuleMultipleInstances(t *testing.T) {
manifest := Manifest{
Wasm: []Wasm{
WasmFile{
Path: "wasm/count_vowels.wasm",
Name: "main",
},
},
}

ctx := context.Background()
config := wasiPluginConfig()

compiledPlugin, err := NewCompiledPlugin(ctx, manifest, PluginConfig{
EnableWasi: true,
}, []HostFunction{})

if err != nil {
t.Fatalf("Could not create plugin: %v", err)
}

var wg sync.WaitGroup
numInstances := 300

// Create and test instances in parallel
for i := 0; i < numInstances; i++ {
wg.Add(1)
go func(instanceNum int) {
defer wg.Done()

plugin, err := compiledPlugin.Instance(ctx, config)
if err != nil {
t.Errorf("Could not create plugin instance %d: %v", instanceNum, err)
return
}
// purposefully not closing the plugin instance

// Sequential calls for this instance
for j := 0; j < 3; j++ {
exit, _, err := plugin.Call("count_vowels", []byte("benjamin"))
if err != nil {
t.Errorf("Instance %d, call %d failed: %v", instanceNum, j, err)
return
}
if exit != 0 {
t.Errorf("Instance %d, call %d returned non-zero exit code: %d", instanceNum, j, exit)
}
}
}(i)
}
wg.Wait()
}

func TestMultipleCallsOutputParallel(t *testing.T) {
manifest := manifest("count_vowels.wasm")
numInstances := 300

var wg sync.WaitGroup

// Create and test instances in parallel
for i := 0; i < numInstances; i++ {
wg.Add(1)
go func(instanceNum int) {
defer wg.Done()

if plugin, ok := pluginInstance(t, manifest); ok {
defer plugin.Close(context.Background())

// Sequential calls for this instance
exit, output1, err := plugin.Call("count_vowels", []byte("aaa"))
if !assertCall(t, err, exit) {
return
}

exit, output2, err := plugin.Call("count_vowels", []byte("bbba"))
if !assertCall(t, err, exit) {
return
}

assert.Equal(t, `{"count":3,"total":3,"vowels":"aeiouAEIOU"}`, string(output1))
assert.Equal(t, `{"count":1,"total":4,"vowels":"aeiouAEIOU"}`, string(output2))
}
}(i)
}

wg.Wait()
}

func BenchmarkInitialize(b *testing.B) {
ctx := context.Background()
cache := wazero.NewCompilationCache()
Expand Down
Loading