Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(csharp): Implement remaining functions in 1.0 spec #1773

Merged
merged 2 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions csharp/src/Apache.Arrow.Adbc/AdbcStatement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,29 @@ public AdbcStatement()
/// </summary>
public virtual byte[] SubstraitPlan
{
get { throw new NotImplementedException(); }
set { throw new NotImplementedException(); }
get { throw AdbcException.NotImplemented("Statement does not support SubstraitPlan"); }
set { throw AdbcException.NotImplemented("Statement does not support SubstraitPlan"); }
}

/// <summary>
/// Binds this statement to a <see cref="RecordBatch"/> to provide parameter values or bulk data ingestion.
/// </summary>
/// <param name="batch">the RecordBatch to bind</param>
/// <param name="schema">the schema of the RecordBatch</param>
public virtual void Bind(RecordBatch batch, Schema schema)
{
throw AdbcException.NotImplemented("Statement does not support Bind");
}

/// <summary>
/// Binds this statement to an <see cref="IArrowArrayStream"/> to provide parameter values or bulk data ingestion.
/// </summary>
/// <param name="stream"></param>
public virtual void BindStream(IArrowArrayStream stream)
{
throw AdbcException.NotImplemented("Statement does not support BindStream");
}

/// <summary>
/// Executes the statement and returns a tuple containing the number
/// of records and the <see cref="IArrowArrayStream"/>.
Expand Down
179 changes: 155 additions & 24 deletions csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,10 @@
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using Apache.Arrow.Adbc.Extensions;
using Apache.Arrow.C;
using Apache.Arrow.Ipc;

#if NETSTANDARD
using Apache.Arrow.Adbc.Extensions;
#endif

