Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass original .NET callback exception as TrapException's inner exception #172

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions src/Function.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2450,12 +2450,26 @@ internal unsafe static IntPtr InvokeCallback(Delegate callback, MethodInfo callb
}
catch (Exception ex)
{
var bytes = Encoding.UTF8.GetBytes(ex.Message);
return HandleCallbackException(ex);
}
}

fixed (byte* ptr = bytes)
{
return Native.wasmtime_trap_new(ptr, (UIntPtr)bytes.Length);
}
internal static unsafe IntPtr HandleCallbackException(Exception ex)
{
// Store the exception as trap cause, so that we can use it as the TrapException's
// InnerException when the trap bubbles up to the next host-to-wasm transition.
// If the exception is already a TrapException, we use that one's InnerException,
// even if it's null.
// Note: This code currently requires that on every host-to-wasm transition where a
// trap can occur, TrapException.FromOwnedTrap() is called when a trap actually occured,
// which will then clear this field.
CallbackTrapCause = ex is TrapException trapException ? trapException.InnerException : ex;

var bytes = Encoding.UTF8.GetBytes(ex.Message);

fixed (byte* ptr = bytes)
{
return Native.wasmtime_trap_new(ptr, (nuint)bytes.Length);
}
}

Expand Down Expand Up @@ -2503,7 +2517,7 @@ internal static class Native
public static extern void wasm_functype_delete(IntPtr functype);

[DllImport(Engine.LibraryName)]
public static unsafe extern IntPtr wasmtime_trap_new(byte* bytes, UIntPtr len);
public static unsafe extern IntPtr wasmtime_trap_new(byte* bytes, nuint len);
}

private readonly IStore? store;
Expand All @@ -2512,6 +2526,19 @@ internal static class Native
internal readonly List<ValueKind> results = new List<ValueKind>();
internal static readonly Native.Finalizer Finalizer = (p) => GCHandle.FromIntPtr(p).Free();

/// <summary>
/// Contains the cause for a trap returned by invoking a wasm function, in case
/// the trap was caused by the host.
/// </summary>
/// <remarks>
/// This thread-local field will be set when catching a .NET exception at the
/// wasm-to-host transition. When the trap bubbles up to the next host-to-wasm
/// transition, the field needs to be cleared, and its value can be used to set
/// the inner exception of the created <see cref="TrapException"/>.
/// </remarks>
[ThreadStatic]
internal static Exception? CallbackTrapCause;

private static readonly Function _null = new Function();
private static readonly object?[] NullParams = new object?[1];
}
Expand Down
17 changes: 14 additions & 3 deletions src/TrapException.cs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ public TrapException() { }
public TrapException(string message) : base(message) { }

/// <inheritdoc/>
public TrapException(string message, Exception inner) : base(message, inner) { }
public TrapException(string message, Exception? inner) : base(message, inner) { }

/// <summary>
/// Gets the trap's frames.
Expand All @@ -281,18 +281,29 @@ public TrapException(string message, Exception inner) : base(message, inner) { }
/// <inheritdoc/>
protected TrapException(SerializationInfo info, StreamingContext context) : base(info, context) { }

internal TrapException(string message, IReadOnlyList<TrapFrame>? frames, TrapCode type) : base(message)
internal TrapException(string message, IReadOnlyList<TrapFrame>? frames, TrapCode type, Exception? innerException = null)
: base(message, innerException)
{
Type = type;
Frames = frames;
}

