Skip to content

Commit

Permalink
Add new Task-based UDP socket methods and reimplement existing ones u…
Browse files Browse the repository at this point in the history
…sing SocketAsyncEventArgs (#47229)

Closes #41502, but does not change the existing APM methods
  • Loading branch information
antonfirsov authored Jan 27, 2021
1 parent c084072 commit e9695b8
Show file tree
Hide file tree
Showing 14 changed files with 776 additions and 154 deletions.
3 changes: 3 additions & 0 deletions src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -385,9 +385,11 @@ public void Listen(int backlog) { }
public int ReceiveFrom(byte[] buffer, ref System.Net.EndPoint remoteEP) { throw null; }
public int ReceiveFrom(byte[] buffer, System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP) { throw null; }
public System.Threading.Tasks.Task<System.Net.Sockets.SocketReceiveFromResult> ReceiveFromAsync(System.ArraySegment<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint) { throw null; }
public System.Threading.Tasks.ValueTask<System.Net.Sockets.SocketReceiveFromResult> ReceiveFromAsync(System.Memory<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public bool ReceiveFromAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP, out System.Net.Sockets.IPPacketInformation ipPacketInformation) { throw null; }
public System.Threading.Tasks.Task<System.Net.Sockets.SocketReceiveMessageFromResult> ReceiveMessageFromAsync(System.ArraySegment<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint) { throw null; }
public System.Threading.Tasks.ValueTask<System.Net.Sockets.SocketReceiveMessageFromResult> ReceiveMessageFromAsync(System.Memory<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default) { throw null; }
public bool ReceiveMessageFromAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
public static void Select(System.Collections.IList? checkRead, System.Collections.IList? checkWrite, System.Collections.IList? checkError, int microSeconds) { }
public int Send(byte[] buffer) { throw null; }
Expand All @@ -414,6 +416,7 @@ public void SendFile(string? fileName, System.ReadOnlySpan<byte> preBuffer, Syst
public int SendTo(byte[] buffer, System.Net.EndPoint remoteEP) { throw null; }
public int SendTo(byte[] buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; }
public System.Threading.Tasks.Task<int> SendToAsync(System.ArraySegment<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; }
public System.Threading.Tasks.ValueTask<int> SendToAsync(System.ReadOnlyMemory<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public bool SendToAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
[System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")]
public void SetIPProtectionLevel(System.Net.Sockets.IPProtectionLevel level) { }
Expand Down
280 changes: 232 additions & 48 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs

Large diffs are not rendered by default.

18 changes: 12 additions & 6 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3266,7 +3266,9 @@ private bool ReceiveAsync(SocketAsyncEventArgs e, CancellationToken cancellation
return socketError == SocketError.IOPending;
}

public bool ReceiveFromAsync(SocketAsyncEventArgs e)
public bool ReceiveFromAsync(SocketAsyncEventArgs e) => ReceiveFromAsync(e, default);

private bool ReceiveFromAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken)
{
ThrowIfDisposed();

Expand Down Expand Up @@ -3300,7 +3302,7 @@ public bool ReceiveFromAsync(SocketAsyncEventArgs e)
SocketError socketError;
try
{
socketError = e.DoOperationReceiveFrom(_handle);
socketError = e.DoOperationReceiveFrom(_handle, cancellationToken);
}
catch
{
Expand All @@ -3313,7 +3315,9 @@ public bool ReceiveFromAsync(SocketAsyncEventArgs e)
return pending;
}

public bool ReceiveMessageFromAsync(SocketAsyncEventArgs e)
public bool ReceiveMessageFromAsync(SocketAsyncEventArgs e) => ReceiveMessageFromAsync(e, default);

private bool ReceiveMessageFromAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken)
{
ThrowIfDisposed();

Expand Down Expand Up @@ -3349,7 +3353,7 @@ public bool ReceiveMessageFromAsync(SocketAsyncEventArgs e)
SocketError socketError;
try
{
socketError = e.DoOperationReceiveMessageFrom(this, _handle);
socketError = e.DoOperationReceiveMessageFrom(this, _handle, cancellationToken);
}
catch
{
Expand Down Expand Up @@ -3423,7 +3427,9 @@ public bool SendPacketsAsync(SocketAsyncEventArgs e)
return socketError == SocketError.IOPending;
}

public bool SendToAsync(SocketAsyncEventArgs e)
public bool SendToAsync(SocketAsyncEventArgs e) => SendToAsync(e, default);

private bool SendToAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken)
{
ThrowIfDisposed();

Expand Down Expand Up @@ -3452,7 +3458,7 @@ public bool SendToAsync(SocketAsyncEventArgs e)
SocketError socketError;
try
{
socketError = e.DoOperationSendTo(_handle);
socketError = e.DoOperationSendTo(_handle, cancellationToken);
}
catch
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1731,7 +1731,7 @@ public SocketError ReceiveMessageFrom(
return operation.ErrorCode;
}

public SocketError ReceiveMessageFromAsync(Memory<byte> buffer, IList<ArraySegment<byte>>? buffers, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, Action<int, byte[], int, SocketFlags, IPPacketInformation, SocketError> callback)
public SocketError ReceiveMessageFromAsync(Memory<byte> buffer, IList<ArraySegment<byte>>? buffers, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, Action<int, byte[], int, SocketFlags, IPPacketInformation, SocketError> callback, CancellationToken cancellationToken = default)
{
SetNonBlocking();

Expand All @@ -1755,7 +1755,7 @@ public SocketError ReceiveMessageFromAsync(Memory<byte> buffer, IList<ArraySegme
IsIPv6 = isIPv6,
};

if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber))
if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken))
{
socketAddressLen = operation.SocketAddressLen;
receivedFlags = operation.ReceivedFlags;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ internal unsafe SocketError DoOperationReceive(SafeSocketHandle handle, Cancella
return errorCode;
}

internal unsafe SocketError DoOperationReceiveFrom(SafeSocketHandle handle)
internal unsafe SocketError DoOperationReceiveFrom(SafeSocketHandle handle, CancellationToken cancellationToken)
{
_receivedFlags = System.Net.Sockets.SocketFlags.None;
_socketAddressSize = 0;
Expand All @@ -164,7 +164,7 @@ internal unsafe SocketError DoOperationReceiveFrom(SafeSocketHandle handle)
int socketAddressLen = _socketAddress!.Size;
if (_bufferList == null)
{
errorCode = handle.AsyncContext.ReceiveFromAsync(_buffer.Slice(_offset, _count), _socketFlags, _socketAddress.Buffer, ref socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback);
errorCode = handle.AsyncContext.ReceiveFromAsync(_buffer.Slice(_offset, _count), _socketFlags, _socketAddress.Buffer, ref socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback, cancellationToken);
}
else
{
Expand Down Expand Up @@ -197,7 +197,7 @@ private void CompleteReceiveMessageFromOperation(int bytesTransferred, byte[] so
_receiveMessageFromPacketInfo = ipPacketInformation;
}

internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeSocketHandle handle)
internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken)
{
_receiveMessageFromPacketInfo = default(IPPacketInformation);
_receivedFlags = System.Net.Sockets.SocketFlags.None;
Expand All @@ -210,7 +210,7 @@ internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeSoc
int bytesReceived;
SocketFlags receivedFlags;
IPPacketInformation ipPacketInformation;
SocketError socketError = handle.AsyncContext.ReceiveMessageFromAsync(_buffer.Slice(_offset, _count), _bufferListInternal, _socketFlags, _socketAddress.Buffer, ref socketAddressSize, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, ReceiveMessageFromCompletionCallback);
SocketError socketError = handle.AsyncContext.ReceiveMessageFromAsync(_buffer.Slice(_offset, _count), _bufferListInternal, _socketFlags, _socketAddress.Buffer, ref socketAddressSize, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, ReceiveMessageFromCompletionCallback, cancellationToken);
if (socketError != SocketError.IOPending)
{
CompleteReceiveMessageFromOperation(bytesReceived, _socketAddress.Buffer, socketAddressSize, receivedFlags, ipPacketInformation, socketError);
Expand Down Expand Up @@ -303,7 +303,7 @@ internal SocketError DoOperationSendPackets(Socket socket, SafeSocketHandle hand
return SocketError.IOPending;
}

internal SocketError DoOperationSendTo(SafeSocketHandle handle)
internal SocketError DoOperationSendTo(SafeSocketHandle handle, CancellationToken cancellationToken)
{
_receivedFlags = System.Net.Sockets.SocketFlags.None;
_socketAddressSize = 0;
Expand All @@ -313,7 +313,7 @@ internal SocketError DoOperationSendTo(SafeSocketHandle handle)
SocketError errorCode;
if (_bufferList == null)
{
errorCode = handle.AsyncContext.SendToAsync(_buffer, _offset, _count, _socketFlags, _socketAddress.Buffer, ref socketAddressLen, out bytesSent, TransferCompletionCallback);
errorCode = handle.AsyncContext.SendToAsync(_buffer, _offset, _count, _socketFlags, _socketAddress.Buffer, ref socketAddressLen, out bytesSent, TransferCompletionCallback, cancellationToken);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ internal unsafe SocketError DoOperationReceiveMultiBuffer(SafeSocketHandle handl
}
}

internal unsafe SocketError DoOperationReceiveFrom(SafeSocketHandle handle)
internal unsafe SocketError DoOperationReceiveFrom(SafeSocketHandle handle, CancellationToken cancellationToken)
{
// WSARecvFrom uses a WSABuffer array describing buffers in which to
// receive data and from which to send data respectively. Single and multiple buffers
Expand All @@ -454,11 +454,11 @@ internal unsafe SocketError DoOperationReceiveFrom(SafeSocketHandle handle)
PinSocketAddressBuffer();

return _bufferList == null ?
DoOperationReceiveFromSingleBuffer(handle) :
DoOperationReceiveFromSingleBuffer(handle, cancellationToken) :
DoOperationReceiveFromMultiBuffer(handle);
}

internal unsafe SocketError DoOperationReceiveFromSingleBuffer(SafeSocketHandle handle)
internal unsafe SocketError DoOperationReceiveFromSingleBuffer(SafeSocketHandle handle, CancellationToken cancellationToken)
{
fixed (byte* bufferPtr = &MemoryMarshal.GetReference(_buffer.Span))
{
Expand All @@ -481,7 +481,7 @@ internal unsafe SocketError DoOperationReceiveFromSingleBuffer(SafeSocketHandle
overlapped,
IntPtr.Zero);

return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped);
return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped, cancellationToken);
}
catch
{
Expand Down Expand Up @@ -518,7 +518,7 @@ internal unsafe SocketError DoOperationReceiveFromMultiBuffer(SafeSocketHandle h
}
}

internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeSocketHandle handle)
internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken)
{
// WSARecvMsg uses a WSAMsg descriptor.
// The WSAMsg buffer is a pinned array to avoid complicating the use of Overlapped.
Expand Down Expand Up @@ -558,25 +558,33 @@ internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeSoc
_wsaRecvMsgWSABufferArrayPinned = GC.AllocateUninitializedArray<WSABuffer>(1, pinned: true);
}

Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None);
_singleBufferHandle = _buffer.Pin();
_singleBufferHandleState = SingleBufferHandleState.Set;
fixed (byte* bufferPtr = &MemoryMarshal.GetReference(_buffer.Span))
{
Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None);
_singleBufferHandleState = SingleBufferHandleState.InProcess;

_wsaRecvMsgWSABufferArrayPinned[0].Pointer = (IntPtr)bufferPtr;
_wsaRecvMsgWSABufferArrayPinned[0].Length = _count;
wsaRecvMsgWSABufferArray = _wsaRecvMsgWSABufferArrayPinned;
wsaRecvMsgWSABufferCount = 1;

_wsaRecvMsgWSABufferArrayPinned[0].Pointer = (IntPtr)_singleBufferHandle.Pointer;
_wsaRecvMsgWSABufferArrayPinned[0].Length = _count;
wsaRecvMsgWSABufferArray = _wsaRecvMsgWSABufferArrayPinned;
wsaRecvMsgWSABufferCount = 1;
return Core();
}
}
else
{
// Use the multi-buffer WSABuffer.
wsaRecvMsgWSABufferArray = _wsaBufferArrayPinned!;
wsaRecvMsgWSABufferCount = (uint)_bufferListInternal!.Count;

return Core();
}

