Skip to content

Commit

Permalink
simplify managed SNI receive callback use
Browse files Browse the repository at this point in the history
  • Loading branch information
Wraith2 committed Jul 24, 2021
1 parent 0288d90 commit bf79179
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ internal abstract class SNIHandle
/// Send a packet asynchronously
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="callback">Completion callback</param>
/// <returns>SNI error code</returns>
public abstract uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null);
public abstract uint SendAsync(SNIPacket packet);

/// <summary>
/// Receive a packet synchronously
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,14 @@ public uint Send(SNIPacket packet)
/// Send a packet asynchronously
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="callback">Completion callback</param>
/// <returns>SNI error code</returns>
public uint SendAsync(SNIPacket packet, SNIAsyncCallback callback)
public uint SendAsync(SNIPacket packet)
{
using (TrySNIEventScope.Create(nameof(SNIMarsConnection)))
{
lock (DemuxerSync)
{
return _lowerHandle.SendAsync(packet, callback);
return _lowerHandle.SendAsync(packet);
}
}
}
Expand Down Expand Up @@ -191,7 +190,7 @@ public void HandleReceiveError(SNIPacket packet)
/// <param name="sniErrorCode">SNI error code</param>
public void HandleSendComplete(SNIPacket packet, uint sniErrorCode)
{
packet.InvokeCompletionCallback(sniErrorCode);
packet.InvokeReceiveCallback(sniErrorCode);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,8 @@ public override uint Send(SNIPacket packet)
/// Send packet asynchronously
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="callback">Completion callback</param>
/// <returns>SNI error code</returns>
private uint InternalSendAsync(SNIPacket packet, SNIAsyncCallback callback)
private uint InternalSendAsync(SNIPacket packet)
{
Debug.Assert(packet.ReservedHeaderSize == SNISMUXHeader.HEADER_LENGTH, "mars handle attempting to send muxed packet without smux reservation in InternalSendAsync");
long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent("SNIMarsHandle.InternalSendAsync | SNI | INFO | SCOPE | Entering Scope {0}");
Expand All @@ -229,9 +228,9 @@ private uint InternalSendAsync(SNIPacket packet, SNIAsyncCallback callback)
}

SNIPacket muxedPacket = SetPacketSMUXHeader(packet);
muxedPacket.SetCompletionCallback(callback ?? HandleSendComplete);
muxedPacket.SetReceiveCallback(HandleSendComplete);
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "MARS Session Id {0}, _sequenceNumber {1}, _sendHighwater {2}, Sending packet", args0: ConnectionId, args1: _sequenceNumber, args2: _sendHighwater);
return _connection.SendAsync(muxedPacket, callback);
return _connection.SendAsync(muxedPacket);
}
}
finally
Expand Down Expand Up @@ -259,7 +258,7 @@ private uint SendPendingPackets()
if (_sendPacketQueue.Count != 0)
{
packet = _sendPacketQueue.Peek();
uint result = InternalSendAsync(packet.Packet, packet.Callback);
uint result = InternalSendAsync(packet.Packet);

if (result != TdsEnums.SNI_SUCCESS && result != TdsEnums.SNI_SUCCESS_IO_PENDING)
{
Expand Down Expand Up @@ -294,17 +293,16 @@ private uint SendPendingPackets()
/// Send a packet asynchronously
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="callback">Completion callback</param>
/// <returns>SNI error code</returns>
public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
public override uint SendAsync(SNIPacket packet)
{
Debug.Assert(_connection != null && Monitor.IsEntered(_connection.DemuxerSync), "SNIMarsHandle.HandleRecieveComplete should be called while holding the SNIMarsConnection.DemuxerSync because it can cause deadlocks");
long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent(s_className);
try
{
lock (this)
{
_sendPacketQueue.Enqueue(new SNIMarsQueuedPacket(packet, callback ?? _handleSendCompleteCallback));
_sendPacketQueue.Enqueue(new SNIMarsQueuedPacket(packet, _handleSendCompleteCallback));
}

SendPendingPackets();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.IO.Pipes;
using System.Net.Security;
Expand Down Expand Up @@ -238,10 +239,10 @@ public override uint ReceiveAsync(ref SNIPacket packet)
{
SNIPacket errorPacket;
packet = RentPacket(headerSize: 0, dataSize: _bufferSize);

packet.SetReceiveCallback(_receiveCallback);
try
{
packet.ReadFromStreamAsync(_stream, _receiveCallback);
packet.ReadFromStreamAsync(_stream);
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "Connection Id {0}, Rented and read packet asynchronously, dataLeft {1}", args0: _connectionId, args1: packet?.DataLeft);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
Expand Down Expand Up @@ -325,14 +326,14 @@ public override uint Send(SNIPacket packet)
}
}

public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
public override uint SendAsync(SNIPacket packet)
{
Debug.Assert(_sendCallback != null, "_sendCallback is null");
long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent(s_className);
try
{
SNIAsyncCallback cb = callback ?? _sendCallback;
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "Connection Id {0}, Packet writing to stream, dataLeft {1}", args0: _connectionId, args1: packet?.DataLeft);
packet.WriteToStreamAsync(_stream, cb, SNIProviders.NP_PROV);
packet.WriteToStreamAsync(_stream, _sendCallback, SNIProviders.NP_PROV);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
finally
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ namespace Microsoft.Data.SqlClient.SNI
/// </summary>
internal sealed partial class SNIPacket
{
private static readonly Action<Task<int>, object> s_readCallback = ReadFromStreamAsyncContinuation;

private const string s_className = nameof(SNIPacket);
private int _dataLength; // the length of the data in the data segment, advanced by Append-ing data, does not include smux header length
private int _dataCapacity; // the total capacity requested, if the array is rented this may be less than the _data.Length, does not include smux header length
private int _dataOffset; // the start point of the data in the data segment, advanced by Take-ing data
private int _headerLength; // the amount of space at the start of the array reserved for the smux header, this is zeroed in SetHeader
// _headerOffset is not needed because it is always 0
private byte[] _data;
private SNIAsyncCallback _completionCallback;
private readonly Action<Task<int>, object> _readCallback;
private SNIAsyncCallback _receiveCallback;

public SNIPacket()
{
_readCallback = ReadFromStreamAsyncContinuation;
}

/// <summary>
Expand All @@ -53,25 +53,19 @@ public SNIPacket()

public int ReservedHeaderSize => _headerLength;

public bool HasCompletionCallback => !(_completionCallback is null);
public bool HasCompletionCallback => _receiveCallback is not null;

/// <summary>
/// Set async completion callback
/// Set async receive callback
/// </summary>
/// <param name="completionCallback">Completion callback</param>
public void SetCompletionCallback(SNIAsyncCallback completionCallback)
{
_completionCallback = completionCallback;
}
/// <param name="receiveCallback">Completion callback</param>
public void SetReceiveCallback(SNIAsyncCallback receiveCallback) => _receiveCallback = receiveCallback;

/// <summary>
/// Invoke the completion callback
/// Invoke the receive callback
/// </summary>
/// <param name="sniErrorCode">SNI error</param>
public void InvokeCompletionCallback(uint sniErrorCode)
{
_completionCallback(this, sniErrorCode);
}
public void InvokeReceiveCallback(uint sniErrorCode) => _receiveCallback(this, sniErrorCode);

/// <summary>
/// Allocate space for data
Expand Down Expand Up @@ -206,7 +200,7 @@ public void Release()
_dataLength = 0;
_dataOffset = 0;
_headerLength = 0;
_completionCallback = null;
_receiveCallback = null;
IsOutOfBand = false;
}

Expand All @@ -226,49 +220,48 @@ public void ReadFromStream(Stream stream)
/// Read data from a stream asynchronously
/// </summary>
/// <param name="stream">Stream to read from</param>
/// <param name="callback">Completion callback</param>
public void ReadFromStreamAsync(Stream stream, SNIAsyncCallback callback)
public void ReadFromStreamAsync(Stream stream)
{
stream.ReadAsync(_data, 0, _dataCapacity, CancellationToken.None)
.ContinueWith(
continuationAction: _readCallback,
state: callback,
continuationAction: s_readCallback,
state: this,
CancellationToken.None,
TaskContinuationOptions.DenyChildAttach,
TaskScheduler.Default
);
}

private void ReadFromStreamAsyncContinuation(Task<int> t, object state)
private static void ReadFromStreamAsyncContinuation(Task<int> t, object state)
{
SNIAsyncCallback callback = (SNIAsyncCallback)state;
SNIPacket packet = (SNIPacket)state;
bool error = false;
Exception e = t.Exception?.InnerException;
if (e != null)
{
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.InternalExceptionError, e);
#if DEBUG
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.ERR, "Connection Id {0}, Internal Exception occurred while reading data: {1}", args0: _owner?.ConnectionId, args1: e?.Message);
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.ERR, "Connection Id {0}, Internal Exception occurred while reading data: {1}", args0: packet._owner?.ConnectionId, args1: e?.Message);
#endif
error = true;
}
else
{
_dataLength = t.Result;
packet._dataLength = t.Result;
#if DEBUG
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "Connection Id {0}, Packet Id {1} _dataLength {2} read from stream.", args0: _owner?.ConnectionId, args1: _id, args2: _dataLength);
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "Connection Id {0}, Packet Id {1} _dataLength {2} read from stream.", args0: packet._owner?.ConnectionId, args1: packet._id, args2: packet._dataLength);
#endif
if (_dataLength == 0)
if (packet._dataLength == 0)
{
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, 0, SNICommon.ConnTerminatedError, Strings.SNI_ERROR_2);
#if DEBUG
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.ERR, "Connection Id {0}, No data read from stream, connection was terminated.", args0: _owner?.ConnectionId);
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.ERR, "Connection Id {0}, No data read from stream, connection was terminated.", args0: packet._owner?.ConnectionId);
#endif
error = true;
}
}

