From 3709500908b63631e7c9dfef5aa595bfe5c6d973 Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Fri, 8 Dec 2023 09:05:08 -0800 Subject: [PATCH] Finished making dictionaries work in File and Memory implementations and updated tests. --- .../Ipc/ArrowFileReaderImplementation.cs | 74 +++++++++++---- .../Ipc/ArrowMemoryReaderImplementation.cs | 60 ++++++------ .../Ipc/ArrowReaderImplementation.cs | 49 ++++------ .../Ipc/ArrowStreamReaderImplementation.cs | 94 ++++++++++++------- csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs | 38 +++++++- .../ArrowReaderBenchmark.cs | 2 +- .../ArrowFileReaderTests.cs | 6 +- csharp/test/Apache.Arrow.Tests/TestData.cs | 2 +- 8 files changed, 209 insertions(+), 116 deletions(-) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs index 5ccdc38120d63..18e22c9eeb2b9 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs @@ -42,7 +42,7 @@ public ArrowFileReaderImplementation(Stream stream, MemoryAllocator allocator, I { } - public async ValueTask RecordBatchCountAsync() + public async ValueTask RecordBatchCountAsync(CancellationToken cancellationToken = default) { if (!HasReadSchema) { @@ -145,7 +145,7 @@ private void ReadSchema(Memory buffer) public async ValueTask ReadRecordBatchAsync(int index, CancellationToken cancellationToken) { await ReadSchemaAsync().ConfigureAwait(false); - await ReadDictionariesAsync().ConfigureAwait(false); + await ReadDictionariesAsync(cancellationToken).ConfigureAwait(false); if (index >= _footer.RecordBatchCount) { @@ -179,7 +179,7 @@ public RecordBatch ReadRecordBatch(int index) public override async ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken) { await ReadSchemaAsync().ConfigureAwait(false); - await ReadDictionariesAsync().ConfigureAwait(false); + await ReadDictionariesAsync(cancellationToken).ConfigureAwait(false); if (_recordBatchIndex >= _footer.RecordBatchCount) { @@ -208,22 +208,43 @@ public override RecordBatch ReadNextRecordBatch() return result; } - private async ValueTask ReadDictionariesAsync() + private async ValueTask ReadDictionariesAsync(CancellationToken cancellationToken = default) { if (HasReadDictionaries) { return; } - // We don't know in what order the dictionaries have been serialized, so we deserialize - // just their indices and then construct them in X order - foreach (Block block in _footer.Dictionaries) + int index = 0; + while (index < _footer.DictionaryCount) { - BaseStream.Position = block.Offset; - await ReadRecordBatchAsync(deferDictionaryLoad: true).ConfigureAwait(false); + index = await ReadNextDictionaryAsync(index, cancellationToken).ConfigureAwait(false); } + } + + private async ValueTask ReadNextDictionaryAsync(int index, CancellationToken cancellationToken) + { + Block block = _footer.Dictionaries[index++]; + BaseStream.Position = block.Offset; + await ReadMessageAsync(async (message, cancellationToken) => + { + if (message.HeaderType != Flatbuf.MessageHeader.DictionaryBatch) + { + return null; + } + Flatbuf.DictionaryBatch dictionaryBatch = message.Header().Value; - DictionaryMemo.FinishLoad(); + long position = BaseStream.Position; + while (!DictionaryMemo.CanLoad(dictionaryBatch.Id)) + { + // recursive load + index = await ReadNextDictionaryAsync(index, cancellationToken); + } + BaseStream.Position = position; + return await CreateArrowObjectAsync(message, cancellationToken); + }, cancellationToken).ConfigureAwait(false); + + return index; } private void ReadDictionaries() @@ -233,15 +254,36 @@ private void ReadDictionaries() return; } - // We don't know in what order the dictionaries have been serialized, so we deserialize - // just their indices and then construct them in X order - foreach (Block block in _footer.Dictionaries) + int index = 0; + while (index < _footer.DictionaryCount) { - BaseStream.Position = block.Offset; - ReadRecordBatch(deferDictionaryLoad: true); + index = ReadNextDictionary(index); } + } + + private int ReadNextDictionary(int index) + { + Block block = _footer.Dictionaries[index++]; + BaseStream.Position = block.Offset; + ReadMessage(message => + { + if (message.HeaderType != Flatbuf.MessageHeader.DictionaryBatch) + { + return null; + } + Flatbuf.DictionaryBatch dictionaryBatch = message.Header().Value; + + long position = BaseStream.Position; + while (!DictionaryMemo.CanLoad(dictionaryBatch.Id)) + { + // recursive load + index = ReadNextDictionary(index); + } + BaseStream.Position = position; + return CreateArrowObject(message); + }); - DictionaryMemo.FinishLoad(); + return index; } /// diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs index af4f963ee520f..6e2336a591bf1 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs @@ -43,30 +43,17 @@ public override RecordBatch ReadNextRecordBatch() { ReadSchema(); - if (_buffer.Length <= _bufferPosition + sizeof(int)) + RecordBatch batch = null; + while (batch == null) { - // reached the end - return null; - } - - // Get Length of record batch for message header. - int messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition)); - _bufferPosition += sizeof(int); - - if (messageLength == 0) - { - //reached the end - return null; - } - else if (messageLength == MessageSerializer.IpcContinuationToken) - { - // ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length if (_buffer.Length <= _bufferPosition + sizeof(int)) { - throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message."); + // reached the end + return null; } - messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition)); + // Get Length of record batch for message header. + int messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition)); _bufferPosition += sizeof(int); if (messageLength == 0) @@ -74,17 +61,36 @@ public override RecordBatch ReadNextRecordBatch() //reached the end return null; } - } + else if (messageLength == MessageSerializer.IpcContinuationToken) + { + // ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length + if (_buffer.Length <= _bufferPosition + sizeof(int)) + { + throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message."); + } + + messageLength = BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition)); + _bufferPosition += sizeof(int); + + if (messageLength == 0) + { + //reached the end + return null; + } + } + + Message message = Message.GetRootAsMessage( + CreateByteBuffer(_buffer.Slice(_bufferPosition, messageLength))); + _bufferPosition += messageLength; - Message message = Message.GetRootAsMessage( - CreateByteBuffer(_buffer.Slice(_bufferPosition, messageLength))); - _bufferPosition += messageLength; + int bodyLength = (int)message.BodyLength; + ByteBuffer bodybb = CreateByteBuffer(_buffer.Slice(_bufferPosition, bodyLength)); + _bufferPosition += bodyLength; - int bodyLength = (int)message.BodyLength; - ByteBuffer bodybb = CreateByteBuffer(_buffer.Slice(_bufferPosition, bodyLength)); - _bufferPosition += bodyLength; + batch = CreateArrowObjectFromMessage(message, bodybb, memoryOwner: null); + } - return CreateArrowObjectFromMessage(message, bodybb, memoryOwner: null); + return batch; } private void ReadSchema() diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index b981ee79dcfff..d3115da52cc6c 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -107,10 +107,7 @@ private static bool MatchEnum(Flatbuf.MessageHeader messageHeader, Type flatBuff /// Null when the message type is not RecordBatch. /// protected RecordBatch CreateArrowObjectFromMessage( - Flatbuf.Message message, - ByteBuffer bodyByteBuffer, - IMemoryOwner memoryOwner, - bool deferDictionaryLoad = false) + Flatbuf.Message message, ByteBuffer bodyByteBuffer, IMemoryOwner memoryOwner) { switch (message.HeaderType) { @@ -119,11 +116,11 @@ protected RecordBatch CreateArrowObjectFromMessage( break; case Flatbuf.MessageHeader.DictionaryBatch: Flatbuf.DictionaryBatch dictionaryBatch = message.Header().Value; - ReadDictionaryBatch(message.Version, dictionaryBatch, bodyByteBuffer, memoryOwner, deferDictionaryLoad); + ReadDictionaryBatch(message.Version, dictionaryBatch, bodyByteBuffer, memoryOwner); break; case Flatbuf.MessageHeader.RecordBatch: Flatbuf.RecordBatch rb = message.Header().Value; - List arrays = BuildArrays(message.Version, Schema, bodyByteBuffer, rb, deferDictionaryLoad); + List arrays = BuildArrays(message.Version, Schema, bodyByteBuffer, rb); return new RecordBatch(Schema, memoryOwner, arrays, (int)rb.Length); default: // NOTE: Skip unsupported message type @@ -143,8 +140,7 @@ private void ReadDictionaryBatch( MetadataVersion version, Flatbuf.DictionaryBatch dictionaryBatch, ByteBuffer bodyByteBuffer, - IMemoryOwner memoryOwner, - bool deferDictionaryLoad) + IMemoryOwner memoryOwner) { long id = dictionaryBatch.Id; IArrowType valueType = DictionaryMemo.GetDictionaryType(id); @@ -157,18 +153,14 @@ private void ReadDictionaryBatch( Field valueField = new Field("dummy", valueType, true); var schema = new Schema(new[] { valueField }, default); - IList arrays = BuildArrays(version, schema, bodyByteBuffer, recordBatch.Value, deferDictionaryLoad); + IList arrays = BuildArrays(version, schema, bodyByteBuffer, recordBatch.Value); if (arrays.Count != 1) { throw new InvalidDataException("Dictionary record batch must contain only one field"); } - if (deferDictionaryLoad) - { - DictionaryMemo.AddDictionaryValues(id, arrays[0].Data); - } - else if (dictionaryBatch.IsDelta) + if (dictionaryBatch.IsDelta) { DictionaryMemo.AddDeltaDictionary(id, arrays[0], _allocator); } @@ -182,8 +174,7 @@ private List BuildArrays( MetadataVersion version, Schema schema, ByteBuffer messageBuffer, - Flatbuf.RecordBatch recordBatchMessage, - bool deferDictionaryLoad) + Flatbuf.RecordBatch recordBatchMessage) { var arrays = new List(recordBatchMessage.NodesLength); @@ -201,8 +192,8 @@ private List BuildArrays( Flatbuf.FieldNode fieldNode = recordBatchEnumerator.CurrentNode; ArrayData arrayData = field.DataType.IsFixedPrimitive() - ? LoadPrimitiveField(version, ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator, deferDictionaryLoad) - : LoadVariableField(version, ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator, deferDictionaryLoad); + ? LoadPrimitiveField(version, ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator) + : LoadVariableField(version, ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator); arrays.Add(ArrowArrayFactory.BuildArray(arrayData)); } while (recordBatchEnumerator.MoveNextNode()); @@ -244,8 +235,7 @@ private ArrayData LoadPrimitiveField( Field field, in Flatbuf.FieldNode fieldNode, ByteBuffer bodyData, - IBufferCreator bufferCreator, - bool deferDictionaryLoad) + IBufferCreator bufferCreator) { int fieldLength = (int)fieldNode.Length; @@ -298,10 +288,10 @@ private ArrayData LoadPrimitiveField( recordBatchEnumerator.MoveNextBuffer(); } - ArrayData[] children = GetChildren(version, ref recordBatchEnumerator, field, bodyData, bufferCreator, deferDictionaryLoad); + ArrayData[] children = GetChildren(version, ref recordBatchEnumerator, field, bodyData, bufferCreator); IArrowArray dictionary = null; - if (field.DataType.TypeId == ArrowTypeId.Dictionary && !deferDictionaryLoad) + if (field.DataType.TypeId == ArrowTypeId.Dictionary) { long id = DictionaryMemo.GetId(field); dictionary = DictionaryMemo.GetDictionary(id); @@ -316,9 +306,9 @@ private ArrayData LoadVariableField( Field field, in Flatbuf.FieldNode fieldNode, ByteBuffer bodyData, - IBufferCreator bufferCreator, - bool deferDictionaryLoad) + IBufferCreator bufferCreator) { + ArrowBuffer nullArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer, bufferCreator); if (!recordBatchEnumerator.MoveNextBuffer()) { @@ -346,10 +336,10 @@ private ArrayData LoadVariableField( } ArrowBuffer[] arrowBuff = new[] { nullArrowBuffer, offsetArrowBuffer, valueArrowBuffer }; - ArrayData[] children = GetChildren(version, ref recordBatchEnumerator, field, bodyData, bufferCreator, deferDictionaryLoad); + ArrayData[] children = GetChildren(version, ref recordBatchEnumerator, field, bodyData, bufferCreator); IArrowArray dictionary = null; - if (field.DataType.TypeId == ArrowTypeId.Dictionary && !deferDictionaryLoad) + if (field.DataType.TypeId == ArrowTypeId.Dictionary) { long id = DictionaryMemo.GetId(field); dictionary = DictionaryMemo.GetDictionary(id); @@ -363,8 +353,7 @@ private ArrayData[] GetChildren( ref RecordBatchEnumerator recordBatchEnumerator, Field field, ByteBuffer bodyData, - IBufferCreator bufferCreator, - bool deferDictionaryLoad) + IBufferCreator bufferCreator) { if (!(field.DataType is NestedType type)) return null; @@ -377,8 +366,8 @@ private ArrayData[] GetChildren( Field childField = type.Fields[index]; ArrayData child = childField.DataType.IsFixedPrimitive() - ? LoadPrimitiveField(version, ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator, deferDictionaryLoad) - : LoadVariableField(version, ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator, deferDictionaryLoad); + ? LoadPrimitiveField(version, ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator) + : LoadVariableField(version, ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator); children[index] = child; } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs index 56a5dd124a74e..647a7b11b214c 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs @@ -57,17 +57,18 @@ protected async ValueTask ReadRecordBatchAsync(CancellationToken ca { await ReadSchemaAsync().ConfigureAwait(false); - RecordBatch result = null; - - while (result == null) + ReadResult result = default; + do { - result = await ReadRecordBatchAsync(deferDictionaryLoad: false, cancellationToken); - } + result = await ReadMessageAsync(CreateArrowObjectAsync, cancellationToken).ConfigureAwait(false); + } while (result.Batch == null && result.MessageLength > 0); - return result; + return result.Batch; } - protected async ValueTask ReadRecordBatchAsync(bool deferDictionaryLoad, CancellationToken cancellationToken = default) + protected async ValueTask ReadMessageAsync( + Func> ctor, + CancellationToken cancellationToken) { int messageLength = await ReadMessageLengthAsync(throwOnFullRead: false, cancellationToken) .ConfigureAwait(false); @@ -75,7 +76,7 @@ protected async ValueTask ReadRecordBatchAsync(bool deferDictionary if (messageLength == 0) { // reached end - return null; + return default; } RecordBatch result = null; @@ -87,43 +88,47 @@ await ArrayPool.Shared.RentReturnAsync(messageLength, async (messageBuff) Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff)); - int bodyLength = checked((int)message.BodyLength); + result = await ctor(message, cancellationToken); + }).ConfigureAwait(false); - IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); - Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); - bytesRead = await BaseStream.ReadFullBufferAsync(bodyBuff, cancellationToken) - .ConfigureAwait(false); - EnsureFullRead(bodyBuff, bytesRead); + return new ReadResult(messageLength, result); + } - Google.FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); - result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner, deferDictionaryLoad); - }).ConfigureAwait(false); + protected async ValueTask CreateArrowObjectAsync(Flatbuf.Message message, CancellationToken cancellationToken = default) + { + int bodyLength = checked((int)message.BodyLength); + + IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); + Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); + int bytesRead = await BaseStream.ReadFullBufferAsync(bodyBuff, cancellationToken) + .ConfigureAwait(false); + EnsureFullRead(bodyBuff, bytesRead); - return result; + Google.FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); + return CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); } protected RecordBatch ReadRecordBatch() { ReadSchema(); - RecordBatch result = null; - - while (result == null) + ReadResult result = default; + do { - result = ReadRecordBatch(deferDictionaryLoad: false); - } + result = ReadMessage(CreateArrowObject); + } while (result.Batch == null && result.MessageLength > 0); - return result; + return result.Batch; } - protected RecordBatch ReadRecordBatch(bool deferDictionaryLoad) + protected ReadResult ReadMessage(Func ctor) { int messageLength = ReadMessageLength(throwOnFullRead: false); if (messageLength == 0) { // reached end - return null; + return default; } RecordBatch result = null; @@ -134,18 +139,23 @@ protected RecordBatch ReadRecordBatch(bool deferDictionaryLoad) Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff)); - int bodyLength = checked((int)message.BodyLength); + result = ctor(message); + }); + + return new ReadResult(messageLength, result); + } - IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); - Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); - bytesRead = BaseStream.ReadFullBuffer(bodyBuff); - EnsureFullRead(bodyBuff, bytesRead); + protected RecordBatch CreateArrowObject(Flatbuf.Message message) + { + int bodyLength = checked((int)message.BodyLength); - Google.FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); - result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner, deferDictionaryLoad); - }); + IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); + Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); + int bytesRead = BaseStream.ReadFullBuffer(bodyBuff); + EnsureFullRead(bodyBuff, bytesRead); - return result; + Google.FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); + return CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); } protected virtual async ValueTask ReadSchemaAsync() @@ -230,7 +240,7 @@ await ArrayPool.Shared.RentReturnAsync(4, async (lengthBuffer) => return messageLength; } - protected int ReadMessageLength(bool throwOnFullRead) + private int ReadMessageLength(bool throwOnFullRead) { int messageLength = 0; ArrayPool.Shared.RentReturn(4, lengthBuffer => @@ -280,5 +290,17 @@ internal static void EnsureFullRead(Memory buffer, int bytesRead) throw new InvalidOperationException("Unexpectedly reached the end of the stream before a full buffer was read."); } } + + internal struct ReadResult + { + public readonly int MessageLength; + public readonly RecordBatch Batch; + + public ReadResult(int messageLength, RecordBatch batch) + { + MessageLength = messageLength; + Batch = batch; + } + } } } diff --git a/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs index 3db3cad43acf7..f1df993f52c74 100644 --- a/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs +++ b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs @@ -117,12 +117,46 @@ public void AddDeltaDictionary(long id, IArrowArray deltaDictionary, MemoryAlloc AddOrReplaceDictionary(id, dictionary); } - public void AddDictionaryValues(long id, ArrayData values) + // Returns true if the corresponding dictionaries have been loaded + public bool CanLoad(long id) { + IArrowType type = GetDictionaryType(id); + if (type is NestedType) + { + NestedTypeVisitor visitor = new NestedTypeVisitor(this); + type.Accept(visitor); + return visitor.CanLoad; + } + + return true; } - public void FinishLoad() + private sealed class NestedTypeVisitor : IArrowTypeVisitor { + private readonly DictionaryMemo _memo; + public bool CanLoad { get; private set; } + + public NestedTypeVisitor(DictionaryMemo memo) + { + _memo = memo; + CanLoad = true; + } + + public void Visit(NestedType type) + { + foreach (Field field in type.Fields) + { + if (field.DataType is DictionaryType && ( + !_memo._fieldToId.TryGetValue(field, out long id) || + !_memo._idToDictionary.TryGetValue(id, out IArrowArray array))) + { + CanLoad = false; + break; + } + } + } + + public void Visit(IArrowType type) { } } } } diff --git a/csharp/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs b/csharp/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs index 4e491a2a6b128..cd8198d434cc7 100644 --- a/csharp/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs +++ b/csharp/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs @@ -38,7 +38,7 @@ public class ArrowReaderBenchmark [GlobalSetup] public async Task GlobalSetup() { - RecordBatch batch = TestData.CreateSampleRecordBatch(length: Count); + RecordBatch batch = TestData.CreateSampleRecordBatch(length: Count, createDictionaryArray: false); _memoryStream = new MemoryStream(); ArrowStreamWriter writer = new ArrowStreamWriter(_memoryStream, batch.Schema); diff --git a/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs index 2f2229ded4c46..585b1acc27f17 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs @@ -66,7 +66,7 @@ public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen) ArrowFileReader reader = new ArrowFileReader(stream, memoryPool, leaveOpen: shouldLeaveOpen); reader.ReadNextRecordBatch(); - Assert.Equal(1, memoryPool.Statistics.Allocations); + Assert.Equal(2, memoryPool.Statistics.Allocations); Assert.True(memoryPool.Statistics.BytesAllocated > 0); reader.Dispose(); @@ -132,8 +132,8 @@ private static async Task TestReadRecordBatchHelper( [Fact] public async Task TestReadMultipleRecordBatchAsync() { - RecordBatch originalBatch1 = TestData.CreateSampleRecordBatch(length: 100); - RecordBatch originalBatch2 = TestData.CreateSampleRecordBatch(length: 50); + RecordBatch originalBatch1 = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false); + RecordBatch originalBatch2 = TestData.CreateSampleRecordBatch(length: 50, createDictionaryArray: false); using (MemoryStream stream = new MemoryStream()) { diff --git a/csharp/test/Apache.Arrow.Tests/TestData.cs b/csharp/test/Apache.Arrow.Tests/TestData.cs index 3af6efb97b437..79e886f0deabb 100644 --- a/csharp/test/Apache.Arrow.Tests/TestData.cs +++ b/csharp/test/Apache.Arrow.Tests/TestData.cs @@ -23,7 +23,7 @@ namespace Apache.Arrow.Tests { public static class TestData { - public static RecordBatch CreateSampleRecordBatch(int length, bool createDictionaryArray = false) + public static RecordBatch CreateSampleRecordBatch(int length, bool createDictionaryArray = true) { return CreateSampleRecordBatch(length, columnSetCount: 1, createDictionaryArray); }