@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
19
19
20
20
import scala .collection .mutable
21
21
22
- import org .apache .spark .sql .catalyst .expressions .{And , Attribute , AttributeReference , Expression , IntegerLiteral , NamedExpression , PredicateHelper , ProjectionOverSchema , SubqueryExpression }
22
+ import org .apache .spark .sql .catalyst .expressions .{Alias , And , Attribute , AttributeReference , Cast , Expression , IntegerLiteral , NamedExpression , PredicateHelper , ProjectionOverSchema , SubqueryExpression }
23
23
import org .apache .spark .sql .catalyst .expressions .aggregate
24
24
import org .apache .spark .sql .catalyst .expressions .aggregate .AggregateExpression
25
25
import org .apache .spark .sql .catalyst .planning .ScanOperation
@@ -30,7 +30,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
30
30
import org .apache .spark .sql .connector .read .{Scan , ScanBuilder , SupportsPushDownAggregates , SupportsPushDownFilters , V1Scan }
31
31
import org .apache .spark .sql .execution .datasources .DataSourceStrategy
32
32
import org .apache .spark .sql .sources
33
- import org .apache .spark .sql .types .StructType
33
+ import org .apache .spark .sql .types .{ DataType , LongType , StructType }
34
34
import org .apache .spark .sql .util .SchemaUtils ._
35
35
36
36
object V2ScanRelationPushDown extends Rule [LogicalPlan ] with PredicateHelper {
@@ -131,7 +131,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
131
131
case (a : Attribute , b : Attribute ) => b.withExprId(a.exprId)
132
132
case (_, b) => b
133
133
}
134
- val output = groupAttrs ++ newOutput.drop(groupAttrs.length)
134
+ val aggOutput = newOutput.drop(groupAttrs.length)
135
+ val output = groupAttrs ++ aggOutput
135
136
136
137
logInfo(
137
138
s """
@@ -147,40 +148,59 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
147
148
148
149
val scanRelation = DataSourceV2ScanRelation (sHolder.relation, wrappedScan, output)
149
150
150
- val plan = Aggregate (
151
- output.take(groupingExpressions.length), resultExpressions, scanRelation)
152
-
153
- // scalastyle:off
154
- // Change the optimized logical plan to reflect the pushed down aggregate
155
- // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
156
- // SELECT min(c1), max(c1) FROM t GROUP BY c2;
157
- // The original logical plan is
158
- // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
159
- // +- RelationV2[c1#9, c2#10] ...
160
- //
161
- // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
162
- // we have the following
163
- // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
164
- // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
165
- //
166
- // We want to change it to
167
- // == Optimized Logical Plan ==
168
- // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
169
- // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
170
- // scalastyle:on
171
- val aggOutput = output.drop(groupAttrs.length)
172
- plan.transformExpressions {
173
- case agg : AggregateExpression =>
174
- val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
175
- val aggFunction : aggregate.AggregateFunction =
176
- agg.aggregateFunction match {
177
- case max : aggregate.Max => max.copy(child = aggOutput(ordinal))
178
- case min : aggregate.Min => min.copy(child = aggOutput(ordinal))
179
- case sum : aggregate.Sum => sum.copy(child = aggOutput(ordinal))
180
- case _ : aggregate.Count => aggregate.Sum (aggOutput(ordinal))
181
- case other => other
182
- }
183
- agg.copy(aggregateFunction = aggFunction)
151
+ if (r.supportCompletePushDown()) {
152
+ val projectExpressions = resultExpressions.map { expr =>
153
+ // TODO: At present, only group-by attributes are supported for push-down.
154
+ // In the future, more attribute conversions can be added here, e.g. GetStructField.
155
+ expr.transform {
156
+ case agg : AggregateExpression =>
157
+ val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
158
+ val child =
159
+ addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType)
160
+ Alias (child, agg.resultAttribute.name)(agg.resultAttribute.exprId)
161
+ }
162
+ }.asInstanceOf [Seq [NamedExpression ]]
163
+ Project (projectExpressions, scanRelation)
164
+ } else {
165
+ val plan = Aggregate (
166
+ output.take(groupingExpressions.length), resultExpressions, scanRelation)
167
+
168
+ // scalastyle:off
169
+ // Change the optimized logical plan to reflect the pushed down aggregate
170
+ // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
171
+ // SELECT min(c1), max(c1) FROM t GROUP BY c2;
172
+ // The original logical plan is
173
+ // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
174
+ // +- RelationV2[c1#9, c2#10] ...
175
+ //
176
+ // After changing the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
177
+ // we have the following
178
+ // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
179
+ // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
180
+ //
181
+ // We want to change it to
182
+ // == Optimized Logical Plan ==
183
+ // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
184
+ // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
185
+ // scalastyle:on
186
+ plan.transformExpressions {
187
+ case agg : AggregateExpression =>
188
+ val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
189
+ val aggAttribute = aggOutput(ordinal)
190
+ val aggFunction : aggregate.AggregateFunction =
191
+ agg.aggregateFunction match {
192
+ case max : aggregate.Max =>
193
+ max.copy(child = addCastIfNeeded(aggAttribute, max.child.dataType))
194
+ case min : aggregate.Min =>
195
+ min.copy(child = addCastIfNeeded(aggAttribute, min.child.dataType))
196
+ case sum : aggregate.Sum =>
197
+ sum.copy(child = addCastIfNeeded(aggAttribute, sum.child.dataType))
198
+ case _ : aggregate.Count =>
199
+ aggregate.Sum (addCastIfNeeded(aggAttribute, LongType ))
200
+ case other => other
201
+ }
202
+ agg.copy(aggregateFunction = aggFunction)
203
+ }
184
204
}
185
205
}
186
206
case _ => aggNode
@@ -189,6 +209,13 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
189
209
}
190
210
}
191
211
212
/**
 * Reconciles the data type of a pushed-down aggregate output column with the type
 * the surrounding plan expects.
 *
 * @param aggAttribute the attribute whose data type is checked
 * @param aggDataType  the data type the caller requires
 * @return a `Cast` of `aggAttribute` to `aggDataType` when the two types differ;
 *         otherwise `aggAttribute` itself, untouched, so no redundant cast is added
 */
private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) =
  if (aggAttribute.dataType != aggDataType) {
    // Types disagree: wrap the attribute so downstream expressions see the expected type.
    Cast(aggAttribute, aggDataType)
  } else {
    aggAttribute
  }
218
+
192
219
def applyColumnPruning (plan : LogicalPlan ): LogicalPlan = plan.transform {
193
220
case ScanOperation (project, filters, sHolder : ScanBuilderHolder ) =>
194
221
// column pruning
0 commit comments