Skip to content

Commit

Permalink
Add test for AssemblyLoadContext.LoadUnmanagedDll
Browse files Browse the repository at this point in the history
  • Loading branch information
elinor-fung committed Apr 1, 2020
1 parent 1f90029 commit 8c68764
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 70 deletions.
Original file line number Diff line number Diff line change
@@ -1,116 +1,161 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.IO;
using System.Runtime.Loader;
using System.Reflection;
using System.Runtime.InteropServices;

using TestLibrary;

public class ALC : AssemblyLoadContext
{
protected override Assembly Load(AssemblyName assemblyName)
public bool LoadUnmanagedDllCalled { get; private set; }

protected override IntPtr LoadUnmanagedDll(string unmanagedDllName)
{
return Assembly.Load(assemblyName);
LoadUnmanagedDllCalled = true;

if (string.Equals(unmanagedDllName, NativeLibraryToLoad.InvalidName))
return LoadUnmanagedDllFromPath(NativeLibraryToLoad.GetFullPath());

return IntPtr.Zero;
}
}

public class ResolveUnmanagedDllTests
{
static int HandlerTracker = 1;
private static readonly int seed = 123;
private static readonly Random rand = new Random(seed);

public static int Main()
{
// Events on the Default Load Context

try
{
AssemblyLoadContext.Default.ResolvingUnmanagedDll += HandlerFail;
NativeSum(10, 10);
}
catch (DllNotFoundException e)
{
if (HandlerTracker != 0)
{
Console.WriteLine("Event Handlers not called as expected");
return 101;
}
ValidateLoadUnmanagedDll();
ValidateResolvingUnmanagedDllEvent();
}
catch (Exception e)
{
Console.WriteLine($"Unexpected exception: {e.Message}");
return 102;
Console.WriteLine($"Test Failure: {e}");
return 101;
}

try
{
AssemblyLoadContext.Default.ResolvingUnmanagedDll += HandlerPass;
return 100;
}

if(NativeSum(10, 10) != 20)
{
Console.WriteLine("Unexpected ReturnValue from NativeSum()");
return 103;
}
if (HandlerTracker != 0)
{
Console.WriteLine("Event Handlers not called as expected");
return 104;
}
}
catch (Exception e)
{
Console.WriteLine($"Unexpected exception: {e.Message}");
return 105;
}
public static void ValidateLoadUnmanagedDll()
{
Console.WriteLine($"Running {nameof(ValidateLoadUnmanagedDll)}...");

// Events on a Custom Load Context
Console.WriteLine(" -- Validate p/invoke...");
int addend1 = rand.Next(int.MaxValue / 2);
int addend2 = rand.Next(int.MaxValue / 2);
int expected = addend1 + addend2;

try
{
string currentDir = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location);
string testAsmDir = Path.Combine(currentDir, "..", "TestAsm", "TestAsm");
ALC alc = new ALC();
int value = NativeSumInAssemblyLoadContext(alc, addend1, addend2);
Assert.IsTrue(alc.LoadUnmanagedDllCalled, "AssemblyLoadContext.LoadUnmanagedDll should have been called.");
Assert.AreEqual(expected, value, $"Unexpected return value for {nameof(NativeSum)}");
}

public static void ValidateResolvingUnmanagedDllEvent()
{
Console.WriteLine($"Running {nameof(ValidateResolvingUnmanagedDllEvent)}...");

Console.WriteLine(" -- Validate p/invoke: custom ALC...");
AssemblyLoadContext alc = new AssemblyLoadContext(nameof(ValidateResolvingUnmanagedDllEvent));
ValidateResolvingUnmanagedDllEvent_PInvoke(alc);

ALC alc = new ALC();
alc.ResolvingUnmanagedDll += HandlerPass;
Console.WriteLine(" -- Validate p/invoke: default ALC...");
ValidateResolvingUnmanagedDllEvent_PInvoke(AssemblyLoadContext.Default);
}

var assembly = alc.LoadFromAssemblyPath(Path.Combine(testAsmDir, "TestAsm.dll"));
var type = assembly.GetType("TestAsm");
var method = type.GetMethod("Sum");
int value = (int)method.Invoke(null, new object[] { 10, 10 });
private static void ValidateResolvingUnmanagedDllEvent_PInvoke(AssemblyLoadContext alc)
{
int addend1 = rand.Next(int.MaxValue / 2);
int addend2 = rand.Next(int.MaxValue / 2);
int expected = addend1 + addend2;

if(value != 20)
using (var handler = new Handlers(alc, returnValid: false))
{
if (alc == AssemblyLoadContext.Default)
{
Console.WriteLine("Unexpected ReturnValue from TestAsm.Sum()");
return 106;
Assert.Throws<DllNotFoundException>(() => NativeSum(addend1, addend2));
}
if (HandlerTracker != 1)
else
{
Console.WriteLine("Event Handlers not called as expected");
return 107;
TargetInvocationException ex = Assert.Throws<TargetInvocationException>(() => NativeSumInAssemblyLoadContext(alc, addend1, addend2));
Assert.AreEqual(typeof(DllNotFoundException), ex.InnerException.GetType());
}

Assert.IsTrue(handler.EventHandlerInvoked, "Event handler should have been invoked");
}
catch (Exception e)

// Multiple handlers - first valid result is used
// Test using valid handlers is done last, as the result will be cached for the ALC
using (var handlerInvalid = new Handlers(alc, returnValid: false))
using (var handlerValid1 = new Handlers(alc, returnValid: true))
using (var handlerValid2 = new Handlers(alc, returnValid: true))
{
Console.WriteLine($"Unexpected exception: {e.Message}");
return 108;
int value = alc == AssemblyLoadContext.Default
? NativeSum(addend1, addend2)
: NativeSumInAssemblyLoadContext(alc, addend1, addend2);

Assert.IsTrue(handlerInvalid.EventHandlerInvoked, "Event handler should have been invoked");
Assert.IsTrue(handlerValid1.EventHandlerInvoked, "Event handler should have been invoked");
Assert.IsFalse(handlerValid2.EventHandlerInvoked, "Event handler should not have been invoked");
Assert.AreEqual(expected, value, $"Unexpected return value for {nameof(NativeSum)} in {alc}");
}

return 100;
}

public static IntPtr HandlerFail(Assembly assembly, string libraryName)
private static int NativeSumInAssemblyLoadContext(AssemblyLoadContext alc, int addend1, int addend2)
{
HandlerTracker--;
return IntPtr.Zero;
}
string currentDir = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location);
var assembly = alc.LoadFromAssemblyPath(Path.Combine(currentDir, "TestAsm.dll"));
var type = assembly.GetType("TestAsm");
var method = type.GetMethod("Sum");