// Fill in WSAMessageBuffer.
unsafe
// Fill in WSAMessageBuffer, run WSARecvMsg and process the IOCP result.
// Logic is in a separate method so we can share code between the (pinned) single buffer and the multi-buffer case
SocketError Core()
{
// Fill in WSAMessageBuffer.
Interop.Winsock.WSAMsg* pMessage = (Interop.Winsock.WSAMsg*)Marshal.UnsafeAddrOfPinnedArrayElement(_wsaMessageBufferPinned, 0);
pMessage->socketAddress = PtrSocketAddressBuffer;
pMessage->addressLength = (uint)_socketAddress.Size;
Expand All @@ -596,26 +604,27 @@ internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeSoc
pMessage->controlBuffer.Length = _controlBufferPinned.Length;
}
pMessage->flags = _socketFlags;
}

NativeOverlapped* overlapped = AllocateNativeOverlapped();
try
{
SocketError socketError = socket.WSARecvMsg(
handle,
Marshal.UnsafeAddrOfPinnedArrayElement(_wsaMessageBufferPinned, 0),
out int bytesTransferred,
overlapped,
IntPtr.Zero);
NativeOverlapped* overlapped = AllocateNativeOverlapped();
try
{
SocketError socketError = socket.WSARecvMsg(
handle,
Marshal.UnsafeAddrOfPinnedArrayElement(_wsaMessageBufferPinned, 0),
out int bytesTransferred,
overlapped,
IntPtr.Zero);

return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped);
}
catch
{
_singleBufferHandleState = SingleBufferHandleState.None;
FreeNativeOverlapped(overlapped);
_singleBufferHandle.Dispose();
throw;
return _bufferList == null ?
ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped, cancellationToken) :
ProcessIOCPResult(socketError == SocketError.Success, bytesTransferred, overlapped);
}
catch
{
_singleBufferHandleState = SingleBufferHandleState.None;
FreeNativeOverlapped(overlapped);
throw;
}
}
}

