Skip to content

Commit

Permalink
feat(csharp/src/Drivers/Apache): extend capability of GetInfo for Spa…
Browse files Browse the repository at this point in the history
…rk driver (#1863)

Extend capability of GetInfo for Spark driver
* Adds dynamic calls to retrieve the following from the DBMS:
  * vendor name
  * vendor version
* Adds vendor SQL support (hard-coded default: `true`)
* Adds driver version (read from the assembly's file info/product version)

Adds tests for supported and unsupported info.
  • Loading branch information
birschick-bq authored May 17, 2024
1 parent abe6d6a commit 3d021ea
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 75 deletions.
91 changes: 31 additions & 60 deletions csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Ipc;
Expand All @@ -35,17 +36,29 @@ public abstract class HiveServer2Connection : AdbcConnection
internal TTransport? transport;
internal TCLIService.Client? client;
internal TSessionHandle? sessionHandle;
private readonly Lazy<string> _vendorVersion;
private readonly Lazy<string> _vendorName;

/// <summary>
/// Stores the connection properties and prepares the lazily-evaluated
/// vendor name and vendor version lookups. No server call is made here;
/// each value is requested through GetInfoTypeStringValue on first access.
/// </summary>
/// <param name="properties">The connection properties.</param>
internal HiveServer2Connection(IReadOnlyDictionary<string, string> properties)
{
    // "PublicationOnly" gives thread-safe initialization in which the first thread
    // to finish successfully publishes the value. Exceptions are not cached, so a
    // failed lookup is attempted again on the next access.
    // https://learn.microsoft.com/en-us/dotnet/framework/performance/lazy-initialization#exceptions-in-lazy-objects
    const LazyThreadSafetyMode RetryOnFailure = LazyThreadSafetyMode.PublicationOnly;

    this.properties = properties;

    _vendorName = new Lazy<string>(
        () => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_NAME),
        RetryOnFailure);
    _vendorVersion = new Lazy<string>(
        () => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_VER),
        RetryOnFailure);
}

/// <summary>
/// Gets the Thrift client for this connection.
/// </summary>
/// <exception cref="InvalidOperationException">The client has not been assigned yet.</exception>
internal TCLIService.Client Client =>
    this.client ?? throw new InvalidOperationException("connection not open");

/// <summary>
/// Gets the DBMS version string (CLI_DBMS_VER), requested from the server
/// on first access.
/// </summary>
protected string VendorVersion
{
    get { return _vendorVersion.Value; }
}

/// <summary>
/// Gets the DBMS name string (CLI_DBMS_NAME), requested from the server
/// on first access.
/// </summary>
protected string VendorName
{
    get { return _vendorName.Value; }
}

internal async Task OpenAsync()
{
TProtocol protocol = await CreateProtocolAsync();
Expand Down Expand Up @@ -81,6 +94,24 @@ protected void PollForResponse()
} while (statusResponse.OperationState == TOperationState.PENDING_STATE || statusResponse.OperationState == TOperationState.RUNNING_STATE);
}

/// <summary>
/// Sends a GetInfo request for the given info type on the current session
/// and returns the string value of the response.
/// </summary>
/// <param name="infoType">The kind of information to request.</param>
/// <returns>The string value carried by the GetInfo response.</returns>
/// <exception cref="InvalidOperationException">No session handle is available.</exception>
/// <exception cref="HiveServer2Exception">The response status is ERROR_STATUS.</exception>
private string GetInfoTypeStringValue(TGetInfoType infoType)
{
    // The session check comes before the Client access so that a missing
    // session is reported ahead of a missing client.
    TSessionHandle session = this.sessionHandle
        ?? throw new InvalidOperationException("session not created");

    TGetInfoReq request = new TGetInfoReq
    {
        SessionHandle = session,
        InfoType = infoType,
    };

    // Blocks the calling thread until the Thrift call completes.
    TGetInfoResp response = Client.GetInfo(request).Result;
    var status = response.Status;

    if (status.StatusCode != TStatusCode.ERROR_STATUS)
    {
        return response.InfoValue.StringValue;
    }

    throw new HiveServer2Exception(status.ErrorMessage)
        .SetNativeError(status.ErrorCode)
        .SetSqlState(status.SqlState);
}

