From f2b5f8d17e9aa75de9ac8ee919aab1622617ba54 Mon Sep 17 00:00:00 2001
From: Bruce Irschick
Date: Wed, 25 Sep 2024 10:48:26 -0700
Subject: [PATCH] feat(csharp/src/Drivers/Apache/Spark): Perform scalar data type conversion for Spark over HTTP (#2152)

Adds scalar conversion from string to Date32/Decimal128/DateTimeOffset for
DATE/DECIMAL/TIMESTAMP, respectively.

Note: FLOAT type is returned as Double; NULL type is returned as String.

Re-uses existing tests but adjusts which data type the comparison uses
depending on the data conversion flag.

Some refactoring to make the conversion available from all HiveServer2
driver types.

Adds minimal support for Impala testing.
---
 .../Apache/Hive2/DataTypeConversion.cs        |  66 +++
 .../Drivers/Apache/Hive2/DecimalUtility.cs    | 452 ++++++++++++++++++
 .../Apache/Hive2/HiveServer2Connection.cs     |   2 +
 .../Apache/Hive2/HiveServer2Parameters.cs     |  27 ++
 .../Drivers/Apache/Hive2/HiveServer2Reader.cs | 206 +++++++-
 .../Apache/Hive2/HiveServer2SchemaParser.cs   |  58 +++
 .../Apache/Hive2/HiveServer2Statement.cs      |   2 +-
 .../Drivers/Apache/Impala/ImpalaAuthType.cs   |  54 +++
 .../Drivers/Apache/Impala/ImpalaConnection.cs |   4 +-
 .../Drivers/Apache/Impala/ImpalaParameters.cs |  38 ++
 csharp/src/Drivers/Apache/Spark/README.md     |  56 +--
 .../src/Drivers/Apache/Spark/SparkAuthType.cs |  58 +++
 .../Drivers/Apache/Spark/SparkConnection.cs   |   2 -
 .../Apache/Spark/SparkConnectionFactory.cs    |   6 +-
 .../Apache/Spark/SparkDatabricksConnection.cs |  46 +-
 .../Spark/SparkDatabricksSchemaParser.cs      |  58 +++
 .../Apache/Spark/SparkHttpConnection.cs       |  53 +-
 .../Drivers/Apache/Spark/SparkParameters.cs   | 102 ----
 .../Drivers/Apache/Spark/SparkServerType.cs   |  56 +++
 .../Apache/Spark/SparkStandardConnection.cs   |   4 +-
 .../src/Drivers/Apache/Thrift/SchemaParser.cs |  14 +-
 .../test/Apache.Arrow.Adbc.Tests/TestBase.cs  |   4 +-
 ...che.Arrow.Adbc.Tests.Drivers.Apache.csproj |  11 +-
 .../Apache/Hive2/DecimalUtilityTests.cs       | 172 +++++++
 .../Apache/Hive2/HiveServer2ParametersTest.cs |  62 +++
 .../Apache/Impala/ImpalaTestEnvironment.cs    |  77 +++
 .../test/Drivers/Apache/Impala/ImpalaTests.cs |  28 +-
 .../Resources/ImpalaData.sql}                 |   0
 .../Apache/Impala/Resources/impalaconfig.json |  14 +-
 .../Apache/Spark/DateTimeValueTests.cs        |   5 +-
 .../test/Drivers/Apache/Spark/DriverTests.cs  |  14 +-
 .../Drivers/Apache/Spark/NumericValueTests.cs |   4 +-
 .../Spark/Resources/SparkData-Databricks.sql  | 133 ++++++
 .../Apache/Spark/SparkTestEnvironment.cs      |  15 +-
 .../Drivers/Apache/Spark/StringValueTests.cs  |  26 +-
 35 files changed, 1632 insertions(+), 297 deletions(-)
 create mode 100644 csharp/src/Drivers/Apache/Hive2/DataTypeConversion.cs
 create mode 100644 csharp/src/Drivers/Apache/Hive2/DecimalUtility.cs
 create mode 100644 csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs
 create mode 100644 csharp/src/Drivers/Apache/Hive2/HiveServer2SchemaParser.cs
 create mode 100644 csharp/src/Drivers/Apache/Impala/ImpalaAuthType.cs
 create mode 100644 csharp/src/Drivers/Apache/Impala/ImpalaParameters.cs
 create mode 100644 csharp/src/Drivers/Apache/Spark/SparkAuthType.cs
 create mode 100644 csharp/src/Drivers/Apache/Spark/SparkDatabricksSchemaParser.cs
 create mode 100644 csharp/src/Drivers/Apache/Spark/SparkServerType.cs
 create mode 100644 csharp/test/Drivers/Apache/Hive2/DecimalUtilityTests.cs
 create mode 100644 csharp/test/Drivers/Apache/Hive2/HiveServer2ParametersTest.cs
 create mode 100644 csharp/test/Drivers/Apache/Impala/ImpalaTestEnvironment.cs
 rename csharp/test/Drivers/Apache/{Spark/Resources/SparkData-3.4.sql => Impala/Resources/ImpalaData.sql} (100%)
 create mode
100644 csharp/test/Drivers/Apache/Spark/Resources/SparkData-Databricks.sql diff --git a/csharp/src/Drivers/Apache/Hive2/DataTypeConversion.cs b/csharp/src/Drivers/Apache/Hive2/DataTypeConversion.cs new file mode 100644 index 0000000000..0d0865d7ab --- /dev/null +++ b/csharp/src/Drivers/Apache/Hive2/DataTypeConversion.cs @@ -0,0 +1,66 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 +{ + [Flags] + internal enum DataTypeConversion + { + Empty = 0, + None = 1, + Scalar = 2, + } + + internal static class DataTypeConversionParser + { + internal const string SupportedList = DataTypeConversionOptions.None + ", " + DataTypeConversionOptions.Scalar; + + internal static DataTypeConversion Parse(string? dataTypeConversion) + { + DataTypeConversion result = DataTypeConversion.Empty; + + if (string.IsNullOrWhiteSpace(dataTypeConversion)) + { + // Default + return DataTypeConversion.Scalar; + } + + string[] conversions = dataTypeConversion!.Split(','); + foreach (string? conversion in conversions) + { + result |= (conversion?.Trim().ToLowerInvariant()) switch + { + null or "" => DataTypeConversion.Empty, + DataTypeConversionOptions.None => DataTypeConversion.None, + DataTypeConversionOptions.Scalar => DataTypeConversion.Scalar, + _ => throw new ArgumentOutOfRangeException(nameof(dataTypeConversion), conversion, "Invalid or unsupported data type conversion"), + }; + } + + if (result.HasFlag(DataTypeConversion.None) && result.HasFlag(DataTypeConversion.Scalar)) + { + throw new ArgumentOutOfRangeException(nameof(dataTypeConversion), dataTypeConversion, "Conflicting data type conversion options"); + } + // Default + if (result == DataTypeConversion.Empty) result = DataTypeConversion.Scalar; + + return result; + } + } +} diff --git a/csharp/src/Drivers/Apache/Hive2/DecimalUtility.cs b/csharp/src/Drivers/Apache/Hive2/DecimalUtility.cs new file mode 100644 index 0000000000..48c3534682 --- /dev/null +++ b/csharp/src/Drivers/Apache/Hive2/DecimalUtility.cs @@ -0,0 +1,452 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Numerics;
+
+namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
+{
+    internal static class DecimalUtility
+    {
+        private const char AsciiZero = '0';
+        private const int AsciiDigitMaxIndex = '9' - AsciiZero;
+        private const char AsciiMinus = '-';
+        private const char AsciiPlus = '+';
+        private const char AsciiUpperE = 'E';
+        private const char AsciiLowerE = 'e';
+        private const char AsciiPeriod = '.';
+
+        /// <summary>
+        /// Gets the BigInteger bytes for the given string value.
+        /// </summary>
+        /// <param name="value">The numeric string value to get bytes for.</param>
+        /// <param name="precision">The decimal precision for the target Decimal[128|256]</param>
+        /// <param name="scale">The decimal scale for the target Decimal[128|256]</param>
+        /// <param name="byteWidth">The width in bytes for the target buffer. Should match the length of the bytes parameter.</param>
+        /// <param name="bytes">The buffer to place the BigInteger bytes into.</param>
+        /// <exception cref="ArgumentOutOfRangeException"></exception>
+        internal static void GetBytes(string value, int precision, int scale, int byteWidth, Span<byte> bytes)
+        {
+            if (precision < 1)
+            {
+                throw new ArgumentOutOfRangeException(nameof(precision), precision, "precision value must be greater than zero.");
+            }
+            if (scale < 0 || scale >= precision)
+            {
+                throw new ArgumentOutOfRangeException(nameof(scale), scale, "scale value must be in the range 0 .. precision.");
+            }
+            if (byteWidth > bytes.Length)
+            {
+                throw new ArgumentOutOfRangeException(nameof(byteWidth), byteWidth, $"value for byteWidth {byteWidth} exceeds the size of bytes.");
+            }
+
+            BigInteger integerValue = ToBigInteger(value, precision, scale);
+
+            FillBytes(bytes, integerValue, byteWidth);
+        }
+
+        private static void FillBytes(Span<byte> bytes, BigInteger integerValue, int byteWidth)
+        {
+            int bytesWritten = 0;
+#if NETCOREAPP
+            if (!integerValue.TryWriteBytes(bytes, out bytesWritten, false, !BitConverter.IsLittleEndian))
+            {
+                throw new OverflowException("Could not extract bytes from integer value " + integerValue);
+            }
+#else
+            byte[] tempBytes = integerValue.ToByteArray();
+            bytesWritten = tempBytes.Length;
+            if (bytesWritten > bytes.Length)
+            {
+                throw new OverflowException($"Decimal size greater than {byteWidth} bytes: {bytesWritten}");
+            }
+            tempBytes.CopyTo(bytes);
+#endif
+            byte fillByte = (byte)(integerValue < 0 ?
255 : 0); + for (int i = bytesWritten; i < byteWidth; i++) + { + bytes[i] = fillByte; + } + } + + private static BigInteger ToBigInteger(string value, int precision, int scale) + { + BigInteger integerValue; +#if NETCOREAPP + ReadOnlySpan significantValue = GetSignificantValue(value, precision, scale); + integerValue = BigInteger.Parse(significantValue); +#else + ReadOnlySpan significantValue = GetSignificantValue(value.AsSpan(), precision, scale); + integerValue = BigInteger.Parse(significantValue.ToString()); +#endif + return integerValue; + } + + private static ReadOnlySpan GetSignificantValue(ReadOnlySpan value, int precision, int scale) + { + ParseDecimal(value, out ParserState state); + + ProcessDecimal(value, + precision, + scale, + state, + out char sign, + out ReadOnlySpan integerSpan, + out ReadOnlySpan fractionalSpan, + out int neededScale); + + Span significant = new char[precision + 1]; + BuildSignificantValue( + sign, + scale, + integerSpan, + fractionalSpan, + neededScale, + significant); + + return significant; + } + + private static void ProcessDecimal(ReadOnlySpan value, int precision, int scale, ParserState state, out char sign, out ReadOnlySpan integerSpan, out ReadOnlySpan fractionalSpan, out int neededScale) + { + int int_length = 0; + int frac_length = 0; + int exponent = 0; + + if (state.IntegerStart != -1 && state.IntegerEnd != -1) int_length = state.IntegerEnd - state.IntegerStart + 1; + if (state.FractionalStart != -1 && state.FractionalEnd != -1) frac_length = state.FractionalEnd - state.FractionalStart + 1; + if (state.ExponentIndex != -1 && state.ExponentStart != -1 && state.ExponentEnd != -1 && state.ExponentEnd >= state.ExponentStart) + { + int expStart = state.ExpSignIndex != -1 ? state.ExpSignIndex : state.ExponentStart; + int expLength = state.ExponentEnd - expStart + 1; + ReadOnlySpan exponentSpan = value.Slice(expStart, expLength); +#if NETCOREAPP + exponent = int.Parse(exponentSpan); +#else + exponent = int.Parse(exponentSpan.ToString()); +#endif + } + integerSpan = int_length > 0 ? value.Slice(state.IntegerStart, state.IntegerEnd - state.IntegerStart + 1) : []; + fractionalSpan = frac_length > 0 ? 
value.Slice(state.FractionalStart, state.FractionalEnd - state.FractionalStart + 1) : []; + Span tempSignificant; + if (exponent != 0) + { + tempSignificant = new char[int_length + frac_length]; + if (int_length > 0) value.Slice(state.IntegerStart, state.IntegerEnd - state.IntegerStart + 1).CopyTo(tempSignificant.Slice(0)); + if (frac_length > 0) value.Slice(state.FractionalStart, state.FractionalEnd - state.FractionalStart + 1).CopyTo(tempSignificant.Slice(int_length)); + // Trim trailing zeros from combined string + while (tempSignificant[tempSignificant.Length - 1] == AsciiZero) + { + tempSignificant = tempSignificant.Slice(0, tempSignificant.Length - 1); + } + // Recalculate integer and fractional length + if (exponent > 0) + { + int_length = Math.Min(int_length + exponent, tempSignificant.Length); + frac_length = Math.Max(Math.Min(frac_length - exponent, tempSignificant.Length - int_length), 0); + } + else + { + int_length = Math.Max(int_length + exponent, 0); + frac_length = Math.Max(Math.Min(frac_length - exponent, tempSignificant.Length - int_length), 0); + } + // Reset the integer and fractional span + integerSpan = tempSignificant.Slice(0, int_length); + fractionalSpan = tempSignificant.Slice(int_length, frac_length); + } + + int neededPrecision = int_length + frac_length; + neededScale = frac_length; + if (neededPrecision > precision) + { + throw new OverflowException($"Decimal precision cannot be greater than that in the Arrow vector: {value.ToString()} has precision > {precision}"); + } + if (neededScale > scale) + { + throw new OverflowException($"Decimal scale cannot be greater than that in the Arrow vector: {value.ToString()} has scale > {scale}"); + } + sign = state.SignIndex != -1 ? value[state.SignIndex] : AsciiPlus; + } + + private static void BuildSignificantValue( + char sign, + int scale, + ReadOnlySpan integerSpan, + ReadOnlySpan fractionalSpan, + int neededScale, + Span significant) + { + significant[0] = sign; + int end = 0; + integerSpan.CopyTo(significant.Slice(end + 1)); + end += integerSpan.Length; + fractionalSpan.CopyTo(significant.Slice(end + 1)); + end += fractionalSpan.Length; + + // Add trailing zeros to adjust for scale + while (neededScale < scale) + { + neededScale++; + end++; + significant[end] = AsciiZero; + } + } + + private enum ParseState + { + StartWhiteSpace, + SignOrDigitOrDecimal, + DigitOrDecimalOrExponent, + FractionOrExponent, + ExpSignOrExpValue, + ExpValue, + EndWhiteSpace, + Invalid, + } + + private struct ParserState + { + public ParseState CurrentState = ParseState.StartWhiteSpace; + public int SignIndex = -1; + public int IntegerStart = -1; + public int IntegerEnd = -1; + public int DecimalIndex = -1; + public int FractionalStart = -1; + public int FractionalEnd = -1; + public int ExponentIndex = -1; + public int ExpSignIndex = -1; + public int ExponentStart = -1; + public int ExponentEnd = -1; + public bool HasZero = false; + + public ParserState() { } + } + + private static void ParseDecimal(ReadOnlySpan value, out ParserState parserState) + { + ParserState state = new ParserState(); + int index = 0; + int length = value.Length; + while (index < length) + { + char c = value[index]; + switch (state.CurrentState) + { + case ParseState.StartWhiteSpace: + if (!char.IsWhiteSpace(c)) + { + state.CurrentState = ParseState.SignOrDigitOrDecimal; + } + else + { + index++; + } + break; + case ParseState.SignOrDigitOrDecimal: + // Is Ascii Numeric + if ((uint)(c - AsciiZero) <= AsciiDigitMaxIndex) + { + if (!state.HasZero && c == AsciiZero) 
state.HasZero |= true; + state.IntegerStart = index; + state.IntegerEnd = index; + index++; + state.CurrentState = ParseState.DigitOrDecimalOrExponent; + } + else if (c == AsciiMinus || c == AsciiPlus) + { + state.SignIndex = index; + index++; + state.CurrentState = ParseState.DigitOrDecimalOrExponent; + } + else if (c == AsciiPeriod) + { + state.DecimalIndex = index; + index++; + state.CurrentState = ParseState.FractionOrExponent; + } + else if (char.IsWhiteSpace(c)) + { + index++; + state.CurrentState = ParseState.EndWhiteSpace; + } + else + { + state.CurrentState = ParseState.Invalid; + } + break; + case ParseState.DigitOrDecimalOrExponent: + // Is Ascii Numeric + if ((uint)(c - AsciiZero) <= AsciiDigitMaxIndex) + { + if (state.IntegerStart == -1) state.IntegerStart = index; + if (!state.HasZero && c == AsciiZero) state.HasZero |= true; + state.IntegerEnd = index; + index++; + } + else if (c == AsciiPeriod) + { + state.DecimalIndex = index; + index++; + state.CurrentState = ParseState.FractionOrExponent; + } + else if (c == AsciiUpperE || c == AsciiLowerE) + { + state.ExponentIndex = index; + index++; + state.CurrentState = ParseState.ExpSignOrExpValue; + } + else if (char.IsWhiteSpace(c)) + { + index++; + state.CurrentState = ParseState.EndWhiteSpace; + } + else + { + state.CurrentState = ParseState.Invalid; + } + break; + case ParseState.FractionOrExponent: + // Is Ascii Numeric + if ((uint)(c - AsciiZero) <= AsciiDigitMaxIndex) + { + if (state.FractionalStart == -1) state.FractionalStart = index; + if (!state.HasZero && c == AsciiZero) state.HasZero |= true; + state.FractionalEnd = index; + index++; + } + else if (c == AsciiUpperE || c == AsciiLowerE) + { + state.ExponentIndex = index; + index++; + state.CurrentState = ParseState.ExpSignOrExpValue; + } + else if (char.IsWhiteSpace(c)) + { + index++; + state.CurrentState = ParseState.EndWhiteSpace; + } + else + { + state.CurrentState = ParseState.Invalid; + } + break; + case ParseState.ExpSignOrExpValue: + // Is Ascii Numeric + if ((uint)(c - AsciiZero) <= AsciiDigitMaxIndex) + { + if (state.ExponentStart == -1) state.ExponentStart = index; + state.ExponentEnd = index; + index++; + state.CurrentState = ParseState.ExpValue; + } + else if (c == AsciiMinus || c == AsciiPlus) + { + state.ExpSignIndex = index; + index++; + state.CurrentState = ParseState.ExpValue; + } + else if (char.IsWhiteSpace(c)) + { + index++; + state.CurrentState = ParseState.EndWhiteSpace; + } + else + { + state.CurrentState = ParseState.Invalid; + } + break; + case ParseState.ExpValue: + // Is Ascii Numeric + if ((uint)(c - AsciiZero) <= AsciiDigitMaxIndex) + { + if (state.ExponentStart == -1) state.ExponentStart = index; + state.ExponentEnd = index; + index++; + } + else if (char.IsWhiteSpace(c)) + { + index++; + state.CurrentState = ParseState.EndWhiteSpace; + } + else + { + state.CurrentState = ParseState.Invalid; + } + break; + case ParseState.EndWhiteSpace: + if (char.IsWhiteSpace(c)) + { + index++; + state.CurrentState = ParseState.EndWhiteSpace; + } + else + { + state.CurrentState = ParseState.Invalid; + } + break; + case ParseState.Invalid: + throw new ArgumentOutOfRangeException(nameof(value), value.ToString(), $"Invalid numeric value at index {index}."); + } + } + // Trim leading zeros from integer portion + if (state.IntegerStart != -1 && state.IntegerEnd != -1) + { + for (int i = state.IntegerStart; i <= state.IntegerEnd; i++) + { + if (value[i] != AsciiZero) break; + + state.IntegerStart = i + 1; + if (state.IntegerStart > state.IntegerEnd) + { + 
state.IntegerStart = -1;
+                        state.IntegerEnd = -1;
+                        break;
+                    }
+                }
+            }
+            // Trim trailing zeros from fractional portion
+            if (state.FractionalStart != -1 && state.FractionalEnd != -1)
+            {
+                for (int i = state.FractionalEnd; i >= state.FractionalStart; i--)
+                {
+                    if (value[i] != AsciiZero) break;
+
+                    state.FractionalEnd = i - 1;
+                    if (state.FractionalStart > state.FractionalEnd)
+                    {
+                        state.FractionalStart = -1;
+                        state.FractionalEnd = -1;
+                        break;
+                    }
+                }
+            }
+            // Must have an integer or fractional part.
+            if (state.IntegerStart == -1 && state.FractionalStart == -1)
+            {
+                if (!state.HasZero)
+                    throw new ArgumentOutOfRangeException(nameof(value), value.ToString(), "input does not contain a valid numeric value.");
+                else
+                {
+                    state.IntegerStart = value.IndexOf(AsciiZero);
+                    state.IntegerEnd = state.IntegerStart;
+                }
+            }
+
+            parserState = state;
+        }
+    }
+}
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
index 5815a98191..2100f5744a 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
@@ -71,6 +71,8 @@ internal async Task OpenAsync()
 
         internal TSessionHandle? SessionHandle { get; private set; }
 
+        protected internal DataTypeConversion DataTypeConversion { get; set; } = DataTypeConversion.None;
+
         protected abstract Task<TTransport> CreateTransportAsync();
 
         protected abstract Task<TProtocol> CreateProtocolAsync(TTransport transport);
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs
new file mode 100644
index 0000000000..5eec97823f
--- /dev/null
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +using System; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 +{ + public static class DataTypeConversionOptions + { + public const string None = "none"; + public const string Scalar = "scalar"; + } +} diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs index 9f87c35ad5..e72076b132 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs @@ -16,24 +16,48 @@ */ using System; -using System.Linq; +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Globalization; using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Ipc; +using Apache.Arrow.Types; using Apache.Hive.Service.Rpc.Thrift; namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 { internal class HiveServer2Reader : IArrowArrayStream { + private const char AsciiZero = '0'; + private const int AsciiDigitMaxIndex = '9' - AsciiZero; + private const char AsciiDash = '-'; + private const char AsciiSpace = ' '; + private const char AsciiColon = ':'; + private const char AsciiPeriod = '.'; + private HiveServer2Statement? _statement; private readonly long _batchSize; + private readonly DataTypeConversion _dataTypeConversion; + private static readonly IReadOnlyDictionary> s_arrowStringConverters = + new Dictionary>() + { + { ArrowTypeId.Date32, ConvertToDate32 }, + { ArrowTypeId.Decimal128, ConvertToDecimal128 }, + { ArrowTypeId.Timestamp, ConvertToTimestamp }, + }; + - public HiveServer2Reader(HiveServer2Statement statement, Schema schema, long batchSize = HiveServer2Connection.BatchSizeDefault) + public HiveServer2Reader( + HiveServer2Statement statement, + Schema schema, + DataTypeConversion dataTypeConversion, + long batchSize = HiveServer2Connection.BatchSizeDefault) { _statement = statement; Schema = schema; _batchSize = batchSize; + _dataTypeConversion = dataTypeConversion; } public Schema Schema { get; } @@ -48,10 +72,20 @@ public HiveServer2Reader(HiveServer2Statement statement, Schema schema, long bat var request = new TFetchResultsReq(_statement.OperationHandle, TFetchOrientation.FETCH_NEXT, _batchSize); TFetchResultsResp response = await _statement.Connection.Client.FetchResults(request, cancellationToken); - int length = response.Results.Columns.Count > 0 ? GetArray(response.Results.Columns[0]).Length : 0; + int columnCount = response.Results.Columns.Count; + IList columnData = []; + bool shouldConvertScalar = _dataTypeConversion.HasFlag(DataTypeConversion.Scalar); + for (int i = 0; i < columnCount; i++) + { + IArrowType? expectedType = shouldConvertScalar ? Schema.FieldsList[i].DataType : null; + IArrowArray columnArray = GetArray(response.Results.Columns[i], expectedType); + columnData.Add(columnArray); + } + + int length = columnCount > 0 ? GetArray(response.Results.Columns[0]).Length : 0; var result = new RecordBatch( Schema, - response.Results.Columns.Select(GetArray), + columnData, length); if (!response.HasMoreRows) @@ -66,9 +100,9 @@ public void Dispose() { } - static IArrowArray GetArray(TColumn column) + private static IArrowArray GetArray(TColumn column, IArrowType? expectedArrowType = default) { - return + IArrowArray arrowArray = (IArrowArray?)column.BoolVal?.Values ?? (IArrowArray?)column.ByteVal?.Values ?? (IArrowArray?)column.I16Val?.Values ?? @@ -78,6 +112,166 @@ static IArrowArray GetArray(TColumn column) (IArrowArray?)column.StringVal?.Values ?? (IArrowArray?)column.BinaryVal?.Values ?? 
throw new InvalidOperationException("unsupported data type"); + if (expectedArrowType != null && arrowArray is StringArray stringArray && s_arrowStringConverters.ContainsKey(expectedArrowType.TypeId)) + { + // Perform a conversion from string to native/scalar type. + Func converter = s_arrowStringConverters[expectedArrowType.TypeId]; + return converter(stringArray, expectedArrowType); + } + return arrowArray; + } + + private static Date32Array ConvertToDate32(StringArray array, IArrowType _) + { + var resultArray = new Date32Array.Builder(); + foreach (string item in (IReadOnlyCollection)array) + { + if (item == null) + { + resultArray.AppendNull(); + continue; + } + + ReadOnlySpan date = item.AsSpan(); + bool isKnownFormat = date.Length >= 8 && date[4] == AsciiDash && date[7] == AsciiDash; + if (isKnownFormat) + { + DateTime value = ConvertToDateTime(date); + resultArray.Append(value); + } + else + { + resultArray.Append(DateTime.Parse(item, CultureInfo.InvariantCulture)); + } + } + + return resultArray.Build(); + } + + private static DateTime ConvertToDateTime(ReadOnlySpan date) + { + int year; + int month; + int day; +#if NETCOREAPP + year = int.Parse(date.Slice(0, 4)); + month = int.Parse(date.Slice(5, 2)); + day = int.Parse(date.Slice(8, 2)); +#else + year = int.Parse(date.Slice(0, 4).ToString()); + month = int.Parse(date.Slice(5, 2).ToString()); + day = int.Parse(date.Slice(8, 2).ToString()); +#endif + DateTime value = new(year, month, day); + return value; + } + + private static Decimal128Array ConvertToDecimal128(StringArray array, IArrowType schemaType) + { + // Using the schema type to get the precision and scale. + Decimal128Type decimalType = (Decimal128Type)schemaType; + var resultArray = new Decimal128Array.Builder(decimalType); + Span buffer = stackalloc byte[decimalType.ByteWidth]; + foreach (string item in (IReadOnlyList)array) + { + if (item == null) + { + resultArray.AppendNull(); + continue; + } + + // Try to parse the value into a decimal because it is the most performant and handles the exponent syntax. But this might overflow. 
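+                // Illustrative cases (assumptions, not taken from this patch): with invariant
+                // culture, "1.23E+2" parses to 123m and takes the fast path below; a value with
+                // more than 28-29 significant digits overflows System.Decimal, so TryParse fails
+                // and the raw string falls through to DecimalUtility.GetBytes instead.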
+ if (decimal.TryParse(item, NumberStyles.Float, CultureInfo.InvariantCulture, out decimal decimalValue)) + { + resultArray.Append(new SqlDecimal(decimalValue)); + } + else + { + DecimalUtility.GetBytes(item, decimalType.Precision, decimalType.Scale, decimalType.ByteWidth, buffer); + resultArray.Append(buffer); + } + } + return resultArray.Build(); + } + + private static TimestampArray ConvertToTimestamp(StringArray array, IArrowType _) + { + // Match the precision of the server + var resultArrayBuilder = new TimestampArray.Builder(TimeUnit.Microsecond); + foreach (string item in (IReadOnlyList)array) + { + if (item == null) + { + resultArrayBuilder.AppendNull(); + continue; + } + + ReadOnlySpan date = item.AsSpan(); + bool isKnownFormat = date.Length >= 17 && date[4] == AsciiDash && date[7] == AsciiDash && date[10] == AsciiSpace && date[13] == AsciiColon && date[16] == AsciiColon; + if (isKnownFormat) + { + DateTimeOffset value = ConvertToDateTimeOffset(date); + resultArrayBuilder.Append(value); + } + else + { + DateTimeOffset value = DateTimeOffset.Parse(item, DateTimeFormatInfo.InvariantInfo, DateTimeStyles.AssumeUniversal); + resultArrayBuilder.Append(value); + } + } + return resultArrayBuilder.Build(); + } + + private static DateTimeOffset ConvertToDateTimeOffset(ReadOnlySpan date) + { + int year; + int month; + int day; + int hour; + int minute; + int second; +#if NETCOREAPP + year = int.Parse(date.Slice(0, 4)); + month = int.Parse(date.Slice(5, 2)); + day = int.Parse(date.Slice(8, 2)); + hour = int.Parse(date.Slice(11, 2)); + minute = int.Parse(date.Slice(14, 2)); + second = int.Parse(date.Slice(17, 2)); +#else + year = int.Parse(date.Slice(0, 4).ToString()); + month = int.Parse(date.Slice(5, 2).ToString()); + day = int.Parse(date.Slice(8, 2).ToString()); + hour = int.Parse(date.Slice(11, 2).ToString()); + minute = int.Parse(date.Slice(14, 2).ToString()); + second = int.Parse(date.Slice(17, 2).ToString()); +#endif + DateTimeOffset dateValue = new(year, month, day, hour, minute, second, TimeSpan.Zero); + int length = date.Length; + if (length >= 20 && date[19] == AsciiPeriod) + { + int start = -1; + int end = 20; + while (end < length && (uint)(date[end] - AsciiZero) <= AsciiDigitMaxIndex) + { + if (start == -1) start = end; + end++; + } + int subSeconds = 0; + int subSecondsLength = start != -1 ? end - start : 0; + if (subSecondsLength > 0) + { +#if NETCOREAPP + subSeconds = int.Parse(date.Slice(start, subSecondsLength)); +#else + subSeconds = int.Parse(date.Slice(start, subSecondsLength).ToString()); +#endif + } + double factorOfMilliseconds = Math.Pow(10, subSecondsLength - 3); + long ticks = (long)(subSeconds * (TimeSpan.TicksPerMillisecond / factorOfMilliseconds)); + dateValue = dateValue.AddTicks(ticks); + } + + return dateValue; } } } diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2SchemaParser.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2SchemaParser.cs new file mode 100644 index 0000000000..913f4d114b --- /dev/null +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2SchemaParser.cs @@ -0,0 +1,58 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using Apache.Arrow.Types; +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 +{ + internal class HiveServer2SchemaParser : SchemaParser + { + public override IArrowType GetArrowType(TPrimitiveTypeEntry thriftType, DataTypeConversion dataTypeConversion) + { + bool convertScalar = dataTypeConversion.HasFlag(DataTypeConversion.Scalar); + return thriftType.Type switch + { + TTypeId.BIGINT_TYPE => Int64Type.Default, + TTypeId.BINARY_TYPE => BinaryType.Default, + TTypeId.BOOLEAN_TYPE => BooleanType.Default, + TTypeId.DOUBLE_TYPE + or TTypeId.FLOAT_TYPE => DoubleType.Default, + TTypeId.INT_TYPE => Int32Type.Default, + TTypeId.SMALLINT_TYPE => Int16Type.Default, + TTypeId.TINYINT_TYPE => Int8Type.Default, + TTypeId.DATE_TYPE => convertScalar ? Date32Type.Default : StringType.Default, + TTypeId.DECIMAL_TYPE => convertScalar ? NewDecima128Type(thriftType) : StringType.Default, + TTypeId.TIMESTAMP_TYPE => convertScalar ? TimestampType.Default : StringType.Default, + TTypeId.CHAR_TYPE + or TTypeId.NULL_TYPE + or TTypeId.STRING_TYPE + or TTypeId.VARCHAR_TYPE + or TTypeId.INTERVAL_DAY_TIME_TYPE + or TTypeId.INTERVAL_YEAR_MONTH_TYPE + or TTypeId.ARRAY_TYPE + or TTypeId.MAP_TYPE + or TTypeId.STRUCT_TYPE + or TTypeId.UNION_TYPE + or TTypeId.USER_DEFINED_TYPE => StringType.Default, + TTypeId.TIMESTAMPLOCALTZ_TYPE => throw new NotImplementedException(), + _ => throw new NotImplementedException(), + }; + } + } +} diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs index 5438322abb..97594161e2 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs @@ -51,7 +51,7 @@ public override async ValueTask ExecuteQueryAsync() private async Task GetResultSetSchemaAsync(TOperationHandle operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken = default) { TGetResultSetMetadataResp response = await HiveServer2Connection.GetResultSetMetadataAsync(operationHandle, client, cancellationToken); - return Connection.SchemaParser.GetArrowSchema(response.Schema); + return Connection.SchemaParser.GetArrowSchema(response.Schema, Connection.DataTypeConversion); } public override async Task ExecuteUpdateAsync() diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaAuthType.cs b/csharp/src/Drivers/Apache/Impala/ImpalaAuthType.cs new file mode 100644 index 0000000000..bbbea6f5a0 --- /dev/null +++ b/csharp/src/Drivers/Apache/Impala/ImpalaAuthType.cs @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace Apache.Arrow.Adbc.Drivers.Apache.Impala +{ + internal enum ImpalaAuthType + { + Invalid = 0, + None, + UsernameOnly, + Basic, + Empty = int.MaxValue, + } + + internal static class AuthTypeOptionsParser + { + internal static bool TryParse(string? authType, out ImpalaAuthType authTypeValue) + { + switch (authType?.Trim().ToLowerInvariant()) + { + case null: + case "": + authTypeValue = ImpalaAuthType.Empty; + return true; + case AuthTypeOptions.None: + authTypeValue = ImpalaAuthType.None; + return true; + case AuthTypeOptions.UsernameOnly: + authTypeValue = ImpalaAuthType.UsernameOnly; + return true; + case AuthTypeOptions.Basic: + authTypeValue = ImpalaAuthType.Basic; + return true; + default: + authTypeValue = ImpalaAuthType.Invalid; + return false; + } + } + } +} diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs index 9a6cd3f5e5..ef12f90f74 100644 --- a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs +++ b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs @@ -84,8 +84,8 @@ public override IArrowArrayStream GetTableTypes() public override Schema GetTableSchema(string? catalog, string? dbSchema, string tableName) => throw new System.NotImplementedException(); - internal override SchemaParser SchemaParser => throw new NotImplementedException(); + internal override SchemaParser SchemaParser { get; } = new HiveServer2SchemaParser(); - internal override IArrowArrayStream NewReader(T statement, Schema schema) => throw new NotImplementedException(); + internal override IArrowArrayStream NewReader(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: DataTypeConversion); } } diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaParameters.cs b/csharp/src/Drivers/Apache/Impala/ImpalaParameters.cs new file mode 100644 index 0000000000..443e4180b1 --- /dev/null +++ b/csharp/src/Drivers/Apache/Impala/ImpalaParameters.cs @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace Apache.Arrow.Adbc.Drivers.Apache.Impala +{ + /// + /// Parameters used for connecting to Impala data sources. 
+    /// </summary>
+    public static class ImpalaParameters
+    {
+        public const string HostName = "adbc.impala.host";
+        public const string Port = "adbc.impala.port";
+        public const string Path = "adbc.impala.path";
+        public const string AuthType = "adbc.impala.auth_type";
+        public const string DataTypeConv = "adbc.impala.data_type_conv";
+    }
+
+    public static class AuthTypeOptions
+    {
+        public const string None = "none";
+        public const string UsernameOnly = "username_only";
+        public const string Basic = "basic";
+    }
+}
diff --git a/csharp/src/Drivers/Apache/Spark/README.md b/csharp/src/Drivers/Apache/Spark/README.md
index 46c3a0799c..6928b9bb9d 100644
--- a/csharp/src/Drivers/Apache/Spark/README.md
+++ b/csharp/src/Drivers/Apache/Spark/README.md
@@ -35,7 +35,7 @@ but can also be passed in the call to `AdbcDatabase.Connect`.
 | `uri` | The full URI that includes scheme, host, port and path. If set, this property takes precedence over `adbc.spark.host`, `adbc.spark.port` and `adbc.spark.path`. | |
 | `username` | The user name used for basic authentication | |
 | `password` | The password for the user name used for basic authentication. | |
-| `adbc.spark.data_type_conv` | Comma-separated list of data conversion options. Each option indicates the type of conversion to perform on data returned from the Spark server.<br><br>Allowed values: `none`.<br><br>Option `none` indicates there is no conversion from Spark type to native type (i.e., no conversion from String to Timestamp for Apache Spark over HTTP). Example `adbc.spark.conv_data_type=none`.<br><br>(_Planned supported values_: `scalar`. Option `scalar` will perform conversion (if necessary) from the Spark data type to corresponding Arrow data types for types `DATE/Date32/DateTime`, `DECIMAL/Decimal128/SqlDecimal`, and `TIMESTAMP/Timestamp/DateTimeOffset`. Example `adbc.spark.conv_data_type=scalar`) | `scalar` |
+| `adbc.spark.data_type_conv` | Comma-separated list of data conversion options. Each option indicates the type of conversion to perform on data returned from the Spark server.<br><br>Allowed values: `none`, `scalar`.<br><br>Option `none` indicates there is no conversion from Spark type to native type (i.e., no conversion from String to Timestamp for Apache Spark over HTTP). Example `adbc.spark.data_type_conv=none`.<br><br>Option `scalar` will perform conversion (if necessary) from the Spark data type to corresponding Arrow data types for types `DATE/Date32/DateTime`, `DECIMAL/Decimal128/SqlDecimal`, and `TIMESTAMP/Timestamp/DateTimeOffset`. Example `adbc.spark.data_type_conv=scalar`. | `scalar` |
 | `adbc.statement.batch_size` | Sets the maximum number of rows to retrieve in a single batch request. | `50000` |
 | `adbc.statement.polltime_milliseconds` | If polling is necessary to get a result, this option sets the length of time (in milliseconds) to wait between polls. | `500` |
 
@@ -70,32 +70,32 @@ The following table depicts how the Spark ADBC driver converts a Spark type to a
 | USER_DEFINED | String | string |
 | VARCHAR | String | string |
 
-### Apache Spark over HTTP (when: adbc.spark.data_type_conv = none)
-
-| Spark Type | Arrow Type | C# Type |
-| :--- | :---: | :---: |
-| ARRAY* | String | string |
-| BIGINT | Int64 | long |
-| BINARY | Binary | byte[] |
-| BOOLEAN | Boolean | bool |
-| CHAR | String | string |
-| DATE* | *String* | *string* |
-| DECIMAL* | *String* | *string* |
-| DOUBLE | Double | double |
-| FLOAT | *Double* | *double* |
-| INT | Int32 | int |
-| INTERVAL_DAY_TIME+ | String | string |
-| INTERVAL_YEAR_MONTH+ | String | string |
-| MAP* | String | string |
-| NULL | String | string |
-| SMALLINT | Int16 | short |
-| STRING | String | string |
-| STRUCT* | String | string |
-| TIMESTAMP* | *String* | *string* |
-| TINYINT | Int8 | sbyte |
-| UNION | String | string |
-| USER_DEFINED | String | string |
-| VARCHAR | String | string |
+### Apache Spark over HTTP (adbc.spark.data_type_conv = ?)
+
+| Spark Type | Arrow Type (`none`) | C# Type (`none`) | Arrow Type (`scalar`) | C# Type (`scalar`) |
+| :--- | :---: | :---: | :---: | :---: |
+| ARRAY* | String | string | | |
+| BIGINT | Int64 | long | | |
+| BINARY | Binary | byte[] | | |
+| BOOLEAN | Boolean | bool | | |
+| CHAR | String | string | | |
+| DATE* | *String* | *string* | Date32 | DateTime |
+| DECIMAL* | *String* | *string* | Decimal128 | SqlDecimal |
+| DOUBLE | Double | double | | |
+| FLOAT | *Double* | *double* | | |
+| INT | Int32 | int | | |
+| INTERVAL_DAY_TIME+ | String | string | | |
+| INTERVAL_YEAR_MONTH+ | String | string | | |
+| MAP* | String | string | | |
+| NULL | String | string | | |
+| SMALLINT | Int16 | short | | |
+| STRING | String | string | | |
+| STRUCT* | String | string | | |
+| TIMESTAMP* | *String* | *string* | Timestamp | DateTimeOffset |
+| TINYINT | Int8 | sbyte | | |
+| UNION | String | string | | |
+| USER_DEFINED | String | string | | |
+| VARCHAR | String | string | | |
 
\+ Interval types are returned as strings @@ -112,7 +112,7 @@ Basic (username and password) authenication is not supported, at this time. ### Apache Spark over HTPP -This is currently unsupported. (Under development) +Support for Spark over HTTP is initial. ### Apache Spark Standard diff --git a/csharp/src/Drivers/Apache/Spark/SparkAuthType.cs b/csharp/src/Drivers/Apache/Spark/SparkAuthType.cs new file mode 100644 index 0000000000..8afb81c1e0 --- /dev/null +++ b/csharp/src/Drivers/Apache/Spark/SparkAuthType.cs @@ -0,0 +1,58 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +namespace Apache.Arrow.Adbc.Drivers.Apache.Spark +{ + internal enum SparkAuthType + { + Invalid = 0, + None, + UsernameOnly, + Basic, + Token, + Empty = int.MaxValue, + } + + internal static class AuthTypeParser + { + internal static bool TryParse(string? authType, out SparkAuthType authTypeValue) + { + switch (authType?.Trim().ToLowerInvariant()) + { + case null: + case "": + authTypeValue = SparkAuthType.Empty; + return true; + case SparkAuthTypeConstants.None: + authTypeValue = SparkAuthType.None; + return true; + case SparkAuthTypeConstants.UsernameOnly: + authTypeValue = SparkAuthType.UsernameOnly; + return true; + case SparkAuthTypeConstants.Basic: + authTypeValue = SparkAuthType.Basic; + return true; + case SparkAuthTypeConstants.Token: + authTypeValue = SparkAuthType.Token; + return true; + default: + authTypeValue = SparkAuthType.Invalid; + return false; + } + } + } +} diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs index c1500663b1..f532369e62 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs @@ -998,8 +998,6 @@ protected static Uri GetBaseAddress(string? uri, string? hostName, string? path, protected abstract void ValidateAuthentication(); protected abstract void ValidateOptions(); - protected SparkDataTypeConversion DataTypeConversion = SparkDataTypeConversion.None; - protected abstract Task GetRowSetAsync(TGetTableTypesResp response); protected abstract Task GetRowSetAsync(TGetColumnsResp response); protected abstract Task GetRowSetAsync(TGetTablesResp response); diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnectionFactory.cs b/csharp/src/Drivers/Apache/Spark/SparkConnectionFactory.cs index ba889e7658..7e432289e0 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkConnectionFactory.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkConnectionFactory.cs @@ -25,15 +25,15 @@ internal class SparkConnectionFactory public static SparkConnection NewConnection(IReadOnlyDictionary properties) { bool _ = properties.TryGetValue(SparkParameters.Type, out string? 
type) && string.IsNullOrEmpty(type); - bool __ = SparkServerTypeConstants.TryParse(type, out SparkServerType serverTypeValue); + bool __ = ServerTypeParser.TryParse(type, out SparkServerType serverTypeValue); return serverTypeValue switch { SparkServerType.Databricks => new SparkDatabricksConnection(properties), SparkServerType.Http => new SparkHttpConnection(properties), // TODO: Re-enable when properly supported //SparkServerType.Standard => new SparkStandardConnection(properties), - SparkServerType.Empty => throw new ArgumentException($"Required property '{SparkParameters.Type}' is missing. Supported types: {SparkServerTypeConstants.SupportedList}", nameof(properties)), - _ => throw new ArgumentOutOfRangeException(nameof(properties), $"Unsupported or unknown value '{type}' given for property '{SparkParameters.Type}'. Supported types: {SparkServerTypeConstants.SupportedList}"), + SparkServerType.Empty => throw new ArgumentException($"Required property '{SparkParameters.Type}' is missing. Supported types: {ServerTypeParser.SupportedList}", nameof(properties)), + _ => throw new ArgumentOutOfRangeException(nameof(properties), $"Unsupported or unknown value '{type}' given for property '{SparkParameters.Type}'. Supported types: {ServerTypeParser.SupportedList}"), }; } diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs index 58ef731378..a20964ce44 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs @@ -15,11 +15,10 @@ * limitations under the License. */ -using System; using System.Collections.Generic; using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; using Apache.Arrow.Ipc; -using Apache.Arrow.Types; using Apache.Hive.Service.Rpc.Thrift; namespace Apache.Arrow.Adbc.Drivers.Apache.Spark @@ -32,7 +31,7 @@ public SparkDatabricksConnection(IReadOnlyDictionary properties) internal override IArrowArrayStream NewReader(T statement, Schema schema) => new SparkDatabricksReader(statement, schema); - internal override SchemaParser SchemaParser => new DatabricksSchemaParser(); + internal override SchemaParser SchemaParser => new SparkDatabricksSchemaParser(); internal override SparkServerType ServerType => SparkServerType.Databricks; @@ -45,7 +44,12 @@ protected override TOpenSessionReq CreateSessionRequest() return req; } - protected override void ValidateOptions() { } + protected override void ValidateOptions() + { + Properties.TryGetValue(SparkParameters.DataTypeConv, out string? dataTypeConv); + // Note: In Databricks, scalar types are provided implicitly. 
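+            // For example (option values assumed for illustration): "scalar" parses to
+            // DataTypeConversion.Scalar, "none" parses to DataTypeConversion.None, and a
+            // missing or empty option defaults to DataTypeConversion.Scalar.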
+ DataTypeConversion = DataTypeConversionParser.Parse(dataTypeConv); + } protected override Task GetResultSetMetadataAsync(TGetSchemasResp response) => Task.FromResult(response.DirectResults.ResultSetMetadata); @@ -66,39 +70,5 @@ protected override Task GetRowSetAsync(TGetCatalogsResp response) => Task.FromResult(response.DirectResults.ResultSet.Results); protected override Task GetRowSetAsync(TGetSchemasResp response) => Task.FromResult(response.DirectResults.ResultSet.Results); - - internal class DatabricksSchemaParser : SchemaParser - { - public override IArrowType GetArrowType(TPrimitiveTypeEntry thriftType) - { - return thriftType.Type switch - { - TTypeId.BIGINT_TYPE => Int64Type.Default, - TTypeId.BINARY_TYPE => BinaryType.Default, - TTypeId.BOOLEAN_TYPE => BooleanType.Default, - TTypeId.DATE_TYPE => Date32Type.Default, - TTypeId.DOUBLE_TYPE => DoubleType.Default, - TTypeId.FLOAT_TYPE => FloatType.Default, - TTypeId.INT_TYPE => Int32Type.Default, - TTypeId.NULL_TYPE => NullType.Default, - TTypeId.SMALLINT_TYPE => Int16Type.Default, - TTypeId.TIMESTAMP_TYPE => new TimestampType(TimeUnit.Microsecond, (string?)null), - TTypeId.TINYINT_TYPE => Int8Type.Default, - TTypeId.DECIMAL_TYPE => new Decimal128Type(thriftType.TypeQualifiers.Qualifiers["precision"].I32Value, thriftType.TypeQualifiers.Qualifiers["scale"].I32Value), - TTypeId.CHAR_TYPE - or TTypeId.STRING_TYPE - or TTypeId.VARCHAR_TYPE - or TTypeId.INTERVAL_DAY_TIME_TYPE - or TTypeId.INTERVAL_YEAR_MONTH_TYPE - or TTypeId.ARRAY_TYPE - or TTypeId.MAP_TYPE - or TTypeId.STRUCT_TYPE - or TTypeId.UNION_TYPE - or TTypeId.USER_DEFINED_TYPE => StringType.Default, - TTypeId.TIMESTAMPLOCALTZ_TYPE => throw new NotImplementedException(), - _ => throw new NotImplementedException(), - }; - } - } } } diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabricksSchemaParser.cs b/csharp/src/Drivers/Apache/Spark/SparkDatabricksSchemaParser.cs new file mode 100644 index 0000000000..995c1edf09 --- /dev/null +++ b/csharp/src/Drivers/Apache/Spark/SparkDatabricksSchemaParser.cs @@ -0,0 +1,58 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Types; +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Spark +{ + internal class SparkDatabricksSchemaParser : SchemaParser + { + public override IArrowType GetArrowType(TPrimitiveTypeEntry thriftType, DataTypeConversion dataTypeConversion) + { + return thriftType.Type switch + { + TTypeId.BIGINT_TYPE => Int64Type.Default, + TTypeId.BINARY_TYPE => BinaryType.Default, + TTypeId.BOOLEAN_TYPE => BooleanType.Default, + TTypeId.DATE_TYPE => Date32Type.Default, + TTypeId.DOUBLE_TYPE => DoubleType.Default, + TTypeId.FLOAT_TYPE => FloatType.Default, + TTypeId.INT_TYPE => Int32Type.Default, + TTypeId.NULL_TYPE => NullType.Default, + TTypeId.SMALLINT_TYPE => Int16Type.Default, + TTypeId.TIMESTAMP_TYPE => new TimestampType(TimeUnit.Microsecond, (string?)null), + TTypeId.TINYINT_TYPE => Int8Type.Default, + TTypeId.DECIMAL_TYPE => NewDecima128Type(thriftType), + TTypeId.CHAR_TYPE + or TTypeId.STRING_TYPE + or TTypeId.VARCHAR_TYPE + or TTypeId.INTERVAL_DAY_TIME_TYPE + or TTypeId.INTERVAL_YEAR_MONTH_TYPE + or TTypeId.ARRAY_TYPE + or TTypeId.MAP_TYPE + or TTypeId.STRUCT_TYPE + or TTypeId.UNION_TYPE + or TTypeId.USER_DEFINED_TYPE => StringType.Default, + TTypeId.TIMESTAMPLOCALTZ_TYPE => throw new NotImplementedException(), + _ => throw new NotImplementedException(), + }; + } + } +} diff --git a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs index 20c361b046..8d7dc6fe63 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs @@ -18,20 +18,14 @@ using System; using System.Collections.Generic; using System.Diagnostics; -using System.Linq; using System.Net; using System.Net.Http; using System.Net.Http.Headers; -using System.Reflection; using System.Text; -using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Apache.Hive2; -using Apache.Arrow.Adbc.Drivers.Apache.Thrift; -using Apache.Arrow.Adbc.Extensions; using Apache.Arrow.Ipc; -using Apache.Arrow.Types; using Apache.Hive.Service.Rpc.Thrift; using Thrift; using Thrift.Protocol; @@ -55,7 +49,7 @@ protected override void ValidateAuthentication() Properties.TryGetValue(AdbcOptions.Username, out string? username); Properties.TryGetValue(AdbcOptions.Password, out string? password); Properties.TryGetValue(SparkParameters.AuthType, out string? authType); - bool isValidAuthType = SparkAuthTypeConstants.TryParse(authType, out SparkAuthType authTypeValue); + bool isValidAuthType = AuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue); switch (authTypeValue) { case SparkAuthType.Token: @@ -121,15 +115,10 @@ protected override void ValidateConnection() protected override void ValidateOptions() { Properties.TryGetValue(SparkParameters.DataTypeConv, out string? dataTypeConv); - SparkDataTypeConversionConstants.TryParse(dataTypeConv, out SparkDataTypeConversion dataTypeConversionValue); - DataTypeConversion = dataTypeConversionValue switch - { - SparkDataTypeConversion.None => dataTypeConversionValue!, - _ => throw new NotImplementedException($"Invalid or unsupported data type conversion option: '{dataTypeConv}'. 
Supported values: {SparkDataTypeConversionConstants.SupportedList}"), - }; + DataTypeConversion = DataTypeConversionParser.Parse(dataTypeConv); } - internal override IArrowArrayStream NewReader(T statement, Schema schema) => new HiveServer2Reader(statement, schema); + internal override IArrowArrayStream NewReader(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion); protected override Task CreateTransportAsync() { @@ -143,7 +132,7 @@ protected override Task CreateTransportAsync() Properties.TryGetValue(SparkParameters.Path, out string? path); Properties.TryGetValue(SparkParameters.Port, out string? port); Properties.TryGetValue(SparkParameters.AuthType, out string? authType); - bool isValidAuthType = SparkAuthTypeConstants.TryParse(authType, out SparkAuthType authTypeValue); + bool isValidAuthType = AuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue); Properties.TryGetValue(SparkParameters.Token, out string? token); Properties.TryGetValue(AdbcOptions.Username, out string? username); Properties.TryGetValue(AdbcOptions.Password, out string? password); @@ -248,39 +237,5 @@ internal static async Task FetchNextAsync(TOperationHandle op internal override SchemaParser SchemaParser => new HiveServer2SchemaParser(); internal override SparkServerType ServerType => SparkServerType.Http; - - internal class HiveServer2SchemaParser : SchemaParser - { - public override IArrowType GetArrowType(TPrimitiveTypeEntry thriftType) - { - return thriftType.Type switch - { - TTypeId.BIGINT_TYPE => Int64Type.Default, - TTypeId.BINARY_TYPE => BinaryType.Default, - TTypeId.BOOLEAN_TYPE => BooleanType.Default, - TTypeId.DOUBLE_TYPE - or TTypeId.FLOAT_TYPE => DoubleType.Default, - TTypeId.INT_TYPE => Int32Type.Default, - TTypeId.SMALLINT_TYPE => Int16Type.Default, - TTypeId.TINYINT_TYPE => Int8Type.Default, - TTypeId.CHAR_TYPE - or TTypeId.DATE_TYPE - or TTypeId.DECIMAL_TYPE - or TTypeId.NULL_TYPE - or TTypeId.STRING_TYPE - or TTypeId.TIMESTAMP_TYPE - or TTypeId.VARCHAR_TYPE - or TTypeId.INTERVAL_DAY_TIME_TYPE - or TTypeId.INTERVAL_YEAR_MONTH_TYPE - or TTypeId.ARRAY_TYPE - or TTypeId.MAP_TYPE - or TTypeId.STRUCT_TYPE - or TTypeId.UNION_TYPE - or TTypeId.USER_DEFINED_TYPE => StringType.Default, - TTypeId.TIMESTAMPLOCALTZ_TYPE => throw new NotImplementedException(), - _ => throw new NotImplementedException(), - }; - } - } } } diff --git a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs index e0a3c6c899..f2251c648d 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs @@ -39,42 +39,6 @@ public static class SparkAuthTypeConstants public const string UsernameOnly = "username_only"; public const string Basic = "basic"; public const string Token = "token"; - - public static bool TryParse(string? 
authType, out SparkAuthType authTypeValue) - { - switch (authType?.Trim().ToLowerInvariant()) - { - case null: - case "": - authTypeValue = SparkAuthType.Empty; - return true; - case None: - authTypeValue = SparkAuthType.None; - return true; - case UsernameOnly: - authTypeValue = SparkAuthType.UsernameOnly; - return true; - case Basic: - authTypeValue = SparkAuthType.Basic; - return true; - case Token: - authTypeValue = SparkAuthType.Token; - return true; - default: - authTypeValue = SparkAuthType.Invalid; - return false; - } - } - } - - public enum SparkAuthType - { - Invalid = 0, - None, - UsernameOnly, - Basic, - Token, - Empty = int.MaxValue, } public static class SparkServerTypeConstants @@ -82,71 +46,5 @@ public static class SparkServerTypeConstants public const string Http = "http"; public const string Databricks = "databricks"; public const string Standard = "standard"; - internal const string SupportedList = Http + ", " + Databricks; - - public static bool TryParse(string? serverType, out SparkServerType serverTypeValue) - { - switch (serverType?.Trim().ToLowerInvariant()) - { - case null: - case "": - serverTypeValue = SparkServerType.Empty; - return true; - case Databricks: - serverTypeValue = SparkServerType.Databricks; - return true; - case Http: - serverTypeValue = SparkServerType.Http; - return true; - case Standard: - serverTypeValue = SparkServerType.Standard; - return true; - default: - serverTypeValue = SparkServerType.Invalid; - return false; - } - } - } - - public enum SparkServerType - { - Invalid = 0, - Http, - Databricks, - Standard, - Empty = int.MaxValue, - } - - public static class SparkDataTypeConversionConstants - { - public const string None = "none"; - public const string Scalar = "scalar"; - public const string SupportedList = None; - - public static bool TryParse(string? dataTypeConversion, out SparkDataTypeConversion dataTypeConversionValue) - { - switch (dataTypeConversion?.Trim().ToLowerInvariant()) - { - case null: - case "": - case Scalar: - dataTypeConversionValue = SparkDataTypeConversion.Scalar; - return true; - case None: - dataTypeConversionValue = SparkDataTypeConversion.None; - return true; - default: - dataTypeConversionValue = SparkDataTypeConversion.Invalid; - return false; - } - } - } - - public enum SparkDataTypeConversion - { - Invalid = 0, - None, - Scalar, - Empty = int.MaxValue, } } diff --git a/csharp/src/Drivers/Apache/Spark/SparkServerType.cs b/csharp/src/Drivers/Apache/Spark/SparkServerType.cs new file mode 100644 index 0000000000..351a2a0b9d --- /dev/null +++ b/csharp/src/Drivers/Apache/Spark/SparkServerType.cs @@ -0,0 +1,56 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +namespace Apache.Arrow.Adbc.Drivers.Apache.Spark +{ + internal enum SparkServerType + { + Invalid = 0, + Http, + Databricks, + Standard, + Empty = int.MaxValue, + } + + internal static class ServerTypeParser + { + internal const string SupportedList = SparkServerTypeConstants.Http + ", " + SparkServerTypeConstants.Databricks; + + internal static bool TryParse(string? serverType, out SparkServerType serverTypeValue) + { + switch (serverType?.Trim().ToLowerInvariant()) + { + case null: + case "": + serverTypeValue = SparkServerType.Empty; + return true; + case SparkServerTypeConstants.Databricks: + serverTypeValue = SparkServerType.Databricks; + return true; + case SparkServerTypeConstants.Http: + serverTypeValue = SparkServerType.Http; + return true; + case SparkServerTypeConstants.Standard: + serverTypeValue = SparkServerType.Standard; + return true; + default: + serverTypeValue = SparkServerType.Invalid; + return false; + } + } + } +} diff --git a/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs index 54b1717037..51813ed6c4 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs @@ -36,7 +36,7 @@ protected override void ValidateAuthentication() Properties.TryGetValue(AdbcOptions.Username, out string? username); Properties.TryGetValue(AdbcOptions.Password, out string? password); Properties.TryGetValue(SparkParameters.AuthType, out string? authType); - bool isValidAuthType = SparkAuthTypeConstants.TryParse(authType, out SparkAuthType authTypeValue); + bool isValidAuthType = AuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue); switch (authTypeValue) { case SparkAuthType.None: @@ -112,7 +112,7 @@ protected override TOpenSessionReq CreateSessionRequest() Properties.TryGetValue(AdbcOptions.Username, out string? username); Properties.TryGetValue(AdbcOptions.Password, out string? password); Properties.TryGetValue(SparkParameters.AuthType, out string? authType); - bool isValidAuthType = SparkAuthTypeConstants.TryParse(authType, out SparkAuthType authTypeValue); + bool isValidAuthType = AuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue); TOpenSessionReq request = base.CreateSessionRequest(); switch (authTypeValue) { diff --git a/csharp/src/Drivers/Apache/Thrift/SchemaParser.cs b/csharp/src/Drivers/Apache/Thrift/SchemaParser.cs index 546f7e1ed0..bee4a8485b 100644 --- a/csharp/src/Drivers/Apache/Thrift/SchemaParser.cs +++ b/csharp/src/Drivers/Apache/Thrift/SchemaParser.cs @@ -16,6 +16,7 @@ */ using System; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; using Apache.Arrow.Types; using Apache.Hive.Service.Rpc.Thrift; @@ -23,27 +24,30 @@ namespace Apache.Arrow.Adbc.Drivers.Apache { internal abstract class SchemaParser { - internal Schema GetArrowSchema(TTableSchema thriftSchema) + internal Schema GetArrowSchema(TTableSchema thriftSchema, DataTypeConversion dataTypeConversion) { Field[] fields = new Field[thriftSchema.Columns.Count]; for (int i = 0; i < thriftSchema.Columns.Count; i++) { TColumnDesc column = thriftSchema.Columns[i]; // Note: no nullable metadata is returned from the Thrift interface. 
- fields[i] = new Field(column.ColumnName, GetArrowType(column.TypeDesc.Types[0]), nullable: true /* assumed */); + fields[i] = new Field(column.ColumnName, GetArrowType(column.TypeDesc.Types[0], dataTypeConversion), nullable: true /* assumed */); } return new Schema(fields, null); } - IArrowType GetArrowType(TTypeEntry thriftType) + IArrowType GetArrowType(TTypeEntry thriftType, DataTypeConversion dataTypeConversion) { if (thriftType.PrimitiveEntry != null) { - return GetArrowType(thriftType.PrimitiveEntry); + return GetArrowType(thriftType.PrimitiveEntry, dataTypeConversion); } throw new InvalidOperationException(); } - public abstract IArrowType GetArrowType(TPrimitiveTypeEntry thriftType); + public abstract IArrowType GetArrowType(TPrimitiveTypeEntry thriftType, DataTypeConversion dataTypeConversion); + + protected static Decimal128Type NewDecima128Type(TPrimitiveTypeEntry thriftType) => + new(thriftType.TypeQualifiers.Qualifiers["precision"].I32Value, thriftType.TypeQualifiers.Qualifiers["scale"].I32Value); } } diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/TestBase.cs b/csharp/test/Apache.Arrow.Adbc.Tests/TestBase.cs index f282c04123..d4662c504f 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/TestBase.cs +++ b/csharp/test/Apache.Arrow.Adbc.Tests/TestBase.cs @@ -48,10 +48,10 @@ public abstract class TestBase : IDisposable /// Constructs a new TestBase object with an output helper. /// /// Test output helper for writing test output. - public TestBase(ITestOutputHelper? outputHelper, TestEnvironment.Factory testEnvFacltory) + public TestBase(ITestOutputHelper? outputHelper, TestEnvironment.Factory testEnvFactory) { OutputHelper = outputHelper; - _testEnvFactory = testEnvFacltory; + _testEnvFactory = testEnvFactory; _testEnvironment = new Lazy(() => _testEnvFactory.Create(() => Connection)); _testConfiguration = new Lazy(() => Utils.LoadTestConfiguration(TestConfigVariable)); _connection = new Lazy(() => NewConnection()); diff --git a/csharp/test/Drivers/Apache/Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj b/csharp/test/Drivers/Apache/Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj index 4b2aa20618..5f20e2e145 100644 --- a/csharp/test/Drivers/Apache/Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj +++ b/csharp/test/Drivers/Apache/Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj @@ -1,4 +1,4 @@ - + net8.0;net472 @@ -27,6 +27,9 @@ + + PreserveNewest + PreserveNewest @@ -36,7 +39,7 @@ PreserveNewest - + PreserveNewest @@ -44,8 +47,4 @@ - - - - diff --git a/csharp/test/Drivers/Apache/Hive2/DecimalUtilityTests.cs b/csharp/test/Drivers/Apache/Hive2/DecimalUtilityTests.cs new file mode 100644 index 0000000000..a625dabb24 --- /dev/null +++ b/csharp/test/Drivers/Apache/Hive2/DecimalUtilityTests.cs @@ -0,0 +1,172 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System; +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Diagnostics; +using System.Globalization; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2 +{ + /// <summary> + /// Tests for the DecimalUtility class. + /// </summary> + public class DecimalUtilityTests(ITestOutputHelper outputHelper) + { + private readonly ITestOutputHelper _outputHelper = outputHelper; + + [SkippableTheory] + [MemberData(nameof(Decimal128Data))] + public void TestCanConvertDecimal(string value, int precision, int scale, int byteWidth, byte[] expected, SqlDecimal? expectedDecimal = default) + { + byte[] actual = new byte[byteWidth]; + DecimalUtility.GetBytes(value, precision, scale, byteWidth, actual); + Assert.Equal(expected, actual); + Assert.Equal(0, byteWidth % 4); + int[] buffer = new int[byteWidth / 4]; + for (int i = 0; i < buffer.Length; i++) + { + buffer[i] = BitConverter.ToInt32(actual, i * sizeof(int)); + } + SqlDecimal actualDecimal = GetSqlDecimal128(actual, 0, precision, scale); + if (expectedDecimal != null) Assert.Equal(expectedDecimal, actualDecimal); + } + + [Fact(Skip = "Run manually to confirm equivalent performance")] + public void TestConvertDecimalPerformance() + { + Stopwatch stopwatch = new(); + + int testCount = 1000000; + string testValue = "99999999999999999999999999999999999999"; + int byteWidth = 16; + byte[] buffer = new byte[byteWidth]; + Decimal128Array.Builder builder = new Decimal128Array.Builder(new Types.Decimal128Type(38, 0)); + stopwatch.Restart(); + for (int i = 0; i < testCount; i++) + { + if (decimal.TryParse(testValue, NumberStyles.Float, NumberFormatInfo.InvariantInfo, out var actualDecimal)) + { + builder.Append(new SqlDecimal(actualDecimal)); + } + else + { + builder.Append(testValue); + } + } + stopwatch.Stop(); + _outputHelper.WriteLine($"Decimal128Builder.Append: {testCount} iterations took {stopwatch.ElapsedMilliseconds} elapsed milliseconds"); + + stopwatch.Restart(); + for (int i = 0; i < testCount; i++) + { + DecimalUtility.GetBytes(testValue, 38, 0, byteWidth, buffer); + builder.Append(buffer); + } + stopwatch.Stop(); + _outputHelper.WriteLine($"DecimalUtility.GetBytes: {testCount} iterations took {stopwatch.ElapsedMilliseconds} elapsed milliseconds"); + } + + public static IEnumerable<object[]> Decimal128Data() + { + yield return new object[] { "0", 1, 0, 16, new byte[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0) }; + + yield return new object[] { "1", 1, 0, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "1E0", 1, 0, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "10e-1", 1, 0, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "0.1e1", 1, 0, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + + yield return new object[] { "12", 2, 0, 16, new byte[] { 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(12) }; + yield return new object[] { "12E0", 2, 0, 16, new byte[] { 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(12) }; + yield return new object[] { "120e-1", 2, 0, 16, new byte[] { 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(12) }; + yield return new object[] { "1.2e1", 2, 0, 16, new byte[] { 12, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(12) }; + + yield return new object[] { "99999999999999999999999999999999999999", 38, 0, 16, new byte[] { 255, 255, 255, 255, 63, 34, 138, 9, 122, 196, 134, 90, 168, 76, 59, 75 } }; + yield return new object[] { "99999999999999999999999999999999999999E0", 38, 0, 16, new byte[] { 255, 255, 255, 255, 63, 34, 138, 9, 122, 196, 134, 90, 168, 76, 59, 75 } }; + yield return new object[] { "999999999999999999999999999999999999990e-1", 38, 0, 16, new byte[] { 255, 255, 255, 255, 63, 34, 138, 9, 122, 196, 134, 90, 168, 76, 59, 75 } }; + yield return new object[] { "0.99999999999999999999999999999999999999e38", 38, 0, 16, new byte[] { 255, 255, 255, 255, 63, 34, 138, 9, 122, 196, 134, 90, 168, 76, 59, 75 } }; + + yield return new object[] { "-1", 1, 0, 16, new byte[] { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-1) }; + yield return new object[] { "-1E0", 1, 0, 16, new byte[] { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-1) }; + yield return new object[] { "-10e-1", 1, 0, 16, new byte[] { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-1) }; + yield return new object[] { "-0.1e1", 1, 0, 16, new byte[] { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-1) }; + + yield return new object[] { "-12", 2, 0, 16, new byte[] { 244, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-12) }; + yield return new object[] { "-12E0", 2, 0, 16, new byte[] { 244, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-12) }; + yield return new object[] { "-120e-1", 2, 0, 16, new byte[] { 244, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-12) }; + yield return new object[] { "-1.2e1", 2, 0, 16, new byte[] { 244, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-12) }; + + yield return new object[] { "1", 38, 0, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "1E0", 38, 0, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "10e-1", 38, 0, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "0.1e1", 38, 0, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + + yield return new object[] { "1", 3, 2, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "1E0", 3, 2, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "10e-1", 3, 2, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "0.1e1", 3, 2, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + + yield return new object[] { "1", 38, 2, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "1E0", 38, 2, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + yield return new object[] { "10e-1", 38, 2, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + 
yield return new object[] { "0.1e1", 38, 2, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(1) }; + + yield return new object[] { "0.1", 38, 1, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0.1) }; + yield return new object[] { "0.1E0", 38, 1, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0.1) }; + yield return new object[] { "1e-1", 38, 1, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0.1) }; + yield return new object[] { "0.01e1", 38, 1, 16, new byte[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0.1) }; + + yield return new object[] { "0.1", 38, 3, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0.1) }; + yield return new object[] { "0.1E0", 38, 3, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0.1) }; + yield return new object[] { "1e-1", 38, 3, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0.1) }; + yield return new object[] { "0.01e1", 38, 3, 16, new byte[] { 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, new SqlDecimal(0.1) }; + + yield return new object[] { "-0.1", 38, 3, 16, new byte[] { 156, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-0.1) }; + yield return new object[] { "-0.1E0", 38, 3, 16, new byte[] { 156, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-0.1) }; + yield return new object[] { "-1e-1", 38, 3, 16, new byte[] { 156, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-0.1) }; + yield return new object[] { "-0.01e1", 38, 3, 16, new byte[] { 156, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 }, new SqlDecimal(-0.1) }; + } + + private static SqlDecimal GetSqlDecimal128(in byte[] valueBuffer, int index, int precision, int scale) + { + const int byteWidth = 16; + const int intWidth = byteWidth / 4; + const int longWidth = byteWidth / 8; + + byte mostSignificantByte = valueBuffer.AsSpan()[(index + 1) * byteWidth - 1]; + bool isPositive = (mostSignificantByte & 0x80) == 0; + + if (isPositive) + { + ReadOnlySpan value = valueBuffer.AsSpan().CastTo().Slice(index * intWidth, intWidth); + return new SqlDecimal((byte)precision, (byte)scale, true, value[0], value[1], value[2], value[3]); + } + else + { + ReadOnlySpan value = valueBuffer.AsSpan().CastTo().Slice(index * longWidth, longWidth); + long data1 = -value[0]; + long data2 = data1 == 0 ? -value[1] : ~value[1]; + + return new SqlDecimal((byte)precision, (byte)scale, false, (int)(data1 & 0xffffffff), (int)(data1 >> 32), (int)(data2 & 0xffffffff), (int)(data2 >> 32)); + } + } + } +} diff --git a/csharp/test/Drivers/Apache/Hive2/HiveServer2ParametersTest.cs b/csharp/test/Drivers/Apache/Hive2/HiveServer2ParametersTest.cs new file mode 100644 index 0000000000..a10e953294 --- /dev/null +++ b/csharp/test/Drivers/Apache/Hive2/HiveServer2ParametersTest.cs @@ -0,0 +1,62 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Xunit; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2 +{ + public class HiveServer2ParametersTest + { + [SkippableTheory] + [MemberData(nameof(GetParametersTestData))] + internal void TestParametersParse(string? dataTypeConversion, DataTypeConversion expected, Type? exceptionType = default) + { + if (exceptionType == default) + Assert.Equal(expected, DataTypeConversionParser.Parse(dataTypeConversion)); + else + Assert.Throws(exceptionType, () => DataTypeConversionParser.Parse(dataTypeConversion)); + } + + public static IEnumerable<object?[]> GetParametersTestData() + { + // Default + yield return new object?[] { null, DataTypeConversion.Scalar }; + yield return new object?[] { "", DataTypeConversion.Scalar }; + yield return new object?[] { ",", DataTypeConversion.Scalar }; + // Explicit + yield return new object?[] { "scalar", DataTypeConversion.Scalar }; + yield return new object?[] { "none", DataTypeConversion.None }; + // Ignore "empty", embedded space, mixed-case + yield return new object?[] { "scalar,", DataTypeConversion.Scalar }; + yield return new object?[] { ",scalar,", DataTypeConversion.Scalar }; + yield return new object?[] { ",scAlAr,", DataTypeConversion.Scalar }; + yield return new object?[] { "scAlAr", DataTypeConversion.Scalar }; + yield return new object?[] { " scalar ", DataTypeConversion.Scalar }; + // Combined - conflicting + yield return new object?[] { "none,scalar", DataTypeConversion.None | DataTypeConversion.Scalar, typeof(ArgumentOutOfRangeException) }; + yield return new object?[] { " nOnE, scAlAr ", DataTypeConversion.None | DataTypeConversion.Scalar, typeof(ArgumentOutOfRangeException) }; + yield return new object?[] { ", none, scalar, ", DataTypeConversion.None | DataTypeConversion.Scalar, typeof(ArgumentOutOfRangeException) }; + yield return new object?[] { "scalar,none", DataTypeConversion.None | DataTypeConversion.Scalar, typeof(ArgumentOutOfRangeException) }; + // Invalid options + yield return new object?[] { "xxx", DataTypeConversion.Empty, typeof(ArgumentOutOfRangeException) }; + yield return new object?[] { "none,scalar,xxx", DataTypeConversion.None | DataTypeConversion.Scalar, typeof(ArgumentOutOfRangeException) }; + } + } +} diff --git a/csharp/test/Drivers/Apache/Impala/ImpalaTestEnvironment.cs b/csharp/test/Drivers/Apache/Impala/ImpalaTestEnvironment.cs new file mode 100644 index 0000000000..57efaf380b --- /dev/null +++ b/csharp/test/Drivers/Apache/Impala/ImpalaTestEnvironment.cs @@ -0,0 +1,77 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License.
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Drivers.Apache.Impala; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Impala +{ + public class ImpalaTestEnvironment : TestEnvironment<ApacheTestConfiguration> + { + public class Factory : Factory<ImpalaTestEnvironment> + { + public override ImpalaTestEnvironment Create(Func<AdbcConnection> getConnection) => new(getConnection); + } + + private ImpalaTestEnvironment(Func<AdbcConnection> getConnection) : base(getConnection) { } + + public override string TestConfigVariable => "IMPALA_TEST_CONFIG_FILE"; + + public override string SqlDataResourceLocation => "Impala/Resources/ImpalaData.sql"; + + public override int ExpectedColumnCount => 17; + + public override AdbcDriver CreateNewDriver() => new ImpalaDriver(); + + public override string GetCreateTemporaryTableStatement(string tableName, string columns) + { + return string.Format("CREATE TABLE {0} ({1})", tableName, columns); + } + + public override string Delimiter => "`"; + + public override Dictionary<string, string> GetDriverParameters(ApacheTestConfiguration testConfiguration) + { + Dictionary<string, string> parameters = new(StringComparer.OrdinalIgnoreCase); + + if (!string.IsNullOrEmpty(testConfiguration.HostName)) + { + parameters.Add("HostName", testConfiguration.HostName!); + } + if (!string.IsNullOrEmpty(testConfiguration.Port)) + { + parameters.Add("Port", testConfiguration.Port!); + } + return parameters; + } + + public override string VendorVersion => ((HiveServer2Connection)Connection).VendorVersion; + + public override bool SupportsDelete => false; + + public override bool SupportsUpdate => false; + + public override bool SupportCatalogName => false; + + public override bool ValidateAffectedRows => false; + + public override string GetInsertStatement(string tableName, string columnName, string? value) => + string.Format("INSERT INTO {0} ({1}) SELECT {2};", tableName, columnName, value ?? "NULL"); + } +} diff --git a/csharp/test/Drivers/Apache/Impala/ImpalaTests.cs b/csharp/test/Drivers/Apache/Impala/ImpalaTests.cs index ae19e22247..f0eee3e64f 100644 --- a/csharp/test/Drivers/Apache/Impala/ImpalaTests.cs +++ b/csharp/test/Drivers/Apache/Impala/ImpalaTests.cs @@ -15,36 +15,30 @@ * limitations under the License. */ -using System; using System.Collections.Generic; using Apache.Arrow.Adbc.Drivers.Apache.Impala; using Apache.Arrow.Adbc.Tests.Xunit; using Xunit; +using Xunit.Abstractions; namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Impala { [TestCaseOrderer("Apache.Arrow.Adbc.Tests.Xunit.TestOrderer", "Apache.Arrow.Adbc.Tests")] - public class ImpalaTests + public class ImpalaTests : TestBase<ApacheTestConfiguration, ImpalaTestEnvironment> { - [SkippableFact, Order(1)] - public void CanDriverConnect() + public ImpalaTests(ITestOutputHelper?
outputHelper) + : base(outputHelper, new ImpalaTestEnvironment.Factory()) { - ApacheTestConfiguration testConfiguration = Utils.GetTestConfiguration("impalaconfig.json"); - - Dictionary parameters = new Dictionary(StringComparer.OrdinalIgnoreCase) - { - { "HostName", testConfiguration.HostName }, - { "Port", testConfiguration.Port }, - }; + } - AdbcDatabase database = new ImpalaDriver().Open(parameters); - AdbcConnection connection = database.Connect(new Dictionary()); - AdbcStatement statement = connection.CreateStatement(); - statement.SqlQuery = testConfiguration.Query; + [SkippableFact, Order(1)] + public void CanExecuteQuery() + { + AdbcStatement statement = Connection.CreateStatement(); + statement.SqlQuery = TestConfiguration.Query; QueryResult queryResult = statement.ExecuteQuery(); - //Adbc.Tests.ConnectionTests.CanDriverConnect(queryResult, testConfiguration.ExpectedResultsCount); - + DriverTests.CanExecuteQuery(queryResult, TestConfiguration.ExpectedResultsCount); } } } diff --git a/csharp/test/Drivers/Apache/Spark/Resources/SparkData-3.4.sql b/csharp/test/Drivers/Apache/Impala/Resources/ImpalaData.sql similarity index 100% rename from csharp/test/Drivers/Apache/Spark/Resources/SparkData-3.4.sql rename to csharp/test/Drivers/Apache/Impala/Resources/ImpalaData.sql diff --git a/csharp/test/Drivers/Apache/Impala/Resources/impalaconfig.json b/csharp/test/Drivers/Apache/Impala/Resources/impalaconfig.json index 550fd3a97c..acd5c1b983 100644 --- a/csharp/test/Drivers/Apache/Impala/Resources/impalaconfig.json +++ b/csharp/test/Drivers/Apache/Impala/Resources/impalaconfig.json @@ -1,6 +1,12 @@ { - "hostName": "", - "port": "", - "query": "", - "expectedResults": 0 + "environment": "Impala", + "hostName": "", + "port": "", + "query": "", + "expectedResults": 0, + "metadata": { + "schema": "", + "table": "", + "expectedColumnCount": 0 + } } diff --git a/csharp/test/Drivers/Apache/Spark/DateTimeValueTests.cs b/csharp/test/Drivers/Apache/Spark/DateTimeValueTests.cs index 771168bc4b..81c872f5a2 100644 --- a/csharp/test/Drivers/Apache/Spark/DateTimeValueTests.cs +++ b/csharp/test/Drivers/Apache/Spark/DateTimeValueTests.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Globalization; using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Spark; using Xunit; using Xunit.Abstractions; @@ -58,7 +59,7 @@ public class DateTimeValueTests : TestBase expectedResults = VendorVersionAsVersion < Version.Parse("3.4.0") - ? new() - { + List expectedResults = TestEnvironment.ServerType != SparkServerType.Databricks + ? 
+ [ -1, // DROP TABLE -1, // CREATE TABLE 1, // INSERT @@ -99,9 +99,9 @@ public void CanExecuteUpdate() 1, // INSERT //1, // UPDATE //1, // DELETE - } - : new List() - { + ] + : + [ -1, // DROP TABLE -1, // CREATE TABLE 1, // INSERT @@ -109,7 +109,7 @@ public void CanExecuteUpdate() 1, // INSERT 1, // UPDATE 1, // DELETE - }; + ]; for (int i = 0; i < queries.Length; i++) { diff --git a/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs b/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs index 1759c210cd..db041cc04f 100644 --- a/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs +++ b/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs @@ -19,6 +19,7 @@ using System.Data.SqlTypes; using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Drivers.Apache.Spark; using Xunit; using Xunit.Abstractions; @@ -261,7 +262,8 @@ public async Task TestFloatValuesInsertSelectDelete(float value) string valueString = ConvertFloatToString(value); await InsertSingleValueAsync(table.TableName, columnName, valueString); object doubleValue = (double)value; - object floatValue = TestEnvironment.GetValueForProtocolVersion(doubleValue, value)!; + // Spark over HTTP returns float as double whereas Spark on Databricks returns float. + object floatValue = TestEnvironment.ServerType != SparkServerType.Databricks ? doubleValue : value; await base.SelectAndValidateValuesAsync(table.TableName, columnName, floatValue, 1); string whereClause = GetWhereClause(columnName, value); if (SupportsDelete) await DeleteFromTableAsync(table.TableName, whereClause, 1); diff --git a/csharp/test/Drivers/Apache/Spark/Resources/SparkData-Databricks.sql b/csharp/test/Drivers/Apache/Spark/Resources/SparkData-Databricks.sql new file mode 100644 index 0000000000..908ffbb930 --- /dev/null +++ b/csharp/test/Drivers/Apache/Spark/Resources/SparkData-Databricks.sql @@ -0,0 +1,133 @@ + + -- Licensed to the Apache Software Foundation (ASF) under one or more + -- contributor license agreements. See the NOTICE file distributed with + -- this work for additional information regarding copyright ownership. + -- The ASF licenses this file to You under the Apache License, Version 2.0 + -- (the "License"); you may not use this file except in compliance with + -- the License. You may obtain a copy of the License at + + -- http://www.apache.org/licenses/LICENSE-2.0 + + -- Unless required by applicable law or agreed to in writing, software + -- distributed under the License is distributed on an "AS IS" BASIS, + -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + -- See the License for the specific language governing permissions and + -- limitations under the License. 
+ +DROP TABLE IF EXISTS {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE}; + +CREATE TABLE IF NOT EXISTS {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( + id LONG, + byte BYTE, + short SHORT, + integer INT, + float FLOAT, + number DOUBLE, + decimal NUMERIC(38, 9), + is_active BOOLEAN, + name STRING, + data BINARY, + date DATE, + timestamp TIMESTAMP, + timestamp_ntz TIMESTAMP_NTZ, + timestamp_ltz TIMESTAMP_LTZ, + numbers ARRAY, + person STRUCT < + name STRING, + age LONG + >, + map MAP < + INT, + STRING + >, + varchar VARCHAR(255), + char CHAR(10) +); + +INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( + id, + byte, short, integer, float, number, decimal, + is_active, + name, data, + date, timestamp, timestamp_ntz, timestamp_ltz, + numbers, + person, + map, + varchar, + char +) +VALUES ( + 1, + 2, 3, 4, 7.89, 1.23, 4.56, + TRUE, + 'John Doe', + -- hex-encoded value `abc123` + X'616263313233', + '2023-09-08', '2023-09-08 12:34:56', '2023-09-08 12:34:56', '2023-09-08 12:34:56+00:00', + ARRAY(1, 2, 3), + STRUCT('John Doe', 30), + MAP(1, 'John Doe'), + 'John Doe', + 'John Doe' +); + +INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( + id, + byte, short, integer, float, number, decimal, + is_active, + name, data, + date, timestamp, timestamp_ntz, timestamp_ltz, + numbers, + person, + map, + varchar, + char +) +VALUES ( + 2, + 127, 32767, 2147483647, 3.4028234663852886e+38, 1.7976931348623157e+308, 9.99999999999999999999999999999999E+28BD, + FALSE, + 'Jane Doe', + -- hex-encoded `def456` + X'646566343536', + '2023-09-09', '2023-09-09 13:45:57', '2023-09-09 13:45:57', '2023-09-09 13:45:57+00:00', + ARRAY(4, 5, 6), + STRUCT('Jane Doe', 40), + MAP(1, 'John Doe'), + 'Jane Doe', + 'Jane Doe' +); + +INSERT INTO {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} ( + id, + byte, short, integer, float, number, decimal, + is_active, + name, data, + date, timestamp, timestamp_ntz, timestamp_ltz, + numbers, + person, + map, + varchar, + char +) +VALUES ( + 3, + -128, -32768, -2147483648, -3.4028234663852886e+38, -1.7976931348623157e+308, -9.99999999999999999999999999999999E+28BD, + FALSE, + 'Jack Doe', + -- hex-encoded `def456` + X'646566343536', + '1556-01-02', '1970-01-01 00:00:00', '1970-01-01 00:00:00', '9999-12-31 23:59:59+00:00', + ARRAY(7, 8, 9), + STRUCT('Jack Doe', 50), + MAP(1, 'John Doe'), + 'Jack Doe', + 'Jack Doe' +); + +UPDATE {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} + SET short = 0 + WHERE id = 3; + +DELETE FROM {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} + WHERE id = 3; diff --git a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs index 90561c01de..7b2a4fe418 100644 --- a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs +++ b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs @@ -19,7 +19,6 @@ using System.Collections.Generic; using Apache.Arrow.Adbc.Drivers.Apache.Hive2; using Apache.Arrow.Adbc.Drivers.Apache.Spark; -using Apache.Hive.Service.Rpc.Thrift; namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark { @@ -34,11 +33,11 @@ private SparkTestEnvironment(Func getConnection) : base(getConne public override string TestConfigVariable => "SPARK_TEST_CONFIG_FILE"; - public override string SqlDataResourceLocation => VendorVersionAsVersion >= Version.Parse("3.4.0") - ? "Spark/Resources/SparkData-3.4.sql" + public override string SqlDataResourceLocation => ServerType == SparkServerType.Databricks + ? 
"Spark/Resources/SparkData-Databricks.sql" : "Spark/Resources/SparkData.sql"; - public override int ExpectedColumnCount => VendorVersionAsVersion >= Version.Parse("3.4.0") ? 19 : 17; + public override int ExpectedColumnCount => ServerType == SparkServerType.Databricks ? 19 : 17; public override AdbcDriver CreateNewDriver() => new SparkDriver(); @@ -47,9 +46,11 @@ public override string GetCreateTemporaryTableStatement(string tableName, string return string.Format("CREATE TABLE {0} ({1})", tableName, columns); } - public string? GetValueForProtocolVersion(string? hiveValue, string? databrickValue) => ServerType != SparkServerType.Databricks ? hiveValue : databrickValue; + public string? GetValueForProtocolVersion(string? hiveValue, string? databrickValue) => + ServerType != SparkServerType.Databricks && ((HiveServer2Connection)Connection).DataTypeConversion.HasFlag(DataTypeConversion.None) ? hiveValue : databrickValue; - public object? GetValueForProtocolVersion(object? hiveValue, object? databrickValue) => ServerType != SparkServerType.Databricks ? hiveValue : databrickValue; + public object? GetValueForProtocolVersion(object? hiveValue, object? databrickValue) => + ServerType != SparkServerType.Databricks && ((HiveServer2Connection)Connection).DataTypeConversion.HasFlag(DataTypeConversion.None) ? hiveValue : databrickValue; public override string Delimiter => "`"; @@ -101,7 +102,7 @@ public override Dictionary GetDriverParameters(SparkTestConfigur return parameters; } - protected SparkServerType ServerType => ((SparkConnection)Connection).ServerType; + internal SparkServerType ServerType => ((SparkConnection)Connection).ServerType; public override string VendorVersion => ((HiveServer2Connection)Connection).VendorVersion; diff --git a/csharp/test/Drivers/Apache/Spark/StringValueTests.cs b/csharp/test/Drivers/Apache/Spark/StringValueTests.cs index b5606a372e..875eb2ea1e 100644 --- a/csharp/test/Drivers/Apache/Spark/StringValueTests.cs +++ b/csharp/test/Drivers/Apache/Spark/StringValueTests.cs @@ -17,9 +17,9 @@ using System; using System.Collections.Generic; -using System.Globalization; using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Drivers.Apache.Spark; using Xunit; using Xunit.Abstractions; @@ -50,11 +50,11 @@ public static IEnumerable ByteArrayData(int size) [InlineData(null)] [InlineData("")] [InlineData("你好")] - [InlineData("String contains formatting characters tab\t, newline\n, carriage return\r.", "3.4.0")] + [InlineData("String contains formatting characters tab\t, newline\n, carriage return\r.", SparkServerType.Databricks)] [InlineData(" Leading and trailing spaces ")] - public async Task TestStringData(string? value, string? minVersion = null) + internal async Task TestStringData(string? value, SparkServerType? 
serverType = default) { - Skip.If(IsBelowMinimumVersion(minVersion)); + Skip.If(serverType != null && TestEnvironment.ServerType != serverType); string columnName = "STRINGTYPE"; using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, "STRING")); await ValidateInsertSelectDeleteSingleValueAsync( @@ -71,11 +71,11 @@ await ValidateInsertSelectDeleteSingleValueAsync( [InlineData(null)] [InlineData("")] [InlineData("你好")] - [InlineData("String contains formatting characters tab\t, newline\n, carriage return\r.", "3.4.0")] + [InlineData("String contains formatting characters tab\t, newline\n, carriage return\r.", SparkServerType.Databricks)] [InlineData(" Leading and trailing spaces ")] - public async Task TestVarcharData(string? value, string? minVersion = null) + internal async Task TestVarcharData(string? value, SparkServerType? serverType = default) { - Skip.If(IsBelowMinimumVersion(minVersion)); + Skip.If(serverType != null && TestEnvironment.ServerType != serverType); string columnName = "VARCHARTYPE"; using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, "VARCHAR(100)")); await ValidateInsertSelectDeleteSingleValueAsync( @@ -94,11 +94,11 @@ await ValidateInsertSelectDeleteSingleValueAsync( [InlineData(null)] [InlineData("")] [InlineData("你好")] - [InlineData("String contains formatting characters tab\t, newline\n, carriage return\r.", "3.4.0")] + [InlineData("String contains formatting characters tab\t, newline\n, carriage return\r.", SparkServerType.Databricks)] [InlineData(" Leading and trailing spaces ")] - public async Task TestCharData(string? value, string? minVersion = null) + internal async Task TestCharData(string? value, SparkServerType? serverType = default) { - Skip.If(IsBelowMinimumVersion(minVersion)); + Skip.If(serverType != null && TestEnvironment.ServerType != serverType); string columnName = "CHARTYPE"; int fieldLength = 100; using TemporaryTable table = await NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName, $"CHAR({fieldLength})")); @@ -127,13 +127,13 @@ public async Task TestVarcharExceptionData(string value) value, value != null ? QuoteValue(value) : value)); - bool version34OrGreater = VendorVersionAsVersion >= Version.Parse("3.4.0"); - string[] expectedTexts = version34OrGreater + bool serverTypeDatabricks = TestEnvironment.ServerType == SparkServerType.Databricks; + string[] expectedTexts = serverTypeDatabricks ? ["DELTA_EXCEED_CHAR_VARCHAR_LIMIT", "DeltaInvariantViolationException"] : ["Exceeds", "length limitation: 10"]; AssertContainsAll(expectedTexts, exception.Message); - string? expectedSqlState = version34OrGreater ? "22001" : null; + string? expectedSqlState = serverTypeDatabricks ? "22001" : null; Assert.Equal(expectedSqlState, exception.SqlState); }
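
Editor's note: the sketch below is an illustration added for review, not part of the patch. It shows how the pieces this change introduces fit together: DataTypeConversionParser.Parse turns the connection's data-type-conversion option string into the DataTypeConversion flags consumed by HiveServer2Reader and the schema parsers, and DecimalUtility.GetBytes renders a DECIMAL value returned as text into the 16-byte little-endian two's-complement buffer backing an Arrow Decimal128 value (the same pattern TestConvertDecimalPerformance exercises above). Both types are internal to the driver assembly, so compiling this outside the test project assumes equivalent internals visibility; the class and method names in the sketch are hypothetical.

// Hedged sketch: exercises the option parser and the scalar DECIMAL conversion
// path directly, mirroring DecimalUtilityTests and HiveServer2ParametersTest.
using System;
using Apache.Arrow;
using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Arrow.Types;

internal static class ScalarConversionSketch
{
    internal static void Run()
    {
        // Option parsing: null or empty input defaults to Scalar, "none" disables
        // scalar conversion, and combining "none,scalar" throws
        // ArgumentOutOfRangeException (see HiveServer2ParametersTest above).
        DataTypeConversion conversion = DataTypeConversionParser.Parse("scalar");
        Console.WriteLine(conversion.HasFlag(DataTypeConversion.Scalar)); // True

        // Scalar DECIMAL conversion: write the string "12.34" as DECIMAL(38, 2)
        // into a 16-byte little-endian two's-complement buffer.
        const int byteWidth = 16;
        byte[] buffer = new byte[byteWidth];
        DecimalUtility.GetBytes("12.34", 38, 2, byteWidth, buffer);

        // Append the raw buffer to a Decimal128 builder, the same way the
        // performance test above feeds converted values into an Arrow array.
        var builder = new Decimal128Array.Builder(new Decimal128Type(38, 2));
        builder.Append(buffer);
        Decimal128Array array = builder.Build();
        Console.WriteLine(array.GetValue(0)); // 12.34
    }
}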