diff --git a/src/Nerdbank.MessagePack/Converters/PrimitiveConverters.cs b/src/Nerdbank.MessagePack/Converters/PrimitiveConverters.cs
index 0dac3f2f..3326f8f1 100644
--- a/src/Nerdbank.MessagePack/Converters/PrimitiveConverters.cs
+++ b/src/Nerdbank.MessagePack/Converters/PrimitiveConverters.cs
@@ -4,6 +4,7 @@
 #pragma warning disable SA1649 // File name should match first type name
 #pragma warning disable SA1402 // File may only contain a single class
 
+using System.Diagnostics.CodeAnalysis;
 using System.Globalization;
 using System.Numerics;
 using System.Text;
@@ -18,12 +19,97 @@ namespace Nerdbank.MessagePack.Converters;
 /// </summary>
 internal class StringConverter : MessagePackConverter<string>
 {
+#if NET
+    /// <inheritdoc/>
+    public override bool PreferAsyncSerialization => true;
+#endif
+
     /// <inheritdoc/>
     public override string? Read(ref MessagePackReader reader, SerializationContext context) => reader.ReadString();
 
     /// <inheritdoc/>
     public override void Write(ref MessagePackWriter writer, in string? value, SerializationContext context) => writer.Write(value);
 
+#if NET
+    /// <inheritdoc/>
+    [Experimental("NBMsgPackAsync")]
+    public override async ValueTask<string?> ReadAsync(MessagePackAsyncReader reader, SerializationContext context)
+    {
+        // Lower bound on the bytes we wait for before each decode pass (fewer only when less of the string remains).
+        const uint MinChunkSize = 2048;
+
+        MessagePackStreamingReader streamingReader = reader.CreateStreamingReader();
+        bool wasNil;
+        // Retry the nil check after every fetch, exactly as the string header read below does;
+        // checking only once would use wasNil without the token ever having been inspected.
+        while (streamingReader.TryReadNil(out wasNil).NeedsMoreBytes())
+        {
+            streamingReader = new(await streamingReader.FetchMoreBytesAsync().ConfigureAwait(false));
+        }
+
+        if (wasNil)
+        {
+            reader.ReturnReader(ref streamingReader);
+            return null;
+        }
+
+        uint length;
+        while (streamingReader.TryReadStringHeader(out length).NeedsMoreBytes())
+        {
+            streamingReader = new(await streamingReader.FetchMoreBytesAsync().ConfigureAwait(false));
+        }
+
+        string result;
+        if (streamingReader.TryReadRaw(length, out ReadOnlySequence<byte> utf8BytesSequence).NeedsMoreBytes())
+        {
+            uint remainingBytesToDecode = length;
+            using SequencePool<char>.Rental sequenceRental = SequencePool<char>.Shared.Rent();
+            Sequence<char> charSequence = sequenceRental.Value;
+            Decoder decoder = StringEncoding.UTF8.GetDecoder();
+            while (remainingBytesToDecode > 0)
+            {
+                // We'll always require at least a reasonable number of bytes to decode at once,
+                // to keep overhead to a minimum.
+                uint desiredBytesThisRound = Math.Min(remainingBytesToDecode, MinChunkSize);
+                if (streamingReader.SequenceReader.Remaining < desiredBytesThisRound)
+                {
+                    // We don't have enough bytes to decode this round. Fetch more.
+                    streamingReader = new(await streamingReader.FetchMoreBytesAsync(desiredBytesThisRound).ConfigureAwait(false));
+                }
+
+                // Consume everything currently buffered, capped at the bytes left in this string and at int.MaxValue.
+                int thisLoopLength = unchecked((int)Math.Min(int.MaxValue, Math.Min(checked((uint)streamingReader.SequenceReader.Remaining), remainingBytesToDecode)));
+                Assumes.True(streamingReader.TryReadRaw(thisLoopLength, out utf8BytesSequence) == MessagePackPrimitives.DecodeResult.Success);
+                // Only flush the decoder on the chunk that completes the string.
+                bool flush = utf8BytesSequence.Length == remainingBytesToDecode;
+                decoder.Convert(utf8BytesSequence, charSequence, flush, out _, out _);
+                remainingBytesToDecode -= checked((uint)utf8BytesSequence.Length);
+            }
+
+            result = string.Create(
+                checked((int)charSequence.Length),
+                charSequence,
+                static (span, seq) => seq.AsReadOnlySequence.CopyTo(span));
+        }
+        else
+        {
+            // We happened to get all bytes at once. Decode now.
+            result = StringEncoding.UTF8.GetString(utf8BytesSequence);
+        }
+
+        reader.ReturnReader(ref streamingReader);
+        return result;
+    }
+
+    /// <inheritdoc/>
+    [Experimental("NBMsgPackAsync")]
+    public override ValueTask WriteAsync(MessagePackAsyncWriter writer, string? value, SerializationContext context)
+    {
+        // We *could* do incremental string encoding, flushing periodically based on the user's preferred flush threshold.
+        return base.WriteAsync(writer, value, context);
+    }
+#endif
+
     /// <inheritdoc/>
     public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => new() { ["type"] = "string" };
 }