public override void Dispose()
{
Expand All @@ -102,65 +133,5 @@ protected Schema GetSchema()
TGetResultSetMetadataResp response = this.Client.GetResultSetMetadata(request).Result;
return SchemaParser.GetArrowSchema(response.Schema);
}

/// <summary>
/// Arrow array stream that pulls results for the connection's operation
/// handle via FetchResults and exposes them as record batches.
/// </summary>
sealed class GetObjectsReader : IArrowArrayStream
{
    // Maximum number of rows requested per FetchResults call.
    const int FetchRowCount = 50000;

    // Set to null once the server reports that there are no more rows.
    HiveServer2Connection? _connection;
    Schema _schema;
    // Arrow batches from the most recent fetch, and the position of the next one to open.
    List<TSparkArrowBatch>? _batches;
    int _index;
    // Reader over the Arrow batch currently being drained.
    IArrowReader? _reader;

    public GetObjectsReader(HiveServer2Connection connection, Schema schema)
    {
        _connection = connection;
        _schema = schema;
    }

    /// <summary>Gets the schema supplied at construction.</summary>
    public Schema Schema => _schema;

    /// <summary>
    /// Returns the next record batch, fetching further results from the server
    /// as needed, or null when the results are exhausted.
    /// </summary>
    /// <param name="cancellationToken">Token passed to the reader and to the fetch call.</param>
    public async ValueTask<RecordBatch?> ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
    {
        for (; ; )
        {
            // Step 1: drain the reader over the current Arrow batch, if there is one.
            if (_reader != null)
            {
                RecordBatch? batch = await _reader.ReadNextRecordBatchAsync(cancellationToken);
                if (batch != null)
                {
                    return batch;
                }

                _reader = null;
            }

            // Step 2: open the next Arrow batch that has already been fetched.
            bool hasBufferedBatch = _batches != null && _index < _batches.Count;
            if (hasBufferedBatch)
            {
                TSparkArrowBatch current = _batches![_index];
                _index++;
                _reader = new ArrowStreamReader(new ChunkStream(_schema, current.Batch));
                continue;
            }

            // Step 3: the buffered batches are used up; reset and fetch more if possible.
            _batches = null;
            _index = 0;

            HiveServer2Connection? source = _connection;
            if (source == null)
            {
                // The previous fetch reported that no rows remain.
                return null;
            }

            TFetchResultsReq fetchRequest = new TFetchResultsReq(
                source.operationHandle,
                TFetchOrientation.FETCH_NEXT,
                FetchRowCount);
            TFetchResultsResp fetchResponse = await source.Client.FetchResults(fetchRequest, cancellationToken);
            _batches = fetchResponse.Results.ArrowBatches;

            if (!fetchResponse.HasMoreRows)
            {
                // Remember not to fetch again once these batches are drained.
                _connection = null;
            }
        }
    }

    /// <summary>Does nothing.</summary>
    public void Dispose()
    {
    }
}
}
}
61 changes: 50 additions & 11 deletions csharp/src/Drivers/Apache/Spark/SparkConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
using System.Diagnostics;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Reflection;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
Expand All @@ -43,16 +44,20 @@ public class SparkConnection : HiveServer2Connection
AdbcInfoCode.DriverName,
AdbcInfoCode.DriverVersion,
AdbcInfoCode.DriverArrowVersion,
AdbcInfoCode.VendorName
AdbcInfoCode.VendorName,
AdbcInfoCode.VendorSql,
AdbcInfoCode.VendorVersion,
};

const string ProductVersionDefault = "1.0.0";
const string InfoDriverName = "ADBC Spark Driver";
const string InfoDriverVersion = "1.0.0";
const string InfoVendorName = "Spark";
const string InfoDriverArrowVersion = "1.0.0";
const bool InfoVendorSql = true;
const int DecimalPrecisionDefault = 10;
const int DecimalScaleDefault = 0;

private readonly Lazy<string> _productVersion;

internal static TSparkGetDirectResults sparkGetDirectResults = new TSparkGetDirectResults(1000);

