diff --git a/csharp/src/Apache.Arrow/Extensions/ArrayPoolExtensions.cs b/csharp/src/Apache.Arrow/Extensions/ArrayPoolExtensions.cs index 95a39439f7b20..51287674b2e70 100644 --- a/csharp/src/Apache.Arrow/Extensions/ArrayPoolExtensions.cs +++ b/csharp/src/Apache.Arrow/Extensions/ArrayPoolExtensions.cs @@ -16,46 +16,36 @@ using System; using System.Buffers; using System.Runtime.CompilerServices; -using System.Threading.Tasks; namespace Apache.Arrow { internal static class ArrayPoolExtensions { [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void RentReturn(this ArrayPool pool, int length, Action> action) + public static ArrayLease RentReturn(this ArrayPool pool, int length, out Memory buffer) { - byte[] array = null; - - try - { - array = pool.Rent(length); - action(array.AsMemory(0, length)); - } - finally - { - if (array != null) - { - pool.Return(array); - } - } + byte[] array = pool.Rent(length); + buffer = array.AsMemory(0, length); + return new ArrayLease(pool, array); } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static async ValueTask RentReturnAsync(this ArrayPool pool, int length, Func, ValueTask> action) + internal struct ArrayLease : IDisposable { - byte[] array = null; + private readonly ArrayPool _pool; + private byte[] _array; - try + public ArrayLease(ArrayPool pool, byte[] array) { - array = pool.Rent(length); - await action(array.AsMemory(0, length)); + _pool = pool; + _array = array; } - finally + + public void Dispose() { - if (array != null) + if (_array != null) { - pool.Return(array); + _pool.Return(_array); + _array = null; } } } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs index 3ae475885f16a..02f36b079349b 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs @@ -42,47 +42,47 @@ public ArrowFileReaderImplementation(Stream stream, MemoryAllocator allocator, I { } - public async ValueTask RecordBatchCountAsync() + public async ValueTask RecordBatchCountAsync(CancellationToken cancellationToken = default) { if (!HasReadSchema) { - await ReadSchemaAsync().ConfigureAwait(false); + await ReadSchemaAsync(cancellationToken).ConfigureAwait(false); } return _footer.RecordBatchCount; } - protected override async ValueTask ReadSchemaAsync() + protected override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default) { if (HasReadSchema) { return; } - await ValidateFileAsync().ConfigureAwait(false); + await ValidateFileAsync(cancellationToken).ConfigureAwait(false); int footerLength = 0; - await ArrayPool.Shared.RentReturnAsync(4, async (buffer) => + using (ArrayPool.Shared.RentReturn(4, out Memory buffer)) { BaseStream.Position = GetFooterLengthPosition(); - int bytesRead = await BaseStream.ReadFullBufferAsync(buffer).ConfigureAwait(false); + int bytesRead = await BaseStream.ReadFullBufferAsync(buffer, cancellationToken).ConfigureAwait(false); EnsureFullRead(buffer, bytesRead); footerLength = ReadFooterLength(buffer); - }).ConfigureAwait(false); + } - await ArrayPool.Shared.RentReturnAsync(footerLength, async (buffer) => + using (ArrayPool.Shared.RentReturn(footerLength, out Memory buffer)) { long footerStartPosition = GetFooterLengthPosition() - footerLength; BaseStream.Position = footerStartPosition; - int bytesRead = await BaseStream.ReadFullBufferAsync(buffer).ConfigureAwait(false); + int bytesRead = await BaseStream.ReadFullBufferAsync(buffer, cancellationToken).ConfigureAwait(false); EnsureFullRead(buffer, bytesRead); ReadSchema(buffer); - }).ConfigureAwait(false); + } } protected override void ReadSchema() @@ -95,7 +95,7 @@ protected override void ReadSchema() ValidateFile(); int footerLength = 0; - ArrayPool.Shared.RentReturn(4, (buffer) => + using (ArrayPool.Shared.RentReturn(4, out Memory buffer)) { BaseStream.Position = GetFooterLengthPosition(); @@ -103,9 +103,9 @@ protected override void ReadSchema() EnsureFullRead(buffer, bytesRead); footerLength = ReadFooterLength(buffer); - }); + } - ArrayPool.Shared.RentReturn(footerLength, (buffer) => + using (ArrayPool.Shared.RentReturn(footerLength, out Memory buffer)) { long footerStartPosition = GetFooterLengthPosition() - footerLength; @@ -115,7 +115,7 @@ protected override void ReadSchema() EnsureFullRead(buffer, bytesRead); ReadSchema(buffer); - }); + } } private long GetFooterLengthPosition() @@ -239,14 +239,14 @@ private void ReadDictionaries() /// /// Check if file format is valid. If it's valid don't run the validation again. /// - private async ValueTask ValidateFileAsync() + private async ValueTask ValidateFileAsync(CancellationToken cancellationToken = default) { if (IsFileValid) { return; } - await ValidateMagicAsync().ConfigureAwait(false); + await ValidateMagicAsync(cancellationToken).ConfigureAwait(false); IsFileValid = true; } @@ -266,20 +266,20 @@ private void ValidateFile() IsFileValid = true; } - private async ValueTask ValidateMagicAsync() + private async ValueTask ValidateMagicAsync(CancellationToken cancellationToken = default) { long startingPosition = BaseStream.Position; int magicLength = ArrowFileConstants.Magic.Length; try { - await ArrayPool.Shared.RentReturnAsync(magicLength, async (buffer) => + using (ArrayPool.Shared.RentReturn(magicLength, out Memory buffer)) { // Seek to the beginning of the stream BaseStream.Position = 0; // Read beginning of stream - await BaseStream.ReadAsync(buffer).ConfigureAwait(false); + await BaseStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); VerifyMagic(buffer); @@ -287,10 +287,10 @@ await ArrayPool.Shared.RentReturnAsync(magicLength, async (buffer) => BaseStream.Position = BaseStream.Length - magicLength; // Read the end of the stream - await BaseStream.ReadAsync(buffer).ConfigureAwait(false); + await BaseStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); VerifyMagic(buffer); - }).ConfigureAwait(false); + } } finally { @@ -305,7 +305,7 @@ private void ValidateMagic() try { - ArrayPool.Shared.RentReturn(magicLength, buffer => + using (ArrayPool.Shared.RentReturn(magicLength, out Memory buffer)) { // Seek to the beginning of the stream BaseStream.Position = 0; @@ -322,7 +322,7 @@ private void ValidateMagic() BaseStream.Read(buffer); VerifyMagic(buffer); - }); + } } finally { diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs index 95b9f60fffe0f..547fa800ec71e 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs @@ -215,7 +215,7 @@ private void WriteFooter(Schema schema) // Write footer length - Buffers.RentReturn(4, (buffer) => + using (Buffers.RentReturn(4, out Memory buffer)) { int footerLength; checked @@ -226,7 +226,7 @@ private void WriteFooter(Schema schema) BinaryPrimitives.WriteInt32LittleEndian(buffer.Span, footerLength); BaseStream.Write(buffer); - }); + } // Write magic @@ -286,7 +286,7 @@ private async Task WriteFooterAsync(Schema schema, CancellationToken cancellatio cancellationToken.ThrowIfCancellationRequested(); - await Buffers.RentReturnAsync(4, async (buffer) => + using (Buffers.RentReturn(4, out Memory buffer)) { int footerLength; checked @@ -297,7 +297,7 @@ await Buffers.RentReturnAsync(4, async (buffer) => BinaryPrimitives.WriteInt32LittleEndian(buffer.Span, footerLength); await BaseStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false); - }).ConfigureAwait(false); + } // Write magic diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs index 184e0348e5e07..5428c88c27bbc 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs @@ -78,7 +78,7 @@ protected async ValueTask ReadMessageAsync(CancellationToken cancell } RecordBatch result = null; - await ArrayPool.Shared.RentReturnAsync(messageLength, async (messageBuff) => + using (ArrayPool.Shared.RentReturn(messageLength, out Memory messageBuff)) { int bytesRead = await BaseStream.ReadFullBufferAsync(messageBuff, cancellationToken) .ConfigureAwait(false); @@ -96,7 +96,7 @@ await ArrayPool.Shared.RentReturnAsync(messageLength, async (messageBuff) Google.FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); - }).ConfigureAwait(false); + } return new ReadResult(messageLength, result); } @@ -125,7 +125,7 @@ protected ReadResult ReadMessage() } RecordBatch result = null; - ArrayPool.Shared.RentReturn(messageLength, messageBuff => + using (ArrayPool.Shared.RentReturn(messageLength, out Memory messageBuff)) { int bytesRead = BaseStream.ReadFullBuffer(messageBuff); EnsureFullRead(messageBuff, bytesRead); @@ -141,12 +141,12 @@ protected ReadResult ReadMessage() Google.FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); - }); + } return new ReadResult(messageLength, result); } - protected virtual async ValueTask ReadSchemaAsync() + protected virtual async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default) { if (HasReadSchema) { @@ -154,18 +154,18 @@ protected virtual async ValueTask ReadSchemaAsync() } // Figure out length of schema - int schemaMessageLength = await ReadMessageLengthAsync(throwOnFullRead: true) + int schemaMessageLength = await ReadMessageLengthAsync(throwOnFullRead: true, cancellationToken) .ConfigureAwait(false); - await ArrayPool.Shared.RentReturnAsync(schemaMessageLength, async (buff) => + using (ArrayPool.Shared.RentReturn(schemaMessageLength, out Memory buff)) { // Read in schema - int bytesRead = await BaseStream.ReadFullBufferAsync(buff).ConfigureAwait(false); + int bytesRead = await BaseStream.ReadFullBufferAsync(buff, cancellationToken).ConfigureAwait(false); EnsureFullRead(buff, bytesRead); Google.FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); - }).ConfigureAwait(false); + } } protected virtual void ReadSchema() @@ -178,20 +178,20 @@ protected virtual void ReadSchema() // Figure out length of schema int schemaMessageLength = ReadMessageLength(throwOnFullRead: true); - ArrayPool.Shared.RentReturn(schemaMessageLength, buff => + using (ArrayPool.Shared.RentReturn(schemaMessageLength, out Memory buff)) { int bytesRead = BaseStream.ReadFullBuffer(buff); EnsureFullRead(buff, bytesRead); Google.FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); - }); + } } private async ValueTask ReadMessageLengthAsync(bool throwOnFullRead, CancellationToken cancellationToken = default) { int messageLength = 0; - await ArrayPool.Shared.RentReturnAsync(4, async (lengthBuffer) => + using (ArrayPool.Shared.RentReturn(4, out Memory lengthBuffer)) { int bytesRead = await BaseStream.ReadFullBufferAsync(lengthBuffer, cancellationToken) .ConfigureAwait(false); @@ -201,7 +201,7 @@ await ArrayPool.Shared.RentReturnAsync(4, async (lengthBuffer) => } else if (bytesRead != 4) { - return; + return 0; } messageLength = BitUtility.ReadInt32(lengthBuffer); @@ -217,13 +217,12 @@ await ArrayPool.Shared.RentReturnAsync(4, async (lengthBuffer) => } else if (bytesRead != 4) { - messageLength = 0; - return; + return 0; } messageLength = BitUtility.ReadInt32(lengthBuffer); } - }).ConfigureAwait(false); + }; return messageLength; } @@ -231,7 +230,7 @@ await ArrayPool.Shared.RentReturnAsync(4, async (lengthBuffer) => private int ReadMessageLength(bool throwOnFullRead) { int messageLength = 0; - ArrayPool.Shared.RentReturn(4, lengthBuffer => + using (ArrayPool.Shared.RentReturn(4, out Memory lengthBuffer)) { int bytesRead = BaseStream.ReadFullBuffer(lengthBuffer); if (throwOnFullRead) @@ -240,7 +239,7 @@ private int ReadMessageLength(bool throwOnFullRead) } else if (bytesRead != 4) { - return; + return 0; } messageLength = BitUtility.ReadInt32(lengthBuffer); @@ -255,13 +254,12 @@ private int ReadMessageLength(bool throwOnFullRead) } else if (bytesRead != 4) { - messageLength = 0; - return; + return 0; } messageLength = BitUtility.ReadInt32(lengthBuffer); } - }); + } return messageLength; } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index 483dcea898fbe..5f490019b2133 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -890,7 +890,7 @@ private protected async ValueTask WriteFlatBufferAsync(CancellationToken cancell private void WriteIpcMessageLength(int length) { - Buffers.RentReturn(_options.SizeOfIpcLength, (buffer) => + using (Buffers.RentReturn(_options.SizeOfIpcLength, out Memory buffer)) { Memory currentBufferPosition = buffer; if (!_options.WriteLegacyIpcFormat) @@ -902,12 +902,12 @@ private void WriteIpcMessageLength(int length) BinaryPrimitives.WriteInt32LittleEndian(currentBufferPosition.Span, length); BaseStream.Write(buffer); - }); + } } private async ValueTask WriteIpcMessageLengthAsync(int length, CancellationToken cancellationToken) { - await Buffers.RentReturnAsync(_options.SizeOfIpcLength, async (buffer) => + using (Buffers.RentReturn(_options.SizeOfIpcLength, out Memory buffer)) { Memory currentBufferPosition = buffer; if (!_options.WriteLegacyIpcFormat) @@ -919,7 +919,7 @@ await Buffers.RentReturnAsync(_options.SizeOfIpcLength, async (buffer) => BinaryPrimitives.WriteInt32LittleEndian(currentBufferPosition.Span, length); await BaseStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false); - }).ConfigureAwait(false); + } } protected int CalculatePadding(long offset, int alignment = 8)