From e9695b8e7840c37390ca58b821cfbeab178d1cc6 Mon Sep 17 00:00:00 2001 From: Anton Firszov Date: Wed, 27 Jan 2021 16:25:59 +0100 Subject: [PATCH] Add new Task-based UDP socket methods and reimplement existing ones using SocketAsyncEventArgs (#47229) Closes #41502, but does not change the existing APM methods --- .../ref/System.Net.Sockets.cs | 3 + .../src/System/Net/Sockets/Socket.Tasks.cs | 280 +++++++++++++++--- .../src/System/Net/Sockets/Socket.cs | 18 +- .../Net/Sockets/SocketAsyncContext.Unix.cs | 4 +- .../Net/Sockets/SocketAsyncEventArgs.Unix.cs | 12 +- .../Sockets/SocketAsyncEventArgs.Windows.cs | 81 ++--- .../Net/Sockets/SocketTaskExtensions.cs | 1 + .../ArgumentValidationTests.cs | 4 +- .../tests/FunctionalTests/ReceiveFrom.cs | 239 +++++++++++++++ .../FunctionalTests/ReceiveMessageFrom.cs | 164 +++++++++- .../SendReceive/SendReceiveNonParallel.cs | 5 + .../tests/FunctionalTests/SendTo.cs | 110 ++++--- .../tests/FunctionalTests/SocketTestHelper.cs | 8 +- .../System.Net.Sockets.Tests.csproj | 1 + 14 files changed, 776 insertions(+), 154 deletions(-) create mode 100644 src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs diff --git a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs index 5333105b924aa7..4a909e43a46300 100644 --- a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs +++ b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs @@ -385,9 +385,11 @@ public void Listen(int backlog) { } public int ReceiveFrom(byte[] buffer, ref System.Net.EndPoint remoteEP) { throw null; } public int ReceiveFrom(byte[] buffer, System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP) { throw null; } public System.Threading.Tasks.Task ReceiveFromAsync(System.ArraySegment buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint) { throw null; } + public System.Threading.Tasks.ValueTask ReceiveFromAsync(System.Memory buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public bool ReceiveFromAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; } public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP, out System.Net.Sockets.IPPacketInformation ipPacketInformation) { throw null; } public System.Threading.Tasks.Task ReceiveMessageFromAsync(System.ArraySegment buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint) { throw null; } + public System.Threading.Tasks.ValueTask ReceiveMessageFromAsync(System.Memory buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default) { throw null; } public bool ReceiveMessageFromAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; } public static void Select(System.Collections.IList? checkRead, System.Collections.IList? checkWrite, System.Collections.IList? checkError, int microSeconds) { } public int Send(byte[] buffer) { throw null; } @@ -414,6 +416,7 @@ public void SendFile(string? fileName, System.ReadOnlySpan preBuffer, Syst public int SendTo(byte[] buffer, System.Net.EndPoint remoteEP) { throw null; } public int SendTo(byte[] buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; } public System.Threading.Tasks.Task SendToAsync(System.ArraySegment buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; } + public System.Threading.Tasks.ValueTask SendToAsync(System.ReadOnlyMemory buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public bool SendToAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; } [System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")] public void SetIPProtectionLevel(System.Net.Sockets.IPProtectionLevel level) { } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs index 07cc7af8128c48..f5432d56e0009e 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs @@ -333,25 +333,51 @@ public Task ReceiveAsync(IList> buffers, SocketFlags soc /// The buffer for the received data. /// A bitwise combination of SocketFlags values that will be used when receiving the data. /// An endpoint of the same type as the endpoint of the remote host. - /// An asynchronous task that completes with a SocketReceiveFromResult containing the number of bytes received and the endpoint of the sending host. + /// An asynchronous task that completes with a containing the number of bytes received and the endpoint of the sending host. public Task ReceiveFromAsync(ArraySegment buffer, SocketFlags socketFlags, EndPoint remoteEndPoint) { - var tcs = new StateTaskCompletionSource(this) { _field1 = remoteEndPoint }; - BeginReceiveFrom(buffer.Array!, buffer.Offset, buffer.Count, socketFlags, ref tcs._field1, iar => + ValidateBuffer(buffer); + return ReceiveFromAsync(buffer, socketFlags, remoteEndPoint, default).AsTask(); + } + + /// + /// Receives data and returns the endpoint of the sending host. + /// + /// The buffer for the received data. + /// A bitwise combination of SocketFlags values that will be used when receiving the data. + /// An endpoint of the same type as the endpoint of the remote host. + /// A cancellation token that can be used to signal the asynchronous operation should be canceled. + /// An asynchronous task that completes with a containing the number of bytes received and the endpoint of the sending host. + public ValueTask ReceiveFromAsync(Memory buffer, SocketFlags socketFlags, EndPoint remoteEndPoint, CancellationToken cancellationToken = default) + { + if (remoteEndPoint is null) { - var innerTcs = (StateTaskCompletionSource)iar.AsyncState!; - try - { - int receivedBytes = ((Socket)innerTcs.Task.AsyncState!).EndReceiveFrom(iar, ref innerTcs._field1); - innerTcs.TrySetResult(new SocketReceiveFromResult - { - ReceivedBytes = receivedBytes, - RemoteEndPoint = innerTcs._field1 - }); - } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; + throw new ArgumentNullException(nameof(remoteEndPoint)); + } + if (!CanTryAddressFamily(remoteEndPoint.AddressFamily)) + { + throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, remoteEndPoint.AddressFamily, _addressFamily), nameof(remoteEndPoint)); + } + if (_rightEndPoint == null) + { + throw new InvalidOperationException(SR.net_sockets_mustbind); + } + + if (cancellationToken.IsCancellationRequested) + { + return ValueTask.FromCanceled(cancellationToken); + } + + AwaitableSocketAsyncEventArgs saea = + Interlocked.Exchange(ref _singleBufferReceiveEventArgs, null) ?? + new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: true); + + Debug.Assert(saea.BufferList == null); + saea.SetBuffer(buffer); + saea.SocketFlags = socketFlags; + saea.RemoteEndPoint = remoteEndPoint; + saea.WrapExceptionsForNetworkStream = false; + return saea.ReceiveFromAsync(this, cancellationToken); } /// @@ -360,28 +386,50 @@ public Task ReceiveFromAsync(ArraySegment buffer, /// The buffer for the received data. /// A bitwise combination of SocketFlags values that will be used when receiving the data. /// An endpoint of the same type as the endpoint of the remote host. - /// An asynchronous task that completes with a SocketReceiveMessageFromResult containing the number of bytes received and additional information about the sending host. + /// An asynchronous task that completes with a containing the number of bytes received and additional information about the sending host. public Task ReceiveMessageFromAsync(ArraySegment buffer, SocketFlags socketFlags, EndPoint remoteEndPoint) { - var tcs = new StateTaskCompletionSource(this) { _field1 = socketFlags, _field2 = remoteEndPoint }; - BeginReceiveMessageFrom(buffer.Array!, buffer.Offset, buffer.Count, socketFlags, ref tcs._field2, iar => + ValidateBuffer(buffer); + return ReceiveMessageFromAsync(buffer, socketFlags, remoteEndPoint, default).AsTask(); + } + + /// + /// Receives data and returns additional information about the sender of the message. + /// + /// The buffer for the received data. + /// A bitwise combination of SocketFlags values that will be used when receiving the data. + /// An endpoint of the same type as the endpoint of the remote host. + /// A cancellation token that can be used to signal the asynchronous operation should be canceled. + /// An asynchronous task that completes with a containing the number of bytes received and additional information about the sending host. + public ValueTask ReceiveMessageFromAsync(Memory buffer, SocketFlags socketFlags, EndPoint remoteEndPoint, CancellationToken cancellationToken = default) + { + if (remoteEndPoint is null) { - var innerTcs = (StateTaskCompletionSource)iar.AsyncState!; - try - { - IPPacketInformation ipPacketInformation; - int receivedBytes = ((Socket)innerTcs.Task.AsyncState!).EndReceiveMessageFrom(iar, ref innerTcs._field1, ref innerTcs._field2, out ipPacketInformation); - innerTcs.TrySetResult(new SocketReceiveMessageFromResult - { - ReceivedBytes = receivedBytes, - RemoteEndPoint = innerTcs._field2, - SocketFlags = innerTcs._field1, - PacketInformation = ipPacketInformation - }); - } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; + throw new ArgumentNullException(nameof(remoteEndPoint)); + } + if (!CanTryAddressFamily(remoteEndPoint.AddressFamily)) + { + throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, remoteEndPoint.AddressFamily, _addressFamily), nameof(remoteEndPoint)); + } + if (_rightEndPoint == null) + { + throw new InvalidOperationException(SR.net_sockets_mustbind); + } + if (cancellationToken.IsCancellationRequested) + { + return ValueTask.FromCanceled(cancellationToken); + } + + AwaitableSocketAsyncEventArgs saea = + Interlocked.Exchange(ref _singleBufferReceiveEventArgs, null) ?? + new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: true); + + Debug.Assert(saea.BufferList == null); + saea.SetBuffer(buffer); + saea.SocketFlags = socketFlags; + saea.RemoteEndPoint = remoteEndPoint; + saea.WrapExceptionsForNetworkStream = false; + return saea.ReceiveMessageFromAsync(this, cancellationToken); } /// @@ -470,14 +518,40 @@ public Task SendAsync(IList> buffers, SocketFlags socket /// An asynchronous task that completes with the number of bytes sent. public Task SendToAsync(ArraySegment buffer, SocketFlags socketFlags, EndPoint remoteEP) { - var tcs = new TaskCompletionSource(this); - BeginSendTo(buffer.Array!, buffer.Offset, buffer.Count, socketFlags, remoteEP, iar => - { - var innerTcs = (TaskCompletionSource)iar.AsyncState!; - try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState!).EndSendTo(iar)); } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; + ValidateBuffer(buffer); + return SendToAsync(buffer, socketFlags, remoteEP, default).AsTask(); + } + + /// + /// Sends data to the specified remote host. + /// + /// The buffer for the data to send. + /// A bitwise combination of SocketFlags values that will be used when sending the data. + /// The remote host to which to send the data. + /// A cancellation token that can be used to cancel the asynchronous operation. + /// An asynchronous task that completes with the number of bytes sent. + public ValueTask SendToAsync(ReadOnlyMemory buffer, SocketFlags socketFlags, EndPoint remoteEP, CancellationToken cancellationToken = default) + { + if (remoteEP is null) + { + throw new ArgumentNullException(nameof(remoteEP)); + } + + if (cancellationToken.IsCancellationRequested) + { + return ValueTask.FromCanceled(cancellationToken); + } + + AwaitableSocketAsyncEventArgs saea = + Interlocked.Exchange(ref _singleBufferSendEventArgs, null) ?? + new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: false); + + Debug.Assert(saea.BufferList == null); + saea.SetBuffer(MemoryMarshal.AsMemory(buffer)); + saea.SocketFlags = socketFlags; + saea.RemoteEndPoint = remoteEP; + saea.WrapExceptionsForNetworkStream = false; + return saea.SendToAsync(this, cancellationToken); } private static void ValidateBufferArguments(byte[] buffer, int offset, int size) @@ -743,7 +817,7 @@ internal AsyncTaskMethodBuilder GetCompletionResponsibility(out bool re } /// A SocketAsyncEventArgs that can be awaited to get the result of an operation. - internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource, IValueTaskSource + internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource, IValueTaskSource, IValueTaskSource, IValueTaskSource { private static readonly Action s_completedSentinel = new Action(state => throw new InvalidOperationException(SR.Format(SR.net_sockets_valuetaskmisuse, nameof(s_completedSentinel)))); /// The owning socket. @@ -827,7 +901,7 @@ protected override void OnCompleted(SocketAsyncEventArgs _) /// This instance. public ValueTask ReceiveAsync(Socket socket, CancellationToken cancellationToken) { - Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use"); + Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); if (socket.ReceiveAsync(this, cancellationToken)) { @@ -845,11 +919,55 @@ public ValueTask ReceiveAsync(Socket socket, CancellationToken cancellation ValueTask.FromException(CreateException(error)); } + public ValueTask ReceiveFromAsync(Socket socket, CancellationToken cancellationToken) + { + Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); + + if (socket.ReceiveFromAsync(this, cancellationToken)) + { + _cancellationToken = cancellationToken; + return new ValueTask(this, _token); + } + + int bytesTransferred = BytesTransferred; + EndPoint remoteEndPoint = RemoteEndPoint!; + SocketError error = SocketError; + + Release(); + + return error == SocketError.Success ? + new ValueTask(new SocketReceiveFromResult() { ReceivedBytes = bytesTransferred, RemoteEndPoint = remoteEndPoint }) : + ValueTask.FromException(CreateException(error)); + } + + public ValueTask ReceiveMessageFromAsync(Socket socket, CancellationToken cancellationToken) + { + Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); + + if (socket.ReceiveMessageFromAsync(this, cancellationToken)) + { + _cancellationToken = cancellationToken; + return new ValueTask(this, _token); + } + + int bytesTransferred = BytesTransferred; + EndPoint remoteEndPoint = RemoteEndPoint!; + SocketFlags socketFlags = SocketFlags; + IPPacketInformation packetInformation = ReceiveMessageFromPacketInfo; + SocketError error = SocketError; + + Release(); + + return error == SocketError.Success ? + new ValueTask(new SocketReceiveMessageFromResult() { ReceivedBytes = bytesTransferred, RemoteEndPoint = remoteEndPoint, SocketFlags = socketFlags, PacketInformation = packetInformation }) : + ValueTask.FromException(CreateException(error)); + } + /// Initiates a send operation on the associated socket. /// This instance. public ValueTask SendAsync(Socket socket, CancellationToken cancellationToken) { - Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use"); + Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); if (socket.SendAsync(this, cancellationToken)) { @@ -869,7 +987,7 @@ public ValueTask SendAsync(Socket socket, CancellationToken cancellationTok public ValueTask SendAsyncForNetworkStream(Socket socket, CancellationToken cancellationToken) { - Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use"); + Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); if (socket.SendAsync(this, cancellationToken)) { @@ -886,9 +1004,29 @@ public ValueTask SendAsyncForNetworkStream(Socket socket, CancellationToken canc ValueTask.FromException(CreateException(error)); } + public ValueTask SendToAsync(Socket socket, CancellationToken cancellationToken) + { + Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); + + if (socket.SendToAsync(this, cancellationToken)) + { + _cancellationToken = cancellationToken; + return new ValueTask(this, _token); + } + + int bytesTransferred = BytesTransferred; + SocketError error = SocketError; + + Release(); + + return error == SocketError.Success ? + new ValueTask(bytesTransferred) : + ValueTask.FromException(CreateException(error)); + } + public ValueTask ConnectAsync(Socket socket) { - Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use"); + Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); try { @@ -1060,6 +1198,52 @@ void IValueTaskSource.GetResult(short token) } } + SocketReceiveFromResult IValueTaskSource.GetResult(short token) + { + if (token != _token) + { + ThrowIncorrectTokenException(); + } + + SocketError error = SocketError; + int bytes = BytesTransferred; + EndPoint remoteEndPoint = RemoteEndPoint!; + CancellationToken cancellationToken = _cancellationToken; + + Release(); + + if (error != SocketError.Success) + { + ThrowException(error, cancellationToken); + } + + return new SocketReceiveFromResult() { ReceivedBytes = bytes, RemoteEndPoint = remoteEndPoint }; + } + + SocketReceiveMessageFromResult IValueTaskSource.GetResult(short token) + { + if (token != _token) + { + ThrowIncorrectTokenException(); + } + + SocketError error = SocketError; + int bytes = BytesTransferred; + EndPoint remoteEndPoint = RemoteEndPoint!; + SocketFlags socketFlags = SocketFlags; + IPPacketInformation packetInformation = ReceiveMessageFromPacketInfo; + CancellationToken cancellationToken = _cancellationToken; + + Release(); + + if (error != SocketError.Success) + { + ThrowException(error, cancellationToken); + } + + return new SocketReceiveMessageFromResult() { ReceivedBytes = bytes, RemoteEndPoint = remoteEndPoint, SocketFlags = socketFlags, PacketInformation = packetInformation }; + } + private void ThrowIncorrectTokenException() => throw new InvalidOperationException(SR.InvalidOperation_IncorrectToken); private void ThrowMultipleContinuationsException() => throw new InvalidOperationException(SR.InvalidOperation_MultipleContinuations); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index efb20cb4166f1b..fdd1055aaef557 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -3266,7 +3266,9 @@ private bool ReceiveAsync(SocketAsyncEventArgs e, CancellationToken cancellation return socketError == SocketError.IOPending; } - public bool ReceiveFromAsync(SocketAsyncEventArgs e) + public bool ReceiveFromAsync(SocketAsyncEventArgs e) => ReceiveFromAsync(e, default); + + private bool ReceiveFromAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken) { ThrowIfDisposed(); @@ -3300,7 +3302,7 @@ public bool ReceiveFromAsync(SocketAsyncEventArgs e) SocketError socketError; try { - socketError = e.DoOperationReceiveFrom(_handle); + socketError = e.DoOperationReceiveFrom(_handle, cancellationToken); } catch { @@ -3313,7 +3315,9 @@ public bool ReceiveFromAsync(SocketAsyncEventArgs e) return pending; } - public bool ReceiveMessageFromAsync(SocketAsyncEventArgs e) + public bool ReceiveMessageFromAsync(SocketAsyncEventArgs e) => ReceiveMessageFromAsync(e, default); + + private bool ReceiveMessageFromAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken) { ThrowIfDisposed(); @@ -3349,7 +3353,7 @@ public bool ReceiveMessageFromAsync(SocketAsyncEventArgs e) SocketError socketError; try { - socketError = e.DoOperationReceiveMessageFrom(this, _handle); + socketError = e.DoOperationReceiveMessageFrom(this, _handle, cancellationToken); } catch { @@ -3423,7 +3427,9 @@ public bool SendPacketsAsync(SocketAsyncEventArgs e) return socketError == SocketError.IOPending; } - public bool SendToAsync(SocketAsyncEventArgs e) + public bool SendToAsync(SocketAsyncEventArgs e) => SendToAsync(e, default); + + private bool SendToAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken) { ThrowIfDisposed(); @@ -3452,7 +3458,7 @@ public bool SendToAsync(SocketAsyncEventArgs e) SocketError socketError; try { - socketError = e.DoOperationSendTo(_handle); + socketError = e.DoOperationSendTo(_handle, cancellationToken); } catch { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs index 4202cdbefdeebf..2f01b6b543ceeb 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs @@ -1731,7 +1731,7 @@ public SocketError ReceiveMessageFrom( return operation.ErrorCode; } - public SocketError ReceiveMessageFromAsync(Memory buffer, IList>? buffers, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, Action callback) + public SocketError ReceiveMessageFromAsync(Memory buffer, IList>? buffers, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, Action callback, CancellationToken cancellationToken = default) { SetNonBlocking(); @@ -1755,7 +1755,7 @@ public SocketError ReceiveMessageFromAsync(Memory buffer, IList(1, pinned: true); } - Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None); - _singleBufferHandle = _buffer.Pin(); - _singleBufferHandleState = SingleBufferHandleState.Set; + fixed (byte* bufferPtr = &MemoryMarshal.GetReference(_buffer.Span)) + { + Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None); + _singleBufferHandleState = SingleBufferHandleState.InProcess; + + _wsaRecvMsgWSABufferArrayPinned[0].Pointer = (IntPtr)bufferPtr; + _wsaRecvMsgWSABufferArrayPinned[0].Length = _count; + wsaRecvMsgWSABufferArray = _wsaRecvMsgWSABufferArrayPinned; + wsaRecvMsgWSABufferCount = 1; - _wsaRecvMsgWSABufferArrayPinned[0].Pointer = (IntPtr)_singleBufferHandle.Pointer; - _wsaRecvMsgWSABufferArrayPinned[0].Length = _count; - wsaRecvMsgWSABufferArray = _wsaRecvMsgWSABufferArrayPinned; - wsaRecvMsgWSABufferCount = 1; + return Core(); + } } else { // Use the multi-buffer WSABuffer. wsaRecvMsgWSABufferArray = _wsaBufferArrayPinned!; wsaRecvMsgWSABufferCount = (uint)_bufferListInternal!.Count; + + return Core(); } - // Fill in WSAMessageBuffer. - unsafe + // Fill in WSAMessageBuffer, run WSARecvMsg and process the IOCP result. + // Logic is in a separate method so we can share code between the (pinned) single buffer and the multi-buffer case + SocketError Core() { + // Fill in WSAMessageBuffer. Interop.Winsock.WSAMsg* pMessage = (Interop.Winsock.WSAMsg*)Marshal.UnsafeAddrOfPinnedArrayElement(_wsaMessageBufferPinned, 0); pMessage->socketAddress = PtrSocketAddressBuffer; pMessage->addressLength = (uint)_socketAddress.Size; @@ -596,26 +604,27 @@ internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeSoc pMessage->controlBuffer.Length = _controlBufferPinned.Length; } pMessage->flags = _socketFlags; - } - NativeOverlapped* overlapped = AllocateNativeOverlapped(); - try - { - SocketError socketError = socket.WSARecvMsg( - handle, - Marshal.UnsafeAddrOfPinnedArrayElement(_wsaMessageBufferPinned, 0), - out int bytesTransferred, - overlapped, - IntPtr.Zero); + NativeOverlapped* overlapped = AllocateNativeOverlapped(); + try + { + SocketError socketError = socket.WSARecvMsg( + handle, + Marshal.UnsafeAddrOfPinnedArrayElement(_wsaMessageBufferPinned, 0), + out int bytesTransferred, + overlapped, + IntPtr.Zero); - return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped); - } - catch - { - _singleBufferHandleState = SingleBufferHandleState.None; - FreeNativeOverlapped(overlapped); - _singleBufferHandle.Dispose(); - throw; + return _bufferList == null ? + ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped, cancellationToken) : + ProcessIOCPResult(socketError == SocketError.Success, bytesTransferred, overlapped); + } + catch + { + _singleBufferHandleState = SingleBufferHandleState.None; + FreeNativeOverlapped(overlapped); + throw; + } } } @@ -773,7 +782,7 @@ internal unsafe SocketError DoOperationSendPackets(Socket socket, SafeSocketHand } } - internal unsafe SocketError DoOperationSendTo(SafeSocketHandle handle) + internal unsafe SocketError DoOperationSendTo(SafeSocketHandle handle, CancellationToken cancellationToken) { // WSASendTo uses a WSABuffer array describing buffers in which to // receive data and from which to send data respectively. Single and multiple buffers @@ -784,11 +793,11 @@ internal unsafe SocketError DoOperationSendTo(SafeSocketHandle handle) PinSocketAddressBuffer(); return _bufferList == null ? - DoOperationSendToSingleBuffer(handle) : + DoOperationSendToSingleBuffer(handle, cancellationToken) : DoOperationSendToMultiBuffer(handle); } - internal unsafe SocketError DoOperationSendToSingleBuffer(SafeSocketHandle handle) + internal unsafe SocketError DoOperationSendToSingleBuffer(SafeSocketHandle handle, CancellationToken cancellationToken) { fixed (byte* bufferPtr = &MemoryMarshal.GetReference(_buffer.Span)) { @@ -810,7 +819,7 @@ internal unsafe SocketError DoOperationSendToSingleBuffer(SafeSocketHandle handl overlapped, IntPtr.Zero); - return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped); + return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped, cancellationToken); } catch { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs index 1d9a3d6826a621..5197487c8cbabb 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs @@ -67,6 +67,7 @@ public static ValueTask SendAsync(this Socket socket, ReadOnlyMemory [EditorBrowsable(EditorBrowsableState.Never)] public static Task SendAsync(this Socket socket, IList> buffers, SocketFlags socketFlags) => socket.SendAsync(buffers, socketFlags); + [EditorBrowsable(EditorBrowsableState.Never)] public static Task SendToAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags, EndPoint remoteEP) => socket.SendToAsync(buffer, socketFlags, remoteEP); diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ArgumentValidationTests.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ArgumentValidationTests.cs index 003c7fdbbc611b..470ff065748bf8 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ArgumentValidationTests.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ArgumentValidationTests.cs @@ -1334,7 +1334,7 @@ public void BeginReceiveFrom_AddressFamily_Throws_Argument() { EndPoint endpoint = new IPEndPoint(IPAddress.IPv6Loopback, 1); AssertExtensions.Throws("remoteEP", () => GetSocket(AddressFamily.InterNetwork).BeginReceiveFrom(s_buffer, 0, 0, SocketFlags.None, ref endpoint, TheAsyncCallback, null)); - AssertExtensions.Throws("remoteEP", () => { GetSocket(AddressFamily.InterNetwork).ReceiveFromAsync(new ArraySegment(s_buffer, 0, 0), SocketFlags.None, endpoint); }); + AssertExtensions.Throws("remoteEndPoint", () => { GetSocket(AddressFamily.InterNetwork).ReceiveFromAsync(new ArraySegment(s_buffer, 0, 0), SocketFlags.None, endpoint); }); } [Fact] @@ -1402,7 +1402,7 @@ public void BeginReceiveMessageFrom_AddressFamily_Throws_Argument() EndPoint remote = new IPEndPoint(IPAddress.IPv6Loopback, 1); AssertExtensions.Throws("remoteEP", () => GetSocket(AddressFamily.InterNetwork).BeginReceiveMessageFrom(s_buffer, 0, 0, SocketFlags.None, ref remote, TheAsyncCallback, null)); - AssertExtensions.Throws("remoteEP", () => { GetSocket(AddressFamily.InterNetwork).ReceiveMessageFromAsync(new ArraySegment(s_buffer, 0, 0), SocketFlags.None, remote); }); + AssertExtensions.Throws("remoteEndPoint", () => { GetSocket(AddressFamily.InterNetwork).ReceiveMessageFromAsync(new ArraySegment(s_buffer, 0, 0), SocketFlags.None, remote); }); } [Fact] diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs new file mode 100644 index 00000000000000..939595d5f4beb4 --- /dev/null +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs @@ -0,0 +1,239 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; +using Xunit.Abstractions; + +namespace System.Net.Sockets.Tests +{ + public abstract class ReceiveFrom : SocketTestHelperBase where T : SocketHelperBase, new() + { + protected static IPEndPoint GetGetDummyTestEndpoint(AddressFamily addressFamily = AddressFamily.InterNetwork) => + addressFamily == AddressFamily.InterNetwork ? new IPEndPoint(IPAddress.Parse("1.2.3.4"), 1234) : new IPEndPoint(IPAddress.Parse("1:2:3::4"), 1234); + + protected static readonly TimeSpan CancellationTestTimeout = TimeSpan.FromSeconds(30); + + protected ReceiveFrom(ITestOutputHelper output) : base(output) { } + + [Theory] + [InlineData(1, -1, 0)] // offset low + [InlineData(1, 2, 0)] // offset high + [InlineData(1, 0, -1)] // count low + [InlineData(1, 1, 2)] // count high + public async Task OutOfRange_Throws(int length, int offset, int count) + { + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + + ArraySegment buffer = new FakeArraySegment + { + Array = new byte[length], + Count = count, + Offset = offset + }.ToActual(); + + await Assert.ThrowsAnyAsync(() => ReceiveFromAsync(socket, buffer, GetGetDummyTestEndpoint())); + } + + [Fact] + public async Task NullBuffer_Throws() + { + if (!ValidatesArrayArguments) return; + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + + await Assert.ThrowsAsync(() => ReceiveFromAsync(socket, null, GetGetDummyTestEndpoint())); + } + + [Fact] + public async Task NullEndpoint_Throws() + { + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + + await Assert.ThrowsAnyAsync(() => ReceiveFromAsync(socket, new byte[1], null)); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ReceiveSent_TCP_Success(bool ipv6) + { + if (ipv6 && PlatformDetection.IsOSX) + { + // [ActiveIssue("https://github.com/dotnet/runtime/issues/47335")] + // accept() will create a (seemingly) DualMode socket on Mac, + // but since recvmsg() does not work with DualMode on that OS, we throw PNSE CheckDualModeReceiveSupport(). + // Weirdly, the flag is readable, but an attempt to write it leads to EINVAL. + // The best option seems to be to skip this test for the Mac+IPV6 case + return; + } + + (Socket sender, Socket receiver) = SocketTestExtensions.CreateConnectedSocketPair(ipv6); + using (sender) + using (receiver) + { + byte[] sendBuffer = { 1, 2, 3 }; + sender.Send(sendBuffer); + + byte[] receiveBuffer = new byte[3]; + var r = await ReceiveFromAsync(receiver, receiveBuffer, sender.LocalEndPoint); + Assert.Equal(3, r.ReceivedBytes); + AssertExtensions.SequenceEqual(sendBuffer, receiveBuffer); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ClosedBeforeOperation_Throws_ObjectDisposedException(bool closeOrDispose) + { + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + socket.BindToAnonymousPort(IPAddress.Any); + if (closeOrDispose) socket.Close(); + else socket.Dispose(); + + await Assert.ThrowsAsync(() => ReceiveFromAsync(socket, new byte[1], GetGetDummyTestEndpoint())); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ClosedDuringOperation_Throws_ObjectDisposedExceptionOrSocketException(bool closeOrDispose) + { + if (UsesSync && PlatformDetection.IsOSX) + { + // [ActiveIssue("https://github.com/dotnet/runtime/issues/47342")] + // On Mac, Close/Dispose hangs when invoked concurrently with a blocking UDP receive. + return; + } + + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + socket.BindToAnonymousPort(IPAddress.Any); + + Task receiveTask = ReceiveFromAsync(socket, new byte[1], GetGetDummyTestEndpoint()); + await Task.Delay(100); + if (closeOrDispose) socket.Close(); + else socket.Dispose(); + + if (UsesApm) + { + await Assert.ThrowsAsync(() => receiveTask) + .TimeoutAfter(CancellationTestTimeout); + } + else + { + SocketException ex = await Assert.ThrowsAsync(() => receiveTask) + .TimeoutAfter(CancellationTestTimeout); + SocketError expectedError = UsesSync ? SocketError.Interrupted : SocketError.OperationAborted; + Assert.Equal(expectedError, ex.SocketErrorCode); + } + } + + [PlatformSpecific(TestPlatforms.Windows)] // It's allowed to shutdown() UDP sockets on Windows, however on Unix this will lead to ENOTCONN + [Theory] + [InlineData(SocketShutdown.Both)] + [InlineData(SocketShutdown.Receive)] + public async Task ShutdownReceiveBeforeOperation_ThrowsSocketException(SocketShutdown shutdown) + { + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + socket.BindToAnonymousPort(IPAddress.Any); + socket.Shutdown(shutdown); + + // [ActiveIssue("https://github.com/dotnet/runtime/issues/47469")] + // Shutdown(Both) does not seem to take immediate effect for Receive(Message)From in a consistent manner, trying to workaround with a delay: + if (shutdown == SocketShutdown.Both) await Task.Delay(50); + + SocketException exception = await Assert.ThrowsAnyAsync(() => ReceiveFromAsync(socket, new byte[1], GetGetDummyTestEndpoint())) + .TimeoutAfter(CancellationTestTimeout); + + Assert.Equal(SocketError.Shutdown, exception.SocketErrorCode); + } + + [PlatformSpecific(TestPlatforms.Windows)] // It's allowed to shutdown() UDP sockets on Windows, however on Unix this will lead to ENOTCONN + [Fact] + public async Task ShutdownSend_ReceiveFromShouldSucceed() + { + using var receiver = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + receiver.BindToAnonymousPort(IPAddress.Loopback); + receiver.Shutdown(SocketShutdown.Send); + + using var sender = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + sender.BindToAnonymousPort(IPAddress.Loopback); + sender.SendTo(new byte[1], receiver.LocalEndPoint); + + var r = await ReceiveFromAsync(receiver, new byte[1], sender.LocalEndPoint); + Assert.Equal(1, r.ReceivedBytes); + } + } + + public sealed class ReceiveFrom_Sync : ReceiveFrom + { + public ReceiveFrom_Sync(ITestOutputHelper output) : base(output) { } + } + + public sealed class ReceiveFrom_SyncForceNonBlocking : ReceiveFrom + { + public ReceiveFrom_SyncForceNonBlocking(ITestOutputHelper output) : base(output) { } + } + + public sealed class ReceiveFrom_Apm : ReceiveFrom + { + public ReceiveFrom_Apm(ITestOutputHelper output) : base(output) { } + } + + public sealed class ReceiveFrom_Task : ReceiveFrom + { + public ReceiveFrom_Task(ITestOutputHelper output) : base(output) { } + } + + public sealed class ReceiveFrom_CancellableTask : ReceiveFrom + { + public ReceiveFrom_CancellableTask(ITestOutputHelper output) : base(output) { } + + [Theory] + [MemberData(nameof(LoopbacksAndBuffers))] + public async Task WhenCanceled_Throws(IPAddress loopback, bool precanceled) + { + using var socket = new Socket(loopback.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + using var dummy = new Socket(loopback.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + socket.BindToAnonymousPort(loopback); + dummy.BindToAnonymousPort(loopback); + Memory buffer = new byte[1]; + + CancellationTokenSource cts = new CancellationTokenSource(); + if (precanceled) cts.Cancel(); + else cts.CancelAfter(100); + + OperationCanceledException ex = await Assert.ThrowsAnyAsync( + () => socket.ReceiveFromAsync(buffer, SocketFlags.None, dummy.LocalEndPoint, cts.Token).AsTask()) + .TimeoutAfter(CancellationTestTimeout); + Assert.Equal(cts.Token, ex.CancellationToken); + } + } + + public sealed class ReceiveFrom_Eap : ReceiveFrom + { + public ReceiveFrom_Eap(ITestOutputHelper output) : base(output) { } + } + + public sealed class ReceiveFrom_SpanSync : ReceiveFrom + { + public ReceiveFrom_SpanSync(ITestOutputHelper output) : base(output) { } + } + + public sealed class ReceiveFrom_SpanSyncForceNonBlocking : ReceiveFrom + { + public ReceiveFrom_SpanSyncForceNonBlocking(ITestOutputHelper output) : base(output) { } + } + + public sealed class ReceiveFrom_MemoryArrayTask : ReceiveFrom + { + public ReceiveFrom_MemoryArrayTask(ITestOutputHelper output) : base(output) { } + } + + public sealed class ReceiveFrom_MemoryNativeTask : ReceiveFrom + { + public ReceiveFrom_MemoryNativeTask(ITestOutputHelper output) : base(output) { } + } +} diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFrom.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFrom.cs index 6bc70e67a7fa87..e23c84fb0b1e4e 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFrom.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFrom.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Xunit; @@ -11,12 +12,47 @@ namespace System.Net.Sockets.Tests { public abstract class ReceiveMessageFrom : SocketTestHelperBase where T : SocketHelperBase, new() { + protected static IPEndPoint GetGetDummyTestEndpoint(AddressFamily addressFamily = AddressFamily.InterNetwork) => + addressFamily == AddressFamily.InterNetwork ? new IPEndPoint(IPAddress.Parse("1.2.3.4"), 1234) : new IPEndPoint(IPAddress.Parse("1:2:3::4"), 1234); + + protected static readonly TimeSpan CancellationTestTimeout = TimeSpan.FromSeconds(30); + protected ReceiveMessageFrom(ITestOutputHelper output) : base(output) { } + [PlatformSpecific(TestPlatforms.AnyUnix)] + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ReceiveSent_TCP_Success(bool ipv6) + { + if (ipv6 && PlatformDetection.IsOSX) + { + // [ActiveIssue("https://github.com/dotnet/runtime/issues/47335")] + // accept() will create a (seemingly) DualMode socket on Mac, + // but since recvmsg() does not work with DualMode on that OS, we throw PNSE CheckDualModeReceiveSupport(). + // Weirdly, the flag is readable, but an attempt to write it leads to EINVAL. + // The best option seems to be to skip this test for the Mac+IPV6 case + return; + } + + (Socket sender, Socket receiver) = SocketTestExtensions.CreateConnectedSocketPair(ipv6); + using (sender) + using (receiver) + { + byte[] sendBuffer = { 1, 2, 3 }; + sender.Send(sendBuffer); + + byte[] receiveBuffer = new byte[3]; + var r = await ReceiveMessageFromAsync(receiver, receiveBuffer, sender.LocalEndPoint); + Assert.Equal(3, r.ReceivedBytes); + AssertExtensions.SequenceEqual(sendBuffer, receiveBuffer); + } + } + [Theory] [InlineData(false)] [InlineData(true)] - public async Task ReceiveSentMessages_Success(bool ipv4) + public async Task ReceiveSentMessages_UDP_Success(bool ipv4) { const int DatagramSize = 256; const int DatagramsToSend = 16; @@ -52,6 +88,89 @@ public async Task ReceiveSentMessages_Success(bool ipv4) Assert.Equal(((IPEndPoint)sender.LocalEndPoint).Address, packetInformation.Address); } } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ClosedBeforeOperation_Throws_ObjectDisposedException(bool closeOrDispose) + { + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + socket.BindToAnonymousPort(IPAddress.Any); + if (closeOrDispose) socket.Close(); + else socket.Dispose(); + + await Assert.ThrowsAsync(() => ReceiveMessageFromAsync(socket, new byte[1], GetGetDummyTestEndpoint())); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ClosedDuringOperation_Throws_ObjectDisposedExceptionOrSocketException(bool closeOrDispose) + { + if (UsesSync && PlatformDetection.IsOSX) + { + // [ActiveIssue("https://github.com/dotnet/runtime/issues/47342")] + // On Mac, Close/Dispose hangs when invoked concurrently with a blocking UDP receive. + return; + } + + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + socket.BindToAnonymousPort(IPAddress.Any); + + Task receiveTask = ReceiveMessageFromAsync(socket, new byte[1], GetGetDummyTestEndpoint()); + await Task.Delay(100); + if (closeOrDispose) socket.Close(); + else socket.Dispose(); + + if (UsesApm) + { + await Assert.ThrowsAsync(() => receiveTask) + .TimeoutAfter(CancellationTestTimeout); + } + else + { + SocketException ex = await Assert.ThrowsAsync(() => receiveTask) + .TimeoutAfter(CancellationTestTimeout); + SocketError expectedError = UsesSync ? SocketError.Interrupted : SocketError.OperationAborted; + Assert.Equal(expectedError, ex.SocketErrorCode); + } + } + + [PlatformSpecific(TestPlatforms.Windows)] // It's allowed to shutdown() UDP sockets on Windows, however on Unix this will lead to ENOTCONN + [Theory] + [InlineData(SocketShutdown.Both)] + [InlineData(SocketShutdown.Receive)] + public async Task ShutdownReceiveBeforeOperation_ThrowsSocketException(SocketShutdown shutdown) + { + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + socket.BindToAnonymousPort(IPAddress.Any); + socket.Shutdown(shutdown); + + // [ActiveIssue("https://github.com/dotnet/runtime/issues/47469")] + // Shutdown(Both) does not seem to take immediate effect for Receive(Message)From in a consistent manner, trying to workaround with a delay: + if (shutdown == SocketShutdown.Both) await Task.Delay(50); + + SocketException exception = await Assert.ThrowsAnyAsync(() => ReceiveMessageFromAsync(socket, new byte[1], GetGetDummyTestEndpoint())) + .TimeoutAfter(CancellationTestTimeout); + + Assert.Equal(SocketError.Shutdown, exception.SocketErrorCode); + } + + [PlatformSpecific(TestPlatforms.Windows)] // It's allowed to shutdown() UDP sockets on Windows, however on Unix this will lead to ENOTCONN + [Fact] + public async Task ShutdownSend_ReceiveFromShouldSucceed() + { + using var receiver = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + receiver.BindToAnonymousPort(IPAddress.Loopback); + receiver.Shutdown(SocketShutdown.Send); + + using var sender = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + sender.BindToAnonymousPort(IPAddress.Loopback); + sender.SendTo(new byte[1], receiver.LocalEndPoint); + + var r = await ReceiveMessageFromAsync(receiver, new byte[1], sender.LocalEndPoint); + Assert.Equal(1, r.ReceivedBytes); + } } public sealed class ReceiveMessageFrom_Sync : ReceiveMessageFrom @@ -74,6 +193,31 @@ public sealed class ReceiveMessageFrom_Task : ReceiveMessageFrom + { + public ReceiveMessageFrom_CancellableTask(ITestOutputHelper output) : base(output) { } + + [Theory] + [MemberData(nameof(LoopbacksAndBuffers))] + public async Task WhenCanceled_Throws(IPAddress loopback, bool precanceled) + { + using var socket = new Socket(loopback.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + using var dummy = new Socket(loopback.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + socket.BindToAnonymousPort(loopback); + dummy.BindToAnonymousPort(loopback); + Memory buffer = new byte[1]; + + CancellationTokenSource cts = new CancellationTokenSource(); + if (precanceled) cts.Cancel(); + else cts.CancelAfter(100); + + OperationCanceledException ex = await Assert.ThrowsAnyAsync( + () => socket.ReceiveMessageFromAsync(buffer, SocketFlags.None, dummy.LocalEndPoint, cts.Token).AsTask()) + .TimeoutAfter(CancellationTestTimeout); + Assert.Equal(cts.Token, ex.CancellationToken); + } + } + public sealed class ReceiveMessageFrom_Eap : ReceiveMessageFrom { public ReceiveMessageFrom_Eap(ITestOutputHelper output) : base(output) { } @@ -87,7 +231,7 @@ public ReceiveMessageFrom_Eap(ITestOutputHelper output) : base(output) { } [InlineData(true, 2)] public void ReceiveSentMessages_ReuseEventArgs_Success(bool ipv4, int bufferMode) { - const int DatagramsToSend = 30; + const int DatagramsToSend = 5; const int TimeoutMs = 30_000; AddressFamily family; @@ -119,34 +263,40 @@ public void ReceiveSentMessages_ReuseEventArgs_Success(bool ipv4, int bufferMode sender.Bind(new IPEndPoint(loopback, 0)); saea.RemoteEndPoint = new IPEndPoint(any, 0); + Random random = new Random(0); + byte[] sendBuffer = new byte[1024]; + random.NextBytes(sendBuffer); + for (int i = 0; i < DatagramsToSend; i++) { + byte[] receiveBuffer = new byte[1024]; switch (bufferMode) { case 0: // single buffer - saea.SetBuffer(new byte[1024], 0, 1024); + saea.SetBuffer(receiveBuffer, 0, 1024); break; case 1: // single buffer in buffer list saea.BufferList = new List> { - new ArraySegment(new byte[1024]) + new ArraySegment(receiveBuffer) }; break; case 2: // multiple buffers in buffer list saea.BufferList = new List> { - new ArraySegment(new byte[512]), - new ArraySegment(new byte[512]) + new ArraySegment(receiveBuffer, 0, 512), + new ArraySegment(receiveBuffer, 512, 512) }; break; } bool pending = receiver.ReceiveMessageFromAsync(saea); - sender.SendTo(new byte[1024], new IPEndPoint(loopback, port)); + sender.SendTo(sendBuffer, new IPEndPoint(loopback, port)); if (pending) Assert.True(completed.Wait(TimeoutMs), "Expected operation to complete within timeout"); completed.Reset(); Assert.Equal(1024, saea.BytesTransferred); + AssertExtensions.SequenceEqual(sendBuffer, receiveBuffer); Assert.Equal(sender.LocalEndPoint, saea.RemoteEndPoint); Assert.Equal(((IPEndPoint)sender.LocalEndPoint).Address, saea.ReceiveMessageFromPacketInfo.Address); } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive/SendReceiveNonParallel.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive/SendReceiveNonParallel.cs index 634d0f4854939f..470472f226d4ec 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive/SendReceiveNonParallel.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive/SendReceiveNonParallel.cs @@ -124,6 +124,11 @@ public sealed class SendReceiveNonParallel_Task : SendReceiveNonParallel + { + public SendReceiveNonParallel_CancellableTask(ITestOutputHelper output) : base(output) { } + } + public sealed class SendReceiveNonParallel_Eap : SendReceiveNonParallel { public SendReceiveNonParallel_Eap(ITestOutputHelper output) : base(output) { } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs index 3ff039c3e662c0..1159d40f2e4a33 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs @@ -1,17 +1,21 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; using System.Threading.Tasks; using Xunit; using Xunit.Abstractions; namespace System.Net.Sockets.Tests { - public abstract class SendToBase : SocketTestHelperBase where T : SocketHelperBase, new() + public abstract class SendTo : SocketTestHelperBase where T : SocketHelperBase, new() { - private static readonly IPEndPoint ValidUdpRemoteEndpoint = new IPEndPoint(IPAddress.Parse("10.20.30.40"), 1234); + protected static readonly IPEndPoint ValidUdpRemoteEndpoint = new IPEndPoint(IPAddress.Parse("10.20.30.40"), 1234); - protected SendToBase(ITestOutputHelper output) : base(output) + protected SendTo(ITestOutputHelper output) : base(output) { } @@ -80,59 +84,77 @@ public async Task Datagram_UDP_AccessDenied_Throws_DoesNotBind() Assert.Equal(SocketError.AccessDenied, e.SocketErrorCode); Assert.Null(socket.LocalEndPoint); } + + [Fact] + public async Task Disposed_Throws() + { + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + socket.Dispose(); + + await Assert.ThrowsAsync(() => SendToAsync(socket, new byte[1], ValidUdpRemoteEndpoint)); + } } - public static class SendTo + public sealed class SendTo_SyncSpan : SendTo { - public static class Sync - { - public sealed class Span : SendToBase - { - public Span(ITestOutputHelper output) : base(output) { } - } + public SendTo_SyncSpan(ITestOutputHelper output) : base(output) { } + } - public sealed class SpanForceNonBlocking : SendToBase - { - public SpanForceNonBlocking(ITestOutputHelper output) : base(output) { } - } + public sealed class SendTo_SyncSpanForceNonBlocking : SendTo + { + public SendTo_SyncSpanForceNonBlocking(ITestOutputHelper output) : base(output) { } + } - public sealed class MemoryArrayTask : SendToBase - { - public MemoryArrayTask(ITestOutputHelper output) : base(output) { } - } + public sealed class SendTo_ArraySync : SendTo + { + public SendTo_ArraySync(ITestOutputHelper output) : base(output) { } + } - public sealed class MemoryNativeTask : SendToBase - { - public MemoryNativeTask(ITestOutputHelper output) : base(output) { } - } + public sealed class SendTo_SyncForceNonBlocking : SendTo + { + public SendTo_SyncForceNonBlocking(ITestOutputHelper output) : base(output) {} + } - public sealed class ArraySync : SendToBase - { - public ArraySync(ITestOutputHelper output) : base(output) { } - } + public sealed class SendTo_Apm : SendTo + { + public SendTo_Apm(ITestOutputHelper output) : base(output) {} + } - public sealed class ArrayForceNonBlocking : SendToBase - { - public ArrayForceNonBlocking(ITestOutputHelper output) : base(output) {} - } - } + public sealed class SendTo_Eap : SendTo + { + public SendTo_Eap(ITestOutputHelper output) : base(output) {} + } - public static class Async + public sealed class SendTo_Task : SendTo + { + public SendTo_Task(ITestOutputHelper output) : base(output) { } + } + + public sealed class SendTo_CancellableTask : SendTo + { + public SendTo_CancellableTask(ITestOutputHelper output) : base(output) { } + + [Fact] + public async Task PreCanceled_Throws() { - public sealed class Apm : SendToBase - { - public Apm(ITestOutputHelper output) : base(output) {} - } + using Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + CancellationTokenSource cts = new CancellationTokenSource(); + cts.Cancel(); - public sealed class Task : SendToBase - { - public Task(ITestOutputHelper output) : base(output) {} - } + OperationCanceledException ex = await Assert.ThrowsAnyAsync( + () => sender.SendToAsync(new byte[1], SocketFlags.None, ValidUdpRemoteEndpoint, cts.Token).AsTask()); - public sealed class Eap : SendToBase - { - public Eap(ITestOutputHelper output) : base(output) {} - } + Assert.Equal(cts.Token, ex.CancellationToken); } } + + public sealed class SendTo_MemoryArrayTask : SendTo + { + public SendTo_MemoryArrayTask(ITestOutputHelper output) : base(output) { } + } + + public sealed class SendTo_MemoryNativeTask : SendTo + { + public SendTo_MemoryNativeTask(ITestOutputHelper output) : base(output) { } + } } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs index 93e050059ca719..2e3801f7a7f172 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs @@ -228,6 +228,8 @@ public class SocketHelperCancellableTask : SocketHelperBase { // Use a cancellable CancellationToken that we never cancel so that implementations can't just elide handling the CancellationToken. private readonly CancellationTokenSource _cts = new CancellationTokenSource(); + // This variant is typically working with Memory overloads. + public override bool ValidatesArrayArguments => false; public override Task AcceptAsync(Socket s) => s.AcceptAsync(); @@ -244,15 +246,15 @@ public override Task ReceiveAsync(Socket s, ArraySegment buffer) => public override Task ReceiveAsync(Socket s, IList> bufferList) => s.ReceiveAsync(bufferList, SocketFlags.None); public override Task ReceiveFromAsync(Socket s, ArraySegment buffer, EndPoint endPoint) => - s.ReceiveFromAsync(buffer, SocketFlags.None, endPoint); + s.ReceiveFromAsync(buffer, SocketFlags.None, endPoint, _cts.Token).AsTask(); public override Task ReceiveMessageFromAsync(Socket s, ArraySegment buffer, EndPoint endPoint) => - s.ReceiveMessageFromAsync(buffer, SocketFlags.None, endPoint); + s.ReceiveMessageFromAsync(buffer, SocketFlags.None, endPoint, _cts.Token).AsTask(); public override Task SendAsync(Socket s, ArraySegment buffer) => s.SendAsync(buffer, SocketFlags.None, _cts.Token).AsTask(); public override Task SendAsync(Socket s, IList> bufferList) => s.SendAsync(bufferList, SocketFlags.None); public override Task SendToAsync(Socket s, ArraySegment buffer, EndPoint endPoint) => - s.SendToAsync(buffer, SocketFlags.None, endPoint); + s.SendToAsync(buffer, SocketFlags.None, endPoint, _cts.Token).AsTask() ; } public sealed class SocketHelperEap : SocketHelperBase diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj b/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj index 7ce628db90aae5..f79246b4d34efa 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj @@ -26,6 +26,7 @@ +