From 1324c71d7a39bdf9a1d851a80b6c5e0baebafee8 Mon Sep 17 00:00:00 2001 From: Andrii Kurdiumov Date: Thu, 9 Sep 2021 20:23:01 +0600 Subject: [PATCH] Implement ComWrappers.RegisterForTrackerSupport to be able create CCW This moves a bit towards #306 and #1453 --- .../InteropServices/ComWrappers.CoreRT.cs | 204 ++++++++++++++++-- 1 file changed, 191 insertions(+), 13 deletions(-) diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.CoreRT.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.CoreRT.cs index fce3316e0f9b..538f3b3edb5b 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.CoreRT.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.CoreRT.cs @@ -17,9 +17,17 @@ namespace System.Runtime.InteropServices /// public abstract partial class ComWrappers { + private const int TrackerRefShift = 32; + private const ulong TrackerRefCounter = 1UL << TrackerRefShift; + private const ulong DestroySentinel = 0x0000000080000000UL; + private const ulong TrackerRefCountMask = 0xffffffff00000000UL; + private const ulong ComRefCountMask = 0x000000007fffffffUL; + internal static IntPtr DefaultIUnknownVftblPtr { get; } = CreateDefaultIUnknownVftbl(); + internal static IntPtr DefaultIReferenceTrackerTargetVftblPtr { get; } = CreateDefaultIReferenceTrackerTargetVftbl(); internal static Guid IID_IUnknown = new Guid(0x00000000, 0x0000, 0x0000, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46); + internal static Guid IID_IReferenceTrackerTarget = new Guid(0x64bd43f8, 0xbfee, 0x4ec4, 0xb7, 0xeb, 0x29, 0x35, 0x15, 0x8d, 0xae, 0x21); private readonly ConditionalWeakTable _ccwTable = new ConditionalWeakTable(); private readonly Lock _lock = new Lock(); @@ -55,26 +63,100 @@ internal unsafe struct InternalComInterfaceDispatch internal ManagedObjectWrapper* _thisPtr; } + internal enum CreateComInterfaceFlagsEx + { + None = 0, + + /// + /// The caller will provide an IUnknown Vtable. + /// + /// + /// This is useful in scenarios when the caller has no need to rely on an IUnknown instance + /// that is used when running managed code is not possible (i.e. during a GC). In traditional + /// COM scenarios this is common, but scenarios involving Reference Tracker hosting + /// calling of the IUnknown API during a GC is possible. + /// + CallerDefinedIUnknown = 1, + + /// + /// Flag used to indicate the COM interface should implement IReferenceTrackerTarget. + /// When this flag is passed, the resulting COM interface will have an internal implementation of IUnknown + /// and as such none should be supplied by the caller. + /// + TrackerSupport = 2, + + LacksICustomQueryInterface = 1 << 29, + IsComActivated = 1 << 30, + IsPegged = 1 << 31, + + InternalMask = IsPegged | IsComActivated | LacksICustomQueryInterface, + } + internal unsafe struct ManagedObjectWrapper { public IntPtr Target; // This is GC Handle - public uint RefCount; + public ulong RefCount; public int UserDefinedCount; public ComInterfaceEntry* UserDefined; internal InternalComInterfaceDispatch* Dispatches; - internal CreateComInterfaceFlags Flags; + internal CreateComInterfaceFlagsEx Flags; public uint AddRef() { - return Interlocked.Increment(ref RefCount); + return GetComCount(Interlocked.Increment(ref RefCount)); } public uint Release() { - Debug.Assert(RefCount != 0); - return Interlocked.Decrement(ref RefCount); + Debug.Assert(GetComCount(RefCount) != 0); + return GetComCount(Interlocked.Decrement(ref RefCount)); + } + + public uint AddRefFromReferenceTracker() + { + ulong prev; + ulong curr; + do + { + prev = RefCount; + curr = prev + TrackerRefCounter; + } while (Interlocked.CompareExchange(ref RefCount, curr, prev) != prev); + + return GetTrackerCount(curr); + } + + public uint ReleaseFromReferenceTracker() + { + Debug.Assert(GetTrackerCount(RefCount) != 0); + ulong prev; + ulong curr; + do + { + prev = RefCount; + curr = prev - TrackerRefCounter; + } + while (Interlocked.CompareExchange(ref RefCount, curr, prev) != prev); + + // If we observe the destroy sentinel, then this release + // must destroy the wrapper. + if (RefCount == DestroySentinel) + Destroy(); + + return GetTrackerCount(RefCount); + } + + public uint Peg() + { + SetFlag(CreateComInterfaceFlagsEx.IsPegged); + return HResults.S_OK; + } + + public uint Unpeg() + { + ResetFlag(CreateComInterfaceFlagsEx.IsPegged); + return HResults.S_OK; } public unsafe int QueryInterface(in Guid riid, out IntPtr ppvObject) @@ -114,12 +196,30 @@ public unsafe void Destroy() private unsafe IntPtr AsRuntimeDefined(in Guid riid) { - if ((Flags & CreateComInterfaceFlags.CallerDefinedIUnknown) == CreateComInterfaceFlags.None) + if ((Flags & CreateComInterfaceFlagsEx.CallerDefinedIUnknown) == CreateComInterfaceFlagsEx.None) { if (riid == IID_IUnknown) { return (IntPtr)(Dispatches + UserDefinedCount); } + + if ((Flags & CreateComInterfaceFlagsEx.TrackerSupport) == CreateComInterfaceFlagsEx.TrackerSupport) + { + if (riid == IID_IReferenceTrackerTarget) + { + return (IntPtr)(Dispatches + UserDefinedCount + 1); + } + } + } + else + { + if ((Flags & CreateComInterfaceFlagsEx.TrackerSupport) == CreateComInterfaceFlagsEx.TrackerSupport) + { + if (riid == IID_IReferenceTrackerTarget) + { + return (IntPtr)(Dispatches + UserDefinedCount); + } + } } return IntPtr.Zero; @@ -137,6 +237,33 @@ private unsafe IntPtr AsUserDefined(in Guid riid) return IntPtr.Zero; } + + private void SetFlag(CreateComInterfaceFlagsEx flag) + { + int setMask = (int)flag; + Interlocked.Or(ref Unsafe.As(ref Flags), setMask); + } + + private void ResetFlag(CreateComInterfaceFlagsEx flag) + { + int resetMask = ~(int)flag; + Interlocked.And(ref Unsafe.As(ref Flags), resetMask); + } + + private static uint GetTrackerCount(ulong c) + { + return (uint)((c & TrackerRefCountMask) >> TrackerRefShift); + } + + private static uint GetComCount(ulong c) + { + return (uint)(c & ComRefCountMask); + } + + private static bool IsMarkedToDestroy(ulong c) + { + return (c & DestroySentinel) != 0; + } } internal unsafe class ManagedObjectWrapperHolder @@ -184,12 +311,10 @@ public NativeObjectWrapper(IntPtr externalComObject, ComWrappers comWrappers, ob } } -#if false /// /// Globally registered instance of the ComWrappers class for reference tracker support. /// private static ComWrappers? s_globalInstanceForTrackerSupport; -#endif /// /// Globally registered instance of the ComWrappers class for marshalling. @@ -240,6 +365,11 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom runtimeDefinedVtable[runtimeDefinedCount++] = DefaultIUnknownVftblPtr; } + if ((flags & CreateComInterfaceFlags.TrackerSupport) == CreateComInterfaceFlags.TrackerSupport) + { + runtimeDefinedVtable[runtimeDefinedCount++] = DefaultIReferenceTrackerTargetVftblPtr; + } + // Compute size for ManagedObjectWrapper instance. int totalDefinedCount = runtimeDefinedCount + userDefinedCount; @@ -262,7 +392,7 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom mow->RefCount = 1; mow->UserDefinedCount = userDefinedCount; mow->UserDefined = userDefined; - mow->Flags = flags; + mow->Flags = (CreateComInterfaceFlagsEx)flags; mow->Dispatches = pDispatches; return mow; } @@ -364,6 +494,9 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal( if (flags.HasFlag(CreateObjectFlags.Aggregation)) throw new NotImplementedException(); + if (flags.HasFlag(CreateObjectFlags.TrackerObject)) + throw new NotImplementedException(); + if (flags.HasFlag(CreateObjectFlags.Unwrap)) { var comInterfaceDispatch = TryGetComInterfaceDispatch(externalComObject); @@ -440,7 +573,6 @@ private void RemoveRCWFromCache(IntPtr comPointer) /// public static void RegisterForTrackerSupport(ComWrappers instance) { -#if false if (instance == null) throw new ArgumentNullException(nameof(instance)); @@ -448,9 +580,6 @@ public static void RegisterForTrackerSupport(ComWrappers instance) { throw new InvalidOperationException(SR.InvalidOperation_ResetGlobalComWrappersInstance); } -#else - throw new NotImplementedException(); -#endif } /// @@ -554,11 +683,60 @@ internal static unsafe uint IUnknown_Release(IntPtr pThis) return refcount; } + [UnmanagedCallersOnly] + internal static unsafe int IReferenceTrackerTarget_QueryInterface(IntPtr pThis, Guid* guid, IntPtr* ppObject) + { + ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis); + return wrapper->QueryInterface(in *guid, out *ppObject); + } + + [UnmanagedCallersOnly] + internal static unsafe uint IReferenceTrackerTarget_AddRefFromReferenceTracker(IntPtr pThis) + { + ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis); + return wrapper->AddRefFromReferenceTracker(); + } + + [UnmanagedCallersOnly] + internal static unsafe uint IReferenceTrackerTarget_ReleaseFromReferenceTracker(IntPtr pThis) + { + ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis); + return wrapper->ReleaseFromReferenceTracker(); + } + + [UnmanagedCallersOnly] + internal static unsafe uint IReferenceTrackerTarget_Peg(IntPtr pThis) + { + ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis); + return wrapper->Peg(); + } + + [UnmanagedCallersOnly] + internal static unsafe uint IReferenceTrackerTarget_Unpeg(IntPtr pThis) + { + ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis); + return wrapper->Unpeg(); + } + private static unsafe IntPtr CreateDefaultIUnknownVftbl() { IntPtr* vftbl = (IntPtr*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComWrappers), 3 * sizeof(IntPtr)); GetIUnknownImpl(out vftbl[0], out vftbl[1], out vftbl[2]); return (IntPtr)vftbl; } + + private static unsafe IntPtr CreateDefaultIReferenceTrackerTargetVftbl() + { + IntPtr* vftbl = (IntPtr*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComWrappers), 7 * sizeof(IntPtr)); + GetIUnknownImpl(out vftbl[0], out vftbl[1], out vftbl[2]); + vftbl[0] = (IntPtr)(delegate* unmanaged)&ComWrappers.IReferenceTrackerTarget_QueryInterface; + vftbl[1] = (IntPtr)(delegate* unmanaged)&ComWrappers.IUnknown_AddRef; + vftbl[2] = (IntPtr)(delegate* unmanaged)&ComWrappers.IUnknown_Release; + vftbl[3] = (IntPtr)(delegate* unmanaged)&ComWrappers.IReferenceTrackerTarget_AddRefFromReferenceTracker; + vftbl[4] = (IntPtr)(delegate* unmanaged)&ComWrappers.IReferenceTrackerTarget_ReleaseFromReferenceTracker; + vftbl[5] = (IntPtr)(delegate* unmanaged)&ComWrappers.IReferenceTrackerTarget_Peg; + vftbl[6] = (IntPtr)(delegate* unmanaged)&ComWrappers.IReferenceTrackerTarget_Unpeg; + return (IntPtr)vftbl; + } } }