Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Content validation update spec #42191

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 57 additions & 3 deletions sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System.Buffers;
using System.Buffers.Binary;
using System.IO;
using System.Security.Cryptography;
using Azure.Core;

namespace Azure.Storage.Shared;
Expand All @@ -18,7 +17,7 @@ internal static class StructuredMessage
public enum Flags
{
None = 0,
CrcSegment = 1,
StorageCrc64 = 1,
}

public static class V1_0
Expand Down Expand Up @@ -86,7 +85,62 @@ public static IDisposable GetStreamHeaderBytes(
}
#endregion

// no stream footer content in 1.0
#region StreamFooter
public static void ReadStreamFooter(
ReadOnlySpan<byte> buffer,
Span<byte> crc64 = default)
{
int expectedBufferSize = 0;
if (!crc64.IsEmpty)
{
Errors.AssertBufferExactSize(crc64, Crc64Length, nameof(crc64));
expectedBufferSize += Crc64Length;
}
Errors.AssertBufferExactSize(buffer, expectedBufferSize, nameof(buffer));

if (!crc64.IsEmpty)
{
buffer.Slice(0, Crc64Length).CopyTo(crc64);
}
}

public static int WriteStreamFooter(Span<byte> buffer, ReadOnlySpan<byte> crc64 = default)
{
int requiredSpace = 0;
if (!crc64.IsEmpty)
{
Errors.AssertBufferExactSize(crc64, Crc64Length, nameof(crc64));
requiredSpace += Crc64Length;
}

Errors.AssertBufferMinimumSize(buffer, requiredSpace, nameof(buffer));
int offset = 0;
if (!crc64.IsEmpty)
{
crc64.CopyTo(buffer.Slice(offset, Crc64Length));
offset += Crc64Length;
}

return offset;
}

/// <summary>
/// Gets stream header in a buffer rented from the provided ArrayPool.
/// </summary>
/// <returns>
/// Disposable to return the buffer to the pool.
/// </returns>
public static IDisposable GetStreamFooterBytes(
ArrayPool<byte> pool,
out Memory<byte> bytes,
ReadOnlySpan<byte> crc64 = default)
{
Argument.AssertNotNull(pool, nameof(pool));
IDisposable disposable = pool.RentAsMemoryDisposable(StreamHeaderLength, out bytes);
WriteStreamFooter(bytes.Span, crc64);
return disposable;
}
#endregion

#region SegmentHeader
public static void ReadSegmentHeader(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ internal class StructuredMessageDecodingStream : Stream
private int _metadataBufferLength = 0;

private int _streamHeaderLength;
// private readonly int _streamFooterLength; // no stream footer in 1.0
private int _streamFooterLength;
private int _segmentHeaderLength;
private int _segmentFooterLength;
private int _totalSegments;
Expand Down Expand Up @@ -170,7 +170,7 @@ public override async ValueTask<int> ReadAsync(Memory<byte> buf, CancellationTok
private long CurrentRegionLength => _currentRegion switch
{
SMRegion.StreamHeader => _streamHeaderLength,
SMRegion.StreamFooter => 0,
SMRegion.StreamFooter => _streamFooterLength,
SMRegion.SegmentHeader => _segmentHeaderLength,
SMRegion.SegmentFooter => _segmentFooterLength,
SMRegion.SegmentContent => _currentSegmentContentLength,
Expand Down Expand Up @@ -229,6 +229,7 @@ private int Decode(Span<byte> buffer)
SMRegion.SegmentFooter => ProcessSegmentFooter(buffer.Slice(bufferConsumed)),
_ => 0,
};
// TODO surface error if processed is 0
gaps.Add((bufferConsumed, processed));
bufferConsumed += processed;
}
Expand Down Expand Up @@ -335,9 +336,10 @@ private int ProcessStreamHeader(ReadOnlySpan<byte> span)
out _innerStreamLength,
out _flags,
out _totalSegments);
if (_flags.HasFlag(StructuredMessage.Flags.CrcSegment))
if (_flags.HasFlag(StructuredMessage.Flags.StorageCrc64))
{
_segmentFooterLength = _flags.HasFlag(StructuredMessage.Flags.CrcSegment) ? StructuredMessage.Crc64Length : 0;
_segmentFooterLength = _flags.HasFlag(StructuredMessage.Flags.StorageCrc64) ? StructuredMessage.Crc64Length : 0;
_streamFooterLength = _flags.HasFlag(StructuredMessage.Flags.StorageCrc64) ? StructuredMessage.Crc64Length : 0;
_segmentCrc = StorageCrc64HashAlgorithm.Create();
_totalContentCrc = StorageCrc64HashAlgorithm.Create();
}
Expand All @@ -347,7 +349,21 @@ private int ProcessStreamHeader(ReadOnlySpan<byte> span)

private int ProcessStreamFooter(ReadOnlySpan<byte> span)
{
return 0;
int totalProcessed = 0;
if (_flags.HasFlag(StructuredMessage.Flags.StorageCrc64))
{
totalProcessed += StructuredMessage.Crc64Length;
using (ArrayPool<byte>.Shared.RentAsSpanDisposable(StructuredMessage.Crc64Length, out Span<byte> calculated))
{
_totalContentCrc.GetCurrentHash(calculated);
ReadOnlySpan<byte> expected = span.Slice(0, StructuredMessage.Crc64Length);
if (!calculated.SequenceEqual(expected))
{
throw Errors.ChecksumMismatch(calculated, expected);
}
}
}
return totalProcessed;
}

