Skip to content

Commit

Permalink
rebase and fix merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
Wraith2 committed Aug 28, 2021
1 parent 241631a commit 6d49b49
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@ internal abstract class SNIHandle
/// Send a packet asynchronously
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="callback">Completion callback</param>
/// <returns>SNI error code</returns>
public abstract uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null);
public abstract uint SendAsync(SNIPacket packet);

/// <summary>
/// Receive a packet synchronously
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,14 @@ public uint Send(SNIPacket packet)
/// Send a packet asynchronously
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="callback">Completion callback</param>
/// <returns>SNI error code</returns>
public uint SendAsync(SNIPacket packet, SNIAsyncCallback callback)
public uint SendAsync(SNIPacket packet)
{
using (TrySNIEventScope.Create(nameof(SNIMarsConnection)))
{
lock (DemuxerSync)
{
return _lowerHandle.SendAsync(packet, callback);
return _lowerHandle.SendAsync(packet);
}
}
}
Expand Down Expand Up @@ -191,7 +190,7 @@ public void HandleReceiveError(SNIPacket packet)
/// <param name="sniErrorCode">SNI error code</param>
public void HandleSendComplete(SNIPacket packet, uint sniErrorCode)
{
packet.InvokeCompletionCallback(sniErrorCode);
packet.InvokeReceiveCallback(sniErrorCode);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,8 @@ public override uint Send(SNIPacket packet)
/// Send packet asynchronously
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="callback">Completion callback</param>
/// <returns>SNI error code</returns>
private uint InternalSendAsync(SNIPacket packet, SNIAsyncCallback callback)
private uint InternalSendAsync(SNIPacket packet)
{
Debug.Assert(packet.ReservedHeaderSize == SNISMUXHeader.HEADER_LENGTH, "mars handle attempting to send muxed packet without smux reservation in InternalSendAsync");
using (TrySNIEventScope.Create("SNIMarsHandle.InternalSendAsync | SNI | INFO | SCOPE | Entering Scope {0}"))
Expand All @@ -215,9 +214,9 @@ private uint InternalSendAsync(SNIPacket packet, SNIAsyncCallback callback)
}

SNIPacket muxedPacket = SetPacketSMUXHeader(packet);
muxedPacket.SetCompletionCallback(callback ?? HandleSendComplete);
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsHandle), EventType.INFO, "MARS Session Id {0}, _sequenceNumber {1}, _sendHighwater {2}, Sending packet", args0: ConnectionId, args1: _sequenceNumber, args2: _sendHighwater);
return _connection.SendAsync(muxedPacket, callback);
muxedPacket.SetReceiveCallback(HandleSendComplete);
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "MARS Session Id {0}, _sequenceNumber {1}, _sendHighwater {2}, Sending packet", args0: ConnectionId, args1: _sequenceNumber, args2: _sendHighwater);
return _connection.SendAsync(muxedPacket);
}
}
}
Expand All @@ -241,7 +240,7 @@ private uint SendPendingPackets()
if (_sendPacketQueue.Count != 0)
{
packet = _sendPacketQueue.Peek();
uint result = InternalSendAsync(packet.Packet, packet.Callback);
uint result = InternalSendAsync(packet.Packet);

if (result != TdsEnums.SNI_SUCCESS && result != TdsEnums.SNI_SUCCESS_IO_PENDING)
{
Expand Down Expand Up @@ -272,16 +271,15 @@ private uint SendPendingPackets()
/// Send a packet asynchronously
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="callback">Completion callback</param>
/// <returns>SNI error code</returns>
public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
public override uint SendAsync(SNIPacket packet)
{
Debug.Assert(_connection != null && Monitor.IsEntered(_connection.DemuxerSync), "SNIMarsHandle.HandleRecieveComplete should be called while holding the SNIMarsConnection.DemuxerSync because it can cause deadlocks");
using (TrySNIEventScope.Create(nameof(SNIMarsHandle)))
{
lock (this)
{
_sendPacketQueue.Enqueue(new SNIMarsQueuedPacket(packet, callback ?? _handleSendCompleteCallback));
_sendPacketQueue.Enqueue(new SNIMarsQueuedPacket(packet, _handleSendCompleteCallback));
}

SendPendingPackets();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.IO.Pipes;
using System.Net.Security;
Expand Down Expand Up @@ -210,11 +211,11 @@ public override uint ReceiveAsync(ref SNIPacket packet)
{
SNIPacket errorPacket;
packet = RentPacket(headerSize: 0, dataSize: _bufferSize);

packet.SetReceiveCallback(_receiveCallback);
try
{
packet.ReadFromStreamAsync(_stream, _receiveCallback);
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNINpHandle), EventType.INFO, "Connection Id {0}, Rented and read packet asynchronously, dataLeft {1}", args0: _connectionId, args1: packet?.DataLeft);
packet.ReadFromStreamAsync(_stream);
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "Connection Id {0}, Rented and read packet asynchronously, dataLeft {1}", args0: _connectionId, args1: packet?.DataLeft);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
catch (ObjectDisposedException ode)
Expand Down Expand Up @@ -288,13 +289,12 @@ public override uint Send(SNIPacket packet)
}
}

public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
public override uint SendAsync(SNIPacket packet)
{
using (TrySNIEventScope.Create(nameof(SNINpHandle)))
{
SNIAsyncCallback cb = callback ?? _sendCallback;
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNINpHandle), EventType.INFO, "Connection Id {0}, Packet writing to stream, dataLeft {1}", args0: _connectionId, args1: packet?.DataLeft);
packet.WriteToStreamAsync(_stream, cb, SNIProviders.NP_PROV);
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "Connection Id {0}, Packet writing to stream, dataLeft {1}", args0: _connectionId, args1: packet?.DataLeft);
packet.WriteToStreamAsync(_stream, _sendCallback, SNIProviders.NP_PROV);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@ internal sealed partial class SNIPacket
private int _headerLength; // the amount of space at the start of the array reserved for the smux header, this is zeroed in SetHeader
// _headerOffset is not needed because it is always 0
private byte[] _data;
private SNIAsyncCallback _completionCallback;
private readonly Action<Task<int>, object> _readCallback;
private SNIAsyncCallback _receiveCallback;

