Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ComWrappers.RegisterForTrackerSupport to be able create CCW #1544

Merged
merged 7 commits into from
Sep 10, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@ namespace System.Runtime.InteropServices
/// </summary>
public abstract partial class ComWrappers
{
private const int TrackerRefShift = 32;
private const ulong TrackerRefCounter = 1UL << TrackerRefShift;
private const ulong DestroySentinel = 0x0000000080000000UL;
private const ulong TrackerRefCountMask = 0xffffffff00000000UL;
private const ulong ComRefCountMask = 0x000000007fffffffUL;

internal static IntPtr DefaultIUnknownVftblPtr { get; } = CreateDefaultIUnknownVftbl();
internal static IntPtr DefaultIReferenceTrackerTargetVftblPtr { get; } = CreateDefaultIReferenceTrackerTargetVftbl();

internal static Guid IID_IUnknown = new Guid(0x00000000, 0x0000, 0x0000, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46);
internal static Guid IID_IReferenceTrackerTarget = new Guid(0x64bd43f8, 0xbfee, 0x4ec4, 0xb7, 0xeb, 0x29, 0x35, 0x15, 0x8d, 0xae, 0x21);

private readonly ConditionalWeakTable<object, ManagedObjectWrapperHolder> _ccwTable = new ConditionalWeakTable<object, ManagedObjectWrapperHolder>();
private readonly Lock _lock = new Lock();
Expand Down Expand Up @@ -55,26 +63,100 @@ internal unsafe struct InternalComInterfaceDispatch
internal ManagedObjectWrapper* _thisPtr;
}

internal enum CreateComInterfaceFlagsEx
{
None = 0,

/// <summary>
/// The caller will provide an IUnknown Vtable.
/// </summary>
/// <remarks>
/// This is useful in scenarios when the caller has no need to rely on an IUnknown instance
/// that is used when running managed code is not possible (i.e. during a GC). In traditional
/// COM scenarios this is common, but scenarios involving <see href="https://docs.microsoft.com/windows/win32/api/windows.ui.xaml.hosting.referencetracker/nn-windows-ui-xaml-hosting-referencetracker-ireferencetrackertarget">Reference Tracker hosting</see>
/// calling of the IUnknown API during a GC is possible.
/// </remarks>
CallerDefinedIUnknown = 1,

/// <summary>
/// Flag used to indicate the COM interface should implement <see href="https://docs.microsoft.com/windows/win32/api/windows.ui.xaml.hosting.referencetracker/nn-windows-ui-xaml-hosting-referencetracker-ireferencetrackertarget">IReferenceTrackerTarget</see>.
/// When this flag is passed, the resulting COM interface will have an internal implementation of IUnknown
/// and as such none should be supplied by the caller.
/// </summary>
TrackerSupport = 2,

LacksICustomQueryInterface = 1 << 29,
IsComActivated = 1 << 30,
IsPegged = 1 << 31,

InternalMask = IsPegged | IsComActivated | LacksICustomQueryInterface,
}

