Skip to content

Commit

Permalink
Decode large strings incrementally in async deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
AArnott committed Jan 4, 2025
1 parent e830210 commit b01e5c3
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 0 deletions.
81 changes: 81 additions & 0 deletions src/Nerdbank.MessagePack/Converters/PrimitiveConverters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma warning disable SA1649 // File name should match first type name
#pragma warning disable SA1402 // File may only contain a single class

using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Numerics;
using System.Text;
Expand All @@ -18,12 +19,92 @@ namespace Nerdbank.MessagePack.Converters;
/// </summary>
internal class StringConverter : MessagePackConverter<string>
{
#if NET
	/// <inheritdoc/>
	/// <remarks>
	/// Async deserialization is preferred because <see cref="ReadAsync"/> can decode
	/// large strings incrementally instead of waiting for the whole payload to be buffered.
	/// </remarks>
	public override bool PreferAsyncSerialization => true;
#endif

	/// <inheritdoc/>
	public override string? Read(ref MessagePackReader reader, SerializationContext context) => reader.ReadString();

	/// <inheritdoc/>
	public override void Write(ref MessagePackWriter writer, in string? value, SerializationContext context) => writer.Write(value);

#if NET
	/// <inheritdoc/>
	/// <remarks>
	/// When the UTF-8 payload is not fully buffered, this method decodes it in chunks using a
	/// stateful <see cref="Decoder"/> (which carries partial multi-byte code points across chunk
	/// boundaries) into a pooled <c>Sequence&lt;char&gt;</c>, then materializes the string with a
	/// single allocation via <see cref="string.Create{TState}(int, TState, System.Buffers.SpanAction{char, TState})"/>.
	/// </remarks>
	[Experimental("NBMsgPackAsync")]
	public override async ValueTask<string?> ReadAsync(MessagePackAsyncReader reader, SerializationContext context)
	{
		// Minimum number of UTF-8 bytes to decode per loop iteration, so that per-chunk
		// overhead (fetching, slicing, Decoder.Convert calls) stays amortized.
		const uint MinChunkSize = 2048;

		MessagePackStreamingReader streamingReader = reader.CreateStreamingReader();
		bool wasNil;
		if (streamingReader.TryReadNil(out wasNil).NeedsMoreBytes())
		{
			streamingReader = new(await streamingReader.FetchMoreBytesAsync().ConfigureAwait(false));
		}

		if (wasNil)
		{
			// msgpack nil maps to a null string.
			reader.ReturnReader(ref streamingReader);
			return null;
		}

		// Read the string header to learn the total UTF-8 byte length of the payload.
		uint length;
		while (streamingReader.TryReadStringHeader(out length).NeedsMoreBytes())
		{
			streamingReader = new(await streamingReader.FetchMoreBytesAsync().ConfigureAwait(false));
		}

		string result;
		if (streamingReader.TryReadRaw(length, out ReadOnlySequence<byte> utf8BytesSequence).NeedsMoreBytes())
		{
			// The full payload is not yet buffered: decode incrementally, chunk by chunk.
			uint remainingBytesToDecode = length;
			using SequencePool<char>.Rental sequenceRental = SequencePool<char>.Shared.Rent();
			Sequence<char> charSequence = sequenceRental.Value;
			Decoder decoder = StringEncoding.UTF8.GetDecoder();
			while (remainingBytesToDecode > 0)
			{
				// We'll always require at least a reasonable number of bytes to decode at once,
				// to keep overhead to a minimum.
				uint desiredBytesThisRound = Math.Min(remainingBytesToDecode, MinChunkSize);
				if (streamingReader.SequenceReader.Remaining < desiredBytesThisRound)
				{
					// We don't have enough bytes to decode this round. Fetch more.
					streamingReader = new(await streamingReader.FetchMoreBytesAsync(desiredBytesThisRound).ConfigureAwait(false));
				}

				// Consume everything currently buffered, capped at the bytes still owed to this
				// string and at int.MaxValue (TryReadRaw takes an int count).
				int thisLoopLength = unchecked((int)Math.Min(int.MaxValue, Math.Min(checked((uint)streamingReader.SequenceReader.Remaining), remainingBytesToDecode)));
				Assumes.True(streamingReader.TryReadRaw(thisLoopLength, out utf8BytesSequence) == MessagePackPrimitives.DecodeResult.Success);
				// Flush only on the final chunk, so the decoder can emit any code point
				// that straddled a chunk boundary.
				bool flush = utf8BytesSequence.Length == remainingBytesToDecode;
				decoder.Convert(utf8BytesSequence, charSequence, flush, out _, out _);
				remainingBytesToDecode -= checked((uint)utf8BytesSequence.Length);
			}

			// Copy the accumulated chars into the final string with one allocation.
			result = string.Create(
				checked((int)charSequence.Length),
				charSequence,
				static (span, seq) => seq.AsReadOnlySequence.CopyTo(span));
		}
		else
		{
			// We happened to get all bytes at once. Decode now.
			result = StringEncoding.UTF8.GetString(utf8BytesSequence);
		}

		reader.ReturnReader(ref streamingReader);
		return result;
	}

