Skip to content

Commit

Permalink
Content validation update spec (#42191)
Browse files Browse the repository at this point in the history
* enum rename and footer read/write methods

* align encode/decode tests | update encoding stream

* decode stream footer

* rename
  • Loading branch information
jaschrep-msft authored Feb 28, 2024
1 parent 79e29b3 commit 034c9cd
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 120 deletions.
60 changes: 57 additions & 3 deletions sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System.Buffers;
using System.Buffers.Binary;
using System.IO;
using System.Security.Cryptography;
using Azure.Core;

namespace Azure.Storage.Shared;
Expand All @@ -18,7 +17,7 @@ internal static class StructuredMessage
/// <summary>
/// Flags describing optional features of a structured message.
/// </summary>
public enum Flags
{
/// <summary>No optional features.</summary>
None = 0,
// NOTE(review): CrcSegment and StorageCrc64 both map to 1. This looks like the
// old and new names of the same flag (the commit describes an "enum rename")
// shown side by side — confirm that only StorageCrc64 is meant to remain.
CrcSegment = 1,
// When set, the streams in this change size the segment footer and stream footer to Crc64Length.
StorageCrc64 = 1,
}

public static class V1_0
Expand Down Expand Up @@ -86,7 +85,62 @@ public static IDisposable GetStreamHeaderBytes(
}
#endregion

// no stream footer content in 1.0
#region StreamFooter
/// <summary>
/// Parses a stream footer, optionally extracting its CRC64.
/// </summary>
/// <param name="buffer">
/// The encoded footer. Must be empty when <paramref name="crc64"/> is empty,
/// otherwise exactly <c>Crc64Length</c> bytes.
/// </param>
/// <param name="crc64">
/// Optional destination for the footer CRC. When supplied it must be exactly
/// <c>Crc64Length</c> bytes.
/// </param>
public static void ReadStreamFooter(
    ReadOnlySpan<byte> buffer,
    Span<byte> crc64 = default)
{
    // No CRC requested: the footer carries nothing, so the buffer must be empty.
    if (crc64.IsEmpty)
    {
        Errors.AssertBufferExactSize(buffer, 0, nameof(buffer));
        return;
    }

    // Validate the destination first, then the source, before copying.
    Errors.AssertBufferExactSize(crc64, Crc64Length, nameof(crc64));
    Errors.AssertBufferExactSize(buffer, Crc64Length, nameof(buffer));

    ReadOnlySpan<byte> footerCrc = buffer.Slice(0, Crc64Length);
    footerCrc.CopyTo(crc64);
}

/// <summary>
/// Writes a stream footer into <paramref name="buffer"/>.
/// </summary>
/// <param name="buffer">
/// Destination for the footer. Must be at least as large as the footer being written.
/// </param>
/// <param name="crc64">
/// Optional CRC to place in the footer. When supplied it must be exactly
/// <c>Crc64Length</c> bytes.
/// </param>
/// <returns>The number of bytes written to <paramref name="buffer"/>.</returns>
public static int WriteStreamFooter(Span<byte> buffer, ReadOnlySpan<byte> crc64 = default)
{
    // Without a CRC the footer is empty; still run the (trivial) size check on the buffer.
    if (crc64.IsEmpty)
    {
        Errors.AssertBufferMinimumSize(buffer, 0, nameof(buffer));
        return 0;
    }

    Errors.AssertBufferExactSize(crc64, Crc64Length, nameof(crc64));
    Errors.AssertBufferMinimumSize(buffer, Crc64Length, nameof(buffer));

    Span<byte> crcDestination = buffer.Slice(0, Crc64Length);
    crc64.CopyTo(crcDestination);
    return Crc64Length;
}

/// <summary>
/// Gets the stream footer in a buffer rented from the provided ArrayPool.
/// </summary>
/// <param name="pool">Pool to rent the footer buffer from. Must not be null.</param>
/// <param name="bytes">Receives the rented memory containing the encoded footer.</param>
/// <param name="crc64">
/// Optional CRC to place in the footer. When supplied it must be exactly
/// <c>Crc64Length</c> bytes; when empty the footer is empty.
/// </param>
/// <returns>
/// Disposable to return the buffer to the pool.
/// </returns>
public static IDisposable GetStreamFooterBytes(
    ArrayPool<byte> pool,
    out Memory<byte> bytes,
    ReadOnlySpan<byte> crc64 = default)
{
    Argument.AssertNotNull(pool, nameof(pool));

    // Size the rental to the footer itself (only the optional CRC64), not to the
    // stream header: renting StreamHeaderLength left trailing bytes in the
    // returned memory that WriteStreamFooter never writes.
    int footerLength = crc64.IsEmpty ? 0 : Crc64Length;
    IDisposable disposable = pool.RentAsMemoryDisposable(footerLength, out bytes);
    try
    {
        WriteStreamFooter(bytes.Span, crc64);
    }
    catch
    {
        // Give the rented buffer back if the footer could not be written
        // (e.g. crc64 has the wrong length); the caller never sees the disposable.
        disposable.Dispose();
        throw;
    }
    return disposable;
}
#endregion

#region SegmentHeader
public static void ReadSegmentHeader(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ internal class StructuredMessageDecodingStream : Stream
private int _metadataBufferLength = 0;

private int _streamHeaderLength;
// private readonly int _streamFooterLength; // no stream footer in 1.0
private int _streamFooterLength;
private int _segmentHeaderLength;
private int _segmentFooterLength;
private int _totalSegments;
Expand Down Expand Up @@ -170,7 +170,7 @@ public override async ValueTask<int> ReadAsync(Memory<byte> buf, CancellationTok
private long CurrentRegionLength => _currentRegion switch
{
SMRegion.StreamHeader => _streamHeaderLength,
SMRegion.StreamFooter => 0,
SMRegion.StreamFooter => _streamFooterLength,
SMRegion.SegmentHeader => _segmentHeaderLength,
SMRegion.SegmentFooter => _segmentFooterLength,
SMRegion.SegmentContent => _currentSegmentContentLength,
Expand Down Expand Up @@ -229,6 +229,7 @@ private int Decode(Span<byte> buffer)
SMRegion.SegmentFooter => ProcessSegmentFooter(buffer.Slice(bufferConsumed)),
_ => 0,
};
// TODO surface error if processed is 0
gaps.Add((bufferConsumed, processed));
bufferConsumed += processed;
}
Expand Down Expand Up @@ -335,9 +336,10 @@ private int ProcessStreamHeader(ReadOnlySpan<byte> span)
out _innerStreamLength,
out _flags,
out _totalSegments);
if (_flags.HasFlag(StructuredMessage.Flags.CrcSegment))
if (_flags.HasFlag(StructuredMessage.Flags.StorageCrc64))
{
_segmentFooterLength = _flags.HasFlag(StructuredMessage.Flags.CrcSegment) ? StructuredMessage.Crc64Length : 0;
_segmentFooterLength = _flags.HasFlag(StructuredMessage.Flags.StorageCrc64) ? StructuredMessage.Crc64Length : 0;
_streamFooterLength = _flags.HasFlag(StructuredMessage.Flags.StorageCrc64) ? StructuredMessage.Crc64Length : 0;
_segmentCrc = StorageCrc64HashAlgorithm.Create();
_totalContentCrc = StorageCrc64HashAlgorithm.Create();
}
Expand All @@ -347,7 +349,21 @@ private int ProcessStreamHeader(ReadOnlySpan<byte> span)

/// <summary>
/// Consumes the stream footer. When the <c>StorageCrc64</c> flag is set, compares the
/// CRC stored in the footer against the current hash of <c>_totalContentCrc</c>.
/// </summary>
/// <param name="span">Encoded bytes beginning at the stream footer.</param>
/// <returns>
/// The number of bytes consumed from <paramref name="span"/>: <c>Crc64Length</c>
/// when the CRC flag is set, otherwise 0.
/// </returns>
private int ProcessStreamFooter(ReadOnlySpan<byte> span)
{
    // The former unconditional "return 0;" has been removed: it made the
    // checksum verification below unreachable.
    int totalProcessed = 0;
    if (_flags.HasFlag(StructuredMessage.Flags.StorageCrc64))
    {
        totalProcessed += StructuredMessage.Crc64Length;
        using (ArrayPool<byte>.Shared.RentAsSpanDisposable(StructuredMessage.Crc64Length, out Span<byte> calculated))
        {
            _totalContentCrc.GetCurrentHash(calculated);
            ReadOnlySpan<byte> expected = span.Slice(0, StructuredMessage.Crc64Length);
            if (!calculated.SequenceEqual(expected))
            {
                // Footer CRC disagrees with what was computed while decoding.
                throw Errors.ChecksumMismatch(calculated, expected);
            }
        }
    }
    return totalProcessed;
}

private int ProcessSegmentHeader(ReadOnlySpan<byte> span)
Expand All @@ -369,7 +385,7 @@ private int ProcessSegmentHeader(ReadOnlySpan<byte> span)
private int ProcessSegmentFooter(ReadOnlySpan<byte> span)
{
int totalProcessed = 0;
if (_flags.HasFlag(StructuredMessage.Flags.CrcSegment))
if (_flags.HasFlag(StructuredMessage.Flags.StorageCrc64))
{
totalProcessed += StructuredMessage.Crc64Length;
using (ArrayPool<byte>.Shared.RentAsSpanDisposable(StructuredMessage.Crc64Length, out Span<byte> calculated))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ internal class StructuredMessageEncodingStream : Stream
private readonly StructuredMessage.Flags _flags;
private bool _disposed;

private bool UseCrcSegment => _flags.HasFlag(StructuredMessage.Flags.CrcSegment);
private readonly StorageCrc64HashAlgorithm _runningCrc;
private readonly byte[] _runningCrcCheckpoints;
private bool UseCrcSegment => _flags.HasFlag(StructuredMessage.Flags.StorageCrc64);
private readonly StorageCrc64HashAlgorithm _totalCrc;
private StorageCrc64HashAlgorithm _segmentCrc;
private readonly byte[] _segmentCrcs;
private int _latestSegmentCrcd = 0;

#region Segments
Expand Down Expand Up @@ -197,16 +198,21 @@ public StructuredMessageEncodingStream(
_segmentContentLength = segmentContentLength;

_streamHeaderLength = StructuredMessage.V1_0.StreamHeaderLength;
_streamFooterLength = 0;
_streamFooterLength = UseCrcSegment ? StructuredMessage.Crc64Length : 0;
_segmentHeaderLength = StructuredMessage.V1_0.SegmentHeaderLength;
_segmentFooterLength = UseCrcSegment ? StructuredMessage.Crc64Length : 0;

if (UseCrcSegment)
{
_runningCrc = StorageCrc64HashAlgorithm.Create();
_runningCrcCheckpoints = ArrayPool<byte>.Shared.Rent(
_totalCrc = StorageCrc64HashAlgorithm.Create();
_segmentCrc = StorageCrc64HashAlgorithm.Create();
_segmentCrcs = ArrayPool<byte>.Shared.Rent(
GetTotalSegments(innerStream, segmentContentLength) * StructuredMessage.Crc64Length);
innerStream = ChecksumCalculatingStream.GetReadStream(innerStream, span => _runningCrc.Append(span));
innerStream = ChecksumCalculatingStream.GetReadStream(innerStream, span =>
{
_totalCrc.Append(span);
_segmentCrc.Append(span);
});
}

_innerStream = innerStream;
Expand Down Expand Up @@ -366,9 +372,22 @@ private int ReadFromStreamHeader(Span<byte> buffer)

/// <summary>
/// Copies the next portion of the encoded stream footer into <paramref name="buffer"/>.
/// </summary>
/// <param name="buffer">Destination for footer bytes.</param>
/// <returns>
/// The number of bytes written to <paramref name="buffer"/>; 0 when the buffer is
/// empty or the footer has been fully read.
/// </returns>
private int ReadFromStreamFooter(Span<byte> buffer)
{
    // Bound the read by the stream footer length. The segment footer length was
    // used here previously; the two are set to the same value in the constructor,
    // but this region is the stream footer.
    int read = Math.Min(buffer.Length, _streamFooterLength - _currentRegionPosition);
    if (read <= 0)
    {
        return 0;
    }

    using IDisposable _ = StructuredMessage.V1_0.GetStreamFooterBytes(
        ArrayPool<byte>.Shared,
        out Memory<byte> footerBytes,
        crc64: UseCrcSegment
            ? _totalCrc.GetCurrentHash() // TODO array pooling
            : default);
    footerBytes.Slice(_currentRegionPosition, read).Span.CopyTo(buffer);

    // End of stream: only the position within the footer advances; the region is not changed.
    _currentRegionPosition += read;

    return read;
}

private int ReadFromSegmentHeader(Span<byte> buffer)
Expand Down Expand Up @@ -404,9 +423,9 @@ private int ReadFromSegmentFooter(Span<byte> buffer)
out Memory<byte> headerBytes,
crc64: UseCrcSegment
? new Span<byte>(
_runningCrcCheckpoints,
(CurrentEncodingSegment-1) * _runningCrc.HashLengthInBytes,
_runningCrc.HashLengthInBytes)
_segmentCrcs,
(CurrentEncodingSegment-1) * _totalCrc.HashLengthInBytes,
_totalCrc.HashLengthInBytes)
: default);
headerBytes.Slice(_currentRegionPosition, read).Span.CopyTo(buffer);
_currentRegionPosition += read;
Expand All @@ -433,11 +452,12 @@ private void CleanupContentSegment()
_currentRegionPosition = 0;
if (UseCrcSegment && CurrentEncodingSegment - 1 == _latestSegmentCrcd)
{
_runningCrc.GetCurrentHash(new Span<byte>(
_runningCrcCheckpoints,
_latestSegmentCrcd * _runningCrc.HashLengthInBytes,
_runningCrc.HashLengthInBytes));
_segmentCrc.GetCurrentHash(new Span<byte>(
_segmentCrcs,
_latestSegmentCrcd * _segmentCrc.HashLengthInBytes,
_segmentCrc.HashLengthInBytes));
_latestSegmentCrcd++;
_segmentCrc = StorageCrc64HashAlgorithm.Create();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,17 @@ public async Task DecodesData(
[Values(true, false)] bool useCrc)
{
int segmentContentLength = seglen ?? int.MaxValue;
Flags flags = useCrc ? Flags.CrcSegment : Flags.None;
Flags flags = useCrc ? Flags.StorageCrc64 : Flags.None;

byte[] originalData = new byte[dataLength];
new Random().NextBytes(originalData);
byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, segmentContentLength, flags);

Stream encodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData));
Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData));
byte[] decodedData;
using (MemoryStream dest = new())
{
await CopyStream(encodingStream, dest, readLen);
await CopyStream(decodingStream, dest, readLen);
decodedData = dest.ToArray();
}

Expand Down
Loading

0 comments on commit 034c9cd

Please sign in to comment.