public static IntPtr HandlerPass(Assembly assembly, string libraryName)
{
HandlerTracker++;
return NativeLibrary.Load(NativeLibraryToLoad.Name, assembly, null);
int value = (int)method.Invoke(null, new object[] { addend1, addend2 });
return value;
}

[DllImport("DoesNotExist")]
[DllImport(NativeLibraryToLoad.InvalidName)]
static extern int NativeSum(int arg1, int arg2);

private class Handlers : IDisposable
{
private AssemblyLoadContext alc;
private bool returnValid;

public bool EventHandlerInvoked { get; private set; }

public Handlers(AssemblyLoadContext alc, bool returnValid)
{
this.alc = alc;
this.returnValid = returnValid;
this.EventHandlerInvoked = false;
this.alc.ResolvingUnmanagedDll += OnResolvingUnmanagedDll;
}

public void Dispose()
{
this.alc.ResolvingUnmanagedDll -= OnResolvingUnmanagedDll;
}

private IntPtr OnResolvingUnmanagedDll(Assembly assembly, string libraryName)
{
EventHandlerInvoked = true;

if (!this.returnValid)
return IntPtr.Zero;

if (string.Equals(libraryName, NativeLibraryToLoad.InvalidName))
return NativeLibrary.Load(NativeLibraryToLoad.Name, assembly, null);

return IntPtr.Zero;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
</ItemGroup>
<ItemGroup>
<ProjectReference Include="../NativeLibraryToLoad/CMakeLists.txt" />
<ProjectReference Include="$(TestSourceDir)Common/CoreCLRTestLibrary/CoreCLRTestLibrary.csproj" />
<ProjectReference Include="TestAsm/TestAsm.csproj">
<ReferenceOutputAssembly>false</ReferenceOutputAssembly>
<OutputItemType>Content</OutputItemType>
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</ProjectReference>
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ public static int Sum(int a, int b)
return NativeSum(a, b);
}

[DllImport("DoesNotExist")]
[DllImport(NativeLibraryToLoad.InvalidName)]
static extern int NativeSum(int arg1, int arg2);
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@
</PropertyGroup>
<ItemGroup>
<Compile Include="*.cs" />
<Compile Include="../../NativeLibraryToLoad/NativeLibraryToLoad.cs" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ public static int Main()
return -1;
}

[DllImport("DoesNotExist")]
[DllImport(NativeLibraryToLoad.InvalidName)]
[DefaultDllImportSearchPaths(DllImportSearchPath.System32)]
static extern int NativeSum(int arg1, int arg2);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public static int Main()
return 100;
}

[DllImport("DoesNotExist")]
[DllImport(NativeLibraryToLoad.InvalidName)]
[DefaultDllImportSearchPaths(DllImportSearchPath.System32)]
static extern int NativeSum(int arg1, int arg2);
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
public class NativeLibraryToLoad
{
public const string Name = "NativeLibrary";
public const string InvalidName = "DoesNotExist";

public static string GetFileName()
{
Expand Down

0 comments on commit 8c68764

Please sign in to comment.