	/// <inheritdoc/>
	/// <remarks>
	/// Currently delegates to the synchronous path via the base implementation.
	/// </remarks>
	[Experimental("NBMsgPackAsync")]
	public override ValueTask WriteAsync(MessagePackAsyncWriter writer, string? value, SerializationContext context)
	{
		// We *could* do incremental string encoding, flushing periodically based on the user's preferred flush threshold.
		return base.WriteAsync(writer, value, context);
	}
#endif

	/// <inheritdoc/>
	public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => new() { ["type"] = "string" };
}
Expand Down
23 changes: 23 additions & 0 deletions test/Nerdbank.MessagePack.Tests/AsyncSerializationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,29 @@ public async Task WithPreBuffering()
Assert.Equal(0, converter.AsyncDeserializationCounter);
}

[Fact]
public async Task DecodeLargeString()
{
	// A 100 KiB string forces the async deserializer down its incremental-decode path.
	string expected = new string('a', 100 * 1024);
	byte[] serialized = this.Serializer.Serialize<string, Witness>(expected, TestContext.Current.CancellationToken);
	ReadOnlySequence<byte> msgpack = new(serialized);

	// Hand the payload over in deliberately uneven fragments so that string decoding
	// must span multiple buffer fetches.
	FragmentedPipeReader pipeReader = new(
		msgpack,
		msgpack.GetPosition(0),
		msgpack.GetPosition(1),
		msgpack.GetPosition(512),
		msgpack.GetPosition(6000),
		msgpack.GetPosition(32 * 1024));

	string? actual = await this.Serializer.DeserializeAsync<string>(pipeReader, Witness.ShapeProvider, TestContext.Current.CancellationToken);

	Assert.Equal(expected, actual);
}

[Fact]
public async Task DecodeEmptyString()
{
	// Boundary case: a zero-length string payload must round-trip through the async path.
	string expected = string.Empty;
	byte[] serialized = this.Serializer.Serialize<string, Witness>(expected, TestContext.Current.CancellationToken);
	ReadOnlySequence<byte> msgpack = new(serialized);

	// A single fragment boundary at the start of the sequence.
	FragmentedPipeReader pipeReader = new(msgpack, msgpack.GetPosition(0));

	string? actual = await this.Serializer.DeserializeAsync<string>(pipeReader, Witness.ShapeProvider, TestContext.Current.CancellationToken);

	Assert.Equal(expected, actual);
}

// Source-generated shape provider so string can be (de)serialized in these tests.
[GenerateShape<string>]
private partial class Witness;

// Simple data type used by other tests in this class as a serializable payload.
[GenerateShape]
public partial record Poco(int X, int Y);

Expand Down
11 changes: 11 additions & 0 deletions test/Nerdbank.MessagePack.Tests/FragmentedPipeReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ internal class FragmentedPipeReader : PipeReader
// Uniform chunk size mode — presumably set by a constructor not shown here; confirm against the other ctor.
private readonly int? chunkSize;

// Explicit fragment boundaries within the buffer, when position-based construction is used.
private readonly SequencePosition[]? chunkPositions;
#if NETFRAMEWORK
// Byte offsets corresponding to chunkPositions. NOTE(review): the NETFRAMEWORK read path
// looks positions up by offset (see ReadAsync) rather than by SequencePosition equality —
// presumably because position equality is unreliable there; confirm.
private readonly long[]? chunkIndexes;
#endif

// Progress markers for the PipeReader contract: what the caller has consumed vs. examined.
private SequencePosition consumed;
private SequencePosition examined;
Expand All @@ -32,6 +35,9 @@ public FragmentedPipeReader(ReadOnlySequence<byte> buffer, params SequencePositi
this.buffer = buffer;
this.consumed = this.examined = buffer.Start;
this.chunkPositions = chunkPositions;
#if NETFRAMEWORK
this.chunkIndexes = [.. chunkPositions.Select(p => buffer.Slice(0, p).Length)];
#endif
}

public override void AdvanceTo(SequencePosition consumed) => this.AdvanceTo(consumed, consumed);
Expand Down Expand Up @@ -63,7 +69,12 @@ public override ValueTask<ReadResult> ReadAsync(CancellationToken cancellationTo
if (this.lastReadReturnedPosition.HasValue && this.examined.Equals(this.lastReadReturnedPosition.Value))
{
// The caller has examined everything we gave them. Give them more.
#if NETFRAMEWORK
long examinedIndex = this.buffer.Slice(0, this.examined).Length;
int lastChunkGivenIndex = Array.IndexOf(this.chunkIndexes!, examinedIndex);
#else
int lastChunkGivenIndex = Array.IndexOf(this.chunkPositions, this.examined);
#endif
Assumes.True(lastChunkGivenIndex >= 0);
chunkEnd = this.chunkPositions.Length > lastChunkGivenIndex + 1 ? this.chunkPositions[lastChunkGivenIndex + 1] : this.buffer.End;
}
Expand Down

0 comments on commit b01e5c3

Please sign in to comment.