namespace Apache.Arrow.Adbc.C
{
public class CAdbcDriverExporter
Expand All @@ -38,6 +35,7 @@ public class CAdbcDriverExporter
#if NET5_0_OR_GREATER
private static unsafe delegate* unmanaged<CAdbcError*, void> ReleaseErrorPtr => (delegate* unmanaged<CAdbcError*, void>)s_releaseError.Pointer;
private static unsafe delegate* unmanaged<CAdbcDriver*, CAdbcError*, AdbcStatusCode> ReleaseDriverPtr => &ReleaseDriver;
private static unsafe delegate* unmanaged<CAdbcPartitions*, void> ReleasePartitionsPtr => &ReleasePartitions;

private static unsafe delegate* unmanaged<CAdbcDatabase*, CAdbcError*, AdbcStatusCode> DatabaseInitPtr => &InitDatabase;
private static unsafe delegate* unmanaged<CAdbcDatabase*, CAdbcError*, AdbcStatusCode> DatabaseReleasePtr => &ReleaseDatabase;
Expand All @@ -55,16 +53,23 @@ public class CAdbcDriverExporter
private static unsafe delegate* unmanaged<CAdbcConnection*, byte*, byte*, CAdbcError*, AdbcStatusCode> ConnectionSetOptionPtr => &SetConnectionOption;

private static unsafe delegate* unmanaged<CAdbcStatement*, CArrowArray*, CArrowSchema*, CAdbcError*, AdbcStatusCode> StatementBindPtr => &BindStatement;
private static unsafe delegate* unmanaged<CAdbcStatement*, CArrowArrayStream*, CAdbcError*, AdbcStatusCode> StatementBindStreamPtr => &BindStreamStatement;
private static unsafe delegate* unmanaged<CAdbcStatement*, CArrowArrayStream*, long*, CAdbcError*, AdbcStatusCode> StatementExecuteQueryPtr => &ExecuteStatementQuery;
private static unsafe delegate* unmanaged<CAdbcStatement*, CArrowSchema*, CAdbcPartitions*, long*, CAdbcError*, AdbcStatusCode> StatementExecutePartitionsPtr => &ExecuteStatementPartitions;
private static unsafe delegate* unmanaged<CAdbcConnection*, CAdbcStatement*, CAdbcError*, AdbcStatusCode> StatementNewPtr => &NewStatement;
private static unsafe delegate* unmanaged<CAdbcStatement*, CAdbcError*, AdbcStatusCode> StatementReleasePtr => &ReleaseStatement;
private static unsafe delegate* unmanaged<CAdbcStatement*, CAdbcError*, AdbcStatusCode> StatementPreparePtr => &PrepareStatement;
private static unsafe delegate* unmanaged<CAdbcStatement*, byte*, CAdbcError*, AdbcStatusCode> StatementSetSqlQueryPtr => &SetStatementSqlQuery;
private static unsafe delegate* unmanaged<CAdbcStatement*, byte*, int, CAdbcError*, AdbcStatusCode> StatementSetSubstraitPlanPtr => &SetStatementSubstraitPlan;
private static unsafe delegate* unmanaged<CAdbcStatement*, CArrowSchema*, CAdbcError*, AdbcStatusCode> StatementGetParameterSchemaPtr => &GetStatementParameterSchema;
#else
private static IntPtr ReleaseErrorPtr => s_releaseError.Pointer;
internal unsafe delegate AdbcStatusCode DriverRelease(CAdbcDriver* driver, CAdbcError* error);
private static unsafe readonly NativeDelegate<DriverRelease> s_releaseDriver = new NativeDelegate<DriverRelease>(ReleaseDriver);
private static IntPtr ReleaseDriverPtr => s_releaseDriver.Pointer;
internal unsafe delegate void PartitionsRelease(CAdbcPartitions* partitions);
private static unsafe readonly NativeDelegate<PartitionsRelease> s_releasePartitions = new NativeDelegate<PartitionsRelease>(ReleasePartitions);
private static IntPtr ReleasePartitionsPtr => s_releasePartitions.Pointer;

private static unsafe readonly NativeDelegate<DatabaseFn> s_databaseInit = new NativeDelegate<DatabaseFn>(InitDatabase);
private static IntPtr DatabaseInitPtr => s_databaseInit.Pointer;
Expand Down Expand Up @@ -102,12 +107,18 @@ public class CAdbcDriverExporter
private static unsafe readonly NativeDelegate<ConnectionSetOption> s_connectionSetOption = new NativeDelegate<ConnectionSetOption>(SetConnectionOption);
private static IntPtr ConnectionSetOptionPtr => s_connectionSetOption.Pointer;

private unsafe delegate AdbcStatusCode StatementBind(CAdbcStatement* statement, CArrowArray* array, CArrowSchema* schema, CAdbcError* error);
internal unsafe delegate AdbcStatusCode StatementBind(CAdbcStatement* statement, CArrowArray* array, CArrowSchema* schema, CAdbcError* error);
private static unsafe readonly NativeDelegate<StatementBind> s_statementBind = new NativeDelegate<StatementBind>(BindStatement);
private static IntPtr StatementBindPtr => s_statementBind.Pointer;
internal unsafe delegate AdbcStatusCode StatementBindStream(CAdbcStatement* statement, CArrowArrayStream* stream, CAdbcError* error);
private static unsafe readonly NativeDelegate<StatementBindStream> s_statementBindStream = new NativeDelegate<StatementBindStream>(BindStreamStatement);
private static IntPtr StatementBindStreamPtr => s_statementBindStream.Pointer;
internal unsafe delegate AdbcStatusCode StatementExecuteQuery(CAdbcStatement* statement, CArrowArrayStream* stream, long* rows, CAdbcError* error);
private static unsafe readonly NativeDelegate<StatementExecuteQuery> s_statementExecuteQuery = new NativeDelegate<StatementExecuteQuery>(ExecuteStatementQuery);
private static IntPtr StatementExecuteQueryPtr = s_statementExecuteQuery.Pointer;
internal unsafe delegate AdbcStatusCode StatementExecutePartitions(CAdbcStatement* statement, CArrowSchema* schema, CAdbcPartitions* partitions, long* rows, CAdbcError* error);
private static unsafe readonly NativeDelegate<StatementExecutePartitions> s_statementExecutePartitions = new NativeDelegate<StatementExecutePartitions>(ExecuteStatementPartitions);
private static IntPtr StatementExecutePartitionsPtr = s_statementExecutePartitions.Pointer;
internal unsafe delegate AdbcStatusCode StatementNew(CAdbcConnection* connection, CAdbcStatement* statement, CAdbcError* error);
private static unsafe readonly NativeDelegate<StatementNew> s_statementNew = new NativeDelegate<StatementNew>(NewStatement);
private static IntPtr StatementNewPtr => s_statementNew.Pointer;
Expand All @@ -119,17 +130,14 @@ public class CAdbcDriverExporter
internal unsafe delegate AdbcStatusCode StatementSetSqlQuery(CAdbcStatement* statement, byte* text, CAdbcError* error);
private static unsafe readonly NativeDelegate<StatementSetSqlQuery> s_statementSetSqlQuery = new NativeDelegate<StatementSetSqlQuery>(SetStatementSqlQuery);
private static IntPtr StatementSetSqlQueryPtr = s_statementSetSqlQuery.Pointer;
internal unsafe delegate AdbcStatusCode StatementSetSubstraitPlan(CAdbcStatement* statement, byte* plan, int length, CAdbcError* error);
private static unsafe readonly NativeDelegate<StatementSetSubstraitPlan> s_statementSetSubstraitPlan = new NativeDelegate<StatementSetSubstraitPlan>(SetStatementSubstraitPlan);
private static IntPtr StatementSetSubstraitPlanPtr = s_statementSetSubstraitPlan.Pointer;
internal unsafe delegate AdbcStatusCode StatementGetParameterSchema(CAdbcStatement* statement, CArrowSchema* schema, CAdbcError* error);
private static unsafe readonly NativeDelegate<StatementGetParameterSchema> s_statementGetParameterSchema = new NativeDelegate<StatementGetParameterSchema>(GetStatementParameterSchema);
private static IntPtr StatementGetParameterSchemaPtr = s_statementGetParameterSchema.Pointer;
#endif

/*
* Not yet implemented

unsafe delegate AdbcStatusCode StatementBindStream(CAdbcStatement* statement, CArrowArrayStream* stream, CAdbcError* error);
unsafe delegate AdbcStatusCode StatementExecutePartitions(CAdbcStatement* statement, CArrowSchema* schema, CAdbcPartitions* partitions, long* rows_affected, CAdbcError* error);
unsafe delegate AdbcStatusCode StatementGetParameterSchema(CAdbcStatement* statement, CArrowSchema* schema, CAdbcError* error);
unsafe delegate AdbcStatusCode StatementSetSubstraitPlan(CAdbcStatement statement, byte* plan, int length, CAdbcError error);
*/

public unsafe static AdbcStatusCode AdbcDriverInit(int version, CAdbcDriver* nativeDriver, CAdbcError* error, AdbcDriver driver)
{
DriverStub stub = new DriverStub(driver);
Expand All @@ -142,7 +150,6 @@ public unsafe static AdbcStatusCode AdbcDriverInit(int version, CAdbcDriver* nat
nativeDriver->DatabaseSetOption = DatabaseSetOptionPtr;
nativeDriver->DatabaseRelease = DatabaseReleasePtr;

// TODO: This should probably only set the pointers for the functionality actually supported by this particular driver
nativeDriver->ConnectionCommit = ConnectionCommitPtr;
nativeDriver->ConnectionGetInfo = ConnectionGetInfoPtr;
nativeDriver->ConnectionGetObjects = ConnectionGetObjectsPtr;
Expand All @@ -156,15 +163,15 @@ public unsafe static AdbcStatusCode AdbcDriverInit(int version, CAdbcDriver* nat
nativeDriver->ConnectionRollback = ConnectionRollbackPtr;

nativeDriver->StatementBind = StatementBindPtr;
// nativeDriver->StatementBindStream = StatementBindStreamPtr;
nativeDriver->StatementBindStream = StatementBindStreamPtr;
nativeDriver->StatementExecuteQuery = StatementExecuteQueryPtr;
// nativeDriver->StatementExecutePartitions = StatementExecutePartitionsPtr;
// nativeDriver->StatementGetParameterSchema = StatementGetParameterSchemaPtr;
nativeDriver->StatementExecutePartitions = StatementExecutePartitionsPtr;
nativeDriver->StatementGetParameterSchema = StatementGetParameterSchemaPtr;
nativeDriver->StatementNew = StatementNewPtr;
nativeDriver->StatementPrepare = StatementPreparePtr;
nativeDriver->StatementRelease = StatementReleasePtr;
nativeDriver->StatementSetSqlQuery = StatementSetSqlQueryPtr;
// nativeDriver->StatementSetSubstraitPlan = StatementSetSubstraitPlanPtr;
nativeDriver->StatementSetSubstraitPlan = StatementSetSubstraitPlanPtr;

return 0;
}
Expand All @@ -181,12 +188,7 @@ private unsafe static AdbcStatusCode SetError(CAdbcError* error, Exception excep
{
ReleaseError(error);

#if NETSTANDARD
error->message = (byte*)MarshalExtensions.StringToCoTaskMemUTF8(exception.Message);
#else
error->message = (byte*)Marshal.StringToCoTaskMemUTF8(exception.Message);
#endif

error->sqlstate0 = (byte)0;
error->sqlstate1 = (byte)0;
error->sqlstate2 = (byte)0;
Expand Down Expand Up @@ -249,6 +251,37 @@ private unsafe static AdbcStatusCode ReleaseDriver(CAdbcDriver* nativeDriver, CA
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
private unsafe static void ReleasePartitions(CAdbcPartitions* partitions)
{
if (partitions != null)
{
if (partitions->partitions != null)
{
for (int i = 0; i < partitions->num_partitions; i++)
{
byte* partition = partitions->partitions[i];
if (partition != null)
{
Marshal.FreeHGlobal((IntPtr)partition);
partitions->partitions[i] = null;
}
}
Marshal.FreeHGlobal((IntPtr)partitions->partitions);
partitions->partitions = null;
}
if (partitions->partition_lengths != null)
{
Marshal.FreeHGlobal((IntPtr)partitions->partition_lengths);
partitions->partition_lengths = null;
}

partitions->release = default;
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
Expand Down Expand Up @@ -512,6 +545,46 @@ private unsafe static AdbcStatusCode SetStatementSqlQuery(CAdbcStatement* native
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
private unsafe static AdbcStatusCode SetStatementSubstraitPlan(CAdbcStatement* nativeStatement, byte* plan, int length, CAdbcError* error)
{
try
{
GCHandle gch = GCHandle.FromIntPtr((IntPtr)nativeStatement->private_data);
AdbcStatement stub = (AdbcStatement)gch.Target;

stub.SubstraitPlan = MarshalExtensions.MarshalBuffer(plan, length);

return AdbcStatusCode.Success;
}
catch (Exception e)
{
return SetError(error, e);
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
private unsafe static AdbcStatusCode GetStatementParameterSchema(CAdbcStatement* nativeStatement, CArrowSchema* schema, CAdbcError* error)
{
try
{
GCHandle gch = GCHandle.FromIntPtr((IntPtr)nativeStatement->private_data);
AdbcStatement stub = (AdbcStatement)gch.Target;

CArrowSchemaExporter.ExportSchema(stub.GetParameterSchema(), schema);

return AdbcStatusCode.Success;
}
catch (Exception e)
{
return SetError(error, e);
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
Expand All @@ -533,6 +606,26 @@ private unsafe static AdbcStatusCode BindStatement(CAdbcStatement* nativeStateme
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
private unsafe static AdbcStatusCode BindStreamStatement(CAdbcStatement* nativeStatement, CArrowArrayStream* stream, CAdbcError* error)
{
try
{
GCHandle gch = GCHandle.FromIntPtr((IntPtr)nativeStatement->private_data);
AdbcStatement stub = (AdbcStatement)gch.Target;

IArrowArrayStream arrayStream = CArrowArrayStreamImporter.ImportArrayStream(stream);
stub.BindStream(arrayStream);
return AdbcStatusCode.Success;
}
catch (Exception e)
{
return SetError(error, e);
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
Expand All @@ -557,6 +650,44 @@ private unsafe static AdbcStatusCode ExecuteStatementQuery(CAdbcStatement* nativ
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
private unsafe static AdbcStatusCode ExecuteStatementPartitions(CAdbcStatement* nativeStatement, CArrowSchema* schema, CAdbcPartitions* partitions, long* rows, CAdbcError* error)
{
try
{
GCHandle gch = GCHandle.FromIntPtr((IntPtr)nativeStatement->private_data);
AdbcStatement stub = (AdbcStatement)gch.Target;
var result = stub.ExecutePartitioned();
if (rows != null)
{
*rows = result.AffectedRows;
}

partitions->release = ReleasePartitionsPtr;
partitions->num_partitions = result.PartitionDescriptors.Count;
partitions->partitions = (byte**)Marshal.AllocHGlobal(IntPtr.Size * result.PartitionDescriptors.Count);
partitions->partition_lengths = (nuint*)Marshal.AllocHGlobal(IntPtr.Size * result.PartitionDescriptors.Count);
for (int i = 0; i < partitions->num_partitions; i++)
{
ReadOnlySpan<byte> partition = result.PartitionDescriptors[i].Descriptor;
partitions->partition_lengths[i] = (nuint)partition.Length;
partitions->partitions[i] = (byte*)Marshal.AllocHGlobal(partition.Length);
fixed (void* descriptor = partition)
{
Buffer.MemoryCopy(descriptor, partitions->partitions[i], partition.Length, partition.Length);
}
}

return AdbcStatusCode.Success;
}
catch (Exception e)
{
return SetError(error, e);
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
Expand Down
Loading
Loading