Merge branch 'main' into dylan/add_internal_tables_to_pg_class
chenzl25 authored Mar 16, 2023
2 parents 495db60 + 9b89bb0 commit 16c0c62
Showing 20 changed files with 306 additions and 85 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/intergration_tests.yml
@@ -39,7 +39,7 @@ jobs:
- schema-registry
- mysql-cdc
- postgres-cdc
#- mysql-sink
- mysql-sink
- postgres-sink
- iceberg-sink
format: ["json", "protobuf"]
19 changes: 19 additions & 0 deletions e2e_test/batch/basic/join.slt.part
@@ -32,6 +32,25 @@ select * from t1 join t2 using(v1) join t3 using(v2);
----
2 1 3 3

statement ok
set batch_parallelism = 1;

query IIIIII
select * from t1 join t2 using(v1) join t3 using(v2);
----
2 1 3 3

statement ok
set batch_parallelism = 1000;

query IIIIII
select * from t1 join t2 using(v1) join t3 using(v2);
----
2 1 3 3

statement ok
set batch_parallelism = 0;

statement ok
create index i1 on t1(v1) include(v2);

2 changes: 1 addition & 1 deletion integration_tests/mysql-sink/mysql_prepare.sql
@@ -1,4 +1,4 @@
CREATE TABLE target_count (
target_id VARCHAR(128),
target_id VARCHAR(128) primary key,
target_count BIGINT
);
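
The primary key added here is what lets the re-enabled mysql-sink integration test (see the workflow change above) exercise the sink's new key-aware path: with a key on target_id, the JDBC sink can target rows for DELETE and UPDATE by primary key instead of comparing every column, as the JDBCSink.java changes below implement.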
TableSchema.java
@@ -102,6 +102,8 @@ public static TableSchema fromProto(ConnectorServiceProto.TableSchema tableSchem
.collect(Collectors.toList()));
}

/** @deprecated The pk here comes from RisingWave; it may not match the pk in the database. */
@Deprecated
public List<String> getPrimaryKeys() {
return primaryKeys;
}
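getPrimaryKeys() is deprecated because it reports RisingWave's notion of the primary key, which may disagree with the key actually declared on the downstream table; the JDBCSink.java changes below therefore ask the target database directly through JDBC metadata. A minimal sketch of that lookup pattern, assuming a placeholder JDBC URL and the target_count table from mysql_prepare.sql (the class and URL are illustrative, not part of this commit):

```java
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;

public class PkProbe {
    public static void main(String[] args) throws SQLException {
        // Placeholder URL; any JDBC-reachable database works the same way.
        try (Connection conn =
                DriverManager.getConnection("jdbc:mysql://localhost:3306/test?user=root")) {
            // getPrimaryKeys(catalog, schema, table): null catalog/schema act as
            // wildcards, so only the table name constrains the lookup.
            try (ResultSet pks = conn.getMetaData().getPrimaryKeys(null, null, "target_count")) {
                while (pks.next()) {
                    // Each result row describes one key column; COLUMN_NAME holds its
                    // name, which is exactly what getPkColumnNames collects below.
                    System.out.println(pks.getString("COLUMN_NAME"));
                }
            }
        }
    }
}
```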
JDBCSink.java
@@ -20,7 +20,9 @@
import com.risingwave.proto.Data;
import io.grpc.Status;
import java.sql.*;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.slf4j.Logger;
@@ -30,10 +32,13 @@ public class JDBCSink extends SinkBase {
public static final String INSERT_TEMPLATE = "INSERT INTO %s (%s) VALUES (%s)";
private static final String DELETE_TEMPLATE = "DELETE FROM %s WHERE %s";
private static final String UPDATE_TEMPLATE = "UPDATE %s SET %s WHERE %s";
private static final String ERROR_REPORT_TEMPLATE = "Error when exec %s, message %s";

private final String tableName;
private final Connection conn;
private final String jdbcUrl;
private final List<String> pkColumnNames;
public static final String JDBC_COLUMN_NAME_KEY = "COLUMN_NAME";

private String updateDeleteConditionBuffer;
private Object[] updateDeleteValueBuffer;
@@ -48,16 +53,38 @@ public JDBCSink(String tableName, String jdbcUrl, TableSchema tableSchema) {
try {
this.conn = DriverManager.getConnection(jdbcUrl);
this.conn.setAutoCommit(false);
this.pkColumnNames = getPkColumnNames(conn, tableName);
} catch (SQLException e) {
throw Status.INTERNAL.withCause(e).asRuntimeException();
throw Status.INTERNAL
.withDescription(
String.format(ERROR_REPORT_TEMPLATE, e.getSQLState(), e.getMessage()))
.asRuntimeException();
}
}

private static List<String> getPkColumnNames(Connection conn, String tableName) {
List<String> pkColumnNames = new ArrayList<>();
try {
var pks = conn.getMetaData().getPrimaryKeys(null, null, tableName);
while (pks.next()) {
pkColumnNames.add(pks.getString(JDBC_COLUMN_NAME_KEY));
}
} catch (SQLException e) {
throw Status.INTERNAL
.withDescription(
String.format(ERROR_REPORT_TEMPLATE, e.getSQLState(), e.getMessage()))
.asRuntimeException();
}
LOG.info("detected pk {}", pkColumnNames);
return pkColumnNames;
}

public JDBCSink(Connection conn, TableSchema tableSchema, String tableName) {
super(tableSchema);
this.tableName = tableName;
this.jdbcUrl = null;
this.conn = conn;
this.pkColumnNames = getPkColumnNames(conn, tableName);
}

private PreparedStatement prepareStatement(SinkRow row) {
@@ -79,35 +106,75 @@ private PreparedStatement prepareStatement(SinkRow row) {
}
return stmt;
} catch (SQLException e) {
throw io.grpc.Status.INTERNAL.withCause(e).asRuntimeException();
throw io.grpc.Status.INTERNAL
.withDescription(
String.format(
ERROR_REPORT_TEMPLATE, e.getSQLState(), e.getMessage()))
.asRuntimeException();
}
case DELETE:
String deleteCondition =
getTableSchema().getPrimaryKeys().stream()
.map(key -> key + " = ?")
.collect(Collectors.joining(" AND "));
String deleteCondition;
if (this.pkColumnNames.isEmpty()) {
deleteCondition =
IntStream.range(0, getTableSchema().getNumColumns())
.mapToObj(
index ->
getTableSchema().getColumnNames()[index]
+ " = ?")
.collect(Collectors.joining(" AND "));
} else {
deleteCondition =
this.pkColumnNames.stream()
.map(key -> key + " = ?")
.collect(Collectors.joining(" AND "));
}
String deleteStmt = String.format(DELETE_TEMPLATE, tableName, deleteCondition);
try {
int placeholderIdx = 1;
PreparedStatement stmt =
conn.prepareStatement(deleteStmt, Statement.RETURN_GENERATED_KEYS);
for (String primaryKey : getTableSchema().getPrimaryKeys()) {
for (String primaryKey : this.pkColumnNames) {
Object fromRow = getTableSchema().getFromRow(primaryKey, row);
stmt.setObject(placeholderIdx++, fromRow);
}
return stmt;
} catch (SQLException e) {
throw Status.INTERNAL.withCause(e).asRuntimeException();
throw Status.INTERNAL
.withDescription(
String.format(
ERROR_REPORT_TEMPLATE, e.getSQLState(), e.getMessage()))
.asRuntimeException();
}
case UPDATE_DELETE:
updateDeleteConditionBuffer =
getTableSchema().getPrimaryKeys().stream()
.map(key -> key + " = ?")
.collect(Collectors.joining(" AND "));
updateDeleteValueBuffer =
getTableSchema().getPrimaryKeys().stream()
.map(key -> getTableSchema().getFromRow(key, row))
.toArray();
if (this.pkColumnNames.isEmpty()) {
updateDeleteConditionBuffer =
IntStream.range(0, getTableSchema().getNumColumns())
.mapToObj(
index ->
getTableSchema().getColumnNames()[index]
+ " = ?")
.collect(Collectors.joining(" AND "));
updateDeleteValueBuffer =
IntStream.range(0, getTableSchema().getNumColumns())
.mapToObj(
index ->
getTableSchema()
.getFromRow(
getTableSchema()
.getColumnNames()[
index],
row))
.toArray();
} else {
updateDeleteConditionBuffer =
this.pkColumnNames.stream()
.map(key -> key + " = ?")
.collect(Collectors.joining(" AND "));
updateDeleteValueBuffer =
this.pkColumnNames.stream()
.map(key -> getTableSchema().getFromRow(key, row))
.toArray();
}
LOG.debug(
"update delete condition: {} on values {}",
updateDeleteConditionBuffer,
@@ -144,7 +211,11 @@ private PreparedStatement prepareStatement(SinkRow row) {
updateDeleteValueBuffer = null;
return stmt;
} catch (SQLException e) {
throw Status.INTERNAL.withCause(e).asRuntimeException();
throw Status.INTERNAL
.withDescription(
String.format(
ERROR_REPORT_TEMPLATE, e.getSQLState(), e.getMessage()))
.asRuntimeException();
}
default:
throw Status.INVALID_ARGUMENT
@@ -163,10 +234,14 @@ public void write(Iterator<SinkRow> rows) {
}
if (stmt != null) {
try {
LOG.debug("Executing statement: " + stmt);
LOG.debug("Executing statement: {}", stmt);
stmt.executeUpdate();
} catch (SQLException e) {
throw Status.INTERNAL.withCause(e).asRuntimeException();
throw Status.INTERNAL
.withDescription(
String.format(
ERROR_REPORT_TEMPLATE, e.getSQLState(), e.getMessage()))
.asRuntimeException();
}
} else {
throw Status.INTERNAL
@@ -187,7 +262,10 @@ public void sync() {
try {
conn.commit();
} catch (SQLException e) {
throw io.grpc.Status.INTERNAL.withCause(e).asRuntimeException();
throw io.grpc.Status.INTERNAL
.withDescription(
String.format(ERROR_REPORT_TEMPLATE, e.getSQLState(), e.getMessage()))
.asRuntimeException();
}
}

@@ -196,7 +274,10 @@ public void drop() {
try {
conn.close();
} catch (SQLException e) {
throw io.grpc.Status.INTERNAL.withCause(e).asRuntimeException();
throw io.grpc.Status.INTERNAL
.withDescription(
String.format(ERROR_REPORT_TEMPLATE, e.getSQLState(), e.getMessage()))
.asRuntimeException();
}
}

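The behavioral core of this diff is the fallback taken when the downstream table declares no primary key: DELETE and UPDATE_DELETE conditions are then built over every column instead of the deprecated RisingWave-side key list. A condensed sketch of the two paths, with buildCondition as a hypothetical helper (in the commit the logic lives inline in prepareStatement):

```java
import java.util.List;
import java.util.stream.Collectors;

public class ConditionSketch {
    static String buildCondition(List<String> pkColumnNames, List<String> allColumnNames) {
        // No primary key on the downstream table: match on every column, which
        // identifies a row only up to duplicates.
        List<String> keys = pkColumnNames.isEmpty() ? allColumnNames : pkColumnNames;
        return keys.stream().map(k -> k + " = ?").collect(Collectors.joining(" AND "));
    }

    public static void main(String[] args) {
        System.out.println(buildCondition(List.of(), List.of("target_id", "target_count")));
        // -> target_id = ? AND target_count = ?
        System.out.println(buildCondition(List.of("target_id"), List.of("target_id", "target_count")));
        // -> target_id = ?
    }
}
```

One caveat of the fallback: an equality match over all columns binds one placeholder per column and can touch several identical rows at once, so declaring a real primary key on the sink table, as mysql_prepare.sql now does, remains the safer configuration.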
24 changes: 23 additions & 1 deletion src/common/src/session_config/mod.rs
@@ -17,6 +17,7 @@ mod search_path;
mod transaction_isolation_level;
mod visibility_mode;

use std::num::NonZeroU64;
use std::ops::Deref;

use chrono_tz::Tz;
@@ -33,7 +34,7 @@ use crate::util::epoch::Epoch;

// This is a hack, &'static str is not allowed as a const generics argument.
// TODO: refine this using the adt_const_params feature.
const CONFIG_KEYS: [&str; 20] = [
const CONFIG_KEYS: [&str; 21] = [
"RW_IMPLICIT_FLUSH",
"CREATE_COMPACTION_GROUP_FOR_MV",
"QUERY_MODE",
@@ -54,6 +55,7 @@ const CONFIG_KEYS: [&str; 20] = [
"RW_FORCE_TWO_PHASE_AGG",
"RW_ENABLE_SHARE_PLAN",
"INTERVALSTYLE",
"BATCH_PARALLELISM",
];

// MUST HAVE 1v1 relationship to CONFIG_KEYS. e.g. CONFIG_KEYS[IMPLICIT_FLUSH] =
Expand All @@ -78,6 +80,7 @@ const ENABLE_TWO_PHASE_AGG: usize = 16;
const FORCE_TWO_PHASE_AGG: usize = 17;
const RW_ENABLE_SHARE_PLAN: usize = 18;
const INTERVAL_STYLE: usize = 19;
const BATCH_PARALLELISM: usize = 20;

trait ConfigEntry: Default + for<'a> TryFrom<&'a [&'a str], Error = RwError> {
fn entry_name() -> &'static str;
@@ -278,6 +281,7 @@ type EnableTwoPhaseAgg = ConfigBool<ENABLE_TWO_PHASE_AGG, true>;
type ForceTwoPhaseAgg = ConfigBool<FORCE_TWO_PHASE_AGG, false>;
type EnableSharePlan = ConfigBool<RW_ENABLE_SHARE_PLAN, true>;
type IntervalStyle = ConfigString<INTERVAL_STYLE>;
type BatchParallelism = ConfigU64<BATCH_PARALLELISM, 0>;

#[derive(Derivative)]
#[derivative(Default)]
@@ -354,6 +358,8 @@ pub struct ConfigMap {

/// see <https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-INTERVALSTYLE>
interval_style: IntervalStyle,

batch_parallelism: BatchParallelism,
}

impl ConfigMap {
@@ -410,6 +416,8 @@ impl ConfigMap {
self.enable_share_plan = val.as_slice().try_into()?;
} else if key.eq_ignore_ascii_case(IntervalStyle::entry_name()) {
self.interval_style = val.as_slice().try_into()?;
} else if key.eq_ignore_ascii_case(BatchParallelism::entry_name()) {
self.batch_parallelism = val.as_slice().try_into()?;
} else {
return Err(ErrorCode::UnrecognizedConfigurationParameter(key.to_string()).into());
}
@@ -458,6 +466,8 @@ impl ConfigMap {
Ok(self.enable_share_plan.to_string())
} else if key.eq_ignore_ascii_case(IntervalStyle::entry_name()) {
Ok(self.interval_style.to_string())
} else if key.eq_ignore_ascii_case(BatchParallelism::entry_name()) {
Ok(self.batch_parallelism.to_string())
} else {
Err(ErrorCode::UnrecognizedConfigurationParameter(key.to_string()).into())
}
@@ -560,6 +570,11 @@ impl ConfigMap {
setting : self.interval_style.to_string(),
description : String::from("It is typically set by an application upon connection to the server.")
},
VariableInfo{
name : BatchParallelism::entry_name().to_lowercase(),
setting : self.batch_parallelism.to_string(),
description: String::from("Sets the parallelism for batch. If 0, use default value.")
},
]
}

@@ -648,4 +663,11 @@ impl ConfigMap {
pub fn get_interval_style(&self) -> &str {
&self.interval_style
}

pub fn get_batch_parallelism(&self) -> Option<NonZeroU64> {
if self.batch_parallelism.0 != 0 {
return Some(NonZeroU64::new(self.batch_parallelism.0).unwrap());
}
None
}
}
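
BATCH_PARALLELISM behaves like any other session variable: 0 keeps the system default, and a positive value sets the parallelism used for batch queries, which get_batch_parallelism then exposes as Option<NonZeroU64>. A usage sketch over JDBC; the URL and port are placeholders for a RisingWave frontend speaking the PostgreSQL protocol, not values taken from this commit:

```java
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.sql.Statement;

public class BatchParallelismDemo {
    public static void main(String[] args) throws SQLException {
        try (Connection conn =
                        DriverManager.getConnection("jdbc:postgresql://localhost:4566/dev");
                Statement stmt = conn.createStatement()) {
            stmt.execute("SET BATCH_PARALLELISM = 4"); // pin batch queries to 4-way parallelism
            stmt.execute("SET BATCH_PARALLELISM = 0"); // revert to the system default
        }
    }
}
```

This mirrors the e2e test above, which runs the same join under parallelism 1, 1000, and 0 and expects identical results.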
1 change: 1 addition & 0 deletions src/connector/src/source/base.rs
@@ -363,6 +363,7 @@ pub type DataType = risingwave_common::types::DataType;
pub struct Column {
pub name: String,
pub data_type: DataType,
pub is_visible: bool,
}

/// Split id resides in every source message, use `Arc` to avoid copying.