-
Notifications
You must be signed in to change notification settings - Fork 299
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
533 additions
and
217 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
289 changes: 289 additions & 0 deletions
289
...ft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,289 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using System.Buffers; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
|
||
namespace Microsoft.Data.SqlClient.SNI | ||
{ | ||
internal sealed partial class SslOverTdsStream | ||
{ | ||
public override int Read(byte[] buffer, int offset, int count) | ||
{ | ||
return Read(buffer.AsSpan(offset, count)); | ||
} | ||
|
||
public override void Write(byte[] buffer, int offset, int count) | ||
{ | ||
Write(buffer.AsSpan(offset, count)); | ||
} | ||
|
||
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | ||
{ | ||
return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask(); | ||
} | ||
|
||
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | ||
{ | ||
return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask(); | ||
} | ||
|
||
public override int Read(Span<byte> buffer) | ||
{ | ||
if (_encapsulate) | ||
{ | ||
if (_packetBytes > 0) | ||
{ | ||
// there are queued bytes from a previous packet available | ||
// work out how many of the remaining bytes we can consume | ||
int wantedCount = Math.Min(buffer.Length, _packetBytes); | ||
int readCount = _stream.Read(buffer.Slice(0, wantedCount)); | ||
if (readCount == 0) | ||
{ | ||
// 0 means the connection was closed, tell the caller | ||
return 0; | ||
} | ||
_packetBytes -= readCount; | ||
return readCount; | ||
} | ||
else | ||
{ | ||
Span<byte> headerBytes = stackalloc byte[TdsEnums.HEADER_LEN]; | ||
|
||
// fetch the packet header to determine how long the packet is | ||
int headerBytesRead = 0; | ||
do | ||
{ | ||
int headerBytesReadIteration = _stream.Read(headerBytes.Slice(headerBytesRead, (TdsEnums.HEADER_LEN - headerBytesRead))); | ||
if (headerBytesReadIteration == 0) | ||
{ | ||
// 0 means the connection was closed, tell the caller | ||
return 0; | ||
} | ||
headerBytesRead += headerBytesReadIteration; | ||
} while (headerBytesRead < TdsEnums.HEADER_LEN); | ||
|
||
// read the packet data size from the header and store it in case it is needed for a subsequent call | ||
_packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; | ||
|
||
// read as much from the packet as the caller can accept | ||
int packetBytesRead = _stream.Read(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes))); | ||
_packetBytes -= packetBytesRead; | ||
return packetBytesRead; | ||
} | ||
} | ||
else | ||
{ | ||
return _stream.Read(buffer); | ||
} | ||
} | ||
|
||
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default) | ||
{ | ||
if (_encapsulate) | ||
{ | ||
if (_packetBytes > 0) | ||
{ | ||
// there are queued bytes from a previous packet available | ||
// work out how many of the remaining bytes we can consume | ||
int wantedCount = Math.Min(buffer.Length, _packetBytes); | ||
|
||
int readCount; | ||
{ | ||
ValueTask<int> remainderReadValueTask = _stream.ReadAsync(buffer.Slice(0, wantedCount), cancellationToken); | ||
if (remainderReadValueTask.IsCompletedSuccessfully) | ||
{ | ||
readCount = remainderReadValueTask.Result; | ||
} | ||
else | ||
{ | ||
readCount = await remainderReadValueTask.AsTask().ConfigureAwait(false); | ||
} | ||
} | ||
if (readCount == 0) | ||
{ | ||
// 0 means the connection was closed, tell the caller | ||
return 0; | ||
} | ||
_packetBytes -= readCount; | ||
return readCount; | ||
} | ||
else | ||
{ | ||
byte[] headerBytes = ArrayPool<byte>.Shared.Rent(TdsEnums.HEADER_LEN); | ||
Array.Clear(headerBytes, 0, headerBytes.Length); | ||
|
||
// fetch the packet header to determine how long the packet is | ||
int headerBytesRead = 0; | ||
do | ||
{ | ||
int headerBytesReadIteration; | ||
{ | ||
ValueTask<int> headerReadValueTask = _stream.ReadAsync(headerBytes.AsMemory(headerBytesRead, (TdsEnums.HEADER_LEN - headerBytesRead)), cancellationToken); | ||
if (headerReadValueTask.IsCompletedSuccessfully) | ||
{ | ||
headerBytesReadIteration = headerReadValueTask.Result; | ||
} | ||
else | ||
{ | ||
headerBytesReadIteration = await headerReadValueTask.AsTask().ConfigureAwait(false); | ||
} | ||
} | ||
if (headerBytesReadIteration == 0) | ||
{ | ||
// 0 means the connection was closed, cleanup the rented array and then tell the caller | ||
ArrayPool<byte>.Shared.Return(headerBytes, clearArray: true); | ||
return 0; | ||
} | ||
headerBytesRead += headerBytesReadIteration; | ||
} while (headerBytesRead < TdsEnums.HEADER_LEN); | ||
|
||
// read the packet data size from the header and store it in case it is needed for a subsequent call | ||
_packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; | ||
|
||
ArrayPool<byte>.Shared.Return(headerBytes, clearArray: true); | ||
|
||
// read as much from the packet as the caller can accept | ||
int packetBytesRead; | ||
{ | ||
ValueTask<int> packetReadValueTask = _stream.ReadAsync(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes)), cancellationToken); | ||
if (packetReadValueTask.IsCompletedSuccessfully) | ||
{ | ||
packetBytesRead = packetReadValueTask.Result; | ||
} | ||
else | ||
{ | ||
packetBytesRead = await packetReadValueTask.AsTask().ConfigureAwait(false); | ||
} | ||
} | ||
_packetBytes -= packetBytesRead; | ||
return packetBytesRead; | ||
} | ||
} | ||
else | ||
{ | ||
int read; | ||
{ | ||
ValueTask<int> readValueTask = _stream.ReadAsync(buffer, cancellationToken); | ||
if (readValueTask.IsCompletedSuccessfully) | ||
{ | ||
read = readValueTask.Result; | ||
} | ||
else | ||
{ | ||
read = await readValueTask.AsTask().ConfigureAwait(false); | ||
} | ||
} | ||
return read; | ||
} | ||
} | ||
|
||
public override void Write(ReadOnlySpan<byte> buffer) | ||
{ | ||
// During the SSL negotiation phase, SSL is tunnelled over TDS packet type 0x12. After | ||
// negotiation, the underlying socket only sees SSL frames. | ||
if (_encapsulate) | ||
{ | ||
ReadOnlySpan<byte> remaining = buffer; | ||
byte[] packetBuffer = null; | ||
while (remaining.Length > 0) | ||
{ | ||
int dataLength = Math.Min(PACKET_SIZE_WITHOUT_HEADER, remaining.Length); | ||
int packetLength = TdsEnums.HEADER_LEN + dataLength; | ||
|
||
if (packetBuffer == null) | ||
{ | ||
packetBuffer = ArrayPool<byte>.Shared.Rent(packetLength); | ||
} | ||
else if (packetBuffer.Length < packetLength) | ||
{ | ||
ArrayPool<byte>.Shared.Return(packetBuffer, clearArray: true); | ||
packetBuffer = ArrayPool<byte>.Shared.Rent(packetLength); | ||
} | ||
|
||
SetupPreLoginPacketHeader(packetBuffer, dataLength, remaining.Length - dataLength); | ||
|
||
Span<byte> data = packetBuffer.AsSpan(TdsEnums.HEADER_LEN, dataLength); | ||
remaining.Slice(0, dataLength).CopyTo(data); | ||
|
||
_stream.Write(packetBuffer.AsSpan(0, packetLength)); | ||
_stream.Flush(); | ||
|
||
remaining = remaining.Slice(dataLength); | ||
} | ||
if (packetBuffer != null) | ||
{ | ||
ArrayPool<byte>.Shared.Return(packetBuffer, clearArray: true); | ||
} | ||
} | ||
else | ||
{ | ||
_stream.Write(buffer); | ||
_stream.Flush(); | ||
} | ||
} | ||
|
||
public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default) | ||
{ | ||
if (_encapsulate) | ||
{ | ||
ReadOnlyMemory<byte> remaining = buffer; | ||
byte[] packetBuffer = null; | ||
while (remaining.Length > 0) | ||
{ | ||
int dataLength = Math.Min(PACKET_SIZE_WITHOUT_HEADER, remaining.Length); | ||
int packetLength = TdsEnums.HEADER_LEN + dataLength; | ||
|
||
if (packetBuffer == null) | ||
{ | ||
packetBuffer = ArrayPool<byte>.Shared.Rent(packetLength); | ||
} | ||
else if (packetBuffer.Length < packetLength) | ||
{ | ||
ArrayPool<byte>.Shared.Return(packetBuffer, clearArray: true); | ||
packetBuffer = ArrayPool<byte>.Shared.Rent(packetLength); | ||
} | ||
|
||
SetupPreLoginPacketHeader(packetBuffer, dataLength, remaining.Length - dataLength); | ||
|
||
remaining.Span.Slice(0, dataLength).CopyTo(packetBuffer.AsSpan(TdsEnums.HEADER_LEN, dataLength)); | ||
|
||
{ | ||
ValueTask packetWriteValueTask = _stream.WriteAsync(new ReadOnlyMemory<byte>(packetBuffer, 0, packetLength), cancellationToken); | ||
if (!packetWriteValueTask.IsCompletedSuccessfully) | ||
{ | ||
await packetWriteValueTask.AsTask().ConfigureAwait(false); | ||
} | ||
} | ||
|
||
await _stream.FlushAsync().ConfigureAwait(false); | ||
|
||
|
||
remaining = remaining.Slice(dataLength); | ||
} | ||
if (packetBuffer != null) | ||
{ | ||
ArrayPool<byte>.Shared.Return(packetBuffer, clearArray: true); | ||
} | ||
} | ||
else | ||
{ | ||
{ | ||
ValueTask valueTask = _stream.WriteAsync(buffer, cancellationToken); | ||
if (!valueTask.IsCompletedSuccessfully) | ||
{ | ||
await valueTask.AsTask().ConfigureAwait(false); | ||
} | ||
} | ||
Task flushTask = _stream.FlushAsync(); | ||
if (flushTask.IsCompletedSuccessfully) | ||
{ | ||
await flushTask.ConfigureAwait(false); | ||
} | ||
} | ||
} | ||
} | ||
} |
Oops, something went wrong.