From 3d021eac66aef5e03306554b0dceab81835dc14a Mon Sep 17 00:00:00 2001 From: Bruce Irschick Date: Fri, 17 May 2024 15:51:09 -0700 Subject: [PATCH] feat(csharp/src/Drivers/Apache): extend capability of GetInfo for Spark driver (#1863) Extend capability of GetInfo for Spark driver * Adds support for the following info codes * vendor name (retrieved dynamically from the DBMS) * vendor version (retrieved dynamically from the DBMS) * vendor sql (hard-coded to `true`) * driver version (read from the assembly's file info/product version) Adds tests for supported and unsupported info codes. --- .../Apache/Hive2/HiveServer2Connection.cs | 91 +++++++------------ .../Drivers/Apache/Spark/SparkConnection.cs | 61 ++++++++++--- .../test/Drivers/Apache/Spark/DriverTests.cs | 77 +++++++++++++++- 3 files changed, 154 insertions(+), 75 deletions(-) diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs index 9e11cec10b..57bca5d1d7 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs @@ -17,6 +17,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Ipc; @@ -35,10 +36,18 @@ public abstract class HiveServer2Connection : AdbcConnection internal TTransport? transport; internal TCLIService.Client? client; internal TSessionHandle? sessionHandle; + private readonly Lazy _vendorVersion; + private readonly Lazy _vendorName; internal HiveServer2Connection(IReadOnlyDictionary properties) { this.properties = properties; + // Note: "LazyThreadSafetyMode.PublicationOnly" is thread-safe initialization where + // the first successful thread sets the value. If an exception is thrown, initialization + // will retry until it successfully returns a value without an exception. 
+ // https://learn.microsoft.com/en-us/dotnet/framework/performance/lazy-initialization#exceptions-in-lazy-objects + _vendorVersion = new Lazy(() => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_VER), LazyThreadSafetyMode.PublicationOnly); + _vendorName = new Lazy(() => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_NAME), LazyThreadSafetyMode.PublicationOnly); } internal TCLIService.Client Client @@ -46,6 +55,10 @@ internal TCLIService.Client Client get { return this.client ?? throw new InvalidOperationException("connection not open"); } } + protected string VendorVersion => _vendorVersion.Value; + + protected string VendorName => _vendorName.Value; + internal async Task OpenAsync() { TProtocol protocol = await CreateProtocolAsync(); @@ -81,6 +94,24 @@ protected void PollForResponse() } while (statusResponse.OperationState == TOperationState.PENDING_STATE || statusResponse.OperationState == TOperationState.RUNNING_STATE); } + private string GetInfoTypeStringValue(TGetInfoType infoType) + { + TGetInfoReq req = new() + { + SessionHandle = this.sessionHandle ?? throw new InvalidOperationException("session not created"), + InfoType = infoType, + }; + + TGetInfoResp getInfoResp = Client.GetInfo(req).Result; + if (getInfoResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new HiveServer2Exception(getInfoResp.Status.ErrorMessage) + .SetNativeError(getInfoResp.Status.ErrorCode) + .SetSqlState(getInfoResp.Status.SqlState); + } + + return getInfoResp.InfoValue.StringValue; + } public override void Dispose() { @@ -102,65 +133,5 @@ protected Schema GetSchema() TGetResultSetMetadataResp response = this.Client.GetResultSetMetadata(request).Result; return SchemaParser.GetArrowSchema(response.Schema); } - - sealed class GetObjectsReader : IArrowArrayStream - { - HiveServer2Connection? connection; - Schema schema; - List? batches; - int index; - IArrowReader? 
reader; - - public GetObjectsReader(HiveServer2Connection connection, Schema schema) - { - this.connection = connection; - this.schema = schema; - } - - public Schema Schema { get { return schema; } } - - public async ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken = default) - { - while (true) - { - if (this.reader != null) - { - RecordBatch? next = await this.reader.ReadNextRecordBatchAsync(cancellationToken); - if (next != null) - { - return next; - } - this.reader = null; - } - - if (this.batches != null && this.index < this.batches.Count) - { - this.reader = new ArrowStreamReader(new ChunkStream(this.schema, this.batches[this.index++].Batch)); - continue; - } - - this.batches = null; - this.index = 0; - - if (this.connection == null) - { - return null; - } - - TFetchResultsReq request = new TFetchResultsReq(this.connection.operationHandle, TFetchOrientation.FETCH_NEXT, 50000); - TFetchResultsResp response = await this.connection.Client.FetchResults(request, cancellationToken); - this.batches = response.Results.ArrowBatches; - - if (!response.HasMoreRows) - { - this.connection = null; - } - } - } - - public void Dispose() - { - } - } } } diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs index 36d62f89f2..3e3bbcae2e 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs @@ -20,6 +20,7 @@ using System.Diagnostics; using System.Net.Http; using System.Net.Http.Headers; +using System.Reflection; using System.Text; using System.Text.RegularExpressions; using System.Threading; @@ -43,16 +44,20 @@ public class SparkConnection : HiveServer2Connection AdbcInfoCode.DriverName, AdbcInfoCode.DriverVersion, AdbcInfoCode.DriverArrowVersion, - AdbcInfoCode.VendorName + AdbcInfoCode.VendorName, + AdbcInfoCode.VendorSql, + AdbcInfoCode.VendorVersion, }; + const string ProductVersionDefault = "1.0.0"; const string 
InfoDriverName = "ADBC Spark Driver"; - const string InfoDriverVersion = "1.0.0"; - const string InfoVendorName = "Spark"; const string InfoDriverArrowVersion = "1.0.0"; + const bool InfoVendorSql = true; const int DecimalPrecisionDefault = 10; const int DecimalScaleDefault = 0; + private readonly Lazy _productVersion; + internal static TSparkGetDirectResults sparkGetDirectResults = new TSparkGetDirectResults(1000); internal static readonly Dictionary timestampConfig = new Dictionary @@ -83,8 +88,11 @@ private enum ColumnTypeId internal SparkConnection(IReadOnlyDictionary properties) : base(properties) { + _productVersion = new Lazy(() => GetProductVersion(), LazyThreadSafetyMode.PublicationOnly); } + protected string ProductVersion => _productVersion.Value; + protected override async ValueTask CreateProtocolAsync() { Trace.TraceError($"create protocol with {properties.Count} properties."); @@ -137,6 +145,7 @@ public override AdbcStatement CreateStatement() public override IArrowArrayStream GetInfo(IReadOnlyList codes) { const int strValTypeID = 0; + const int boolValTypeId = 1; UnionType infoUnionType = new UnionType( new Field[] @@ -178,8 +187,11 @@ public override IArrowArrayStream GetInfo(IReadOnlyList codes) ArrowBuffer.Builder typeBuilder = new ArrowBuffer.Builder(); ArrowBuffer.Builder offsetBuilder = new ArrowBuffer.Builder(); StringArray.Builder stringInfoBuilder = new StringArray.Builder(); + BooleanArray.Builder booleanInfoBuilder = new BooleanArray.Builder(); + int nullCount = 0; int arrayLength = codes.Count; + int offset = 0; foreach (AdbcInfoCode code in codes) { @@ -188,32 +200,53 @@ public override IArrowArrayStream GetInfo(IReadOnlyList codes) case AdbcInfoCode.DriverName: infoNameBuilder.Append((UInt32)code); typeBuilder.Append(strValTypeID); - offsetBuilder.Append(stringInfoBuilder.Length); + offsetBuilder.Append(offset++); stringInfoBuilder.Append(InfoDriverName); + booleanInfoBuilder.AppendNull(); break; case AdbcInfoCode.DriverVersion: 
infoNameBuilder.Append((UInt32)code); typeBuilder.Append(strValTypeID); - offsetBuilder.Append(stringInfoBuilder.Length); - stringInfoBuilder.Append(InfoDriverVersion); + offsetBuilder.Append(offset++); + stringInfoBuilder.Append(ProductVersion); + booleanInfoBuilder.AppendNull(); break; case AdbcInfoCode.DriverArrowVersion: infoNameBuilder.Append((UInt32)code); typeBuilder.Append(strValTypeID); - offsetBuilder.Append(stringInfoBuilder.Length); + offsetBuilder.Append(offset++); stringInfoBuilder.Append(InfoDriverArrowVersion); + booleanInfoBuilder.AppendNull(); break; case AdbcInfoCode.VendorName: infoNameBuilder.Append((UInt32)code); typeBuilder.Append(strValTypeID); - offsetBuilder.Append(stringInfoBuilder.Length); - stringInfoBuilder.Append(InfoVendorName); + offsetBuilder.Append(offset++); + string vendorName = VendorName; + stringInfoBuilder.Append(vendorName); + booleanInfoBuilder.AppendNull(); + break; + case AdbcInfoCode.VendorVersion: + infoNameBuilder.Append((UInt32)code); + typeBuilder.Append(strValTypeID); + offsetBuilder.Append(offset++); + string? 
vendorVersion = VendorVersion; + stringInfoBuilder.Append(vendorVersion); + booleanInfoBuilder.AppendNull(); + break; + case AdbcInfoCode.VendorSql: + infoNameBuilder.Append((UInt32)code); + typeBuilder.Append(boolValTypeId); + offsetBuilder.Append(offset++); + stringInfoBuilder.AppendNull(); + booleanInfoBuilder.Append(InfoVendorSql); break; default: infoNameBuilder.Append((UInt32)code); typeBuilder.Append(strValTypeID); - offsetBuilder.Append(stringInfoBuilder.Length); + offsetBuilder.Append(offset++); stringInfoBuilder.AppendNull(); + booleanInfoBuilder.AppendNull(); nullCount++; break; } @@ -231,7 +264,7 @@ public override IArrowArrayStream GetInfo(IReadOnlyList codes) IArrowArray[] childrenArrays = new IArrowArray[] { stringInfoBuilder.Build(), - new BooleanArray.Builder().Build(), + booleanInfoBuilder.Build(), new Int64Array.Builder().Build(), new Int32Array.Builder().Build(), new ListArray.Builder(StringType.Default).Build(), @@ -749,6 +782,12 @@ private static bool TryParse(string input, out Decimal128Type? value) return true; } } + + private string GetProductVersion() + { + FileVersionInfo fileVersionInfo = FileVersionInfo.GetVersionInfo(Assembly.GetExecutingAssembly().Location); + return fileVersionInfo.ProductVersion ?? 
ProductVersionDefault; + } } internal struct TableInfoPair diff --git a/csharp/test/Drivers/Apache/Spark/DriverTests.cs b/csharp/test/Drivers/Apache/Spark/DriverTests.cs index a4f3a4607d..a7507e473a 100644 --- a/csharp/test/Drivers/Apache/Spark/DriverTests.cs +++ b/csharp/test/Drivers/Apache/Spark/DriverTests.cs @@ -84,12 +84,30 @@ public async Task CanGetInfo() { AdbcConnection adbcConnection = NewConnection(); - using IArrowArrayStream stream = adbcConnection.GetInfo(new List() { AdbcInfoCode.DriverName, AdbcInfoCode.DriverVersion, AdbcInfoCode.VendorName }); + // Test the supported info codes + List handledCodes = new List() + { + AdbcInfoCode.DriverName, + AdbcInfoCode.DriverVersion, + AdbcInfoCode.VendorName, + AdbcInfoCode.DriverArrowVersion, + AdbcInfoCode.VendorVersion, + AdbcInfoCode.VendorSql + }; + using IArrowArrayStream stream = adbcConnection.GetInfo(handledCodes); RecordBatch recordBatch = await stream.ReadNextRecordBatchAsync(); UInt32Array infoNameArray = (UInt32Array)recordBatch.Column("info_name"); - List expectedValues = new List() { "DriverName", "DriverVersion", "VendorName" }; + List expectedValues = new List() + { + "DriverName", + "DriverVersion", + "VendorName", + "DriverArrowVersion", + "VendorVersion", + "VendorSql" + }; for (int i = 0; i < infoNameArray.Length; i++) { @@ -98,8 +116,59 @@ public async Task CanGetInfo() Assert.Contains(value.ToString(), expectedValues); - StringArray stringArray = (StringArray)valueArray.Fields[0]; - Console.WriteLine($"{value}={stringArray.GetString(i)}"); + switch (value) + { + case AdbcInfoCode.VendorSql: + // TODO: How does external developer know the second field is the boolean field? + BooleanArray booleanArray = (BooleanArray)valueArray.Fields[1]; + bool? 
boolValue = booleanArray.GetValue(i); + OutputHelper?.WriteLine($"{value}={boolValue}"); + Assert.True(boolValue); + break; + default: + StringArray stringArray = (StringArray)valueArray.Fields[0]; + string stringValue = stringArray.GetString(i); + OutputHelper?.WriteLine($"{value}={stringValue}"); + Assert.NotNull(stringValue); + break; + } + } + + // Test the unhandled info codes. + List unhandledCodes = new List() + { + AdbcInfoCode.VendorArrowVersion, + AdbcInfoCode.VendorSubstrait, + AdbcInfoCode.VendorSubstraitMaxVersion + }; + using IArrowArrayStream stream2 = adbcConnection.GetInfo(unhandledCodes); + + recordBatch = await stream2.ReadNextRecordBatchAsync(); + infoNameArray = (UInt32Array)recordBatch.Column("info_name"); + + List unexpectedValues = new List() + { + "VendorArrowVersion", + "VendorSubstrait", + "VendorSubstraitMaxVersion" + }; + for (int i = 0; i < infoNameArray.Length; i++) + { + AdbcInfoCode? value = (AdbcInfoCode?)infoNameArray.GetValue(i); + DenseUnionArray valueArray = (DenseUnionArray)recordBatch.Column("info_value"); + + Assert.Contains(value.ToString(), unexpectedValues); + switch (value) + { + case AdbcInfoCode.VendorSql: + BooleanArray booleanArray = (BooleanArray)valueArray.Fields[1]; + Assert.Null(booleanArray.GetValue(i)); + break; + default: + StringArray stringArray = (StringArray)valueArray.Fields[0]; + Assert.Null(stringArray.GetString(i)); + break; + } } }