callback(this, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS);
packet.InvokeReceiveCallback(error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Net.Security;
Expand Down Expand Up @@ -805,15 +806,14 @@ public override void SetAsyncCallbacks(SNIAsyncCallback receiveCallback, SNIAsyn
/// Send a packet asynchronously
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="callback">Completion callback</param>
/// <returns>SNI error code</returns>
public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
public override uint SendAsync(SNIPacket packet)
{
Debug.Assert(_sendCallback !=null, "_sendCallback is null");
long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent(s_className);
try
{
SNIAsyncCallback cb = callback ?? _sendCallback;
packet.WriteToStreamAsync(_stream, cb, SNIProviders.TCP_PROV);
packet.WriteToStreamAsync(_stream, _sendCallback, SNIProviders.TCP_PROV);
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "Connection Id {0}, Data sent to stream asynchronously", args0: _connectionId);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
Expand All @@ -832,10 +832,10 @@ public override uint ReceiveAsync(ref SNIPacket packet)
{
SNIPacket errorPacket;
packet = RentPacket(headerSize: 0, dataSize: _bufferSize);

packet.SetReceiveCallback(_receiveCallback);
try
{
packet.ReadFromStreamAsync(_stream, _receiveCallback);
packet.ReadFromStreamAsync(_stream);
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "Connection Id {0}, Data received from stream asynchronously", args0: _connectionId);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
Expand Down

0 comments on commit bf79179

Please sign in to comment.