diff --git a/MQTTnet.sln.DotSettings b/MQTTnet.sln.DotSettings index b12b29857..962baa61f 100644 --- a/MQTTnet.sln.DotSettings +++ b/MQTTnet.sln.DotSettings @@ -236,6 +236,7 @@ See the LICENSE file in the project root for more information. True True True + True True True True diff --git a/Source/MQTTnet.Server/Internal/Adapter/MqttTcpServerListener.cs b/Source/MQTTnet.Server/Internal/Adapter/MqttTcpServerListener.cs index 4982cc40d..fc76b0575 100644 --- a/Source/MQTTnet.Server/Internal/Adapter/MqttTcpServerListener.cs +++ b/Source/MQTTnet.Server/Internal/Adapter/MqttTcpServerListener.cs @@ -136,7 +136,7 @@ async Task AcceptClientConnectionsAsync(CancellationToken cancellationToken) { try { - var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false); + var clientSocket = await _socket.AcceptAsync(cancellationToken).ConfigureAwait(false); if (clientSocket == null) { continue; diff --git a/Source/MQTTnet.Tests/MqttTcpChannel_Tests.cs b/Source/MQTTnet.Tests/MqttTcpChannel_Tests.cs index 7b6276642..3813e23b2 100644 --- a/Source/MQTTnet.Tests/MqttTcpChannel_Tests.cs +++ b/Source/MQTTnet.Tests/MqttTcpChannel_Tests.cs @@ -2,78 +2,80 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.VisualStudio.TestTools.UnitTesting; -using MQTTnet.Implementations; using System; using System.Net; using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Implementations; -namespace MQTTnet.Tests +namespace MQTTnet.Tests; + +[TestClass] +public class MqttTcpChannel_Tests { - [TestClass] - public class MqttTcpChannel_Tests + [TestMethod] + public async Task Dispose_Channel_While_Used() { - [TestMethod] - public async Task Dispose_Channel_While_Used() - { - var ct = new CancellationTokenSource(); - var serverSocket = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp); + using var ct = new CancellationTokenSource(); + using var serverSocket = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp); - try - { - serverSocket.Bind(new IPEndPoint(IPAddress.Any, 50001)); - serverSocket.Listen(0); + try + { + serverSocket.Bind(new IPEndPoint(IPAddress.Any, 50001)); + serverSocket.Listen(0); #pragma warning disable 4014 - Task.Run(async () => + Task.Run( + async () => #pragma warning restore 4014 { while (!ct.IsCancellationRequested) { - var client = await serverSocket.AcceptAsync(); + var client = await serverSocket.AcceptAsync(CancellationToken.None); var data = new byte[] { 128 }; await client.SendAsync(new ArraySegment(data), SocketFlags.None); } - }, ct.Token); + }, + ct.Token); - var clientSocket = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp); - await clientSocket.ConnectAsync(new DnsEndPoint("localhost", 50001), CancellationToken.None); + using var clientSocket = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp); + await clientSocket.ConnectAsync(new DnsEndPoint("localhost", 50001), CancellationToken.None); - var tcpChannel = new MqttTcpChannel(clientSocket.GetStream(), "test", null); + var tcpChannel = new MqttTcpChannel(clientSocket.GetStream(), "test", null); - await Task.Delay(100, ct.Token); + await Task.Delay(100, ct.Token); - var buffer = new byte[1]; - await tcpChannel.ReadAsync(buffer, 0, 1, ct.Token); + var buffer = new byte[1]; + await tcpChannel.ReadAsync(buffer, 0, 1, ct.Token); - Assert.AreEqual(128, buffer[0]); + Assert.AreEqual(128, buffer[0]); - // This block should fail after dispose. + // This block should fail after dispose. #pragma warning disable 4014 - Task.Run(() => + Task.Run( + () => #pragma warning restore 4014 { Task.Delay(200, ct.Token); tcpChannel.Dispose(); - }, ct.Token); + }, + ct.Token); - try - { - await tcpChannel.ReadAsync(buffer, 0, 1, CancellationToken.None); - } - catch (Exception exception) - { - Assert.IsInstanceOfType(exception, typeof(SocketException)); - Assert.AreEqual(SocketError.OperationAborted, ((SocketException)exception).SocketErrorCode); - } + try + { + await tcpChannel.ReadAsync(buffer, 0, 1, CancellationToken.None); } - finally + catch (Exception exception) { - ct.Cancel(false); - serverSocket.Dispose(); + Assert.IsInstanceOfType(exception, typeof(SocketException)); + Assert.AreEqual(SocketError.OperationAborted, ((SocketException)exception).SocketErrorCode); } } + finally + { + ct.Cancel(false); + } } -} +} \ No newline at end of file diff --git a/Source/MQTTnet/Implementations/CrossPlatformSocket.cs b/Source/MQTTnet/Implementations/CrossPlatformSocket.cs index 96f90dc24..f1caab8f7 100644 --- a/Source/MQTTnet/Implementations/CrossPlatformSocket.cs +++ b/Source/MQTTnet/Implementations/CrossPlatformSocket.cs @@ -108,11 +108,11 @@ public int TcpKeepAliveTime set => _socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime, value); } - public async Task AcceptAsync() + public async Task AcceptAsync(CancellationToken cancellationToken) { try { - var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false); + var clientSocket = await _socket.AcceptAsync(cancellationToken).ConfigureAwait(false); return new CrossPlatformSocket(clientSocket); } catch (ObjectDisposedException)