Expand Down Expand Up @@ -773,7 +782,7 @@ internal unsafe SocketError DoOperationSendPackets(Socket socket, SafeSocketHand
}
}

internal unsafe SocketError DoOperationSendTo(SafeSocketHandle handle)
internal unsafe SocketError DoOperationSendTo(SafeSocketHandle handle, CancellationToken cancellationToken)
{
// WSASendTo uses a WSABuffer array describing buffers in which to
// receive data and from which to send data respectively. Single and multiple buffers
Expand All @@ -784,11 +793,11 @@ internal unsafe SocketError DoOperationSendTo(SafeSocketHandle handle)
PinSocketAddressBuffer();

return _bufferList == null ?
DoOperationSendToSingleBuffer(handle) :
DoOperationSendToSingleBuffer(handle, cancellationToken) :
DoOperationSendToMultiBuffer(handle);
}

internal unsafe SocketError DoOperationSendToSingleBuffer(SafeSocketHandle handle)
internal unsafe SocketError DoOperationSendToSingleBuffer(SafeSocketHandle handle, CancellationToken cancellationToken)
{
fixed (byte* bufferPtr = &MemoryMarshal.GetReference(_buffer.Span))
{
Expand All @@ -810,7 +819,7 @@ internal unsafe SocketError DoOperationSendToSingleBuffer(SafeSocketHandle handl
overlapped,
IntPtr.Zero);

return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped);
return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped, cancellationToken);
}
catch
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public static ValueTask<int> SendAsync(this Socket socket, ReadOnlyMemory<byte>
[EditorBrowsable(EditorBrowsableState.Never)]
public static Task<int> SendAsync(this Socket socket, IList<ArraySegment<byte>> buffers, SocketFlags socketFlags) =>
socket.SendAsync(buffers, socketFlags);

[EditorBrowsable(EditorBrowsableState.Never)]
public static Task<int> SendToAsync(this Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags, EndPoint remoteEP) =>
socket.SendToAsync(buffer, socketFlags, remoteEP);
Expand Down
Loading

0 comments on commit e9695b8

Please sign in to comment.