Skip to content

Commit

Permalink
Snowflake Dialect - pt. 7 (#10612)
Browse files Browse the repository at this point in the history
- Closes #9486
- All tests are succeeding or marked pending
- Created follow up tickets for things that still need to be addressed, including:
- Fixing upload / table update #10609
- Fixing `Count_Distinct` on Boolean columns #10611
- Running the tests on CI is not part of this PR - to be addressed separately
  • Loading branch information
radeusgd authored Jul 23, 2024
1 parent 71bae7e commit ba56f8e
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ from Standard.Base import all
import Standard.Base.Errors.Illegal_Argument.Illegal_Argument
import Standard.Base.Errors.Illegal_State.Illegal_State
import Standard.Base.Errors.Unimplemented.Unimplemented
import Standard.Base.Runtime.Ref.Ref

import Standard.Table.Internal.Problem_Builder.Problem_Builder
import Standard.Table.Internal.Vector_Builder.Vector_Builder
Expand Down Expand Up @@ -167,13 +168,12 @@ type Snowflake_Dialect
mapping = self.get_type_mapping
source_type = mapping.sql_type_to_value_type column.sql_type_reference.get
target_value_type = mapping.sql_type_to_value_type target_type
# Boolean to Numeric casts need special handling:
transformed_expression = case source_type.is_boolean && target_value_type.is_numeric of
True ->
SQL_Expression.Operation "IIF" [Internals_Access.column_expression column, SQL_Expression.Literal "1", SQL_Expression.Literal "0"]
False -> Internals_Access.column_expression column
target_type_sql_text = mapping.sql_type_to_text target_type
new_expression = SQL_Expression.Operation "CAST" [transformed_expression, SQL_Expression.Literal target_type_sql_text]

new_expression = make_custom_cast column source_type target_value_type . if_nothing <|
source_expression = Internals_Access.column_expression column
target_type_sql_text = mapping.sql_type_to_text target_type
SQL_Expression.Operation "CAST" [source_expression, SQL_Expression.Literal target_type_sql_text]

new_sql_type_reference = infer_result_type_from_database_callback new_expression
Internal_Column.Value column.name new_sql_type_reference new_expression

Expand Down Expand Up @@ -699,5 +699,29 @@ make_distinct_extension distinct_expressions =
SQL_Builder.code " QUALIFY ROW_NUMBER() OVER (PARTITION BY " ++ joined ++ " ORDER BY 1) = 1 "
Context_Extension.Value position=550 expressions=distinct_expressions run_generator=run_generator

## PRIVATE
Returns a custom cast expression if it is needed for a specific pair of types,
or Nothing if the default cast is sufficient.
make_custom_cast (column : Internal_Column) (source_value_type : Value_Type) (target_value_type : Value_Type) -> SQL_Expression | Nothing =
result = Ref.new Nothing

# Custom expression for boolean to float cast, as regular cast does not support it.
if source_value_type.is_boolean && target_value_type.is_floating_point then
result.put <|
SQL_Expression.Operation "IIF" [Internals_Access.column_expression column, SQL_Expression.Literal "1", SQL_Expression.Literal "0"]

# If the text length is bounded, we need to add a `LEFT` call to truncate to desired length avoiding errors.
if target_value_type.is_text && target_value_type.size.is_nothing.not then
# But we only do so if the source type was also text.
# For any other source type, we keep the original behaviour - failing to convert if the text representation would not fit.
if source_value_type.is_text then result.put <|
max_size = (target_value_type.size : Integer)
truncated = SQL_Expression.Operation "LEFT" [Internals_Access.column_expression column, SQL_Expression.Literal max_size.to_text]
# We still need a cast to ensure the Value_Type gets the max size in it - LEFT returns no size limit unfortunately.
target_type_name = "VARCHAR(" + max_size.to_text + ")"
SQL_Expression.Operation "CAST" [truncated, SQL_Expression.Literal target_type_name]

result.get

## PRIVATE
snowflake_dialect_name = "Snowflake"
1 change: 1 addition & 0 deletions lib/scala/pkg/src/main/resources/default/src/Main.enso
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ from Standard.Table import all
from Standard.Database import all
from Standard.AWS import all
from Standard.Google_Api import all
from Standard.Snowflake import all
import Standard.Visualization

main =
Expand Down
92 changes: 32 additions & 60 deletions test/Snowflake_Tests/src/Snowflake_Spec.enso
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ snowflake_specific_spec suite_builder default_connection db_name setup =
# The integer column is treated as NUMBER(38, 0) in Snowflake so the value type reflects that:
i.at "Value Type" . to_vector . should_equal [Value_Type.Char, Value_Type.Decimal 38 0, Value_Type.Boolean, Value_Type.Float]

group_builder.specify "should return Table information, also for aggregated results" <|
group_builder.specify "should return Table information, also for aggregated results" pending="TODO: fix https://github.com/enso-org/enso/issues/10611" <|
i = data.t.aggregate columns=[Aggregate_Column.Concatenate "strs", Aggregate_Column.Sum "ints", Aggregate_Column.Count_Distinct "bools"] . column_info
i.at "Column" . to_vector . should_equal ["Concatenate strs", "Sum ints", "Count Distinct bools"]
i.at "Items Count" . to_vector . should_equal [1, 1, 1]
Expand Down Expand Up @@ -240,7 +240,7 @@ snowflake_specific_spec suite_builder default_connection db_name setup =
# We expect warnings about coercing Decimal types
w1 = Problems.expect_warning Inexact_Type_Coercion t1
w1.requested_type . should_equal (Value_Type.Decimal 24 -3)
w1.actual_type . should_equal (Value_Type.Decimal 38 0)
w1.actual_type . should_equal (Value_Type.Decimal Nothing Nothing)

t1.update_rows (Table.new [["d1", [1.2345678910]], ["d2", [12.3456]], ["d3", [1234567.8910]], ["f", [1.5]]]) update_action=Update_Action.Insert . should_succeed

Expand All @@ -257,18 +257,6 @@ snowflake_specific_spec suite_builder default_connection db_name setup =
m1.at "d3" . to_vector . should_equal [1234568]
m1.at "f" . to_vector . should_equal [1.5]

suite_builder.group "[Snowflake] Dialect-specific codegen" group_builder->
data = Snowflake_Info_Data.setup default_connection

group_builder.teardown <|
data.teardown

group_builder.specify "should generate queries for the Distinct operation" <|
t = data.connection.query (SQL_Query.Table_Name data.tinfo)
code_template = 'SELECT "{Tinfo}"."strs" AS "strs", "{Tinfo}"."ints" AS "ints", "{Tinfo}"."bools" AS "bools", "{Tinfo}"."doubles" AS "doubles" FROM (SELECT DISTINCT ON ("{Tinfo}_inner"."strs") "{Tinfo}_inner"."strs" AS "strs", "{Tinfo}_inner"."ints" AS "ints", "{Tinfo}_inner"."bools" AS "bools", "{Tinfo}_inner"."doubles" AS "doubles" FROM (SELECT "{Tinfo}"."strs" AS "strs", "{Tinfo}"."ints" AS "ints", "{Tinfo}"."bools" AS "bools", "{Tinfo}"."doubles" AS "doubles" FROM "{Tinfo}" AS "{Tinfo}") AS "{Tinfo}_inner") AS "{Tinfo}"'
expected_code = code_template.replace "{Tinfo}" data.tinfo
t.distinct ["strs"] . to_sql . prepare . should_equal [expected_code, []]

suite_builder.group "[Snowflake] Table.aggregate should correctly infer result types" group_builder->
data = Snowflake_Aggregate_Data.setup default_connection

Expand All @@ -284,8 +272,13 @@ snowflake_specific_spec suite_builder default_connection db_name setup =
group_builder.specify "Counts" <|
r = data.t.aggregate columns=[Aggregate_Column.Count, Aggregate_Column.Count_Empty "txt", Aggregate_Column.Count_Not_Empty "txt", Aggregate_Column.Count_Distinct "i1", Aggregate_Column.Count_Not_Nothing "i2", Aggregate_Column.Count_Nothing "i3"]
r.column_count . should_equal 6
r.columns.each column->
column.value_type . should_equal (Value_Type.Decimal 18 0)

r.at "Count" . value_type . should_equal (Value_Type.Decimal 18 0)
r.at "Count Empty txt" . value_type . should_equal (Value_Type.Decimal 13 0)
r.at "Count Not Empty txt" . value_type . should_equal (Value_Type.Decimal 13 0)
r.at "Count Distinct i1" . value_type . should_equal (Value_Type.Decimal 18 0)
r.at "Count Not Nothing i2" . value_type . should_equal (Value_Type.Decimal 18 0)
r.at "Count Nothing i3" . value_type . should_equal (Value_Type.Decimal 13 0)

group_builder.specify "Sum" <|
r = data.t.aggregate columns=[Aggregate_Column.Sum "i1", Aggregate_Column.Sum "i2", Aggregate_Column.Sum "i3", Aggregate_Column.Sum "i4", Aggregate_Column.Sum "r1", Aggregate_Column.Sum "r2"]
Expand All @@ -308,24 +301,14 @@ snowflake_specific_spec suite_builder default_connection db_name setup =


suite_builder.group "[Snowflake] Warning/Error handling" group_builder->
group_builder.specify "query warnings should be propagated" <|
long_name = (Name_Generator.random_name "T") + ("a" * 100)
r = default_connection.get.execute_update 'CREATE TEMPORARY TABLE "'+long_name+'" ("A" VARCHAR)'
w1 = Problems.expect_only_warning SQL_Warning r
# The display text may itself be truncated, so we just check the first words.
w1.to_display_text . should_contain "identifier"
# And check the full message for words that could be truncated in short message.
w1.message . should_contain "truncated to"

table = default_connection.get.query (SQL_Query.Raw_SQL 'SELECT 1 AS "'+long_name+'"')
w2 = Problems.expect_only_warning SQL_Warning table
w2.message . should_contain "truncated"
effective_name = table.column_names . at 0
effective_name . should_not_equal long_name
long_name.should_contain effective_name

group_builder.specify "is capable of handling weird tables" <|
default_connection.get.execute_update 'CREATE TEMPORARY TABLE "empty-column-name" ("" VARCHAR)' . should_fail_with SQL_Error
default_connection.get.execute_update 'CREATE TEMPORARY TABLE "empty-column-name" ("" VARCHAR)' . should_succeed
t = default_connection.get.query "empty-column-name"
t.columns.length . should_equal 1
# The column is renamed to something valid upon read:
t.column_names . should_equal ["Column 1"]
# Should be readable:
t.read . at 0 . to_vector . should_equal []

Problems.assume_no_problems <|
default_connection.get.execute_update 'CREATE TEMPORARY TABLE "clashing-unicode-names" ("ś" VARCHAR, "s\u0301" INTEGER)'
Expand All @@ -343,7 +326,9 @@ snowflake_specific_spec suite_builder default_connection db_name setup =
r3.catch.cause . should_be_a Duplicate_Output_Column_Names

r4 = default_connection.get.query 'SELECT 1 AS ""'
r4.should_fail_with SQL_Error
r4.should_fail_with Illegal_Argument
r4.catch.to_display_text . should_contain "The provided custom SQL query is invalid and may suffer data corruption"
r4.catch.to_display_text . should_contain "The name '' is invalid"

suite_builder.group "[Snowflake] Edge Cases" group_builder->
group_builder.specify "materialize should respect the overridden type" pending="TODO" <|
Expand Down Expand Up @@ -525,31 +510,31 @@ snowflake_specific_spec suite_builder default_connection db_name setup =
suite_builder.group "[Snowflake] math functions" group_builder->
group_builder.specify "round, trunc, ceil, floor" <|
col = table_builder [["x", [0.1, 0.9, 3.1, 3.9, -0.1, -0.9, -3.1, -3.9]]] . at "x"
col . cast Value_Type.Integer . ceil . value_type . should_equal Value_Type.Float
col . cast Value_Type.Integer . ceil . value_type . should_equal (Value_Type.Decimal 38 0)

col . cast Value_Type.Float . round . value_type . should_equal Value_Type.Float
col . cast Value_Type.Integer . round . value_type . should_equal Value_Type.Float
col . cast Value_Type.Decimal . round . value_type . should_equal Value_Type.Decimal
col . cast Value_Type.Integer . round . value_type . should_equal (Value_Type.Decimal 38 0)
col . cast Value_Type.Decimal . round . value_type . should_equal (Value_Type.Decimal 38 0)

col . cast Value_Type.Float . round 1 . value_type . should_equal Value_Type.Float
col . cast Value_Type.Integer . round 1 . value_type . should_equal Value_Type.Decimal
col . cast Value_Type.Decimal . round 1 . value_type . should_equal Value_Type.Decimal
col . cast Value_Type.Integer . round 1 . value_type . should_equal (Value_Type.Decimal 38 0)
col . cast Value_Type.Decimal . round 1 . value_type . should_equal (Value_Type.Decimal 38 0)

col . cast Value_Type.Float . round use_bankers=True . value_type . should_equal Value_Type.Float
col . cast Value_Type.Integer . round use_bankers=True . value_type . should_equal Value_Type.Float
col . cast Value_Type.Decimal . round use_bankers=True . value_type . should_equal Value_Type.Decimal
col . cast Value_Type.Decimal . round use_bankers=True . value_type . should_equal Value_Type.Float

col . cast Value_Type.Float . ceil . value_type . should_equal Value_Type.Float
col . cast Value_Type.Integer . ceil . value_type . should_equal Value_Type.Float
col . cast Value_Type.Decimal . ceil . value_type . should_equal Value_Type.Decimal
col . cast Value_Type.Integer . ceil . value_type . should_equal (Value_Type.Decimal 38 0)
col . cast Value_Type.Decimal . ceil . value_type . should_equal (Value_Type.Decimal 38 0)

col . cast Value_Type.Float . floor . value_type . should_equal Value_Type.Float
col . cast Value_Type.Integer . floor . value_type . should_equal Value_Type.Float
col . cast Value_Type.Decimal . floor . value_type . should_equal Value_Type.Decimal
col . cast Value_Type.Integer . floor . value_type . should_equal (Value_Type.Decimal 38 0)
col . cast Value_Type.Decimal . floor . value_type . should_equal (Value_Type.Decimal 38 0)

col . cast Value_Type.Float . truncate . value_type . should_equal Value_Type.Float
col . cast Value_Type.Integer . truncate . value_type . should_equal Value_Type.Float
col . cast Value_Type.Decimal . truncate . value_type . should_equal Value_Type.Decimal
col . cast Value_Type.Integer . truncate . value_type . should_equal (Value_Type.Decimal 38 0)
col . cast Value_Type.Decimal . truncate . value_type . should_equal (Value_Type.Decimal 38 0)

do_op n op =
table = light_table_builder [["x", [n]]]
Expand Down Expand Up @@ -578,18 +563,6 @@ snowflake_specific_spec suite_builder default_connection db_name setup =
do_op Number.positive_infinity op . should_equal Number.positive_infinity
do_op Number.negative_infinity op . should_equal Number.negative_infinity

group_builder.specify "round returns the correct type" <|
do_round 231.2 1 . should_be_a Float
do_round 231.2 0 . should_be_a Float
do_round 231.2 . should_be_a Float
do_round 231.2 -1 . should_be_a Float

group_builder.specify "round returns the correct type" <|
do_round 231 1 . should_be_a Float
do_round 231 0 . should_be_a Float
do_round 231 . should_be_a Float
do_round 231 -1 . should_be_a Float

type Lazy_Ref
Value ~get

Expand All @@ -604,7 +577,6 @@ add_snowflake_specs suite_builder create_connection_fn db_name =
ix = name_counter.get
name_counter . put ix+1
name = Name_Generator.random_name "table_"+ix.to_text

in_mem_table = Table.new columns
in_mem_table.select_into_database_table (connection.if_nothing default_connection.get) name primary_key=Nothing temporary=True
light_table_builder columns =
Expand Down Expand Up @@ -662,7 +634,7 @@ add_table_specs suite_builder =
cloud_setup.with_prepared_environment <|
with_secret "my_snowflake_username" base_details.credentials.username username_secret-> with_secret "my_snowflake_password" base_details.credentials.password password_secret->
secret_credentials = Credentials.Username_And_Password username_secret password_secret
details = Snowflake_Details.Snowflake base_details.account_name secret_credentials base_details.database base_details.schema base_details.warehouse
details = Snowflake_Details.Snowflake base_details.account secret_credentials base_details.database base_details.schema base_details.warehouse
connection = Database.connect details
connection.should_succeed
Panic.with_finalizer connection.close <|
Expand Down
20 changes: 20 additions & 0 deletions test/Table_Tests/src/Common_Table_Operations/Aggregate_Spec.enso
Original file line number Diff line number Diff line change
Expand Up @@ -1061,6 +1061,26 @@ add_specs suite_builder setup =
m1.columns.first.name . should_equal "Count Distinct A B"
m1.columns.first.to_vector . should_equal [3]

group_builder.specify "should work correctly with Boolean columns" pending=(if prefix.contains "Snowflake" then "TODO: fix https://github.com/enso-org/enso/issues/10611") <|
table = table_builder [["A", [True, True, True]], ["B", [False, False, False]], ["C", [True, False, True]], ["D", [Nothing, False, True]]]

t_with_nulls = table.aggregate columns=[..Count_Distinct "A", ..Count_Distinct "B", ..Count_Distinct "C", ..Count_Distinct "D"]
m1 = materialize t_with_nulls
m1.column_count . should_equal 4
m1.at "Count Distinct A" . to_vector . should_equal [1]
m1.at "Count Distinct B" . to_vector . should_equal [1]
m1.at "Count Distinct C" . to_vector . should_equal [2]
m1.at "Count Distinct D" . to_vector . should_equal [3]

t_without_nulls = table.aggregate columns=[..Count_Distinct "A" ignore_nothing=True, ..Count_Distinct "B" ignore_nothing=True, ..Count_Distinct "C" ignore_nothing=True, ..Count_Distinct "D" ignore_nothing=True]
m2 = materialize t_without_nulls
m2.column_count . should_equal 4
m2.at "Count Distinct A" . to_vector . should_equal [1]
m2.at "Count Distinct B" . to_vector . should_equal [1]
m2.at "Count Distinct C" . to_vector . should_equal [2]
# The NULL is ignored, and not counted towards the total
m2.at "Count Distinct D" . to_vector . should_equal [2]

suite_builder.group prefix+"Table.aggregate Standard_Deviation" pending=(resolve_pending test_selection.std_dev) group_builder->
group_builder.specify "should correctly handle single elements" <|
r1 = table_builder [["X", [1]]] . aggregate columns=[Standard_Deviation "X" (population=False), Standard_Deviation "X" (population=True)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ add_specs suite_builder setup =
table_builder = setup.light_table_builder
pending_datetime = if setup.test_selection.date_time.not then "Date/Time operations are not supported by this backend."
suite_builder.group prefix+"(Derived_Columns_Spec) Table.set with Simple_Expression" group_builder->
group_builder.specify "arithmetics" <|
group_builder.specify "arithmetics" pending=(if prefix.contains "Snowflake" then "TODO: re-enable these once https://github.com/enso-org/enso/pull/10583 is merged") <|
t = table_builder [["A", [1, 2]], ["B", [10, 40]]]
t.set (Simple_Expression.Simple_Expr (Column_Ref.Name "A") Simple_Calculation.Copy) "C" . at "C" . to_vector . should_equal [1, 2]
t.set (..Simple_Expr (..Name "A") ..Copy) "C" . at "C" . to_vector . should_equal [1, 2]
Expand Down
Loading

0 comments on commit ba56f8e

Please sign in to comment.