public SNIPacket()
{
_readCallback = ReadFromStreamAsyncContinuation;
}

/// <summary>
Expand All @@ -52,25 +50,19 @@ public SNIPacket()

public int ReservedHeaderSize => _headerLength;

public bool HasCompletionCallback => !(_completionCallback is null);
public bool HasCompletionCallback => _receiveCallback is not null;

/// <summary>
/// Set async completion callback
/// Set async receive callback
/// </summary>
/// <param name="completionCallback">Completion callback</param>
public void SetCompletionCallback(SNIAsyncCallback completionCallback)
{
_completionCallback = completionCallback;
}
/// <param name="receiveCallback">Completion callback</param>
public void SetReceiveCallback(SNIAsyncCallback receiveCallback) => _receiveCallback = receiveCallback;

/// <summary>
/// Invoke the completion callback
/// Invoke the receive callback
/// </summary>
/// <param name="sniErrorCode">SNI error</param>
public void InvokeCompletionCallback(uint sniErrorCode)
{
_completionCallback(this, sniErrorCode);
}
public void InvokeReceiveCallback(uint sniErrorCode) => _receiveCallback(this, sniErrorCode);

/// <summary>
/// Allocate space for data
Expand Down Expand Up @@ -205,7 +197,7 @@ public void Release()
_dataLength = 0;
_dataOffset = 0;
_headerLength = 0;
_completionCallback = null;
_receiveCallback = null;
IsOutOfBand = false;
}

Expand All @@ -225,22 +217,21 @@ public void ReadFromStream(Stream stream)
/// Read data from a stream asynchronously
/// </summary>
/// <param name="stream">Stream to read from</param>
/// <param name="callback">Completion callback</param>
public void ReadFromStreamAsync(Stream stream, SNIAsyncCallback callback)
public void ReadFromStreamAsync(Stream stream)
{
stream.ReadAsync(_data, 0, _dataCapacity, CancellationToken.None)
.ContinueWith(
continuationAction: _readCallback,
state: callback,
continuationAction: s_readCallback,
state: this,
CancellationToken.None,
TaskContinuationOptions.DenyChildAttach,
TaskScheduler.Default
);
}

private void ReadFromStreamAsyncContinuation(Task<int> t, object state)
private static void ReadFromStreamAsyncContinuation(Task<int> t, object state)
{
SNIAsyncCallback callback = (SNIAsyncCallback)state;
SNIPacket packet = (SNIPacket)state;
bool error = false;
Exception e = t.Exception?.InnerException;
if (e != null)
Expand All @@ -253,11 +244,11 @@ private void ReadFromStreamAsyncContinuation(Task<int> t, object state)
}
else
{
_dataLength = t.Result;
packet._dataLength = t.Result;
#if DEBUG
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIPacket), EventType.INFO, "Connection Id {0}, Packet Id {1} _dataLength {2} read from stream.", args0: _owner?.ConnectionId, args1: _id, args2: _dataLength);
#endif
if (_dataLength == 0)
if (packet._dataLength == 0)
{
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, 0, SNICommon.ConnTerminatedError, Strings.SNI_ERROR_2);
#if DEBUG
Expand All @@ -267,7 +258,7 @@ private void ReadFromStreamAsyncContinuation(Task<int> t, object state)
}
}

callback(this, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS);
packet.InvokeReceiveCallback(error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Net.Security;
Expand Down Expand Up @@ -800,15 +801,13 @@ public override void SetAsyncCallbacks(SNIAsyncCallback receiveCallback, SNIAsyn
/// Send a packet asynchronously
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="callback">Completion callback</param>
/// <returns>SNI error code</returns>
public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
public override uint SendAsync(SNIPacket packet)
{
using (TrySNIEventScope.Create(nameof(SNITCPHandle)))
{
SNIAsyncCallback cb = callback ?? _sendCallback;
packet.WriteToStreamAsync(_stream, cb, SNIProviders.TCP_PROV);
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Data sent to stream asynchronously", args0: _connectionId);
packet.WriteToStreamAsync(_stream, _sendCallback, SNIProviders.TCP_PROV);
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "Connection Id {0}, Data sent to stream asynchronously", args0: _connectionId);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
}
Expand All @@ -822,11 +821,11 @@ public override uint ReceiveAsync(ref SNIPacket packet)
{
SNIPacket errorPacket;
packet = RentPacket(headerSize: 0, dataSize: _bufferSize);

packet.SetReceiveCallback(_receiveCallback);
try
{
packet.ReadFromStreamAsync(_stream, _receiveCallback);
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Data received from stream asynchronously", args0: _connectionId);
packet.ReadFromStreamAsync(_stream);
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "Connection Id {0}, Data received from stream asynchronously", args0: _connectionId);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
catch (Exception e) when (e is ObjectDisposedException || e is SocketException || e is IOException)
Expand Down

0 comments on commit 6d49b49

Please sign in to comment.