Skip to content

Commit

Permalink
fix: make sure we instantiate non-main modules (#93)
Browse files Browse the repository at this point in the history
Fixes #92

---------

Signed-off-by: Edoardo Vacchi <[email protected]>
Co-authored-by: Edoardo Vacchi <[email protected]>
  • Loading branch information
mhmd-azeez and evacchi authored Jan 30, 2025
1 parent 1e14b80 commit 9679867
Show file tree
Hide file tree
Showing 14 changed files with 440 additions and 67 deletions.
30 changes: 14 additions & 16 deletions extism.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@ func RuntimeVersion() string {

// Runtime represents the Extism plugin's runtime environment, including the underlying Wazero runtime and modules.
type Runtime struct {
Wazero wazero.Runtime
Extism api.Module
Env api.Module
hasWasi bool
Wazero wazero.Runtime
Extism api.Module
Env api.Module
}

// PluginInstanceConfig contains configuration options for the Extism plugin.
Expand Down Expand Up @@ -112,13 +111,12 @@ func (l LogLevel) String() string {

// Plugin is used to call WASM functions
type Plugin struct {
close []func(ctx context.Context) error
extism api.Module

module api.Module
Timeout time.Duration
Config map[string]string
// NOTE: maybe we can have some nice methods for getting/setting vars
close []func(ctx context.Context) error
extism api.Module
mainModule api.Module
modules map[string]api.Module
Timeout time.Duration
Config map[string]string
Var map[string][]byte
AllowedHosts []string
AllowedPaths map[string]string
Expand All @@ -138,7 +136,7 @@ func logStd(level LogLevel, message string) {
}

func (p *Plugin) Module() *Module {
return &Module{inner: p.module}
return &Module{inner: p.mainModule}
}

// SetLogger sets a custom logging callback
Expand Down Expand Up @@ -443,7 +441,7 @@ func (p *Plugin) GetErrorWithContext(ctx context.Context) string {

// FunctionExists returns true when the named function is present in the plugin's main Module
func (p *Plugin) FunctionExists(name string) bool {
return p.module.ExportedFunction(name) != nil
return p.mainModule.ExportedFunction(name) != nil
}

// Call a function by name with the given input, returning the output
Expand All @@ -469,15 +467,15 @@ func (p *Plugin) CallWithContext(ctx context.Context, name string, data []byte)

ctx = context.WithValue(ctx, InputOffsetKey("inputOffset"), intputOffset)

var f = p.module.ExportedFunction(name)
var f = p.mainModule.ExportedFunction(name)

if f == nil {
return 1, []byte{}, fmt.Errorf("unknown function: %s", name)
} else if n := len(f.Definition().ResultTypes()); n > 1 {
return 1, []byte{}, fmt.Errorf("function %s has %v results, expected 0 or 1", name, n)
}

var isStart = name == "_start"
var isStart = name == "_start" || name == "_initialize"
if p.guestRuntime.init != nil && !isStart && !p.guestRuntime.initialized {
err := p.guestRuntime.init(ctx)
if err != nil {
Expand All @@ -501,7 +499,7 @@ func (p *Plugin) CallWithContext(ctx context.Context, name string, data []byte)
if exitCode == 0 {
// It's possible for the function to return 0 as an error code, even
// if the module is closed.
if p.module.IsClosed() {
if p.mainModule.IsClosed() {
return 0, nil, fmt.Errorf("module is closed")
}
err = nil
Expand Down
177 changes: 171 additions & 6 deletions extism_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"strings"
"sync"
"testing"
"time"

observe "github.com/dylibso/observe-sdk/go"
"github.com/dylibso/observe-sdk/go/adapter/stdout"
"github.com/stretchr/testify/assert"
Expand All @@ -13,12 +20,6 @@ import (
"github.com/tetratelabs/wazero/experimental"
"github.com/tetratelabs/wazero/experimental/logging"
"github.com/tetratelabs/wazero/sys"
"log"
"os"
"strings"
"sync"
"testing"
"time"
)

func TestWasmUrl(t *testing.T) {
Expand Down Expand Up @@ -1038,6 +1039,170 @@ func TestEnableExperimentalFeature(t *testing.T) {
}
}

func TestModuleLinking(t *testing.T) {
manifest := Manifest{
Wasm: []Wasm{
WasmFile{
Path: "wasm/lib.wasm",
Name: "lib",
},
WasmFile{
Path: "wasm/main.wasm",
Name: "main",
},
},
}

if plugin, ok := pluginInstance(t, manifest); ok {
defer plugin.Close(context.Background())

exit, output, err := plugin.Call("run_test", []byte("benjamin"))

if assertCall(t, err, exit) {
expected := "Hello, BENJAMIN"

actual := string(output)

assert.Equal(t, expected, actual)
}
}
}

func TestModuleLinkingMultipleInstances(t *testing.T) {
manifest := Manifest{
Wasm: []Wasm{
WasmFile{
Path: "wasm/lib.wasm",
Name: "lib",
},
WasmFile{
Path: "wasm/main.wasm",
Name: "main",
},
},
}

ctx := context.Background()
config := wasiPluginConfig()

compiledPlugin, err := NewCompiledPlugin(ctx, manifest, PluginConfig{
EnableWasi: true,
}, []HostFunction{})

if err != nil {
t.Fatalf("Could not create plugin: %v", err)
}

for i := 0; i < 3; i++ {
plugin, err := compiledPlugin.Instance(ctx, config)
if err != nil {
t.Fatalf("Could not create plugin instance: %v", err)
}
// purposefully not closing the plugin instance

for j := 0; j < 3; j++ {

exit, output, err := plugin.Call("run_test", []byte("benjamin"))

if assertCall(t, err, exit) {
expected := "Hello, BENJAMIN"

actual := string(output)

assert.Equal(t, expected, actual)
}
}
}
}

func TestCompiledModuleMultipleInstances(t *testing.T) {
manifest := Manifest{
Wasm: []Wasm{
WasmFile{
Path: "wasm/count_vowels.wasm",
Name: "main",
},
},
}

ctx := context.Background()
config := wasiPluginConfig()

compiledPlugin, err := NewCompiledPlugin(ctx, manifest, PluginConfig{
EnableWasi: true,
}, []HostFunction{})

if err != nil {
t.Fatalf("Could not create plugin: %v", err)
}

var wg sync.WaitGroup
numInstances := 300

// Create and test instances in parallel
for i := 0; i < numInstances; i++ {
wg.Add(1)
go func(instanceNum int) {
defer wg.Done()

plugin, err := compiledPlugin.Instance(ctx, config)
if err != nil {
t.Errorf("Could not create plugin instance %d: %v", instanceNum, err)
return
}
// purposefully not closing the plugin instance

// Sequential calls for this instance
for j := 0; j < 3; j++ {
exit, _, err := plugin.Call("count_vowels", []byte("benjamin"))
if err != nil {
t.Errorf("Instance %d, call %d failed: %v", instanceNum, j, err)
return
}
if exit != 0 {
t.Errorf("Instance %d, call %d returned non-zero exit code: %d", instanceNum, j, exit)
}
}
}(i)
}
wg.Wait()
}

func TestMultipleCallsOutputParallel(t *testing.T) {
manifest := manifest("count_vowels.wasm")
numInstances := 300

var wg sync.WaitGroup

// Create and test instances in parallel
for i := 0; i < numInstances; i++ {
wg.Add(1)
go func(instanceNum int) {
defer wg.Done()

if plugin, ok := pluginInstance(t, manifest); ok {
defer plugin.Close(context.Background())

// Sequential calls for this instance
exit, output1, err := plugin.Call("count_vowels", []byte("aaa"))
if !assertCall(t, err, exit) {
return
}

exit, output2, err := plugin.Call("count_vowels", []byte("bbba"))
if !assertCall(t, err, exit) {
return
}

assert.Equal(t, `{"count":3,"total":3,"vowels":"aeiouAEIOU"}`, string(output1))
assert.Equal(t, `{"count":1,"total":4,"vowels":"aeiouAEIOU"}`, string(output2))
}
}(i)
}

wg.Wait()
}

func BenchmarkInitialize(b *testing.B) {
ctx := context.Background()
cache := wazero.NewCompilationCache()
Expand Down
Loading

0 comments on commit 9679867

Please sign in to comment.