Skip to content

Commit

Permalink
Does not support join in update statement #671
Browse files Browse the repository at this point in the history
  • Loading branch information
Tapac committed Mar 1, 2020
1 parent 19dd97a commit f01e0a1
Show file tree
Hide file tree
Showing 9 changed files with 213 additions and 26 deletions.
23 changes: 14 additions & 9 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,7 @@ class Join(
}
if (p.joinType != JoinType.CROSS) {
append(" ON ")
p.conditions.appendTo(this, " AND ") { (pkColumn, fkColumn) -> append(pkColumn, " = ", fkColumn) }
if (p.additionalConstraint != null) {
if (p.conditions.isNotEmpty()) {
append(" AND ")
}
append(" (")
append(SqlExpressionBuilder.(p.additionalConstraint)())
append(")")
}
p.appendConditions(this)
}
}
}
Expand Down Expand Up @@ -268,6 +260,19 @@ class Join(
init {
require(joinType == JoinType.CROSS || conditions.isNotEmpty() || additionalConstraint != null) { "Missing join condition on $${this.joinPart}" }
}

fun appendConditions(builder: QueryBuilder) = builder {
conditions.appendTo(this, " AND ") { (pkColumn, fkColumn) -> append(pkColumn, " = ", fkColumn) }
if (additionalConstraint != null) {
if (conditions.isNotEmpty()) {
append(" AND ")
}
append(" (")
append(SqlExpressionBuilder.(additionalConstraint)())
append(")")
}
}

}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.jetbrains.exposed.sql.statements

import org.jetbrains.exposed.exceptions.throwUnsupportedException
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.statements.api.PreparedStatementApi

Expand All @@ -12,8 +13,13 @@ open class UpdateStatement(val targetsSet: ColumnSet, val limit: Int?, val where
return executeUpdate()
}

override fun prepareSQL(transaction: Transaction): String =
transaction.db.dialect.functionProvider.update(targetsSet, firstDataSet, limit, where, transaction)
override fun prepareSQL(transaction: Transaction): String {
return when (targetsSet) {
is Table -> transaction.db.dialect.functionProvider.update(targetsSet, firstDataSet, limit, where, transaction)
is Join -> transaction.db.dialect.functionProvider.update(targetsSet, firstDataSet, limit, where, transaction)
else -> transaction.throwUnsupportedException("UPDATE with ${targetsSet::class.simpleName} unsupported")
}
}