internal unsafe struct ManagedObjectWrapper
{
public IntPtr Target; // This is GC Handle
public uint RefCount;
public ulong RefCount;

public int UserDefinedCount;
public ComInterfaceEntry* UserDefined;
internal InternalComInterfaceDispatch* Dispatches;

internal CreateComInterfaceFlags Flags;
internal CreateComInterfaceFlagsEx Flags;

public uint AddRef()
{
return Interlocked.Increment(ref RefCount);
return GetComCount(Interlocked.Increment(ref RefCount));
}

public uint Release()
{
Debug.Assert(RefCount != 0);
return Interlocked.Decrement(ref RefCount);
Debug.Assert(GetComCount(RefCount) != 0);
return GetComCount(Interlocked.Decrement(ref RefCount));
}

public uint AddRefFromReferenceTracker()
{
ulong prev;
ulong curr;
do
{
prev = RefCount;
curr = prev + TrackerRefCounter;
} while (Interlocked.CompareExchange(ref RefCount, curr, prev) != prev);

return GetTrackerCount(curr);
}

public uint ReleaseFromReferenceTracker()
{
Debug.Assert(GetTrackerCount(RefCount) != 0);
ulong prev;
ulong curr;
do
{
prev = RefCount;
curr = prev - TrackerRefCounter;
}
while (Interlocked.CompareExchange(ref RefCount, curr, prev) != prev);

// If we observe the destroy sentinel, then this release
// must destroy the wrapper.
if (RefCount == DestroySentinel)
Destroy();

return GetTrackerCount(RefCount);
}

public uint Peg()
{
SetFlag(CreateComInterfaceFlagsEx.IsPegged);
return HResults.S_OK;
}

public uint Unpeg()
{
ResetFlag(CreateComInterfaceFlagsEx.IsPegged);
return HResults.S_OK;
}

public unsafe int QueryInterface(in Guid riid, out IntPtr ppvObject)
Expand Down Expand Up @@ -114,12 +196,25 @@ public unsafe void Destroy()

private unsafe IntPtr AsRuntimeDefined(in Guid riid)
{
if ((Flags & CreateComInterfaceFlags.CallerDefinedIUnknown) == CreateComInterfaceFlags.None)
int i = UserDefinedCount;
if ((Flags & CreateComInterfaceFlagsEx.CallerDefinedIUnknown) == 0)
{
if (riid == IID_IUnknown)
{
return (IntPtr)(Dispatches + UserDefinedCount);
return (IntPtr)(Dispatches + i);
}

i++;
}

if ((Flags & CreateComInterfaceFlagsEx.TrackerSupport) != 0)
{
if (riid == IID_IReferenceTrackerTarget)
{
return (IntPtr)(Dispatches + i);
}
jkotas marked this conversation as resolved.
Show resolved Hide resolved

i++;
}

return IntPtr.Zero;
Expand All @@ -137,6 +232,33 @@ private unsafe IntPtr AsUserDefined(in Guid riid)

return IntPtr.Zero;
}

private void SetFlag(CreateComInterfaceFlagsEx flag)
{
int setMask = (int)flag;
Interlocked.Or(ref Unsafe.As<CreateComInterfaceFlagsEx, int>(ref Flags), setMask);
}

private void ResetFlag(CreateComInterfaceFlagsEx flag)
{
int resetMask = ~(int)flag;
Interlocked.And(ref Unsafe.As<CreateComInterfaceFlagsEx, int>(ref Flags), resetMask);
}

private static uint GetTrackerCount(ulong c)
{
return (uint)((c & TrackerRefCountMask) >> TrackerRefShift);
}

private static uint GetComCount(ulong c)
{
return (uint)(c & ComRefCountMask);
}

private static bool IsMarkedToDestroy(ulong c)
{
return (c & DestroySentinel) != 0;
}
}

internal unsafe class ManagedObjectWrapperHolder
Expand Down Expand Up @@ -184,12 +306,10 @@ public NativeObjectWrapper(IntPtr externalComObject, ComWrappers comWrappers, ob
}
}

#if false
/// <summary>
/// Globally registered instance of the ComWrappers class for reference tracker support.
/// </summary>
private static ComWrappers? s_globalInstanceForTrackerSupport;
#endif

/// <summary>
/// Globally registered instance of the ComWrappers class for marshalling.
Expand Down Expand Up @@ -240,6 +360,11 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom
runtimeDefinedVtable[runtimeDefinedCount++] = DefaultIUnknownVftblPtr;
}

if ((flags & CreateComInterfaceFlags.TrackerSupport) != 0)
{
runtimeDefinedVtable[runtimeDefinedCount++] = DefaultIReferenceTrackerTargetVftblPtr;
}

// Compute size for ManagedObjectWrapper instance.
int totalDefinedCount = runtimeDefinedCount + userDefinedCount;