private int ProcessSegmentHeader(ReadOnlySpan<byte> span)
Expand All @@ -369,7 +385,7 @@ private int ProcessSegmentHeader(ReadOnlySpan<byte> span)
private int ProcessSegmentFooter(ReadOnlySpan<byte> span)
{
int totalProcessed = 0;
if (_flags.HasFlag(StructuredMessage.Flags.CrcSegment))
if (_flags.HasFlag(StructuredMessage.Flags.StorageCrc64))
{
totalProcessed += StructuredMessage.Crc64Length;
using (ArrayPool<byte>.Shared.RentAsSpanDisposable(StructuredMessage.Crc64Length, out Span<byte> calculated))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ internal class StructuredMessageEncodingStream : Stream
private readonly StructuredMessage.Flags _flags;
private bool _disposed;

private bool UseCrcSegment => _flags.HasFlag(StructuredMessage.Flags.CrcSegment);
private readonly StorageCrc64HashAlgorithm _runningCrc;
private readonly byte[] _runningCrcCheckpoints;
private bool UseCrcSegment => _flags.HasFlag(StructuredMessage.Flags.StorageCrc64);
private readonly StorageCrc64HashAlgorithm _totalCrc;
private StorageCrc64HashAlgorithm _segmentCrc;
private readonly byte[] _segmentCrcs;
private int _latestSegmentCrcd = 0;

#region Segments
Expand Down Expand Up @@ -197,16 +198,21 @@ public StructuredMessageEncodingStream(
_segmentContentLength = segmentContentLength;

_streamHeaderLength = StructuredMessage.V1_0.StreamHeaderLength;
_streamFooterLength = 0;
_streamFooterLength = UseCrcSegment ? StructuredMessage.Crc64Length : 0;
_segmentHeaderLength = StructuredMessage.V1_0.SegmentHeaderLength;
_segmentFooterLength = UseCrcSegment ? StructuredMessage.Crc64Length : 0;

if (UseCrcSegment)
{
_runningCrc = StorageCrc64HashAlgorithm.Create();
_runningCrcCheckpoints = ArrayPool<byte>.Shared.Rent(
_totalCrc = StorageCrc64HashAlgorithm.Create();
_segmentCrc = StorageCrc64HashAlgorithm.Create();
_segmentCrcs = ArrayPool<byte>.Shared.Rent(
GetTotalSegments(innerStream, segmentContentLength) * StructuredMessage.Crc64Length);
innerStream = ChecksumCalculatingStream.GetReadStream(innerStream, span => _runningCrc.Append(span));
innerStream = ChecksumCalculatingStream.GetReadStream(innerStream, span =>
{
_totalCrc.Append(span);
_segmentCrc.Append(span);
});
}

_innerStream = innerStream;
Expand Down Expand Up @@ -366,9 +372,22 @@ private int ReadFromStreamHeader(Span<byte> buffer)

private int ReadFromStreamFooter(Span<byte> buffer)
{
// method left intact for future stream footer content
// end of stream, no need to change _currentRegion
return 0;
int read = Math.Min(buffer.Length, _segmentFooterLength - _currentRegionPosition);
if (read <= 0)
{
return 0;
}

using IDisposable _ = StructuredMessage.V1_0.GetStreamFooterBytes(
ArrayPool<byte>.Shared,
out Memory<byte> footerBytes,
crc64: UseCrcSegment
? _totalCrc.GetCurrentHash() // TODO array pooling
: default);
footerBytes.Slice(_currentRegionPosition, read).Span.CopyTo(buffer);
_currentRegionPosition += read;

return read;
}

private int ReadFromSegmentHeader(Span<byte> buffer)
Expand Down Expand Up @@ -404,9 +423,9 @@ private int ReadFromSegmentFooter(Span<byte> buffer)
out Memory<byte> headerBytes,
crc64: UseCrcSegment
? new Span<byte>(
_runningCrcCheckpoints,
(CurrentEncodingSegment-1) * _runningCrc.HashLengthInBytes,
_runningCrc.HashLengthInBytes)
_segmentCrcs,
(CurrentEncodingSegment-1) * _totalCrc.HashLengthInBytes,
_totalCrc.HashLengthInBytes)
: default);
headerBytes.Slice(_currentRegionPosition, read).Span.CopyTo(buffer);
_currentRegionPosition += read;
Expand All @@ -433,11 +452,12 @@ private void CleanupContentSegment()
_currentRegionPosition = 0;
if (UseCrcSegment && CurrentEncodingSegment - 1 == _latestSegmentCrcd)
{
_runningCrc.GetCurrentHash(new Span<byte>(
_runningCrcCheckpoints,
_latestSegmentCrcd * _runningCrc.HashLengthInBytes,
_runningCrc.HashLengthInBytes));
_segmentCrc.GetCurrentHash(new Span<byte>(
_segmentCrcs,
_latestSegmentCrcd * _segmentCrc.HashLengthInBytes,
_segmentCrc.HashLengthInBytes));
_latestSegmentCrcd++;
_segmentCrc = StorageCrc64HashAlgorithm.Create();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,17 @@ public async Task DecodesData(
[Values(true, false)] bool useCrc)
{
int segmentContentLength = seglen ?? int.MaxValue;
Flags flags = useCrc ? Flags.CrcSegment : Flags.None;
Flags flags = useCrc ? Flags.StorageCrc64 : Flags.None;

byte[] originalData = new byte[dataLength];
new Random().NextBytes(originalData);
byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, segmentContentLength, flags);

Stream encodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData));
Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData));
byte[] decodedData;
using (MemoryStream dest = new())
{
await CopyStream(encodingStream, dest, readLen);
await CopyStream(decodingStream, dest, readLen);
decodedData = dest.ToArray();
}

Expand Down
Loading
Loading