override fun arguments(): Iterable<Iterable<Pair<IColumnType, Any?>>> = QueryBuilder(true).run {
values.forEach {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,23 +350,23 @@ abstract class FunctionProvider {
/**
* Returns the SQL command that updates one or more rows of a table.
*
* @param targets Column set to update values from.
* @param target Table to update values from.
* @param columnsAndValues Pairs of column to update and values to update with.
* @param limit Maximum number of rows to update.
* @param where Condition that decides the rows to update.
* @param transaction Transaction where the operation is executed.
*/
open fun update(
targets: ColumnSet,
target: Table,
columnsAndValues: List<Pair<Column<*>, Any?>>,
limit: Int?,
where: Op<Boolean>?,
transaction: Transaction
): String = with(QueryBuilder(true)) {
+"UPDATE "
targets.describe(transaction, this)
+" SET "
columnsAndValues.appendTo(this) { (col, value) ->
target.describe(transaction, this)

columnsAndValues.appendTo(this, prefix = " SET ") { (col, value) ->
append("${transaction.identity(col)}=")
registerArgument(col, value)
}
Expand All @@ -379,6 +379,23 @@ abstract class FunctionProvider {
toString()
}

/**
* Returns the SQL command that updates one or more rows of a join.
*
* @param targets Join to update values from.
* @param columnsAndValues Pairs of column to update and values to update with.
* @param limit Maximum number of rows to update.
* @param where Condition that decides the rows to update.
* @param transaction Transaction where the operation is executed.
*/
open fun update(
targets: Join,
columnsAndValues: List<Pair<Column<*>, Any?>>,
limit: Int?,
where: Op<Boolean>?,
transaction: Transaction
) : String = transaction.throwUnsupportedException("UPDATE with a join clause is unsupported")

/**
* Returns the SQL command that insert a new row into a table, but if another row with the same primary/unique key already exists then it updates the values of that row instead.
* This operation is also known as "Insert or update".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,29 @@ internal open class MysqlFunctionProvider : FunctionProvider() {
val def = super.delete(false, table, where, limit, transaction)
return if (ignore) def.replaceFirst("DELETE", "DELETE IGNORE") else def
}

override fun update(
targets: Join,
columnsAndValues: List<Pair<Column<*>, Any?>>,
limit: Int?,
where: Op<Boolean>?,
transaction: Transaction
): String = with(QueryBuilder(true)) {
+"UPDATE "
targets.describe(transaction, this)
+" SET "
columnsAndValues.appendTo(this) { (col, value) ->
append("${transaction.identity(col)}=")
registerArgument(col, value)
}

where?.let {
+" WHERE "
+it
}
limit?.let { +" LIMIT $it" }
toString()
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,20 +135,54 @@ internal object OracleFunctionProvider : FunctionProvider() {
}

override fun update(
targets: ColumnSet,
target: Table,
columnsAndValues: List<Pair<Column<*>, Any?>>,
limit: Int?,
where: Op<Boolean>?,
transaction: Transaction
): String {
val def = super.update(targets, columnsAndValues, null, where, transaction)
val def = super.update(target, columnsAndValues, null, where, transaction)
return when {
limit != null && where != null -> "$def AND ROWNUM <= $limit"
limit != null -> "$def WHERE ROWNUM <= $limit"
else -> def
}
}

override fun update(
targets: Join,
columnsAndValues: List<Pair<Column<*>, Any?>>,
limit: Int?,
where: Op<Boolean>?,
transaction: Transaction
): String = with(QueryBuilder(true)) {
val tableToUpdate = columnsAndValues.map { it.first.table }.distinct().singleOrNull()
if (tableToUpdate == null) {
transaction.throwUnsupportedException("Oracle supports a join updates with a single table columns to update.")
}
if (targets.joinParts.any { it.joinType != JoinType.INNER }) {
exposedLogger.warn("All tables in UPDATE statement will be joined with inner join")
}
+"UPDATE ("
val subQuery = targets.selectAll()
where?.let {
subQuery.adjustWhere { it }
}
subQuery.prepareSQL(this)
+") x"

columnsAndValues.appendTo(this, prefix = " SET ") { (col, value) ->
append("${transaction.identity(col)}=")
registerArgument(col, value)
}

limit?.let {
"WHERE ROWNUM <= $it"
}

toString()
}

override fun delete(
ignore: Boolean,
table: Table,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ internal object PostgreSQLFunctionProvider : FunctionProvider() {
}

override fun update(
targets: ColumnSet,
target: Table,
columnsAndValues: List<Pair<Column<*>, Any?>>,
limit: Int?,
where: Op<Boolean>?,
Expand All @@ -111,7 +111,51 @@ internal object PostgreSQLFunctionProvider : FunctionProvider() {
if (limit != null) {
transaction.throwUnsupportedException("PostgreSQL doesn't support LIMIT in UPDATE clause.")
}
return super.update(targets, columnsAndValues, limit, where, transaction)
return super.update(target, columnsAndValues, limit, where, transaction)
}

override fun update(
targets: Join,
columnsAndValues: List<Pair<Column<*>, Any?>>,
limit: Int?,
where: Op<Boolean>?,
transaction: Transaction
): String = with(QueryBuilder(true)) {
if (limit != null) {
transaction.throwUnsupportedException("PostgreSQL doesn't support LIMIT in UPDATE clause.")
}
val tableToUpdate = columnsAndValues.map { it.first.table }.distinct().singleOrNull()
if (tableToUpdate == null) {
transaction.throwUnsupportedException("PostgreSQL supports a join updates with a single table columns to update.")
}
if (targets.joinParts.any { it.joinType != JoinType.INNER }) {
exposedLogger.warn("All tables in UPDATE statement will be joined with inner join")
}
+"UPDATE "
tableToUpdate.describe(transaction, this)
+" SET "
columnsAndValues.appendTo(this) { (col, value) ->
append("${transaction.identity(col)}=")
registerArgument(col, value)
}
+" FROM "
if (targets.table != tableToUpdate)
targets.table.describe(transaction, this)

targets.joinParts.appendTo(this, ",") {
if (it.joinPart != tableToUpdate)
it.joinPart.describe(transaction, this)
}
+" WHERE "
targets.joinParts.appendTo(this, " AND ") {
it.appendConditions(this)
}
where?.let {
+ " AND "
+it
}
limit?.let { +" LIMIT $it" }
toString()
}

override fun replace(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,55 @@ internal object SQLServerFunctionProvider : FunctionProvider() {
append("DATEPART(MINUTE, ", expr, ")")
}

override fun update(targets: ColumnSet, columnsAndValues: List<Pair<Column<*>, Any?>>, limit: Int?, where: Op<Boolean>?, transaction: Transaction): String {
val def = super.update(targets, columnsAndValues, null, where, transaction)
override fun update(target: Table, columnsAndValues: List<Pair<Column<*>, Any?>>, limit: Int?, where: Op<Boolean>?, transaction: Transaction): String {
val def = super.update(target, columnsAndValues, null, where, transaction)
return if (limit != null) def.replaceFirst("UPDATE", "UPDATE TOP($limit)") else def
}

override fun update(
targets: Join,
columnsAndValues: List<Pair<Column<*>, Any?>>,
limit: Int?,
where: Op<Boolean>?,
transaction: Transaction
): String = with(QueryBuilder(true)) {
val tableToUpdate = columnsAndValues.map { it.first.table }.distinct().singleOrNull()
if (tableToUpdate == null) {
transaction.throwUnsupportedException("SQLServer supports a join updates with a single table columns to update.")
}
if (targets.joinParts.any { it.joinType != JoinType.INNER }) {
exposedLogger.warn("All tables in UPDATE statement will be joined with inner join")
}
if (limit != null)
+"UPDATE TOP($limit)"
else
+"UPDATE "
tableToUpdate.describe(transaction, this)
+" SET "
columnsAndValues.appendTo(this) { (col, value) ->
append("${transaction.identity(col)}=")
registerArgument(col, value)
}
+" FROM "
if (targets.table != tableToUpdate)
targets.table.describe(transaction, this)

targets.joinParts.appendTo(this, ",") {
if (it.joinPart != tableToUpdate)
it.joinPart.describe(transaction, this)
}
+" WHERE "
targets.joinParts.appendTo(this, " AND ") {
it.appendConditions(this)
}
where?.let {
+ " AND "
+it
}
limit?.let { +" LIMIT $it" }
toString()
}

override fun delete(ignore: Boolean, table: Table, where: String?, limit: Int?, transaction: Transaction): String {
val def = super.delete(ignore, table, where, null, transaction)
return if (limit != null) def.replaceFirst("DELETE", "DELETE TOP($limit)") else def
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ internal object SQLiteFunctionProvider : FunctionProvider() {
}

override fun update(
targets: ColumnSet,
target: Table,
columnsAndValues: List<Pair<Column<*>, Any?>>,
limit: Int?,
where: Op<Boolean>?,
Expand All @@ -108,7 +108,7 @@ internal object SQLiteFunctionProvider : FunctionProvider() {
if (limit != null) {
transaction.throwUnsupportedException("SQLite doesn't support LIMIT in UPDATE clause.")
}
return super.update(targets, columnsAndValues, limit, where, transaction)
return super.update(target, columnsAndValues, limit, where, transaction)
}

override fun delete(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
package org.jetbrains.exposed.sql.tests.shared.dml

import org.jetbrains.exposed.exceptions.UnsupportedByDialectException
import org.jetbrains.exposed.sql.select
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.tests.DatabaseTestsBase
import org.jetbrains.exposed.sql.tests.TestDB
import org.jetbrains.exposed.sql.tests.shared.assertEquals
import org.jetbrains.exposed.sql.tests.shared.expectException
import org.jetbrains.exposed.sql.update
import org.junit.Test

class UpdateTests : DatabaseTestsBase() {
Expand Down Expand Up @@ -56,4 +55,19 @@ class UpdateTests : DatabaseTestsBase() {
}
}
}

@Test
fun testUpdateWithJoin() {
val dialects = listOf(TestDB.H2, TestDB.SQLITE, TestDB.H2_MYSQL)
withCitiesAndUsers(dialects) { cities, users, userData ->
val join = users.innerJoin(userData)
join.update {
it[userData.comment] = users.name
}

join.selectAll().forEach {
assertEquals(it[users.name], it[userData.comment])
}
}
}
}

0 comments on commit f01e0a1

Please sign in to comment.