
PostgreSQL RETURNING clause for INSERT / UPDATE / DELETE #1271

Closed
Jakobeha opened this issue Jun 17, 2021 · 5 comments · Fixed by #2061

@Jakobeha

Jakobeha commented Jun 17, 2021

Hello,

PostgreSQL allows a RETURNING clause on INSERT, UPDATE, and DELETE statements (and possibly others, I'm not sure) that returns data from the inserted, updated, or deleted rows. For instance, you can delete rows and get them back in the same statement.
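For readers unfamiliar with the clause, here is a minimal sketch using plain JDBC rather than Exposed. It assumes a PostgreSQL connection and a hypothetical sessions table with id, user_id, and expires_at columns. The point it shows is that a statement with RETURNING produces a result set, so it is run with executeQuery() instead of executeUpdate().

```kotlin
import java.sql.Connection

// Hypothetical table and columns, used only for illustration.
const val DELETE_EXPIRED_SQL =
    "DELETE FROM sessions WHERE expires_at < now() RETURNING id, user_id"

/** Deletes expired sessions and returns the (id, user_id) pairs of the deleted rows. */
fun deleteExpiredSessions(conn: Connection): List<Pair<Long, Long>> =
    conn.prepareStatement(DELETE_EXPIRED_SQL).use { stmt ->
        // RETURNING makes the DELETE produce rows, so executeQuery() is used
        // to read back the deleted rows instead of just an affected-row count.
        stmt.executeQuery().use { rs ->
            val rows = mutableListOf<Pair<Long, Long>>()
            while (rs.next()) rows += rs.getLong("id") to rs.getLong("user_id")
            rows
        }
    }
```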

This feature would really be useful to me, and I didn't see any existing functionality or issues discussing it.

I'm planning to hack together my own implementation first by subclassing Statement, and I'll post my code and maybe create a PR.

@vlsi

vlsi commented Jun 19, 2021

Just in case: RETURNING is not limited to plain columns. Oracle even allows aggregates there (e.g. returning count(*)). PostgreSQL accepts arbitrary per-row expressions in RETURNING but rejects aggregate functions.

@Jakobeha
Author

Here is the code I have so far. It can surely be cleaned up and might have bugs, but it does the job for me. Just include these 3 files and then use Table#updateReturning and Table#deleteReturningWhere. Remember that PostgreSQL does not support LIMIT in UPDATE or DELETE statements, so leave limit as null.

Much of it is copied from JetBrains Exposed classes. It would be much cleaner to refactor the base classes to support a RETURNING clause directly, but until then this is a good workaround for anyone who needs the functionality.

ReturningStatement.kt

import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.statements.api.PreparedStatementApi
import org.jetbrains.exposed.sql.statements.Statement
import org.jetbrains.exposed.sql.statements.StatementType
import org.jetbrains.exposed.sql.transactions.TransactionManager
import java.sql.ResultSet

abstract class ReturningStatement(type: StatementType, targets: List<Table>) :
  Iterable<ResultRow>, Statement<ResultSet>(type, targets) {
  protected val transaction get() = TransactionManager.current()

  abstract val set: FieldSet

  override fun PreparedStatementApi.executeInternal(transaction: Transaction): ResultSet =
    executeQuery()

  private var iterator: Iterator<ResultRow>? = null

  fun exec() {
    require(iterator == null) { "already executed" }

    val resultIterator = ResultIterator(transaction.exec(this)!!)
    iterator = if (transaction.db.supportsMultipleResultSets) resultIterator
    else Iterable { resultIterator }.toList().iterator()
  }

  override fun iterator(): Iterator<ResultRow> =
    iterator ?: throw IllegalStateException("must call exec() first")

  protected inner class ResultIterator(val rs: ResultSet) : Iterator<ResultRow> {
    private var hasNext: Boolean? = null

    private val fieldsIndex = set.realFields.toSet().mapIndexed { index, expression -> expression to index }.toMap()

    override operator fun next(): ResultRow {
      if (hasNext == null) hasNext()
      if (hasNext == false) throw NoSuchElementException()
      hasNext = null
      return ResultRow.create(rs, fieldsIndex)
    }

    override fun hasNext(): Boolean {
      if (hasNext == null) hasNext = rs.next()
      if (hasNext == false) rs.close()
      return hasNext!!
    }
  }
}
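The ResultIterator above walks a forward-only ResultSet: hasNext() advances the cursor at most once per row and caches the answer, next() consumes that cached answer, and the cursor is closed once it is exhausted. Below is a dependency-free sketch of the same pattern, using a hypothetical Cursor interface in place of java.sql.ResultSet so it can run without a database.

```kotlin
// Hypothetical stand-in for java.sql.ResultSet: advance() moves to the next row
// (like ResultSet.next()), current() reads it, close() releases resources.
interface Cursor<T> {
    fun advance(): Boolean
    fun current(): T
    fun close()
}

class CursorIterator<T>(private val cursor: Cursor<T>) : Iterator<T> {
    // null = not checked yet; true/false = cached result of the last advance().
    private var hasNext: Boolean? = null

    override fun hasNext(): Boolean {
        if (hasNext == null) hasNext = cursor.advance()
        if (hasNext == false) cursor.close() // release the cursor once exhausted
        return hasNext!!
    }

    override fun next(): T {
        if (!hasNext()) throw NoSuchElementException()
        hasNext = null // force the following hasNext() call to advance again
        return cursor.current()
    }
}

// In-memory cursor used to exercise the iterator without a database.
class ListCursor<T>(private val rows: List<T>) : Cursor<T> {
    private var index = -1
    var closeCount = 0
        private set

    override fun advance(): Boolean = ++index < rows.size
    override fun current(): T = rows[index]
    override fun close() { closeCount++ }
}
```

Caching matters because ResultSet.next() both tests for and moves to the next row, so calling it twice from hasNext() would silently skip a row.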

DeleteReturningStatement.kt

import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.statements.StatementType

class DeleteReturningStatement(
  private val table: Table,
  private val where: Op<Boolean>? = null,
  private val limit: Int? = null, // PostgreSQL does not support LIMIT on DELETE, so default to none
  private val returning: ColumnSet? = null
) : ReturningStatement(StatementType.DELETE, listOf(table)) {
  override val set: FieldSet = returning ?: table

  override fun prepareSQL(transaction: Transaction): String = buildString {
    append("DELETE FROM ")
    append(transaction.identity(table))
    if (where != null) {
      append(" WHERE ")
      append(QueryBuilder(true).append(where).toString())
    }
    if (limit != null) {
      append(" LIMIT ")
      append(limit)
    }
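    // List the returned columns explicitly instead of "*", so the result-set column order
    // matches the field index that ResultIterator builds from set.realFields.
    append(" RETURNING ")
    append(QueryBuilder(true).apply { set.realFields.appendTo { it.toQueryBuilder(this) } }.toString())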
  }

  override fun arguments(): Iterable<Iterable<Pair<IColumnType, Any?>>> =
    QueryBuilder(true).run {
      where?.toQueryBuilder(this)
      listOf(args)
    }

  companion object {
    fun where(
      table: Table,
      op: Op<Boolean>,
      limit: Int? = null,
      returning: ColumnSet? = null
    ): DeleteReturningStatement = DeleteReturningStatement(
      table,
      op,
      limit,
      returning
    ).apply {
      exec()
    }
  }
}

fun Table.deleteReturningWhere(
  limit: Int? = null,
  returning: ColumnSet? = null,
  where: SqlExpressionBuilder.() -> Op<Boolean>
): DeleteReturningStatement =
  DeleteReturningStatement.where(
    this,
    SqlExpressionBuilder.run(where),
    limit,
    returning
  )
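To make the generated statement concrete, here is a hypothetical string-only helper (renderDeleteReturning is not part of Exposed or of the code above) that mirrors the shape of the SQL prepareSQL builds, assuming plain strings instead of Exposed expressions.

```kotlin
// Hypothetical helper mirroring the shape of DeleteReturningStatement's SQL,
// using plain strings instead of Exposed tables and expressions.
fun renderDeleteReturning(
    table: String,
    where: String? = null,
    limit: Int? = null,
    returning: List<String>
): String = buildString {
    require(returning.isNotEmpty()) { "RETURNING needs at least one column" }
    append("DELETE FROM ").append(table)
    if (where != null) append(" WHERE ").append(where)
    // PostgreSQL rejects LIMIT on DELETE, so it is only emitted when explicitly requested.
    if (limit != null) append(" LIMIT ").append(limit)
    // Explicit columns (rather than "*") fix the order of the returned result-set columns.
    append(" RETURNING ").append(returning.joinToString(", "))
}
```

For example, renderDeleteReturning("users", "users.id = ?", returning = listOf("users.id", "users.name")) yields DELETE FROM users WHERE users.id = ? RETURNING users.id, users.name.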

UpdateReturningStatement.kt

import org.jetbrains.exposed.dao.id.EntityID
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.statements.StatementType

class UpdateReturningStatement(
  private val table: Table,
  private val where: Op<Boolean>? = null,
  private val limit: Int? = null,
  private val returning: ColumnSet? = null
) : ReturningStatement(StatementType.UPDATE, listOf(table)) {
  override val set: FieldSet = returning ?: table

  private val firstDataSet: List<Pair<Column<*>, Any?>>
    get() = values.toList()

  override fun prepareSQL(transaction: Transaction): String =
    with(QueryBuilder(true)) {
      +"UPDATE "
      table.describe(transaction, this)

      firstDataSet.appendTo(this, prefix = " SET ") { (col, value) ->
        append("${transaction.identity(col)}=")
        registerArgument(col, value)
      }

      where?.let {
        +" WHERE "
        +it
      }
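      limit?.let {
        // "+it" would resolve to Int.unaryPlus() and append nothing, so build the text explicitly
        +" LIMIT $it"
      }

      // List the returned columns explicitly instead of "*", so the result-set column order
      // matches the field index that ResultIterator builds from set.realFields
      set.realFields.appendTo(this, prefix = " RETURNING ") { it.toQueryBuilder(this) }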

      toString()
    }

  override fun arguments(): Iterable<Iterable<Pair<IColumnType, Any?>>> =
    QueryBuilder(true).run {
      for ((key, value) in values) {
        registerArgument(key, value)
      }
      where?.toQueryBuilder(this)
      listOf(args)
    }

  // region UpdateBuilder
  private val values: MutableMap<Column<*>, Any?> = LinkedHashMap()

  operator fun <S> set(column: Column<S>, value: S) {
    when {
      values.containsKey(column) -> error("$column is already initialized")
      !column.columnType.nullable && value == null -> error("Trying to set null to not nullable column $column")
      else -> values[column] = value
    }
  }

  @JvmName("setWithEntityIdExpression")
  operator fun <S, ID : EntityID<S>, E : Expression<S>> set(
    column: Column<ID>,
    value: E
  ) {
    require(!values.containsKey(column)) { "$column is already initialized" }
    values[column] = value
  }

  @JvmName("setWithEntityIdValue")
  operator fun <S : Comparable<S>, ID : EntityID<S>, E : S?> set(
    column: Column<ID>,
    value: E
  ) {
    require(!values.containsKey(column)) { "$column is already initialized" }
    values[column] = value
  }

  operator fun <T, S : T, E : Expression<S>> set(column: Column<T>, value: E) =
    update(column, value)

  operator fun <S> set(column: CompositeColumn<S>, value: S) {
    @Suppress("UNCHECKED_CAST")
    column.getRealColumnsWithValues(value).forEach { (realColumn, itsValue) ->
      set(
        realColumn as Column<Any?>,
        itsValue
      )
    }
  }

  fun <T, S : T?> update(column: Column<T>, value: Expression<S>) {
    require(!values.containsKey(column)) { "$column is already initialized" }
    values[column] = value
  }

  fun <T, S : T?> update(
    column: Column<T>,
    value: SqlExpressionBuilder.() -> Expression<S>
  ) {
    require(!values.containsKey(column)) { "$column is already initialized" }
    values[column] = SqlExpressionBuilder.value()
  }
  // endregion
}

fun <T : Table> T.updateReturning(
  where: SqlExpressionBuilder.() -> Op<Boolean>,
  limit: Int? = null,
  returning: ColumnSet? = null,
  body: T.(UpdateReturningStatement) -> Unit
): UpdateReturningStatement = UpdateReturningStatement(
  this,
  SqlExpressionBuilder.run(where),
  limit,
  returning
).apply {
  this@updateReturning.body(this)
  exec()
}
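The UpdateBuilder region above collects SET values in a LinkedHashMap and refuses to assign a column twice. Because prepareSQL and arguments() both iterate that map, the ? placeholders and the bound arguments come out in the same order. A minimal sketch of that idea, with a hypothetical Col type standing in for Exposed's Column:

```kotlin
// Hypothetical column descriptor standing in for Exposed's Column<T>.
data class Col<T>(val name: String, val nullable: Boolean)

class SetClause {
    // LinkedHashMap preserves insertion order, so placeholders and arguments line up.
    private val values = LinkedHashMap<Col<*>, Any?>()

    operator fun <T> set(column: Col<T>, value: T) {
        require(!values.containsKey(column)) { "${column.name} is already initialized" }
        require(column.nullable || value != null) {
            "Trying to set null to not nullable column ${column.name}"
        }
        values[column] = value
    }

    // Renders "SET a=?, b=?" together with the arguments to bind, in the same order.
    fun render(): Pair<String, List<Any?>> =
        values.keys.joinToString(", ", prefix = "SET ") { "${it.name}=?" } to values.values.toList()
}
```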

@stengvac
Contributor

stengvac commented Oct 1, 2021

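If you only need to return all columns of the updated rows (the equivalent of RETURNING *), there is a shorter implementation that extends Exposed's UpdateStatement: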

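import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.statements.UpdateStatement
import org.jetbrains.exposed.sql.statements.api.PreparedStatementApi
import org.jetbrains.exposed.sql.transactions.TransactionManager
import java.sql.ResultSet

class UpdateReturningStatement(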
    table: Table,
    where: Op<Boolean>? = null,
) : UpdateStatement(table, null, where) {

    var resultRows: List<ResultRow> = listOf()
        private set

    override fun PreparedStatementApi.executeInternal(transaction: Transaction): Int {
        if (values.isEmpty()) return 0
        // executeUpdate() only returns the affected row count, so executeQuery() is used to read the RETURNING rows
        val updatedReturning = executeQuery()
        resultRows = ResultIterator(updatedReturning, targetsSet).iterator().asSequence().toList()

        return resultRows.size
    }

    override fun prepareSQL(transaction: Transaction): String {
        val sql = super.prepareSQL(transaction)
        return QueryBuilder(prepared = true).apply {
            append(sql)
            targetsSet.realFields.appendTo(prefix = " RETURNING ") {
                it.toQueryBuilder(this)
            }
        }.toString()
    }

    //copied from AbstractQuery
    private class ResultIterator(
        private val rs: ResultSet,
        fieldSet: FieldSet
    ) : Iterator<ResultRow> {
        private var hasNext: Boolean? = null
        private val fieldsIndex = fieldSet.realFields.toSet().mapIndexed { index, expression -> expression to index }.toMap()

        override operator fun next(): ResultRow {
            if (hasNext == null) hasNext()
            if (hasNext == false) throw NoSuchElementException()
            hasNext = null
            return ResultRow.create(rs, fieldsIndex)
        }

        override fun hasNext(): Boolean {
            if (hasNext == null) hasNext = rs.next()
            if (hasNext == false) rs.close()
            return hasNext!!
        }
    }
}

fun <T : Table> T.updateReturning(
    where: SqlExpressionBuilder.() -> Op<Boolean>,
    body: T.(UpdateReturningStatement) -> Unit
): List<ResultRow> {
    val statement = UpdateReturningStatement(
        this,
        SqlExpressionBuilder.run(where)
    )
    body(statement)
    statement.execute(TransactionManager.current())!!

    return statement.resultRows
}

@joc-a
Collaborator

joc-a commented May 9, 2023

Hi @Jakobeha, thanks for submitting this issue and the code snippets. Please go ahead and open a PR for this, including the necessary tests for the functionality, and we will review and get back to you.

@Flaxoos

Flaxoos commented Oct 30, 2023

Hey @Jakobeha @joc-a, what's the status on this? Has a PR been made?

6 participants