-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
585 additions
and
323 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Runtime.CompilerServices; | ||
using System.Runtime.ExceptionServices; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
using System.Threading.Tasks.Sources; | ||
|
||
namespace ValueTaskSupplement | ||
{ | ||
public static partial class ValueTaskEx | ||
{ | ||
public static ValueTask<T> FromResult<T>(T result) | ||
{ | ||
return new ValueTask<T>(result); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
using System; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
using System.Threading.Tasks.Sources; | ||
|
||
namespace ValueTaskSupplement | ||
{ | ||
public static partial class ValueTaskEx | ||
{ | ||
public static ValueTask<T> Lazy<T>(Func<ValueTask<T>> factory) | ||
{ | ||
return new ValueTask<T>(new AsyncLazySource<T>(factory), 0); | ||
} | ||
|
||
class AsyncLazySource<T> : IValueTaskSource<T> | ||
{ | ||
static readonly ContextCallback execContextCallback = ExecutionContextCallback; | ||
static readonly SendOrPostCallback syncContextCallback = SynchronizationContextCallback; | ||
|
||
Func<ValueTask<T>> factory; | ||
object syncLock; | ||
ValueTask<T> source; | ||
bool initialized; | ||
|
||
public AsyncLazySource(Func<ValueTask<T>> factory) | ||
{ | ||
this.factory = factory; | ||
this.syncLock = new object(); | ||
} | ||
|
||
ValueTask<T> GetSource() | ||
{ | ||
return LazyInitializer.EnsureInitialized(ref source, ref initialized, ref syncLock, factory); | ||
} | ||
|
||
public T GetResult(short token) | ||
{ | ||
return GetSource().Result; | ||
} | ||
|
||
public ValueTaskSourceStatus GetStatus(short token) | ||
{ | ||
var task = GetSource(); | ||
return task.IsCompletedSuccessfully ? ValueTaskSourceStatus.Succeeded | ||
: task.IsCanceled ? ValueTaskSourceStatus.Canceled | ||
: task.IsFaulted ? ValueTaskSourceStatus.Faulted | ||
: ValueTaskSourceStatus.Pending; | ||
} | ||
|
||
public void OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) | ||
{ | ||
var task = GetSource(); | ||
if (task.IsCompleted) | ||
{ | ||
continuation(state); | ||
} | ||
OnCompletedSlow(task, continuation, state, flags); | ||
} | ||
|
||
static async void OnCompletedSlow(ValueTask<T> source, Action<object> continuation, object state, ValueTaskSourceOnCompletedFlags flags) | ||
{ | ||
ExecutionContext execContext = null; | ||
SynchronizationContext syncContext = null; | ||
if ((flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) == ValueTaskSourceOnCompletedFlags.FlowExecutionContext) | ||
{ | ||
execContext = ExecutionContext.Capture(); | ||
} | ||
if ((flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) == ValueTaskSourceOnCompletedFlags.UseSchedulingContext) | ||
{ | ||
syncContext = SynchronizationContext.Current; | ||
} | ||
|
||
await source.ConfigureAwait(false); | ||
|
||
if (execContext != null) | ||
{ | ||
ExecutionContext.Run(execContext, execContextCallback, Tuple.Create(continuation, state, syncContext)); | ||
} | ||
else if (syncContext != null) | ||
{ | ||
syncContext.Post(syncContextCallback, Tuple.Create(continuation, state, syncContext)); | ||
} | ||
else | ||
{ | ||
continuation(state); | ||
} | ||
} | ||
|
||
static void ExecutionContextCallback(object state) | ||
{ | ||
var t = (Tuple<Action<object>, object, SynchronizationContext>)state; | ||
if (t.Item3 != null) | ||
{ | ||
SynchronizationContextCallback(state); | ||
} | ||
else | ||
{ | ||
t.Item1.Invoke(t.Item2); | ||
} | ||
} | ||
|
||
static void SynchronizationContextCallback(object state) | ||
{ | ||
var t = (Tuple<Action<object>, object, SynchronizationContext>)state; | ||
t.Item1.Invoke(t.Item2); | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,244 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Runtime.CompilerServices; | ||
using System.Runtime.ExceptionServices; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
using System.Threading.Tasks.Sources; | ||
|
||
namespace ValueTaskSupplement | ||
{ | ||
public static partial class ValueTaskEx | ||
{ | ||
public static ValueTask<T[]> WhenAll<T>(IEnumerable<ValueTask<T>> tasks) | ||
{ | ||
return new ValueTask<T[]>(new WhenAllPromiseAll<T>(tasks), 0); | ||
} | ||
|
||
class WhenAllPromiseAll<T> : IValueTaskSource<T[]> | ||
{ | ||
static readonly ContextCallback execContextCallback = ExecutionContextCallback; | ||
static readonly SendOrPostCallback syncContextCallback = SynchronizationContextCallback; | ||
|
||
int completedCount = 0; | ||
ExceptionDispatchInfo exception; | ||
Action<object> continuation = ContinuationSentinel.AvailableContinuation; | ||
object state; | ||
SynchronizationContext syncContext; | ||
ExecutionContext execContext; | ||
|
||
T[] result; | ||
|
||
public WhenAllPromiseAll(IEnumerable<ValueTask<T>> tasks) | ||
{ | ||
if (tasks is ValueTask<T>[] array) | ||
{ | ||
Run(array); | ||
return; | ||
} | ||
if (tasks is IReadOnlyCollection<ValueTask<T>> c) | ||
{ | ||
Run(c, c.Count); | ||
return; | ||
} | ||
if (tasks is ICollection<ValueTask<T>> c2) | ||
{ | ||
Run(c2, c2.Count); | ||
return; | ||
} | ||
|
||
var list = new TempList<ValueTask<T>>(99); | ||
try | ||
{ | ||
foreach (var item in tasks) | ||
{ | ||
list.Add(item); | ||
} | ||
|
||
Run(list.AsSpan()); | ||
} | ||
finally | ||
{ | ||
list.Dispose(); | ||
} | ||
} | ||
|
||
void Run(ReadOnlySpan<ValueTask<T>> tasks) | ||
{ | ||
result = new T[tasks.Length]; | ||
|
||
var i = 0; | ||
foreach (var task in tasks) | ||
{ | ||
var awaiter = task.GetAwaiter(); | ||
if (awaiter.IsCompleted) | ||
{ | ||
try | ||
{ | ||
result[i] = awaiter.GetResult(); | ||
} | ||
catch (Exception ex) | ||
{ | ||
exception = ExceptionDispatchInfo.Capture(ex); | ||
return; | ||
} | ||
TryInvokeContinuationWithIncrement(); | ||
} | ||
else | ||
{ | ||
RegisterContinuation(awaiter, i); | ||
} | ||
|
||
i++; | ||
} | ||
} | ||
|
||
void Run(IEnumerable<ValueTask<T>> tasks, int length) | ||
{ | ||
result = new T[length]; | ||
|
||
var i = 0; | ||
foreach (var task in tasks) | ||
{ | ||
var awaiter = task.GetAwaiter(); | ||
if (awaiter.IsCompleted) | ||
{ | ||
try | ||
{ | ||
result[i] = awaiter.GetResult(); | ||
} | ||
catch (Exception ex) | ||
{ | ||
exception = ExceptionDispatchInfo.Capture(ex); | ||
return; | ||
} | ||
TryInvokeContinuationWithIncrement(); | ||
} | ||
else | ||
{ | ||
RegisterContinuation(awaiter, i); | ||
} | ||
|
||
i++; | ||
} | ||
} | ||
|
||
void RegisterContinuation(ValueTaskAwaiter<T> awaiter, int index) | ||
{ | ||
awaiter.UnsafeOnCompleted(() => | ||
{ | ||
try | ||
{ | ||
result[index] = awaiter.GetResult(); | ||
} | ||
catch (Exception ex) | ||
{ | ||
exception = ExceptionDispatchInfo.Capture(ex); | ||
TryInvokeContinuation(); | ||
return; | ||
} | ||
TryInvokeContinuationWithIncrement(); | ||
}); | ||
} | ||
|
||
void TryInvokeContinuationWithIncrement() | ||
{ | ||
if (Interlocked.Increment(ref completedCount) == result.Length) | ||
{ | ||
TryInvokeContinuation(); | ||
} | ||
} | ||
|
||
void TryInvokeContinuation() | ||
{ | ||
var c = Interlocked.Exchange(ref continuation, ContinuationSentinel.CompletedContinuation); | ||
if (c != ContinuationSentinel.AvailableContinuation && c != ContinuationSentinel.CompletedContinuation) | ||
{ | ||
var spinWait = new SpinWait(); | ||
while (state == null) // worst case, state is not set yet so wait. | ||
{ | ||
spinWait.SpinOnce(); | ||
} | ||
|
||
if (execContext != null) | ||
{ | ||
ExecutionContext.Run(execContext, execContextCallback, Tuple.Create(c, this)); | ||
} | ||
else if (syncContext != null) | ||
{ | ||
syncContext.Post(syncContextCallback, Tuple.Create(c, this)); | ||
} | ||
else | ||
{ | ||
c(state); | ||
} | ||
} | ||
} | ||
|
||
public T[] GetResult(short token) | ||
{ | ||
if (exception != null) | ||
{ | ||
exception.Throw(); | ||
} | ||
return result; | ||
} | ||
|
||
public ValueTaskSourceStatus GetStatus(short token) | ||
{ | ||
return (completedCount == result.Length) ? ValueTaskSourceStatus.Succeeded | ||
: (exception != null) ? ((exception.SourceException is OperationCanceledException) ? ValueTaskSourceStatus.Canceled : ValueTaskSourceStatus.Faulted) | ||
: ValueTaskSourceStatus.Pending; | ||
} | ||
|
||
public void OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) | ||
{ | ||
if (Interlocked.CompareExchange(ref this.continuation, continuation, ContinuationSentinel.AvailableContinuation) != ContinuationSentinel.AvailableContinuation) | ||
{ | ||
throw new InvalidOperationException("does not allow multiple await."); | ||
} | ||
|
||
this.state = state; | ||
if ((flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) == ValueTaskSourceOnCompletedFlags.FlowExecutionContext) | ||
{ | ||
execContext = ExecutionContext.Capture(); | ||
} | ||
if ((flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) == ValueTaskSourceOnCompletedFlags.UseSchedulingContext) | ||
{ | ||
syncContext = SynchronizationContext.Current; | ||
} | ||
|
||
if (GetStatus(token) != ValueTaskSourceStatus.Pending) | ||
{ | ||
TryInvokeContinuation(); | ||
} | ||
} | ||
|
||
static void ExecutionContextCallback(object state) | ||
{ | ||
var t = (Tuple<Action<object>, WhenAllPromiseAll<T>>)state; | ||
var self = t.Item2; | ||
if (self.syncContext != null) | ||
{ | ||
SynchronizationContextCallback(state); | ||
} | ||
else | ||
{ | ||
var invokeState = self.state; | ||
self.state = null; | ||
t.Item1.Invoke(invokeState); | ||
} | ||
} | ||
|
||
static void SynchronizationContextCallback(object state) | ||
{ | ||
var t = (Tuple<Action<object>, WhenAllPromiseAll<T>>)state; | ||
var self = t.Item2; | ||
var invokeState = self.state; | ||
self.state = null; | ||
t.Item1.Invoke(invokeState); | ||
} | ||
} | ||
|
||
} | ||
} |
Oops, something went wrong.