internal static readonly Dictionary<string, string> timestampConfig = new Dictionary<string, string>
Expand Down Expand Up @@ -83,8 +88,11 @@ private enum ColumnTypeId
/// <summary>
/// Passes the connection properties to the base connection and prepares the
/// lazily-evaluated driver product version.
/// </summary>
/// <param name="properties">The connection properties.</param>
internal SparkConnection(IReadOnlyDictionary<string, string> properties)
    : base(properties)
{
    // PublicationOnly: a failed lookup is not cached and is retried on the next access.
    _productVersion = new Lazy<string>(
        GetProductVersion,
        LazyThreadSafetyMode.PublicationOnly);
}

/// <summary>
/// Gets the driver's product version, computed on first access.
/// </summary>
protected string ProductVersion
{
    get { return _productVersion.Value; }
}

protected override async ValueTask<TProtocol> CreateProtocolAsync()
{
Trace.TraceError($"create protocol with {properties.Count} properties.");
Expand Down Expand Up @@ -137,6 +145,7 @@ public override AdbcStatement CreateStatement()
public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> codes)
{
const int strValTypeID = 0;
const int boolValTypeId = 1;

UnionType infoUnionType = new UnionType(
new Field[]
Expand Down Expand Up @@ -178,8 +187,11 @@ public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> codes)
ArrowBuffer.Builder<byte> typeBuilder = new ArrowBuffer.Builder<byte>();
ArrowBuffer.Builder<int> offsetBuilder = new ArrowBuffer.Builder<int>();
StringArray.Builder stringInfoBuilder = new StringArray.Builder();
BooleanArray.Builder booleanInfoBuilder = new BooleanArray.Builder();

int nullCount = 0;
int arrayLength = codes.Count;
int offset = 0;

foreach (AdbcInfoCode code in codes)
{
Expand All @@ -188,32 +200,53 @@ public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> codes)
case AdbcInfoCode.DriverName:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
offsetBuilder.Append(offset++);
stringInfoBuilder.Append(InfoDriverName);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.DriverVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
stringInfoBuilder.Append(InfoDriverVersion);
offsetBuilder.Append(offset++);
stringInfoBuilder.Append(ProductVersion);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.DriverArrowVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
offsetBuilder.Append(offset++);
stringInfoBuilder.Append(InfoDriverArrowVersion);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.VendorName:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
stringInfoBuilder.Append(InfoVendorName);
offsetBuilder.Append(offset++);
string vendorName = VendorName;
stringInfoBuilder.Append(vendorName);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.VendorVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(offset++);
string? vendorVersion = VendorVersion;
stringInfoBuilder.Append(vendorVersion);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.VendorSql:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(boolValTypeId);
offsetBuilder.Append(offset++);
stringInfoBuilder.AppendNull();
booleanInfoBuilder.Append(InfoVendorSql);
break;
default:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
offsetBuilder.Append(offset++);
stringInfoBuilder.AppendNull();
booleanInfoBuilder.AppendNull();
nullCount++;
break;
}
Expand All @@ -231,7 +264,7 @@ public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> codes)
IArrowArray[] childrenArrays = new IArrowArray[]
{
stringInfoBuilder.Build(),
new BooleanArray.Builder().Build(),
booleanInfoBuilder.Build(),
new Int64Array.Builder().Build(),
new Int32Array.Builder().Build(),
new ListArray.Builder(StringType.Default).Build(),
Expand Down Expand Up @@ -749,6 +782,12 @@ private static bool TryParse(string input, out Decimal128Type? value)
return true;
}
}

/// <summary>
/// Determines the product version of the driver assembly.
/// </summary>
/// <remarks>
/// The version is read from the assembly file's version info when the assembly
/// has a file location. Otherwise it is taken from the assembly's informational
/// version attribute. <c>ProductVersionDefault</c> is used when neither source
/// yields a value.
/// </remarks>
/// <returns>The product version string.</returns>
private string GetProductVersion()
{
    Assembly driverAssembly = Assembly.GetExecutingAssembly();

    // Assembly.Location is an empty string when the assembly is bundled into a
    // single-file application or was loaded from a byte array. Passing that to
    // FileVersionInfo.GetVersionInfo throws, so only consult the file when a
    // location is actually available.
    string location = driverAssembly.Location;
    if (!string.IsNullOrEmpty(location))
    {
        FileVersionInfo fileVersionInfo = FileVersionInfo.GetVersionInfo(location);
        return fileVersionInfo.ProductVersion ?? ProductVersionDefault;
    }

    // No file on disk: fall back to the version recorded in assembly metadata.
    string? informationalVersion = driverAssembly
        .GetCustomAttribute<AssemblyInformationalVersionAttribute>()
        ?.InformationalVersion;

    return string.IsNullOrEmpty(informationalVersion)
        ? ProductVersionDefault
        : informationalVersion!;
}
}

