diff --git a/c/driver/postgresql/bind_stream.h b/c/driver/postgresql/bind_stream.h index 3e440e8005..e5e218ea44 100644 --- a/c/driver/postgresql/bind_stream.h +++ b/c/driver/postgresql/bind_stream.h @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -44,18 +45,15 @@ struct BindStream { Handle bind_schema; int64_t current_row = -1; - struct ArrowSchemaView bind_schema_view; std::vector bind_schema_fields; + std::vector> bind_field_writers; // OIDs for parameter types std::vector param_types; std::vector param_values; - std::vector param_lengths; std::vector param_formats; - std::vector param_values_offsets; - std::vector param_values_buffer; - // XXX: this assumes fixed-length fields only - will need more - // consideration to deal with variable-length fields + std::vector param_lengths; + Handle param_buffer; bool has_tz_field = false; std::string tz_setting; @@ -77,10 +75,11 @@ struct BindStream { CHECK_NA_DETAIL(INTERNAL, ArrowArrayStreamGetSchema(&bind.value, &bind_schema.value, &na_error), &na_error, error); + + struct ArrowSchemaView bind_schema_view; CHECK_NA_DETAIL(INTERNAL, ArrowSchemaViewInit(&bind_schema_view, &bind_schema.value, &na_error), &na_error, error); - if (bind_schema_view.type != ArrowType::NANOARROW_TYPE_STRUCT) { SetError(error, "%s", "[libpq] Bind parameters must have type STRUCT"); return ADBC_STATUS_INVALID_STATE; @@ -99,173 +98,90 @@ struct BindStream { ArrowArrayViewInitFromSchema(&array_view.value, &bind_schema.value, &na_error), &na_error, error); + ArrowBufferInit(&param_buffer.value); + return std::move(callback)(); } - AdbcStatusCode SetParamTypes(const PostgresTypeResolver& type_resolver, - struct AdbcError* error) { + AdbcStatusCode SetParamTypes(PGconn* pg_conn, const PostgresTypeResolver& type_resolver, + const bool autocommit, struct AdbcError* error) { param_types.resize(bind_schema->n_children); param_values.resize(bind_schema->n_children); param_lengths.resize(bind_schema->n_children); 
param_formats.resize(bind_schema->n_children, kPgBinaryFormat); - param_values_offsets.reserve(bind_schema->n_children); - - for (size_t i = 0; i < bind_schema_fields.size(); i++) { - PostgresTypeId type_id; - switch (bind_schema_fields[i].type) { - case ArrowType::NANOARROW_TYPE_BOOL: - type_id = PostgresTypeId::kBool; - param_lengths[i] = 1; - break; - case ArrowType::NANOARROW_TYPE_INT8: - case ArrowType::NANOARROW_TYPE_INT16: - type_id = PostgresTypeId::kInt2; - param_lengths[i] = 2; - break; - case ArrowType::NANOARROW_TYPE_INT32: - type_id = PostgresTypeId::kInt4; - param_lengths[i] = 4; - break; - case ArrowType::NANOARROW_TYPE_INT64: - type_id = PostgresTypeId::kInt8; - param_lengths[i] = 8; - break; - case ArrowType::NANOARROW_TYPE_FLOAT: - type_id = PostgresTypeId::kFloat4; - param_lengths[i] = 4; - break; - case ArrowType::NANOARROW_TYPE_DOUBLE: - type_id = PostgresTypeId::kFloat8; - param_lengths[i] = 8; - break; - case ArrowType::NANOARROW_TYPE_STRING: - case ArrowType::NANOARROW_TYPE_LARGE_STRING: - type_id = PostgresTypeId::kText; - param_lengths[i] = 0; - break; - case ArrowType::NANOARROW_TYPE_BINARY: - type_id = PostgresTypeId::kBytea; - param_lengths[i] = 0; - break; - case ArrowType::NANOARROW_TYPE_DATE32: - type_id = PostgresTypeId::kDate; - param_lengths[i] = 4; - break; - case ArrowType::NANOARROW_TYPE_TIMESTAMP: - type_id = PostgresTypeId::kTimestamp; - param_lengths[i] = 8; - break; - case ArrowType::NANOARROW_TYPE_DURATION: - case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: - type_id = PostgresTypeId::kInterval; - param_lengths[i] = 16; - break; - case ArrowType::NANOARROW_TYPE_DECIMAL128: - case ArrowType::NANOARROW_TYPE_DECIMAL256: - type_id = PostgresTypeId::kNumeric; - param_lengths[i] = 0; - break; - case ArrowType::NANOARROW_TYPE_DICTIONARY: { - struct ArrowSchemaView value_view; - CHECK_NA(INTERNAL, - ArrowSchemaViewInit(&value_view, bind_schema->children[i]->dictionary, - nullptr), - error); - switch (value_view.type) { - 
case NANOARROW_TYPE_BINARY: - case NANOARROW_TYPE_LARGE_BINARY: - type_id = PostgresTypeId::kBytea; - param_lengths[i] = 0; - break; - case NANOARROW_TYPE_STRING: - case NANOARROW_TYPE_LARGE_STRING: - type_id = PostgresTypeId::kText; - param_lengths[i] = 0; - break; - default: - SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", - static_cast(i + 1), " ('", - bind_schema->children[i]->name, - "') has unsupported dictionary value parameter type ", - ArrowTypeString(value_view.type)); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - break; - } - default: - SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", - static_cast(i + 1), " ('", bind_schema->children[i]->name, - "') has unsupported parameter type ", - ArrowTypeString(bind_schema_fields[i].type)); - return ADBC_STATUS_NOT_IMPLEMENTED; + bind_field_writers.resize(bind_schema->n_children); + + for (size_t i = 0; i < bind_field_writers.size(); i++) { + PostgresType type; + CHECK_NA_DETAIL(INTERNAL, + PostgresType::FromSchema(type_resolver, bind_schema->children[i], + &type, &na_error), + &na_error, error); + + // tz-aware timestamps require special handling to set the timezone to UTC + // prior to sending over the binary protocol; must be reset after execute + if (!has_tz_field && type.type_id() == PostgresTypeId::kTimestamptz) { + RAISE_ADBC(SetDatabaseTimezoneUTC(pg_conn, autocommit, error)); + has_tz_field = true; } - param_types[i] = type_resolver.GetOID(type_id); - if (param_types[i] == 0) { - SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", - static_cast(i + 1), " ('", bind_schema->children[i]->name, - "') has type with no corresponding PostgreSQL type ", - ArrowTypeString(bind_schema_fields[i].type)); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - } + std::unique_ptr writer; + CHECK_NA_DETAIL( + INTERNAL, + MakeCopyFieldWriter(bind_schema->children[i], array_view->children[i], + type_resolver, &writer, &na_error), + &na_error, error); - size_t param_values_length = 0; - for (int length : 
param_lengths) { - param_values_offsets.push_back(param_values_length); - param_values_length += length; + param_types[i] = type.oid(); + param_formats[i] = kPgBinaryFormat; + bind_field_writers[i] = std::move(writer); } - param_values_buffer.resize(param_values_length); + return ADBC_STATUS_OK; } - AdbcStatusCode Prepare(PGconn* pg_conn, const std::string& query, - struct AdbcError* error, const bool autocommit) { - // tz-aware timestamps require special handling to set the timezone to UTC - // prior to sending over the binary protocol; must be reset after execute - for (int64_t col = 0; col < bind_schema->n_children; col++) { - if ((bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) && - (strcmp("", bind_schema_fields[col].timezone))) { - has_tz_field = true; - - if (autocommit) { - PGresult* begin_result = PQexec(pg_conn, "BEGIN"); - if (PQresultStatus(begin_result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, begin_result, - "[libpq] Failed to begin transaction for timezone data: %s", - PQerrorMessage(pg_conn)); - PQclear(begin_result); - return code; - } - PQclear(begin_result); - } - - PGresult* get_tz_result = PQexec(pg_conn, "SELECT current_setting('TIMEZONE')"); - if (PQresultStatus(get_tz_result) != PGRES_TUPLES_OK) { - AdbcStatusCode code = SetError(error, get_tz_result, - "[libpq] Could not query current timezone: %s", - PQerrorMessage(pg_conn)); - PQclear(get_tz_result); - return code; - } - - tz_setting = std::string(PQgetvalue(get_tz_result, 0, 0)); - PQclear(get_tz_result); - - PGresult* set_utc_result = PQexec(pg_conn, "SET TIME ZONE 'UTC'"); - if (PQresultStatus(set_utc_result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = SetError(error, set_utc_result, - "[libpq] Failed to set time zone to UTC: %s", - PQerrorMessage(pg_conn)); - PQclear(set_utc_result); - return code; - } - PQclear(set_utc_result); - break; + AdbcStatusCode SetDatabaseTimezoneUTC(PGconn* pg_conn, const bool autocommit, + struct AdbcError* 
error) { + if (autocommit) { + PGresult* begin_result = PQexec(pg_conn, "BEGIN"); + if (PQresultStatus(begin_result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, begin_result, + "[libpq] Failed to begin transaction for timezone data: %s", + PQerrorMessage(pg_conn)); + PQclear(begin_result); + return code; } + PQclear(begin_result); } + PGresult* get_tz_result = PQexec(pg_conn, "SELECT current_setting('TIMEZONE')"); + if (PQresultStatus(get_tz_result) != PGRES_TUPLES_OK) { + AdbcStatusCode code = + SetError(error, get_tz_result, "[libpq] Could not query current timezone: %s", + PQerrorMessage(pg_conn)); + PQclear(get_tz_result); + return code; + } + + tz_setting = std::string(PQgetvalue(get_tz_result, 0, 0)); + PQclear(get_tz_result); + + PGresult* set_utc_result = PQexec(pg_conn, "SET TIME ZONE 'UTC'"); + if (PQresultStatus(set_utc_result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, set_utc_result, "[libpq] Failed to set time zone to UTC: %s", + PQerrorMessage(pg_conn)); + PQclear(set_utc_result); + return code; + } + PQclear(set_utc_result); + + return ADBC_STATUS_OK; + } + + AdbcStatusCode Prepare(PGconn* pg_conn, const std::string& query, + struct AdbcError* error) { PGresult* result = PQprepare(pg_conn, /*stmtName=*/"", query.c_str(), /*nParams=*/bind_schema->n_children, param_types.data()); if (PQresultStatus(result) != PGRES_COMMAND_OK) { @@ -317,170 +233,40 @@ struct BindStream { AdbcStatusCode BindAndExecuteCurrentRow(PGconn* pg_conn, PGresult** result_out, int result_format, AdbcError* error) { - int64_t row = current_row; + param_buffer->size_bytes = 0; + int64_t last_offset = 0; for (int64_t col = 0; col < array_view->n_children; col++) { - if (ArrowArrayViewIsNull(array_view->children[col], row)) { - param_values[col] = nullptr; - continue; + if (!ArrowArrayViewIsNull(array_view->children[col], current_row)) { + // Note that this Write() call currently writes the (int32_t) byte size of the + // field in addition 
to the serialized value. + CHECK_NA_DETAIL( + INTERNAL, + bind_field_writers[col]->Write(&param_buffer.value, current_row, &na_error), + &na_error, error); } else { - param_values[col] = param_values_buffer.data() + param_values_offsets[col]; + CHECK_NA(INTERNAL, ArrowBufferAppendInt32(&param_buffer.value, 0), error); } - switch (bind_schema_fields[col].type) { - case ArrowType::NANOARROW_TYPE_BOOL: { - const int8_t val = - ArrowBitGet(array_view->children[col]->buffer_views[1].data.as_uint8, row); - std::memcpy(param_values[col], &val, sizeof(int8_t)); - break; - } - - case ArrowType::NANOARROW_TYPE_INT8: { - const int16_t val = - array_view->children[col]->buffer_views[1].data.as_int8[row]; - const uint16_t value = ToNetworkInt16(val); - std::memcpy(param_values[col], &value, sizeof(int16_t)); - break; - } - case ArrowType::NANOARROW_TYPE_INT16: { - const uint16_t value = ToNetworkInt16( - array_view->children[col]->buffer_views[1].data.as_int16[row]); - std::memcpy(param_values[col], &value, sizeof(int16_t)); - break; - } - case ArrowType::NANOARROW_TYPE_INT32: { - const uint32_t value = ToNetworkInt32( - array_view->children[col]->buffer_views[1].data.as_int32[row]); - std::memcpy(param_values[col], &value, sizeof(int32_t)); - break; - } - case ArrowType::NANOARROW_TYPE_INT64: { - const int64_t value = ToNetworkInt64( - array_view->children[col]->buffer_views[1].data.as_int64[row]); - std::memcpy(param_values[col], &value, sizeof(int64_t)); - break; - } - case ArrowType::NANOARROW_TYPE_FLOAT: { - const uint32_t value = ToNetworkFloat4( - array_view->children[col]->buffer_views[1].data.as_float[row]); - std::memcpy(param_values[col], &value, sizeof(uint32_t)); - break; - } - case ArrowType::NANOARROW_TYPE_DOUBLE: { - const uint64_t value = ToNetworkFloat8( - array_view->children[col]->buffer_views[1].data.as_double[row]); - std::memcpy(param_values[col], &value, sizeof(uint64_t)); - break; - } - case ArrowType::NANOARROW_TYPE_STRING: - case 
ArrowType::NANOARROW_TYPE_LARGE_STRING: - case ArrowType::NANOARROW_TYPE_BINARY: { - const ArrowBufferView view = - ArrowArrayViewGetBytesUnsafe(array_view->children[col], row); - // TODO: overflow check? - param_lengths[col] = static_cast(view.size_bytes); - param_values[col] = const_cast(view.data.as_char); - break; - } - case ArrowType::NANOARROW_TYPE_DATE32: { - // 2000-01-01 - constexpr int32_t kPostgresDateEpoch = 10957; - const int32_t raw_value = - array_view->children[col]->buffer_views[1].data.as_int32[row]; - if (raw_value < INT32_MIN + kPostgresDateEpoch) { - SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s", col + 1, - "('", bind_schema->children[col]->name, "') Row #", row + 1, - "has value which exceeds postgres date limits"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - const uint32_t value = ToNetworkInt32(raw_value - kPostgresDateEpoch); - std::memcpy(param_values[col], &value, sizeof(int32_t)); - break; - } - case ArrowType::NANOARROW_TYPE_DURATION: - case ArrowType::NANOARROW_TYPE_TIMESTAMP: { - int64_t val = array_view->children[col]->buffer_views[1].data.as_int64[row]; - - bool overflow_safe = true; - - auto unit = bind_schema_fields[col].time_unit; - - switch (unit) { - case NANOARROW_TIME_UNIT_SECOND: - overflow_safe = - val <= kMaxSafeSecondsToMicros && val >= kMinSafeSecondsToMicros; - if (overflow_safe) { - val *= 1000000; - } - - break; - case NANOARROW_TIME_UNIT_MILLI: - overflow_safe = - val <= kMaxSafeMillisToMicros && val >= kMinSafeMillisToMicros; - if (overflow_safe) { - val *= 1000; - } - break; - case NANOARROW_TIME_UNIT_MICRO: - break; - case NANOARROW_TIME_UNIT_NANO: - val /= 1000; - break; - } - - if (!overflow_safe) { - SetError(error, - "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 - " has value '%" PRIi64 "' which exceeds PostgreSQL timestamp limits", - col + 1, bind_schema->children[col]->name, row + 1, - array_view->children[col]->buffer_views[1].data.as_int64[row]); - return 
ADBC_STATUS_INVALID_ARGUMENT; - } - - if (val < (std::numeric_limits::min)() + kPostgresTimestampEpoch) { - SetError(error, - "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 - " has value '%" PRIi64 "' which would underflow", - col + 1, bind_schema->children[col]->name, row + 1, - array_view->children[col]->buffer_views[1].data.as_int64[row]); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - if (bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) { - const uint64_t value = ToNetworkInt64(val - kPostgresTimestampEpoch); - std::memcpy(param_values[col], &value, sizeof(int64_t)); - } else if (bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_DURATION) { - // postgres stores an interval as a 64 bit offset in microsecond - // resolution alongside a 32 bit day and 32 bit month - // for now we just send 0 for the day / month values - const uint64_t value = ToNetworkInt64(val); - std::memcpy(param_values[col], &value, sizeof(int64_t)); - std::memset(param_values[col] + sizeof(int64_t), 0, sizeof(int64_t)); - } - break; - } - case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: { - struct ArrowInterval interval; - ArrowIntervalInit(&interval, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO); - ArrowArrayViewGetIntervalUnsafe(array_view->children[col], row, &interval); - - const uint32_t months = ToNetworkInt32(interval.months); - const uint32_t days = ToNetworkInt32(interval.days); - const uint64_t ms = ToNetworkInt64(interval.ns / 1000); - - std::memcpy(param_values[col], &ms, sizeof(uint64_t)); - std::memcpy(param_values[col] + sizeof(uint64_t), &days, sizeof(uint32_t)); - std::memcpy(param_values[col] + sizeof(uint64_t) + sizeof(uint32_t), &months, - sizeof(uint32_t)); - break; - } - default: - SetError(error, "%s%" PRId64 "%s%s%s%s", "[libpq] Field #", col + 1, " ('", - bind_schema->children[col]->name, - "') has unsupported type for ingestion ", - ArrowTypeString(bind_schema_fields[col].type)); - return ADBC_STATUS_NOT_IMPLEMENTED; + + int64_t 
param_length = param_buffer->size_bytes - last_offset - sizeof(int32_t); + if (param_length > (std::numeric_limits::max)()) { + SetError(error, "Parameter %" PRId64 " serialized to >2GB of binary", col); + return ADBC_STATUS_INTERNAL; + } + + param_lengths[col] = static_cast(param_length); + last_offset = param_buffer->size_bytes; + } + + last_offset = 0; + for (int64_t col = 0; col < array_view->n_children; col++) { + last_offset += sizeof(int32_t); + if (param_lengths[col] == 0) { + param_values[col] = nullptr; + } else { + param_values[col] = reinterpret_cast(param_buffer->data) + last_offset; } + last_offset += param_lengths[col]; } PGresult* result = diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc index d5d94c4c87..08742d5173 100644 --- a/c/driver/postgresql/postgresql_test.cc +++ b/c/driver/postgresql/postgresql_test.cc @@ -1139,10 +1139,11 @@ TEST_F(PostgresStatementTest, SqlIngestTimestampOverflow) { IsOkStatus(&error)); ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), - IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error)); - ASSERT_THAT(error.message, - ::testing::HasSubstr("Row #1 has value '9223372036854775807' which " - "exceeds PostgreSQL timestamp limits")); + IsStatus(ADBC_STATUS_INTERNAL, &error)); + ASSERT_THAT( + error.message, + ::testing::HasSubstr( + "Row 0 timestamp value 9223372036854775807 with unit 0 would overflow")); } { @@ -1169,10 +1170,11 @@ TEST_F(PostgresStatementTest, SqlIngestTimestampOverflow) { IsOkStatus(&error)); ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), - IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error)); - ASSERT_THAT(error.message, - ::testing::HasSubstr("Row #1 has value '-9223372036854775808' which " - "exceeds PostgreSQL timestamp limits")); + IsStatus(ADBC_STATUS_INTERNAL, 
&error)); + ASSERT_THAT( + error.message, + ::testing::HasSubstr( + "Row 0 timestamp value -9223372036854775808 with unit 0 would overflow")); } } diff --git a/c/driver/postgresql/result_reader.cc b/c/driver/postgresql/result_reader.cc index 21bc2bdbc4..c297703098 100644 --- a/c/driver/postgresql/result_reader.cc +++ b/c/driver/postgresql/result_reader.cc @@ -142,7 +142,7 @@ AdbcStatusCode PqResultArrayReader::Initialize(int64_t* rows_affected, // there is a result with more than zero rows to populate. if (bind_stream_) { RAISE_ADBC(bind_stream_->Begin([] { return ADBC_STATUS_OK; }, error)); - RAISE_ADBC(bind_stream_->SetParamTypes(*type_resolver_, error)); + RAISE_ADBC(bind_stream_->SetParamTypes(conn_, *type_resolver_, autocommit_, error)); RAISE_ADBC(helper_.Prepare(bind_stream_->param_types, error)); RAISE_ADBC(BindNextAndExecute(nullptr, error)); @@ -251,7 +251,7 @@ AdbcStatusCode PqResultArrayReader::ExecuteAll(int64_t* affected_rows, AdbcError // stream (if there is one) or execute the query without binding. 
if (bind_stream_) { RAISE_ADBC(bind_stream_->Begin([] { return ADBC_STATUS_OK; }, error)); - RAISE_ADBC(bind_stream_->SetParamTypes(*type_resolver_, error)); + RAISE_ADBC(bind_stream_->SetParamTypes(conn_, *type_resolver_, autocommit_, error)); RAISE_ADBC(helper_.Prepare(bind_stream_->param_types, error)); // Reset affected rows to zero before binding and executing any diff --git a/c/driver/postgresql/result_reader.h b/c/driver/postgresql/result_reader.h index 51da6399ed..cbbff6a588 100644 --- a/c/driver/postgresql/result_reader.h +++ b/c/driver/postgresql/result_reader.h @@ -17,6 +17,10 @@ #pragma once +#if !defined(NOMINMAX) +#define NOMINMAX +#endif + #include #include #include @@ -34,13 +38,21 @@ class PqResultArrayReader { public: PqResultArrayReader(PGconn* conn, std::shared_ptr type_resolver, std::string query) - : conn_(conn), helper_(conn, std::move(query)), type_resolver_(type_resolver) { + : conn_(conn), + helper_(conn, std::move(query)), + type_resolver_(type_resolver), + autocommit_(false) { ArrowErrorInit(&na_error_); error_ = ADBC_ERROR_INIT; } ~PqResultArrayReader() { ResetErrors(); } + // Ensure the reader knows what the autocommit status was on creation. This is used + // so that the temporary timezone setting required for parameter binding can be wrapped + // in a transaction (or not) accordingly. 
+ void SetAutocommit(bool autocommit) { autocommit_ = autocommit; } + void SetBind(struct ArrowArrayStream* stream) { bind_stream_ = std::make_unique(); bind_stream_->SetBind(stream); @@ -62,6 +74,7 @@ class PqResultArrayReader { std::shared_ptr type_resolver_; std::vector> field_readers_; nanoarrow::UniqueSchema schema_; + bool autocommit_; struct AdbcError error_; struct ArrowError na_error_; diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index f091b4fde9..d86f775b34 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -456,6 +456,7 @@ AdbcStatusCode PostgresStatement::ExecuteBind(struct ArrowArrayStream* stream, int64_t* rows_affected, struct AdbcError* error) { PqResultArrayReader reader(connection_->conn(), type_resolver_, query_); + reader.SetAutocommit(connection_->autocommit()); reader.SetBind(&bind_); RAISE_ADBC(reader.ToArrayStream(rows_affected, stream, error)); return ADBC_STATUS_OK; diff --git a/c/driver/sqlite/sqlite.cc b/c/driver/sqlite/sqlite.cc index 6628acdd97..207b591f3d 100644 --- a/c/driver/sqlite/sqlite.cc +++ b/c/driver/sqlite/sqlite.cc @@ -1008,7 +1008,8 @@ class SqliteStatement : public driver::Statement { "parameter count mismatch: expected {} but found {}", expected, actual); } - int64_t rows = 0; + int64_t output_rows = 0; + int64_t changed_rows = 0; SqliteMutexGuard guard(conn_); @@ -1027,7 +1028,11 @@ class SqliteStatement : public driver::Statement { } while (sqlite3_step(stmt_) == SQLITE_ROW) { - rows++; + output_rows++; + } + + if (sqlite3_column_count(stmt_) == 0) { + changed_rows += sqlite3_changes(conn_); } if (!binder_.schema.release) break; @@ -1041,9 +1046,10 @@ class SqliteStatement : public driver::Statement { } if (sqlite3_column_count(stmt_) == 0) { - rows = sqlite3_changes(conn_); + return changed_rows; + } else { + return output_rows; } - return rows; } Result ExecuteUpdateImpl(PreparedState& state) { return ExecuteUpdateImpl(); } diff --git 
a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h index 3ceed0dee1..9e5ec86283 100644 --- a/c/validation/adbc_validation.h +++ b/c/validation/adbc_validation.h @@ -358,7 +358,7 @@ class StatementTest { void TestNewInit(); void TestRelease(); - // ---- Type-specific tests -------------------- + // ---- Type-specific ingest tests ------------- void TestSqlIngestBool(); @@ -427,6 +427,8 @@ class StatementTest { void TestSqlPrepareErrorNoQuery(); void TestSqlPrepareErrorParamCountMismatch(); + void TestSqlBind(); + void TestSqlQueryEmpty(); void TestSqlQueryInts(); void TestSqlQueryFloats(); @@ -533,6 +535,7 @@ class StatementTest { TEST_F(FIXTURE, SqlPrepareErrorParamCountMismatch) { \ TestSqlPrepareErrorParamCountMismatch(); \ } \ + TEST_F(FIXTURE, SqlBind) { TestSqlBind(); } \ TEST_F(FIXTURE, SqlQueryEmpty) { TestSqlQueryEmpty(); } \ TEST_F(FIXTURE, SqlQueryInts) { TestSqlQueryInts(); } \ TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); } \ diff --git a/c/validation/adbc_validation_statement.cc b/c/validation/adbc_validation_statement.cc index 81b4696981..4549faf20c 100644 --- a/c/validation/adbc_validation_statement.cc +++ b/c/validation/adbc_validation_statement.cc @@ -2136,6 +2136,71 @@ void StatementTest::TestSqlPrepareErrorParamCountMismatch() { ::testing::Not(IsOkStatus(&error))); } +void StatementTest::TestSqlBind() { + if (!quirks()->supports_dynamic_parameter_binding()) { + GTEST_SKIP(); + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + ASSERT_THAT(quirks()->DropTable(&connection, "bindtest", &error), IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery( + &statement, "CREATE TABLE bindtest (col1 INTEGER, col2 TEXT)", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); + + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, + {{"", 
NANOARROW_TYPE_INT32}, {"", NANOARROW_TYPE_STRING}}), + IsOkErrno()); + + std::vector> int_values{std::nullopt, -123, 123}; + std::vector> string_values{"abc", std::nullopt, "defg"}; + + int batch_result = MakeBatch( + &schema.value, &array.value, &na_error, int_values, string_values); + ASSERT_THAT(batch_result, IsOkErrno()); + + auto insert_query = std::string("INSERT INTO bindtest VALUES (") + + quirks()->BindParameter(0) + ", " + quirks()->BindParameter(1) + + ")"; + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, insert_query.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + int64_t rows_affected = -10; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(-1), ::testing::Eq(3))); + + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement, "SELECT * FROM bindtest ORDER BY \"col1\" ASC NULLS FIRST", &error), + IsOkStatus(&error)); + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(0), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->length, 3); + CompareArray(reader.array_view->children[0], int_values); + CompareArray(reader.array_view->children[1], string_values); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->release, nullptr); + } +} + void StatementTest::TestSqlQueryEmpty() { ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); diff --git a/c/validation/adbc_validation_util.h b/c/validation/adbc_validation_util.h index 6f25462698..e7a4d76b29 100644 --- 
a/c/validation/adbc_validation_util.h +++ b/c/validation/adbc_validation_util.h @@ -401,42 +401,22 @@ void CompareArray(struct ArrowArrayView* array, SCOPED_TRACE("Array index " + std::to_string(i)); if (v.has_value()) { ASSERT_FALSE(ArrowArrayViewIsNull(array, i)); - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value || std::is_same::value) { ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_float[i]); - } else if constexpr (std::is_same::value) { + ASSERT_EQ(ArrowArrayViewGetDoubleUnsafe(array, i), *v); + } else if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) { ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_double[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_float[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, ArrowBitGet(array->buffer_views[1].data.as_uint8, i)); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_int8[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_int16[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_int32[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_int64[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_uint8[i]); - } else if constexpr 
(std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_uint16[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_uint32[i]); - } else if constexpr (std::is_same::value) { + ASSERT_EQ(ArrowArrayViewGetIntUnsafe(array, i), *v); + } else if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) { ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_uint64[i]); + ASSERT_EQ(ArrowArrayViewGetUIntUnsafe(array, i), *v); } else if constexpr (std::is_same::value) { struct ArrowStringView view = ArrowArrayViewGetStringUnsafe(array, i); std::string str(view.data, view.size_bytes);