Expand All @@ -262,7 +387,7 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom
mow->RefCount = 1;
mow->UserDefinedCount = userDefinedCount;
mow->UserDefined = userDefined;
mow->Flags = flags;
mow->Flags = (CreateComInterfaceFlagsEx)flags;
mow->Dispatches = pDispatches;
return mow;
}
Expand Down Expand Up @@ -364,6 +489,9 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal(
if (flags.HasFlag(CreateObjectFlags.Aggregation))
throw new NotImplementedException();

if (flags.HasFlag(CreateObjectFlags.TrackerObject))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to make sure that I understand this right: This blocks the newly added code from getting executed for now. Is that correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I implement only CCW. I thought that throw that for RCW was a right way, but if I remove that lines, then app from #1453 is basically working (except it does not produce any Toast for me). But exit code == 0. So I would remove these lines, and let it leak. At least you wasn't agains that previously.

throw new NotImplementedException();

if (flags.HasFlag(CreateObjectFlags.Unwrap))
{
var comInterfaceDispatch = TryGetComInterfaceDispatch(externalComObject);
Expand Down Expand Up @@ -440,17 +568,13 @@ private void RemoveRCWFromCache(IntPtr comPointer)
/// </remarks>
public static void RegisterForTrackerSupport(ComWrappers instance)
{
#if false
if (instance == null)
throw new ArgumentNullException(nameof(instance));

if (null != Interlocked.CompareExchange(ref s_globalInstanceForTrackerSupport, instance, null))
{
throw new InvalidOperationException(SR.InvalidOperation_ResetGlobalComWrappersInstance);
}
#else
throw new NotImplementedException();
#endif
}

/// <summary>
Expand Down Expand Up @@ -554,11 +678,60 @@ internal static unsafe uint IUnknown_Release(IntPtr pThis)
return refcount;
}

[UnmanagedCallersOnly]
internal static unsafe int IReferenceTrackerTarget_QueryInterface(IntPtr pThis, Guid* guid, IntPtr* ppObject)
{
ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis);
return wrapper->QueryInterface(in *guid, out *ppObject);
}

[UnmanagedCallersOnly]
internal static unsafe uint IReferenceTrackerTarget_AddRefFromReferenceTracker(IntPtr pThis)
{
ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis);
return wrapper->AddRefFromReferenceTracker();
}

[UnmanagedCallersOnly]
internal static unsafe uint IReferenceTrackerTarget_ReleaseFromReferenceTracker(IntPtr pThis)
{
ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis);
return wrapper->ReleaseFromReferenceTracker();
}

[UnmanagedCallersOnly]
internal static unsafe uint IReferenceTrackerTarget_Peg(IntPtr pThis)
{
ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis);
return wrapper->Peg();
}

[UnmanagedCallersOnly]
internal static unsafe uint IReferenceTrackerTarget_Unpeg(IntPtr pThis)
{
ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis);
return wrapper->Unpeg();
}

private static unsafe IntPtr CreateDefaultIUnknownVftbl()
{
IntPtr* vftbl = (IntPtr*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComWrappers), 3 * sizeof(IntPtr));
GetIUnknownImpl(out vftbl[0], out vftbl[1], out vftbl[2]);
return (IntPtr)vftbl;
}

private static unsafe IntPtr CreateDefaultIReferenceTrackerTargetVftbl()
{
IntPtr* vftbl = (IntPtr*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComWrappers), 7 * sizeof(IntPtr));
GetIUnknownImpl(out vftbl[0], out vftbl[1], out vftbl[2]);
jkotas marked this conversation as resolved.
Show resolved Hide resolved
vftbl[0] = (IntPtr)(delegate* unmanaged<IntPtr, Guid*, IntPtr*, int>)&ComWrappers.IReferenceTrackerTarget_QueryInterface;
vftbl[1] = (IntPtr)(delegate* unmanaged<IntPtr, uint>)&ComWrappers.IUnknown_AddRef;
vftbl[2] = (IntPtr)(delegate* unmanaged<IntPtr, uint>)&ComWrappers.IUnknown_Release;
vftbl[3] = (IntPtr)(delegate* unmanaged<IntPtr, uint>)&ComWrappers.IReferenceTrackerTarget_AddRefFromReferenceTracker;
vftbl[4] = (IntPtr)(delegate* unmanaged<IntPtr, uint>)&ComWrappers.IReferenceTrackerTarget_ReleaseFromReferenceTracker;
vftbl[5] = (IntPtr)(delegate* unmanaged<IntPtr, uint>)&ComWrappers.IReferenceTrackerTarget_Peg;
vftbl[6] = (IntPtr)(delegate* unmanaged<IntPtr, uint>)&ComWrappers.IReferenceTrackerTarget_Unpeg;
return (IntPtr)vftbl;
}
}
}