internal struct TableInfoPair
Expand Down
77 changes: 73 additions & 4 deletions csharp/test/Drivers/Apache/Spark/DriverTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,30 @@ public async Task CanGetInfo()
{
AdbcConnection adbcConnection = NewConnection();

using IArrowArrayStream stream = adbcConnection.GetInfo(new List<AdbcInfoCode>() { AdbcInfoCode.DriverName, AdbcInfoCode.DriverVersion, AdbcInfoCode.VendorName });
// Test the supported info codes
List<AdbcInfoCode> handledCodes = new List<AdbcInfoCode>()
{
AdbcInfoCode.DriverName,
AdbcInfoCode.DriverVersion,
AdbcInfoCode.VendorName,
AdbcInfoCode.DriverArrowVersion,
AdbcInfoCode.VendorVersion,
AdbcInfoCode.VendorSql
};
using IArrowArrayStream stream = adbcConnection.GetInfo(handledCodes);

RecordBatch recordBatch = await stream.ReadNextRecordBatchAsync();
UInt32Array infoNameArray = (UInt32Array)recordBatch.Column("info_name");

List<string> expectedValues = new List<string>() { "DriverName", "DriverVersion", "VendorName" };
List<string> expectedValues = new List<string>()
{
"DriverName",
"DriverVersion",
"VendorName",
"DriverArrowVersion",
"VendorVersion",
"VendorSql"
};

for (int i = 0; i < infoNameArray.Length; i++)
{
Expand All @@ -98,8 +116,59 @@ public async Task CanGetInfo()

Assert.Contains(value.ToString(), expectedValues);

StringArray stringArray = (StringArray)valueArray.Fields[0];
Console.WriteLine($"{value}={stringArray.GetString(i)}");
switch (value)
{
case AdbcInfoCode.VendorSql:
// TODO: How does external developer know the second field is the boolean field?
BooleanArray booleanArray = (BooleanArray)valueArray.Fields[1];
bool? boolValue = booleanArray.GetValue(i);
OutputHelper?.WriteLine($"{value}={boolValue}");
Assert.True(boolValue);
break;
default:
StringArray stringArray = (StringArray)valueArray.Fields[0];
string stringValue = stringArray.GetString(i);
OutputHelper?.WriteLine($"{value}={stringValue}");
Assert.NotNull(stringValue);
break;
}
}

// Test the unhandled info codes.
List<AdbcInfoCode> unhandledCodes = new List<AdbcInfoCode>()
{
AdbcInfoCode.VendorArrowVersion,
AdbcInfoCode.VendorSubstrait,
AdbcInfoCode.VendorSubstraitMaxVersion
};
using IArrowArrayStream stream2 = adbcConnection.GetInfo(unhandledCodes);

recordBatch = await stream2.ReadNextRecordBatchAsync();
infoNameArray = (UInt32Array)recordBatch.Column("info_name");

List<string> unexpectedValues = new List<string>()
{
"VendorArrowVersion",
"VendorSubstrait",
"VendorSubstraitMaxVersion"
};
for (int i = 0; i < infoNameArray.Length; i++)
{
AdbcInfoCode? value = (AdbcInfoCode?)infoNameArray.GetValue(i);
DenseUnionArray valueArray = (DenseUnionArray)recordBatch.Column("info_value");

Assert.Contains(value.ToString(), unexpectedValues);
switch (value)
{
case AdbcInfoCode.VendorSql:
BooleanArray booleanArray = (BooleanArray)valueArray.Fields[1];
Assert.Null(booleanArray.GetValue(i));
break;
default:
StringArray stringArray = (StringArray)valueArray.Fields[0];
Assert.Null(stringArray.GetString(i));
break;
}
}
}

Expand Down

0 comments on commit 3d021ea

Please sign in to comment.