Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-32976][SQL]Support column list in INSERT statement #29893

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,8 @@ query
;

insertInto
: INSERT OVERWRITE TABLE? multipartIdentifier (partitionSpec (IF NOT EXISTS)?)? #insertOverwriteTable
| INSERT INTO TABLE? multipartIdentifier partitionSpec? (IF NOT EXISTS)? #insertIntoTable
: INSERT OVERWRITE TABLE? multipartIdentifier identifierList? (partitionSpec (IF NOT EXISTS)?)? #insertOverwriteTable
| INSERT INTO TABLE? multipartIdentifier identifierList? partitionSpec? (IF NOT EXISTS)? #insertIntoTable
| INSERT OVERWRITE LOCAL? DIRECTORY path=STRING rowFormat? createFileFormat? #insertOverwriteHiveDir
| INSERT OVERWRITE LOCAL? DIRECTORY (path=STRING)? tableProvider (OPTIONS options=tablePropertyList)? #insertOverwriteDir
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,7 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case u @ UnresolvedRelation(ident, _, isStreaming) =>
lookupTempView(ident, isStreaming).getOrElse(u)
case i @ InsertIntoStatement(UnresolvedRelation(ident, _, false), _, _, _, _) =>
case i @ InsertIntoStatement(UnresolvedRelation(ident, _, false), _, _, _, _, _) =>
lookupTempView(ident)
.map(view => i.copy(table = view))
.getOrElse(i)
Expand Down Expand Up @@ -922,7 +922,7 @@ class Analyzer(
.map(ResolvedTable(catalog.asTableCatalog, ident, _))
.getOrElse(u)

case i @ InsertIntoStatement(u @ UnresolvedRelation(_, _, false), _, _, _, _)
case i @ InsertIntoStatement(u @ UnresolvedRelation(_, _, false), _, _, _, _, _)
if i.query.resolved =>
lookupV2Relation(u.multipartIdentifier, u.options, false)
.map(v2Relation => i.copy(table = v2Relation))
Expand Down Expand Up @@ -992,7 +992,7 @@ class Analyzer(
}

def apply(plan: LogicalPlan): LogicalPlan = ResolveTempViews(plan).resolveOperatorsUp {
case i @ InsertIntoStatement(table, _, _, _, _) if i.query.resolved =>
case i @ InsertIntoStatement(table, _, _, _, _, _) if i.query.resolved =>
val relation = table match {
case u @ UnresolvedRelation(_, _, false) =>
lookupRelation(u.multipartIdentifier, u.options, false).getOrElse(u)
Expand Down Expand Up @@ -1088,7 +1088,8 @@ class Analyzer(

object ResolveInsertInto extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _) if i.query.resolved =>
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _)
if i.query.resolved && i.columns.forall(_.resolved) =>
// ifPartitionNotExists is append with validation, but validation is not supported
if (i.ifPartitionNotExists) {
throw new AnalysisException(
Expand All @@ -1102,11 +1103,12 @@ class Analyzer(
val query = addStaticPartitionColumns(r, i.query, staticPartitions)

if (!i.overwrite) {
AppendData.byPosition(r, query)
AppendData.byPosition(r, query, i.columns)
} else if (conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC) {
OverwritePartitionsDynamic.byPosition(r, query)
OverwritePartitionsDynamic.byPosition(r, query, i.columns)
} else {
OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, staticPartitions))
OverwriteByExpression.byPosition(
r, query, staticDeleteExpression(r, staticPartitions), i.columns)
}
}

Expand Down Expand Up @@ -3028,38 +3030,50 @@ class Analyzer(
*/
object ResolveOutputRelation extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case append @ AppendData(table, query, _, isByName)
if table.resolved && query.resolved && !append.outputResolved =>
case i @ InsertIntoStatement(table, _, _, _, _, _)
if table.resolved && i.columns.exists(!_.resolved) =>
val tableOutputs = table.output.map(x => x.name -> x).toMap
val resolved = i.columns.map {
case u: UnresolvedAttribute => withPosition(u) {
table.resolve(u.nameParts, resolver).map(_.toAttribute)
.getOrElse(failAnalysis(s"Cannot resolve column name ${u.name}"))
}
case other => other
}
i.copy(columns = resolved)

case append @ AppendData(table, query, cols, _, isByName)
if table.resolved && query.resolved && (!append.outputResolved || cols.nonEmpty) =>
validateStoreAssignmentPolicy()
val projection =
TableOutputResolver.resolveOutputColumns(table.name, table.output, query, isByName, conf)
val projection = TableOutputResolver.resolveOutputColumns(
table.name, table.output, cols, query, isByName, conf)

if (projection != query) {
append.copy(query = projection)
append.copy(query = projection, columns = Nil)
} else {
append
}

case overwrite @ OverwriteByExpression(table, _, query, _, isByName)
if table.resolved && query.resolved && !overwrite.outputResolved =>
case overwrite @ OverwriteByExpression(table, _, query, cols, _, isByName)
if table.resolved && query.resolved && (!overwrite.outputResolved || cols.nonEmpty) =>
validateStoreAssignmentPolicy()
val projection =
TableOutputResolver.resolveOutputColumns(table.name, table.output, query, isByName, conf)
val projection = TableOutputResolver.resolveOutputColumns(
table.name, table.output, cols, query, isByName, conf)

if (projection != query) {
overwrite.copy(query = projection)
overwrite.copy(query = projection, columns = Nil)
} else {
overwrite
}

case overwrite @ OverwritePartitionsDynamic(table, query, _, isByName)
if table.resolved && query.resolved && !overwrite.outputResolved =>
case overwrite @ OverwritePartitionsDynamic(table, query, cols, _, isByName)
if table.resolved && query.resolved && (!overwrite.outputResolved || cols.nonEmpty) =>
validateStoreAssignmentPolicy()
val projection =
TableOutputResolver.resolveOutputColumns(table.name, table.output, query, isByName, conf)
val projection = TableOutputResolver.resolveOutputColumns(
table.name, table.output, cols, query, isByName, conf)

if (projection != query) {
overwrite.copy(query = projection)
overwrite.copy(query = projection, columns = Nil)
} else {
overwrite
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ trait CheckAnalysis extends PredicateHelper {
case u: UnresolvedRelation =>
u.failAnalysis(s"Table or view not found: ${u.multipartIdentifier.quoted}")

case InsertIntoStatement(u: UnresolvedRelation, _, _, _, _) =>
case InsertIntoStatement(u: UnresolvedRelation, _, _, _, _, _) =>
failAnalysis(s"Table not found: ${u.multipartIdentifier.quoted}")

case u: UnresolvedV2Relation if isView(u.originalNameParts) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,31 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.util.SchemaUtils

object TableOutputResolver {
def resolveOutputColumns(
tableName: String,
expected: Seq[Attribute],
specified: Seq[Attribute],
query: LogicalPlan,
byName: Boolean,
conf: SQLConf): LogicalPlan = {

if (expected.size < query.output.size) {
SchemaUtils.checkColumnNameDuplication(
specified.map(_.name), "in the column list", conf.resolver)

val expectedSize = if (specified.nonEmpty && expected.size != specified.size) {
throw new AnalysisException(
s"$tableName requires that the data to be inserted have the same number" +
s" of columns as the target table that has ${expected.size} column(s) but the" +
s" specified part has only ${specified.length} column(s)"
)
} else {
expected.size
}

if (expectedSize < query.output.size) {
throw new AnalysisException(
s"""Cannot write to '$tableName', too many data columns:
|Table columns: ${expected.map(c => s"'${c.name}'").mkString(", ")}
Expand All @@ -54,17 +69,24 @@ object TableOutputResolver {
}

} else {
if (expected.size > query.output.size) {
if (expectedSize > query.output.size) {
throw new AnalysisException(
s"""Cannot write to '$tableName', not enough data columns:
|Table columns: ${expected.map(c => s"'${c.name}'").mkString(", ")}
|Data columns: ${query.output.map(c => s"'${c.name}'").mkString(", ")}"""
.stripMargin)
}

query.output.zip(expected).flatMap {
case (queryExpr, tableAttr) =>
checkField(tableAttr, queryExpr, byName, conf, err => errors += err)
if (specified.nonEmpty) {
val nameToQueryExpr = specified.zip(query.output).toMap
expected.flatMap { tblAttr =>
checkField(tblAttr, nameToQueryExpr(tblAttr), byName, conf, err => errors += err)
}
} else {
query.output.zip(expected).flatMap {
case (queryExpr, tableAttr) =>
checkField(tableAttr, queryExpr, byName, conf, err => errors += err)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ package object dsl {
partition: Map[String, Option[String]] = Map.empty,
overwrite: Boolean = false,
ifPartitionNotExists: Boolean = false): LogicalPlan =
InsertIntoStatement(table, partition, logicalPlan, overwrite, ifPartitionNotExists)
InsertIntoStatement(table, partition, Nil, logicalPlan, overwrite, ifPartitionNotExists)

def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging

/**
* Parameters used for writing query to a table:
* (multipartIdentifier, partitionKeys, ifPartitionNotExists).
* (multipartIdentifier, tableColumnList, partitionKeys, ifPartitionNotExists).
*/
type InsertTableParams = (Seq[String], Map[String, Option[String]], Boolean)
type InsertTableParams = (Seq[String], Seq[Attribute], Map[String, Option[String]], Boolean)

/**
* Parameters used for writing query to a directory: (isLocal, CatalogStorageFormat, provider).
Expand All @@ -269,18 +269,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
ctx match {
case table: InsertIntoTableContext =>
val (tableIdent, partition, ifPartitionNotExists) = visitInsertIntoTable(table)
val (tableIdent, cols, partition, ifPartitionNotExists) = visitInsertIntoTable(table)
InsertIntoStatement(
UnresolvedRelation(tableIdent),
partition,
cols,
query,
overwrite = false,
ifPartitionNotExists)
case table: InsertOverwriteTableContext =>
val (tableIdent, partition, ifPartitionNotExists) = visitInsertOverwriteTable(table)
val (tableIdent, cols, partition, ifPartitionNotExists) = visitInsertOverwriteTable(table)
InsertIntoStatement(
UnresolvedRelation(tableIdent),
partition,
cols,
query,
overwrite = true,
ifPartitionNotExists)
Expand All @@ -301,13 +303,16 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
override def visitInsertIntoTable(
ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) {
val tableIdent = visitMultipartIdentifier(ctx.multipartIdentifier)
val cols = Option(ctx.identifierList())
.map(visitIdentifierList)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a check that column name must be one part?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unnecessary, it comes from a grammar that accepts a one-part name

.getOrElse(Nil).map(UnresolvedAttribute(_))
val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)

if (ctx.EXISTS != null) {
operationNotAllowed("INSERT INTO ... IF NOT EXISTS", ctx)
}

(tableIdent, partitionKeys, false)
(tableIdent, cols, partitionKeys, false)
}

/**
Expand All @@ -317,6 +322,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) {
assert(ctx.OVERWRITE() != null)
val tableIdent = visitMultipartIdentifier(ctx.multipartIdentifier)
val cols = Option(ctx.identifierList())
.map(visitIdentifierList).getOrElse(Nil).map(UnresolvedAttribute(_))
val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)

val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty)
Expand All @@ -325,7 +332,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
dynamicPartitionKeys.keys.mkString(", "), ctx)
}

(tableIdent, partitionKeys, ctx.EXISTS() != null)
(tableIdent, cols, partitionKeys, ctx.EXISTS() != null)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.analysis.ViewType
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, FunctionResource}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.{DataType, StructType}
Expand Down Expand Up @@ -310,6 +310,7 @@ case class DescribeColumnStatement(
* An INSERT INTO statement, as parsed from SQL.
*
* @param table the logical plan representing the table.
 * @param columns the user-specified list of target columns for the insert
 *                (from the optional column list in the statement); empty if none was given.
* @param query the logical plan representing data to write to.
* @param overwrite overwrite existing table or partitions.
* @param partitionSpec a map from the partition key to the partition value (optional).
Expand All @@ -324,6 +325,7 @@ case class DescribeColumnStatement(
case class InsertIntoStatement(
table: LogicalPlan,
partitionSpec: Map[String, Option[String]],
columns: Seq[Attribute],
query: LogicalPlan,
overwrite: Boolean,
ifPartitionNotExists: Boolean) extends ParsedStatement {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,26 @@ trait V2WriteCommand extends Command {
case class AppendData(
table: NamedRelation,
query: LogicalPlan,
columns: Seq[Attribute],
writeOptions: Map[String, String],
isByName: Boolean) extends V2WriteCommand
isByName: Boolean) extends V2WriteCommand {
override lazy val resolved: Boolean = outputResolved && columns.forall(_.resolved)
}

object AppendData {
def byName(
table: NamedRelation,
df: LogicalPlan,
writeOptions: Map[String, String] = Map.empty): AppendData = {
new AppendData(table, df, writeOptions, isByName = true)
new AppendData(table, df, Nil, writeOptions, isByName = true)
}

def byPosition(
table: NamedRelation,
query: LogicalPlan,
columns: Seq[Attribute] = Nil,
writeOptions: Map[String, String] = Map.empty): AppendData = {
new AppendData(table, query, writeOptions, isByName = false)
new AppendData(table, query, columns, writeOptions, isByName = false)
}
}

Expand All @@ -84,9 +88,11 @@ case class OverwriteByExpression(
table: NamedRelation,
deleteExpr: Expression,
query: LogicalPlan,
columns: Seq[Attribute],
writeOptions: Map[String, String],
isByName: Boolean) extends V2WriteCommand {
override lazy val resolved: Boolean = outputResolved && deleteExpr.resolved
override lazy val resolved: Boolean =
outputResolved && deleteExpr.resolved && columns.forall(_.resolved)
}

object OverwriteByExpression {
Expand All @@ -95,15 +101,16 @@ object OverwriteByExpression {
df: LogicalPlan,
deleteExpr: Expression,
writeOptions: Map[String, String] = Map.empty): OverwriteByExpression = {
OverwriteByExpression(table, deleteExpr, df, writeOptions, isByName = true)
OverwriteByExpression(table, deleteExpr, df, Nil, writeOptions, isByName = true)
}

def byPosition(
table: NamedRelation,
query: LogicalPlan,
deleteExpr: Expression,
columns: Seq[Attribute] = Nil,
writeOptions: Map[String, String] = Map.empty): OverwriteByExpression = {
OverwriteByExpression(table, deleteExpr, query, writeOptions, isByName = false)
OverwriteByExpression(table, deleteExpr, query, columns, writeOptions, isByName = false)
}
}

Expand All @@ -113,22 +120,26 @@ object OverwriteByExpression {
case class OverwritePartitionsDynamic(
table: NamedRelation,
query: LogicalPlan,
columns: Seq[Attribute] = Nil,
writeOptions: Map[String, String],
isByName: Boolean) extends V2WriteCommand
isByName: Boolean) extends V2WriteCommand {
override lazy val resolved: Boolean = outputResolved && columns.forall(_.resolved)
}

object OverwritePartitionsDynamic {
def byName(
table: NamedRelation,
df: LogicalPlan,
writeOptions: Map[String, String] = Map.empty): OverwritePartitionsDynamic = {
OverwritePartitionsDynamic(table, df, writeOptions, isByName = true)
OverwritePartitionsDynamic(table, df, Nil, writeOptions, isByName = true)
}

def byPosition(
table: NamedRelation,
query: LogicalPlan,
columns: Seq[Attribute] = Nil,
writeOptions: Map[String, String] = Map.empty): OverwritePartitionsDynamic = {
OverwritePartitionsDynamic(table, query, writeOptions, isByName = false)
OverwritePartitionsDynamic(table, query, columns, writeOptions, isByName = false)
}
}

Expand Down
Loading