From 1fc83543dd4ffff5ee169a43c8512f7121c3a719 Mon Sep 17 00:00:00 2001 From: tokoko Date: Mon, 7 Oct 2024 22:12:41 +0000 Subject: [PATCH 1/7] fix(substrait): remove optimize calls from substrait consumer --- .../substrait/src/logical_plan/consumer.rs | 53 +++++++++++++------ .../tests/cases/substrait_validations.rs | 7 +-- 2 files changed, 39 insertions(+), 21 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 030536f9f830..2a691a30132e 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -783,7 +783,6 @@ pub async fn from_substrait_rel( let t = ctx.table(table_reference.clone()).await?; let t = ensure_schema_compatability(t, substrait_schema)?; - let t = t.into_optimized_plan()?; extract_projection(t, &read.projection) } Some(ReadType::VirtualTable(vt)) => { @@ -866,7 +865,7 @@ pub async fn from_substrait_rel( // directly use unwrap here since we could determine it is a valid one let table_reference = TableReference::Bare { table: name.into() }; let t = ctx.table(table_reference).await?; - let t = t.into_optimized_plan()?; + let t = t.into_unoptimized_plan(); extract_projection(t, &read.projection) } _ => not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type), @@ -990,24 +989,46 @@ pub async fn from_substrait_rel( fn ensure_schema_compatability( table: DataFrame, substrait_schema: DFSchema, -) -> Result { +) -> Result { let df_schema = table.schema().to_owned().strip_qualifiers(); if df_schema.logically_equivalent_names_and_types(&substrait_schema) { - return Ok(table); + return Ok(table.into_unoptimized_plan()); } - let selected_columns = substrait_schema - .strip_qualifiers() - .fields() - .iter() - .map(|substrait_field| { - let df_field = - df_schema.field_with_unqualified_name(substrait_field.name())?; - ensure_field_compatability(df_field, substrait_field)?; - Ok(col(format!("\"{}\"", df_field.name()))) - }) - .collect::>()?; - table.select(selected_columns) + let qualified_schema = table.schema().to_owned(); + + let t = table.into_unoptimized_plan(); + + match t { + LogicalPlan::TableScan(mut scan) => { + let column_indices: Vec = substrait_schema + .strip_qualifiers() + .fields() + .iter() + .map(|substrait_field| { + let df_field = + df_schema.field_with_unqualified_name(substrait_field.name())?; + ensure_field_compatability(df_field, substrait_field)?; + + Ok(df_schema + .index_of_column_by_name(None, substrait_field.name().as_str()) + .unwrap()) + }) + .collect::>()?; + + let fields = column_indices + .iter() + .map(|i| qualified_schema.qualified_field(*i)) + .map(|(qualifier, field)| (qualifier.cloned(), Arc::new(field.clone()))) + .collect(); + + scan.projected_schema = + DFSchemaRef::new(DFSchema::new_with_metadata(fields, HashMap::new())?); + scan.projection = Some(column_indices); + Ok(LogicalPlan::TableScan(scan)) + } + _ => Ok(t), + } } /// Ensures that the given Substrait field is compatible with the given DataFusion field diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs b/datafusion/substrait/tests/cases/substrait_validations.rs index cb1fb67fc044..96cf807f4fe6 100644 --- a/datafusion/substrait/tests/cases/substrait_validations.rs +++ b/datafusion/substrait/tests/cases/substrait_validations.rs @@ -91,8 +91,7 @@ mod tests { assert_eq!( format!("{}", plan), "Projection: DATA.a, DATA.b\ - \n Projection: DATA.a, DATA.b\ - \n TableScan: DATA projection=[b, a]" + \n TableScan: DATA projection=[a, b]" ); Ok(()) } @@ -115,9 +114,7 @@ mod tests { assert_eq!( format!("{}", plan), "Projection: DATA.a, DATA.b\ - \n Projection: DATA.a, DATA.b\ - \n Projection: DATA.a, DATA.b, DATA.c\ - \n TableScan: DATA projection=[b, a, c]" + \n TableScan: DATA projection=[b, a]" ); Ok(()) } From 60623afe6e2f44e1312c03b0eb1b06f622097f65 Mon Sep 17 00:00:00 2001 From: tokoko Date: Tue, 8 Oct 2024 06:36:48 +0000 Subject: [PATCH 2/7] fix(substrait): fix schema comparison in ensure_schema_compatability --- .../substrait/src/logical_plan/consumer.rs | 6 +- .../tests/cases/consumer_integration.rs | 122 +++++++++--------- .../substrait/tests/cases/function_test.rs | 2 +- .../substrait/tests/cases/logical_plans.rs | 4 +- .../tests/cases/roundtrip_logical_plan.rs | 4 +- .../tests/cases/substrait_validations.rs | 2 +- 6 files changed, 69 insertions(+), 71 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 2a691a30132e..19234706e51f 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -990,13 +990,11 @@ fn ensure_schema_compatability( table: DataFrame, substrait_schema: DFSchema, ) -> Result { - let df_schema = table.schema().to_owned().strip_qualifiers(); + let df_schema = table.schema().to_owned(); if df_schema.logically_equivalent_names_and_types(&substrait_schema) { return Ok(table.into_unoptimized_plan()); } - let qualified_schema = table.schema().to_owned(); - let t = table.into_unoptimized_plan(); match t { @@ -1018,7 +1016,7 @@ fn ensure_schema_compatability( let fields = column_indices .iter() - .map(|i| qualified_schema.qualified_field(*i)) + .map(|i| df_schema.qualified_field(*i)) .map(|(qualifier, field)| (qualifier.cloned(), Arc::new(field.clone()))) .collect(); diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index b1cc76305031..fffa29df1db5 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -55,7 +55,7 @@ mod tests { \n Aggregate: groupBy=[[LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS]], aggr=[[sum(LINEITEM.L_QUANTITY), sum(LINEITEM.L_EXTENDEDPRICE), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX), avg(LINEITEM.L_QUANTITY), avg(LINEITEM.L_EXTENDEDPRICE), avg(LINEITEM.L_DISCOUNT), count(Int64(1))]]\ \n Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, LINEITEM.L_QUANTITY, LINEITEM.L_EXTENDEDPRICE, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT), LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) * (CAST(Int32(1) AS Decimal128(15, 2)) + LINEITEM.L_TAX), LINEITEM.L_DISCOUNT\ \n Filter: LINEITEM.L_SHIPDATE <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 0, milliseconds: 10368000 }\")\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]" + \n TableScan: LINEITEM" ); Ok(()) } @@ -76,19 +76,19 @@ mod tests { \n CrossJoin:\ \n CrossJoin:\ \n CrossJoin:\ - \n TableScan: PARTSUPP projection=[PS_PARTKEY, PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\ - \n TableScan: SUPPLIER projection=[S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\ - \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ - \n TableScan: REGION projection=[R_REGIONKEY, R_NAME, R_COMMENT]\ + \n TableScan: PARTSUPP\ + \n TableScan: SUPPLIER\ + \n TableScan: NATION\ + \n TableScan: REGION\ \n CrossJoin:\ \n CrossJoin:\ \n CrossJoin:\ \n CrossJoin:\ - \n TableScan: PART projection=[P_PARTKEY, P_NAME, P_MFGR, P_BRAND, P_TYPE, P_SIZE, P_CONTAINER, P_RETAILPRICE, P_COMMENT]\ - \n TableScan: SUPPLIER projection=[S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\ - \n TableScan: PARTSUPP projection=[PS_PARTKEY, PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\ - \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ - \n TableScan: REGION projection=[R_REGIONKEY, R_NAME, R_COMMENT]" + \n TableScan: PART\ + \n TableScan: SUPPLIER\ + \n TableScan: PARTSUPP\ + \n TableScan: NATION\ + \n TableScan: REGION" ); Ok(()) } @@ -107,9 +107,9 @@ mod tests { \n Filter: CUSTOMER.C_MKTSEGMENT = Utf8(\"BUILDING\") AND CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1995-03-15\") AS Date32) AND LINEITEM.L_SHIPDATE > CAST(Utf8(\"1995-03-15\") AS Date32)\ \n CrossJoin:\ \n CrossJoin:\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n TableScan: CUSTOMER projection=[C_CUSTKEY, C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]" + \n TableScan: LINEITEM\ + \n TableScan: CUSTOMER\ + \n TableScan: ORDERS" ); Ok(()) } @@ -126,8 +126,8 @@ mod tests { \n Filter: ORDERS.O_ORDERDATE >= CAST(Utf8(\"1993-07-01\") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1993-10-01\") AS Date32) AND EXISTS ()\ \n Subquery:\ \n Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]" + \n TableScan: LINEITEM\ + \n TableScan: ORDERS" ); Ok(()) } @@ -147,12 +147,12 @@ mod tests { \n CrossJoin:\ \n CrossJoin:\ \n CrossJoin:\ - \n TableScan: CUSTOMER projection=[C_CUSTKEY, C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n TableScan: SUPPLIER projection=[S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\ - \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ - \n TableScan: REGION projection=[R_REGIONKEY, R_NAME, R_COMMENT]" + \n TableScan: CUSTOMER\ + \n TableScan: ORDERS\ + \n TableScan: LINEITEM\ + \n TableScan: SUPPLIER\ + \n TableScan: NATION\ + \n TableScan: REGION" ); Ok(()) } @@ -165,7 +165,7 @@ mod tests { "Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * LINEITEM.L_DISCOUNT) AS REVENUE]]\ \n Projection: LINEITEM.L_EXTENDEDPRICE * LINEITEM.L_DISCOUNT\ \n Filter: LINEITEM.L_SHIPDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8(\"1995-01-01\") AS Date32) AND LINEITEM.L_DISCOUNT >= Decimal128(Some(5),3,2) AND LINEITEM.L_DISCOUNT <= Decimal128(Some(7),3,2) AND LINEITEM.L_QUANTITY < CAST(Int32(24) AS Decimal128(15, 2))\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]" + \n TableScan: LINEITEM" ); Ok(()) } @@ -209,10 +209,10 @@ mod tests { \n CrossJoin:\ \n CrossJoin:\ \n CrossJoin:\ - \n TableScan: CUSTOMER projection=[C_CUSTKEY, C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]" + \n TableScan: CUSTOMER\ + \n TableScan: ORDERS\ + \n TableScan: LINEITEM\ + \n TableScan: NATION" ); Ok(()) } @@ -232,17 +232,17 @@ mod tests { \n Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"JAPAN\")\ \n CrossJoin:\ \n CrossJoin:\ - \n TableScan: PARTSUPP projection=[PS_PARTKEY, PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\ - \n TableScan: SUPPLIER projection=[S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\ - \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ + \n TableScan: PARTSUPP\ + \n TableScan: SUPPLIER\ + \n TableScan: NATION\ \n Aggregate: groupBy=[[PARTSUPP.PS_PARTKEY]], aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]]\ \n Projection: PARTSUPP.PS_PARTKEY, PARTSUPP.PS_SUPPLYCOST * CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0))\ \n Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"JAPAN\")\ \n CrossJoin:\ \n CrossJoin:\ - \n TableScan: PARTSUPP projection=[PS_PARTKEY, PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\ - \n TableScan: SUPPLIER projection=[S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\ - \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]" + \n TableScan: PARTSUPP\ + \n TableScan: SUPPLIER\ + \n TableScan: NATION" ); Ok(()) } @@ -258,8 +258,8 @@ mod tests { \n Projection: LINEITEM.L_SHIPMODE, CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8(\"1-URGENT\") OR ORDERS.O_ORDERPRIORITY = Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END, CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8(\"1-URGENT\") AND ORDERS.O_ORDERPRIORITY != Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END\ \n Filter: ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"MAIL\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"SHIP\") AS Utf8)) AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE AND LINEITEM.L_SHIPDATE < LINEITEM.L_COMMITDATE AND LINEITEM.L_RECEIPTDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND LINEITEM.L_RECEIPTDATE < CAST(Utf8(\"1995-01-01\") AS Date32)\ \n CrossJoin:\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]" + \n TableScan: ORDERS\ + \n TableScan: LINEITEM" ); Ok(()) } @@ -277,8 +277,8 @@ mod tests { \n Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY]], aggr=[[count(ORDERS.O_ORDERKEY)]]\ \n Projection: CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY\ \n Left Join: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY Filter: NOT ORDERS.O_COMMENT LIKE CAST(Utf8(\"%special%requests%\") AS Utf8)\ - \n TableScan: CUSTOMER projection=[C_CUSTKEY, C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]" + \n TableScan: CUSTOMER\ + \n TableScan: ORDERS" ); Ok(()) } @@ -293,8 +293,8 @@ mod tests { \n Projection: CASE WHEN PART.P_TYPE LIKE CAST(Utf8(\"PROMO%\") AS Utf8) THEN LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) ELSE Decimal128(Some(0),19,4) END, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ \n Filter: LINEITEM.L_PARTKEY = PART.P_PARTKEY AND LINEITEM.L_SHIPDATE >= Date32(\"1995-09-01\") AND LINEITEM.L_SHIPDATE < CAST(Utf8(\"1995-10-01\") AS Date32)\ \n CrossJoin:\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n TableScan: PART projection=[P_PARTKEY, P_NAME, P_MFGR, P_BRAND, P_TYPE, P_SIZE, P_CONTAINER, P_RETAILPRICE, P_COMMENT]" + \n TableScan: LINEITEM\ + \n TableScan: PART" ); Ok(()) } @@ -320,10 +320,10 @@ mod tests { \n Subquery:\ \n Projection: SUPPLIER.S_SUPPKEY\ \n Filter: SUPPLIER.S_COMMENT LIKE CAST(Utf8(\"%Customer%Complaints%\") AS Utf8)\ - \n TableScan: SUPPLIER projection=[S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\ + \n TableScan: SUPPLIER\ \n CrossJoin:\ - \n TableScan: PARTSUPP projection=[PS_PARTKEY, PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\ - \n TableScan: PART projection=[P_PARTKEY, P_NAME, P_MFGR, P_BRAND, P_TYPE, P_SIZE, P_CONTAINER, P_RETAILPRICE, P_COMMENT]" + \n TableScan: PARTSUPP\ + \n TableScan: PART" ); Ok(()) } @@ -352,12 +352,12 @@ mod tests { \n Filter: sum(LINEITEM.L_QUANTITY) > CAST(Int32(300) AS Decimal128(15, 2))\ \n Aggregate: groupBy=[[LINEITEM.L_ORDERKEY]], aggr=[[sum(LINEITEM.L_QUANTITY)]]\ \n Projection: LINEITEM.L_ORDERKEY, LINEITEM.L_QUANTITY\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ + \n TableScan: LINEITEM\ \n CrossJoin:\ \n CrossJoin:\ - \n TableScan: CUSTOMER projection=[C_CUSTKEY, C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]" + \n TableScan: CUSTOMER\ + \n TableScan: ORDERS\ + \n TableScan: LINEITEM" ); Ok(()) } @@ -370,8 +370,8 @@ mod tests { \n Projection: LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ \n Filter: PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8(\"Brand#12\") AND (PART.P_CONTAINER = CAST(Utf8(\"SM CASE\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"SM BOX\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"SM PACK\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"SM PKG\") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(1) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(1) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(5) AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR REG\") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8(\"DELIVER IN PERSON\") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8(\"Brand#23\") AND (PART.P_CONTAINER = CAST(Utf8(\"MED BAG\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"MED BOX\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"MED PKG\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"MED PACK\") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(10) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(10) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(10) AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR REG\") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8(\"DELIVER IN PERSON\") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8(\"Brand#34\") AND (PART.P_CONTAINER = CAST(Utf8(\"LG CASE\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"LG BOX\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"LG PACK\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"LG PKG\") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(20) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(20) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(15) AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR REG\") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8(\"DELIVER IN PERSON\")\ \n CrossJoin:\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n TableScan: PART projection=[P_PARTKEY, P_NAME, P_MFGR, P_BRAND, P_TYPE, P_SIZE, P_CONTAINER, P_RETAILPRICE, P_COMMENT]" + \n TableScan: LINEITEM\ + \n TableScan: PART" ); Ok(()) } @@ -390,17 +390,17 @@ mod tests { \n Subquery:\ \n Projection: PART.P_PARTKEY\ \n Filter: PART.P_NAME LIKE CAST(Utf8(\"forest%\") AS Utf8)\ - \n TableScan: PART projection=[P_PARTKEY, P_NAME, P_MFGR, P_BRAND, P_TYPE, P_SIZE, P_CONTAINER, P_RETAILPRICE, P_COMMENT]\ + \n TableScan: PART\ \n Subquery:\ \n Projection: Decimal128(Some(5),2,1) * sum(LINEITEM.L_QUANTITY)\ \n Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_QUANTITY)]]\ \n Projection: LINEITEM.L_QUANTITY\ \n Filter: LINEITEM.L_PARTKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_SUPPKEY = LINEITEM.L_PARTKEY AND LINEITEM.L_SHIPDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8(\"1995-01-01\") AS Date32)\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n TableScan: PARTSUPP projection=[PS_PARTKEY, PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\ + \n TableScan: LINEITEM\ + \n TableScan: PARTSUPP\ \n CrossJoin:\ - \n TableScan: SUPPLIER projection=[S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\ - \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]" + \n TableScan: SUPPLIER\ + \n TableScan: NATION" ); Ok(()) } @@ -418,17 +418,17 @@ mod tests { \n Filter: SUPPLIER.S_SUPPKEY = LINEITEM.L_SUPPKEY AND ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND ORDERS.O_ORDERSTATUS = Utf8(\"F\") AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE AND EXISTS () AND NOT EXISTS () AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"SAUDI ARABIA\")\ \n Subquery:\ \n Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ + \n TableScan: LINEITEM\ \n Subquery:\ \n Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ + \n TableScan: LINEITEM\ \n CrossJoin:\ \n CrossJoin:\ \n CrossJoin:\ - \n TableScan: SUPPLIER projection=[S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\ - \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]\ - \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]" + \n TableScan: SUPPLIER\ + \n TableScan: LINEITEM\ + \n TableScan: ORDERS\ + \n TableScan: NATION" ); Ok(()) } @@ -447,11 +447,11 @@ mod tests { \n Aggregate: groupBy=[[]], aggr=[[avg(CUSTOMER.C_ACCTBAL)]]\ \n Projection: CUSTOMER.C_ACCTBAL\ \n Filter: CUSTOMER.C_ACCTBAL > Decimal128(Some(0),3,2) AND (substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"13\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"31\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"23\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"29\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"30\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"18\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"17\") AS Utf8))\ - \n TableScan: CUSTOMER projection=[C_CUSTKEY, C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\ + \n TableScan: CUSTOMER\ \n Subquery:\ \n Filter: ORDERS.O_CUSTKEY = ORDERS.O_ORDERKEY\ - \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT]\ - \n TableScan: CUSTOMER projection=[C_CUSTKEY, C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]" + \n TableScan: ORDERS\ + \n TableScan: CUSTOMER" ); Ok(()) } diff --git a/datafusion/substrait/tests/cases/function_test.rs b/datafusion/substrait/tests/cases/function_test.rs index 5806b55d84c4..b136b0af19c2 100644 --- a/datafusion/substrait/tests/cases/function_test.rs +++ b/datafusion/substrait/tests/cases/function_test.rs @@ -37,7 +37,7 @@ mod tests { plan_str, "Projection: nation.n_name\ \n Filter: contains(nation.n_name, Utf8(\"IA\"))\ - \n TableScan: nation projection=[n_nationkey, n_name, n_regionkey, n_comment]" + \n TableScan: nation" ); Ok(()) } diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 6794b32838a8..f4e34af35d78 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -43,7 +43,7 @@ mod tests { assert_eq!( format!("{}", plan), "Projection: NOT DATA.D AS EXPR$0\ - \n TableScan: DATA projection=[D]" + \n TableScan: DATA" ); Ok(()) } @@ -69,7 +69,7 @@ mod tests { format!("{}", plan), "Projection: sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\ \n WindowAggr: windowExpr=[[sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n TableScan: DATA projection=[D, PART, ORD]" + \n TableScan: DATA" ); Ok(()) } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index ce6d1825cd25..352442a15bd2 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -472,12 +472,12 @@ async fn roundtrip_inlist_5() -> Result<()> { \n Subquery:\ \n Projection: data2.a\ \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ - \n TableScan: data2 projection=[a, b, c, d, e, f]\ + \n TableScan: data2\ \n TableScan: data projection=[a, f], partial_filters=[data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()]\ \n Subquery:\ \n Projection: data2.a\ \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ - \n TableScan: data2 projection=[a, b, c, d, e, f]", + \n TableScan: data2", true).await } diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs b/datafusion/substrait/tests/cases/substrait_validations.rs index 96cf807f4fe6..b6a9c4da83de 100644 --- a/datafusion/substrait/tests/cases/substrait_validations.rs +++ b/datafusion/substrait/tests/cases/substrait_validations.rs @@ -70,7 +70,7 @@ mod tests { assert_eq!( format!("{}", plan), "Projection: DATA.a, DATA.b\ - \n TableScan: DATA projection=[a, b]" + \n TableScan: DATA" ); Ok(()) } From cbd37b1e754127d13877bf708ae45e79999b2325 Mon Sep 17 00:00:00 2001 From: tokoko Date: Tue, 8 Oct 2024 17:23:47 +0000 Subject: [PATCH 3/7] fix(substrait): correctly apply read projections --- datafusion/substrait/src/lib.rs | 1 + .../substrait/src/logical_plan/consumer.rs | 85 +++++++++---------- .../tests/cases/substrait_validations.rs | 6 +- 3 files changed, 42 insertions(+), 50 deletions(-) diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs index 0b1c796553c0..a6f7c033f9d0 100644 --- a/datafusion/substrait/src/lib.rs +++ b/datafusion/substrait/src/lib.rs @@ -68,6 +68,7 @@ //! //! // Receive a substrait protobuf from somewhere, and turn it into a LogicalPlan //! let logical_round_trip = logical_plan::consumer::from_substrait_plan(&ctx, &substrait_plan).await?; +//! let logical_round_trip = ctx.state().optimize(&logical_round_trip)?; //! assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); //! # Ok(()) //! # } diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 19234706e51f..11b7f7447f59 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -55,7 +55,6 @@ use crate::variation_const::{ use datafusion::arrow::array::{new_empty_array, AsArray}; use datafusion::common::scalar::ScalarStructBuilder; use datafusion::dataframe::DataFrame; -use datafusion::logical_expr::builder::project; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, @@ -227,7 +226,6 @@ pub async fn from_substrait_plan( // Nothing to do if the schema is already equivalent return Ok(plan); } - match plan { // If the last node of the plan produces expressions, bake the renames into those expressions. // This isn't necessary for correctness, but helps with roundtrip tests. @@ -327,11 +325,10 @@ pub async fn from_substrait_extended_expr( }) } -/// parse projection -pub fn extract_projection( - t: LogicalPlan, +pub fn apply_projection( + schema: DFSchema, projection: &::core::option::Option, -) -> Result { +) -> Result { match projection { Some(MaskExpression { select, .. }) => match &select.as_ref() { Some(projection) => { @@ -340,41 +337,20 @@ pub fn extract_projection( .iter() .map(|item| item.field as usize) .collect(); - match t { - LogicalPlan::TableScan(mut scan) => { - let fields = column_indices - .iter() - .map(|i| scan.projected_schema.qualified_field(*i)) - .map(|(qualifier, field)| { - (qualifier.cloned(), Arc::new(field.clone())) - }) - .collect(); - scan.projection = Some(column_indices); - scan.projected_schema = DFSchemaRef::new( - DFSchema::new_with_metadata(fields, HashMap::new())?, - ); - Ok(LogicalPlan::TableScan(scan)) - } - LogicalPlan::Projection(projection) => { - // create another Projection around the Projection to handle the field masking - let fields: Vec = column_indices - .into_iter() - .map(|i| { - let (qualifier, field) = - projection.schema.qualified_field(i); - let column = - Column::new(qualifier.cloned(), field.name()); - Expr::Column(column) - }) - .collect(); - project(LogicalPlan::Projection(projection), fields) - } - _ => plan_err!("unexpected plan for table"), - } + + let fields = column_indices + .iter() + .map(|i| schema.qualified_field(*i)) + .map(|(qualifier, field)| { + (qualifier.cloned(), Arc::new(field.clone())) + }) + .collect(); + + Ok(DFSchema::new_with_metadata(fields, HashMap::new())?) } - _ => Ok(t), + _ => Ok(schema), }, - _ => Ok(t), + _ => Ok(schema), } } @@ -781,9 +757,11 @@ pub async fn from_substrait_rel( from_substrait_named_struct(named_struct, extensions)? .replace_qualifier(table_reference.clone()); + let substrait_schema = + apply_projection(substrait_schema, &read.projection)?; + let t = ctx.table(table_reference.clone()).await?; - let t = ensure_schema_compatability(t, substrait_schema)?; - extract_projection(t, &read.projection) + ensure_schema_compatability(t, substrait_schema) } Some(ReadType::VirtualTable(vt)) => { let base_schema = read.base_schema.as_ref().ok_or_else(|| { @@ -834,6 +812,10 @@ pub async fn from_substrait_rel( })) } Some(ReadType::LocalFiles(lf)) => { + let named_struct = read.base_schema.as_ref().ok_or_else(|| { + substrait_datafusion_err!("No base schema provided for LocalFiles") + })?; + fn extract_filename(name: &str) -> Option { let corrected_url = if name.starts_with("file://") && !name.starts_with("file:///") { @@ -864,9 +846,16 @@ pub async fn from_substrait_rel( let name = filename.unwrap(); // directly use unwrap here since we could determine it is a valid one let table_reference = TableReference::Bare { table: name.into() }; - let t = ctx.table(table_reference).await?; - let t = t.into_unoptimized_plan(); - extract_projection(t, &read.projection) + let t = ctx.table(table_reference.clone()).await?; + + let substrait_schema = + from_substrait_named_struct(named_struct, extensions)? + .replace_qualifier(table_reference); + + let substrait_schema = + apply_projection(substrait_schema, &read.projection)?; + + ensure_schema_compatability(t, substrait_schema) } _ => not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type), }, @@ -991,12 +980,13 @@ fn ensure_schema_compatability( substrait_schema: DFSchema, ) -> Result { let df_schema = table.schema().to_owned(); - if df_schema.logically_equivalent_names_and_types(&substrait_schema) { - return Ok(table.into_unoptimized_plan()); - } let t = table.into_unoptimized_plan(); + if df_schema.logically_equivalent_names_and_types(&substrait_schema) { + return Ok(t); + } + match t { LogicalPlan::TableScan(mut scan) => { let column_indices: Vec = substrait_schema @@ -1023,6 +1013,7 @@ fn ensure_schema_compatability( scan.projected_schema = DFSchemaRef::new(DFSchema::new_with_metadata(fields, HashMap::new())?); scan.projection = Some(column_indices); + Ok(LogicalPlan::TableScan(scan)) } _ => Ok(t), diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs b/datafusion/substrait/tests/cases/substrait_validations.rs index b6a9c4da83de..afda9004849f 100644 --- a/datafusion/substrait/tests/cases/substrait_validations.rs +++ b/datafusion/substrait/tests/cases/substrait_validations.rs @@ -103,10 +103,10 @@ mod tests { ); // the DataFusion schema { b, a, c, d } contains the Substrait schema { a, b, c } let df_schema = vec![ - ("b", DataType::Int32, true), + ("d", DataType::Int32, true), ("a", DataType::Int32, false), ("c", DataType::Int32, false), - ("d", DataType::Int32, false), + ("b", DataType::Int32, false), ]; let ctx = generate_context_with_table("DATA", df_schema)?; let plan = from_substrait_plan(&ctx, &proto_plan).await?; @@ -114,7 +114,7 @@ mod tests { assert_eq!( format!("{}", plan), "Projection: DATA.a, DATA.b\ - \n TableScan: DATA projection=[b, a]" + \n TableScan: DATA projection=[a, b]" ); Ok(()) } From 5cab5c77088eb322088fe0d81f2e1d2d37eb569c Mon Sep 17 00:00:00 2001 From: tokoko Date: Tue, 8 Oct 2024 18:50:57 +0000 Subject: [PATCH 4/7] fix(substrait): nits --- datafusion/substrait/src/logical_plan/consumer.rs | 4 ++-- datafusion/substrait/tests/cases/substrait_validations.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 11b7f7447f59..efc88bb56405 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -348,9 +348,9 @@ pub fn apply_projection( Ok(DFSchema::new_with_metadata(fields, HashMap::new())?) } - _ => Ok(schema), + None => Ok(schema), }, - _ => Ok(schema), + None => Ok(schema), } } diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs b/datafusion/substrait/tests/cases/substrait_validations.rs index afda9004849f..5ae586afe56f 100644 --- a/datafusion/substrait/tests/cases/substrait_validations.rs +++ b/datafusion/substrait/tests/cases/substrait_validations.rs @@ -101,7 +101,7 @@ mod tests { let proto_plan = read_json( "tests/testdata/test_plans/simple_select_with_mask.substrait.json", ); - // the DataFusion schema { b, a, c, d } contains the Substrait schema { a, b, c } + // the DataFusion schema { d, a, c, b } contains the Substrait schema { a, b, c } let df_schema = vec![ ("d", DataType::Int32, true), ("a", DataType::Int32, false), From e8b2b2a695e657034824200f9f2403c01179a3fd Mon Sep 17 00:00:00 2001 From: tokoko Date: Wed, 9 Oct 2024 06:40:06 +0000 Subject: [PATCH 5/7] fix(substrait): split schema validation and apply_projection --- .../substrait/src/logical_plan/consumer.rs | 72 ++++++++++++------- 1 file changed, 48 insertions(+), 24 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index efc88bb56405..8a7e0bc19761 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -68,7 +68,7 @@ use datafusion::{ prelude::{Column, SessionContext}, scalar::ScalarValue, }; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::sync::Arc; use substrait::proto::exchange_rel::ExchangeKind; use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode; @@ -325,11 +325,11 @@ pub async fn from_substrait_extended_expr( }) } -pub fn apply_projection( +pub fn apply_masking( schema: DFSchema, - projection: &::core::option::Option, + mask_expression: &::core::option::Option, ) -> Result { - match projection { + match mask_expression { Some(MaskExpression { select, .. }) => match &select.as_ref() { Some(projection) => { let column_indices: Vec = projection @@ -346,7 +346,10 @@ pub fn apply_projection( }) .collect(); - Ok(DFSchema::new_with_metadata(fields, HashMap::new())?) + Ok(DFSchema::new_with_metadata( + fields, + schema.metadata().clone(), + )?) } None => Ok(schema), }, @@ -753,15 +756,20 @@ pub async fn from_substrait_rel( }, }; + let t = ctx.table(table_reference.clone()).await?; + let substrait_schema = from_substrait_named_struct(named_struct, extensions)? - .replace_qualifier(table_reference.clone()); + .replace_qualifier(table_reference); - let substrait_schema = - apply_projection(substrait_schema, &read.projection)?; + ensure_schema_compatability( + t.schema().to_owned(), + substrait_schema.clone(), + )?; - let t = ctx.table(table_reference.clone()).await?; - ensure_schema_compatability(t, substrait_schema) + let substrait_schema = apply_masking(substrait_schema, &read.projection)?; + + apply_projection(t, substrait_schema) } Some(ReadType::VirtualTable(vt)) => { let base_schema = read.base_schema.as_ref().ok_or_else(|| { @@ -852,10 +860,14 @@ pub async fn from_substrait_rel( from_substrait_named_struct(named_struct, extensions)? .replace_qualifier(table_reference); - let substrait_schema = - apply_projection(substrait_schema, &read.projection)?; + ensure_schema_compatability( + t.schema().to_owned(), + substrait_schema.clone(), + )?; + + let substrait_schema = apply_masking(substrait_schema, &read.projection)?; - ensure_schema_compatability(t, substrait_schema) + apply_projection(t, substrait_schema) } _ => not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type), }, @@ -972,13 +984,27 @@ pub async fn from_substrait_rel( /// 1. All fields present in the Substrait schema are present in the DataFusion schema. The /// DataFusion schema may have MORE fields, but not the other way around. /// 2. All fields are compatible. See [`ensure_field_compatability`] for details -/// -/// This function returns a DataFrame with fields adjusted if necessary in the event that the -/// Substrait schema is a subset of the DataFusion schema. fn ensure_schema_compatability( - table: DataFrame, + table_schema: DFSchema, substrait_schema: DFSchema, -) -> Result { +) -> Result<()> { + substrait_schema + .strip_qualifiers() + .fields() + .iter() + .map(|substrait_field| { + let df_field = + table_schema.field_with_unqualified_name(substrait_field.name())?; + ensure_field_compatability(df_field, substrait_field) + }) + .collect::>()?; + + Ok(()) +} + +/// This function returns a DataFrame with fields adjusted if necessary in the event that the +/// Substrait schema is a subset of the DataFusion schema. +fn apply_projection(table: DataFrame, substrait_schema: DFSchema) -> Result { let df_schema = table.schema().to_owned(); let t = table.into_unoptimized_plan(); @@ -994,10 +1020,6 @@ fn ensure_schema_compatability( .fields() .iter() .map(|substrait_field| { - let df_field = - df_schema.field_with_unqualified_name(substrait_field.name())?; - ensure_field_compatability(df_field, substrait_field)?; - Ok(df_schema .index_of_column_by_name(None, substrait_field.name().as_str()) .unwrap()) @@ -1010,8 +1032,10 @@ fn ensure_schema_compatability( .map(|(qualifier, field)| (qualifier.cloned(), Arc::new(field.clone()))) .collect(); - scan.projected_schema = - DFSchemaRef::new(DFSchema::new_with_metadata(fields, HashMap::new())?); + scan.projected_schema = DFSchemaRef::new(DFSchema::new_with_metadata( + fields, + df_schema.metadata().clone(), + )?); scan.projection = Some(column_indices); Ok(LogicalPlan::TableScan(scan)) From bbf7e48069b29d6f70ca3986cbf27dfd181b0cf0 Mon Sep 17 00:00:00 2001 From: tokoko Date: Thu, 10 Oct 2024 22:38:20 +0000 Subject: [PATCH 6/7] fix(substrait): return an error when apply_projection is called with something other than a TableScan --- datafusion/substrait/src/logical_plan/consumer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 8a7e0bc19761..d2d8943e591a 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1040,7 +1040,7 @@ fn apply_projection(table: DataFrame, substrait_schema: DFSchema) -> Result Ok(t), + _ => plan_err!("DataFrame passed to apply_projection must be a TableScan"), } } From 229429e62a8b64781c4f1ecc1585952c7e129999 Mon Sep 17 00:00:00 2001 From: tokoko Date: Thu, 10 Oct 2024 22:56:56 +0000 Subject: [PATCH 7/7] fix(substrait): clippy errors --- datafusion/substrait/src/logical_plan/consumer.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index d2d8943e591a..442c8f49513b 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -992,14 +992,11 @@ fn ensure_schema_compatability( .strip_qualifiers() .fields() .iter() - .map(|substrait_field| { + .try_for_each(|substrait_field| { let df_field = table_schema.field_with_unqualified_name(substrait_field.name())?; ensure_field_compatability(df_field, substrait_field) }) - .collect::>()?; - - Ok(()) } /// This function returns a DataFrame with fields adjusted if necessary in the event that the