diff --git a/test/Nerdbank.MessagePack.Tests/AsyncSerializationTests.cs b/test/Nerdbank.MessagePack.Tests/AsyncSerializationTests.cs
index 8bfdd15a..c98fc20d 100644
--- a/test/Nerdbank.MessagePack.Tests/AsyncSerializationTests.cs
+++ b/test/Nerdbank.MessagePack.Tests/AsyncSerializationTests.cs
@@ -55,6 +55,29 @@ public async Task WithPreBuffering()
         Assert.Equal(0, converter.AsyncDeserializationCounter);
     }
 
+    [Fact]
+    public async Task DecodeLargeString()
+    {
+        string expected = new string('a', 100 * 1024);
+        ReadOnlySequence<byte> msgpack = new(this.Serializer.Serialize(expected, TestContext.Current.CancellationToken));
+        FragmentedPipeReader pipeReader = new(msgpack, msgpack.GetPosition(0), msgpack.GetPosition(1), msgpack.GetPosition(512), msgpack.GetPosition(6000), msgpack.GetPosition(32 * 1024));
+        string? actual = await this.Serializer.DeserializeAsync<string>(pipeReader, Witness.ShapeProvider, TestContext.Current.CancellationToken);
+        Assert.Equal(expected, actual);
+    }
+
+    [Fact]
+    public async Task DecodeEmptyString()
+    {
+        string expected = string.Empty;
+        ReadOnlySequence<byte> msgpack = new(this.Serializer.Serialize(expected, TestContext.Current.CancellationToken));
+        FragmentedPipeReader pipeReader = new(msgpack, msgpack.GetPosition(0));
+        string? actual = await this.Serializer.DeserializeAsync<string>(pipeReader, Witness.ShapeProvider, TestContext.Current.CancellationToken);
+        Assert.Equal(expected, actual);
+    }
+
+    [GenerateShape<string>]
+    private partial class Witness;
+
     [GenerateShape]
     public partial record Poco(int X, int Y);
 
diff --git a/test/Nerdbank.MessagePack.Tests/FragmentedPipeReader.cs b/test/Nerdbank.MessagePack.Tests/FragmentedPipeReader.cs
index 10cdfa15..3ee2f833 100644
--- a/test/Nerdbank.MessagePack.Tests/FragmentedPipeReader.cs
+++ b/test/Nerdbank.MessagePack.Tests/FragmentedPipeReader.cs
@@ -13,6 +13,9 @@ internal class FragmentedPipeReader : PipeReader
 
     private readonly int? chunkSize;
     private readonly SequencePosition[]? chunkPositions;
+#if NETFRAMEWORK
+    private readonly long[]? chunkIndexes;
+#endif
 
     private SequencePosition consumed;
     private SequencePosition examined;
@@ -32,6 +35,9 @@ public FragmentedPipeReader(ReadOnlySequence<byte> buffer, params SequencePositi
         this.buffer = buffer;
         this.consumed = this.examined = buffer.Start;
         this.chunkPositions = chunkPositions;
+#if NETFRAMEWORK
+        this.chunkIndexes = [.. chunkPositions.Select(p => buffer.Slice(0, p).Length)];
+#endif
     }
 
     public override void AdvanceTo(SequencePosition consumed) => this.AdvanceTo(consumed, consumed);
@@ -63,7 +69,12 @@ public override ValueTask<ReadResult> ReadAsync(CancellationToken cancellationTo
         if (this.lastReadReturnedPosition.HasValue && this.examined.Equals(this.lastReadReturnedPosition.Value))
         {
             // The caller has examined everything we gave them. Give them more.
+#if NETFRAMEWORK
+            long examinedIndex = this.buffer.Slice(0, this.examined).Length;
+            int lastChunkGivenIndex = Array.IndexOf(this.chunkIndexes!, examinedIndex);
+#else
             int lastChunkGivenIndex = Array.IndexOf(this.chunkPositions, this.examined);
+#endif
             Assumes.True(lastChunkGivenIndex >= 0);
             chunkEnd = this.chunkPositions.Length > lastChunkGivenIndex + 1 ? this.chunkPositions[lastChunkGivenIndex + 1] : this.buffer.End;
         }