internal static TrapException FromOwnedTrap(IntPtr trap, bool delete = true)
{
// Get the cause of the trap if available (in case the trap was caused by a
// .NET exception thrown in a callback).
var callbackTrapCause = Function.CallbackTrapCause;

if (callbackTrapCause is not null)
{
// Clear the field as we consumed the value.
Function.CallbackTrapCause = null;
}

var accessor = new TrapAccessor(trap);
try
{
var trappedException = new TrapException(accessor.Message, accessor.GetFrames(), accessor.TrapCode)
var trappedException = new TrapException(accessor.Message, accessor.GetFrames(), accessor.TrapCode, callbackTrapCause)
{
ExitCode = accessor.ExitStatus
};
Expand Down
2 changes: 1 addition & 1 deletion src/WasmtimeException.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public WasmtimeException() { }
public WasmtimeException(string message) : base(message) { }

/// <inheritdoc/>
public WasmtimeException(string message, Exception inner) : base(message, inner) { }
public WasmtimeException(string message, Exception? inner) : base(message, inner) { }

/// <inheritdoc/>
protected WasmtimeException(SerializationInfo info, StreamingContext context) : base(info, context) { }
Expand Down
10 changes: 10 additions & 0 deletions tests/Modules/Trap.wat
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
(module
(import "" "host_trap" (func $host_trap))
(import "" "trap_from_host_exception" (func $trap_from_host_exception))
(import "" "call_host_callback" (func $call_host_callback))
(export "ok" (func $ok))
(export "ok_value" (func $ok_value))
(export "run" (func $run))
(export "run_div_zero" (func $run_div_zero))
(export "run_div_zero_with_result" (func $run_div_zero_with_result))
(export "host_trap" (func $host_trap))
(export "trap_from_host_exception" (func $trap_from_host_exception))
(export "call_host_callback" (func $call_host_callback))
(export "trap_in_wasm" (func $third))
(start $start)

(func $start
(call $call_host_callback)
)

(func $run
(call $first)
Expand Down
91 changes: 91 additions & 0 deletions tests/TrapTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System;
using System.IO;

using FluentAssertions;
using Xunit;

Expand Down Expand Up @@ -47,13 +49,25 @@ public class TrapTests : IClassFixture<TrapFixture>, IDisposable

private Linker Linker { get; set; }

private Action TrapFromHostExceptionCallback { get; set; }

private Action HostCallback { get; set; }

public TrapTests(TrapFixture fixture)
{
Fixture = fixture;
Store = new Store(Fixture.Engine);
Linker = new Linker(Fixture.Engine);

Linker.Define("", "host_trap", Function.FromCallback(Store, () => throw new Exception()));

Linker.Define("", "trap_from_host_exception", Function.FromCallback(
Store,
() => TrapFromHostExceptionCallback?.Invoke()));

Linker.Define("", "call_host_callback", Function.FromCallback(
Store,
() => HostCallback?.Invoke()));
}

[Fact]
Expand Down Expand Up @@ -203,6 +217,83 @@ public void ItHandlesCustomResultTypeWithTrapResult()
result.TrapStackDepth.Should().Be(1);
}

[Fact]
public void ItPassesCallbackTrapCauseAsInnerException()
{
var instance = Linker.Instantiate(Store, Fixture.Module);
var callTrap = instance.GetAction("trap_from_host_exception");
var trapInWasm = instance.GetAction("trap_in_wasm");

var exceptionToThrow = new IOException("My I/O exception.");

TrapFromHostExceptionCallback = () => throw exceptionToThrow;

// Verify that the IOException thrown at the host callback is passed as
// InnerException to the TrapException thrown on the host-to-wasm transition.
var action = callTrap;

action
.Should()
.Throw<TrapException>()
.Where(e => e.Type == TrapCode.Undefined &&
e.InnerException == exceptionToThrow);

// After that, ensure that when invoking another function that traps in wasm
// (so it cannot have a cause), the TrapException's InnerException is now null.
action = trapInWasm;
action
.Should()
.Throw<TrapException>()
.Where(e => e.Type == TrapCode.Unreachable &&
e.InnerException == null);

// Also verify the InnerException is set when using an ActionResult.
var callTrapAsActionResult = instance.GetFunction<ActionResult>("trap_from_host_exception");
var result = callTrapAsActionResult();

result.Type.Should().Be(ResultType.Trap);
result.Trap.Type.Should().Be(TrapCode.Undefined);
result.Trap.InnerException.Should().Be(exceptionToThrow);
}

[Fact]
public void ItPassesCallbackTrapCauseAsInnerExceptionOverTwoLevels()
{
var instance = Linker.Instantiate(Store, Fixture.Module);
var callTrap = instance.GetAction("trap_from_host_exception");
var callHostCallback = instance.GetAction("call_host_callback");

var exceptionToThrow = new IOException("My I/O exception.");

TrapFromHostExceptionCallback = () => throw exceptionToThrow;
HostCallback = callTrap;

// Verify that the IOException is passed as InnerException to the
// TrapException even after two levels of wasm-to-host transitions.
var action = callHostCallback;

action
.Should()
.Throw<TrapException>()
.Where(e => e.Type == TrapCode.Undefined &&
e.InnerException == exceptionToThrow);
}

[Fact]
public void ItPassesCallbackTrapCauseAsInnerExceptionWhenInstantiating()
{
var exceptionToThrow = new IOException("My I/O exception.");
HostCallback = () => throw exceptionToThrow;

var action = () => Linker.Instantiate(Store, Fixture.Module);

action
.Should()
.Throw<TrapException>()
.Where(e => e.Type == TrapCode.Undefined &&
e.InnerException == exceptionToThrow);
}

public void Dispose()
{
Store.Dispose();
Expand Down