Skip to content

Commit

Permalink
Disable DDL execution within transaction
Browse files Browse the repository at this point in the history
- Rollback auto transaction when throwing exception

Signed-off-by: Wenbo Li <[email protected]>
  • Loading branch information
hnjylwb committed Mar 13, 2024
1 parent 3e58a6d commit e5c5b2a
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 117 deletions.
267 changes: 150 additions & 117 deletions src/database/database_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ void DatabaseEngine::ExecuteSql(const std::string &sql, ResultWriter &writer, co
DescribeTable(table_name, writer);
}
} else if (sql[1] == 'c') {
if (InTransaction(connection)) {
throw DbException("Cannot execute \\c within a transaction block");
}
if (sql.size() > 3) {
auto db_name = sql.substr(3);
StringUtil::RTrim(db_name);
Expand Down Expand Up @@ -148,130 +151,156 @@ void DatabaseEngine::ExecuteSql(const std::string &sql, ResultWriter &writer, co

bool is_modification_sql = false;
// 如果该语句不在事务块内,则自动开启一个事务
if (xids_.find(&connection) == xids_.end()) {
if (!InTransaction(connection)) {
Begin(connection);
auto_transaction_set_.insert(&connection);
}
switch (statement->type_) {
// 对于 DDL 查询,直接调用对应的函数
case StatementType::CREATE_DATABASE_STATEMENT: {
const auto &create_database_statement = dynamic_cast<CreateDatabaseStatement &>(*statement);
CreateDatabase(create_database_statement.database_, false, writer);
break;
}
case StatementType::CREATE_TABLE_STATEMENT: {
const auto &create_table_statement = dynamic_cast<CreateTableStatement &>(*statement);
CreateTable(create_table_statement.table_, ColumnList(create_table_statement.columns_), writer);
break;
}
case StatementType::CREATE_INDEX_STATEMENT: {
const auto &create_index_statement = dynamic_cast<CreateIndexStatement &>(*statement);
CreateIndex(create_index_statement.index_name_, create_index_statement.table_name_,
create_index_statement.column_names_, writer);
break;
}
case StatementType::DROP_DATABASE_STATEMENT: {
const auto &drop_database_statement = dynamic_cast<DropDatabaseStatement &>(*statement);
DropDatabase(drop_database_statement.database_, drop_database_statement.missing_ok_, writer);
break;
}
case StatementType::DROP_TABLE_STATEMENT: {
const auto &drop_table_statement = dynamic_cast<DropTableStatement &>(*statement);
DropTable(drop_table_statement.table_, writer);
break;
}
case StatementType::DROP_INDEX_STATEMENT: {
const auto &drop_index_statement = dynamic_cast<DropIndexStatement &>(*statement);
DropIndex(drop_index_statement.index_name_, writer);
break;
}
case StatementType::EXPLAIN_STATEMENT: {
const auto &explain_statement = dynamic_cast<ExplainStatement &>(*statement);
Explain(explain_statement, writer);
break;
}
case StatementType::LOCK_STATEMENT: {
const auto &lock_statement = dynamic_cast<LockStatement &>(*statement);
Lock(xids_[&connection], lock_statement, writer);
break;
}
case StatementType::VARIABLE_SET_STATEMENT: {
const auto &variable_set_statement = dynamic_cast<VariableSetStatement &>(*statement);
VariableSet(connection, variable_set_statement, writer);
break;
}
case StatementType::VARIABLE_SHOW_STATEMENT: {
const auto &variable_show_statement = dynamic_cast<VariableShowStatement &>(*statement);
VariableShow(connection, variable_show_statement, writer);
break;
}
case StatementType::ANALYZE_STATEMENT: {
const auto &analyze_statement = dynamic_cast<AnalyzeStatement &>(*statement);
Analyze(analyze_statement, writer);
break;
}
case StatementType::VACUUM_STATEMENT: {
const auto &vacuum_statement = dynamic_cast<VacuumStatement &>(*statement);
Vacuum(vacuum_statement, writer);
break;
}
case StatementType::UPDATE_STATEMENT:
case StatementType::DELETE_STATEMENT:
is_modification_sql = true;
// 对于 DML 查询,需要生成查询计划并执行
default: {
try {
// 生成查询计划
Planner planner(force_join_);
auto plan = planner.PlanQuery(*statement);

if (enable_optimizer_) {
// 查询计划优化
Optimizer optimizer(*catalog_, join_order_algorithm_);
plan = optimizer.Optimize(plan);
try {
switch (statement->type_) {
// 对于 DDL 查询,直接调用对应的函数
case StatementType::CREATE_DATABASE_STATEMENT: {
if (CheckInTransaction(connection)) {
throw DbException("Cannot execute DDL statement within a transaction block");
}

// 得到优化后的查询计划后,打印表头
auto column_list = plan->OutputColumns();
writer.BeginTable();
writer.BeginHeader();
for (size_t i = 0; i < column_list.Length(); i++) {
writer.WriteHeaderCell(column_list.GetColumn(i).name_);
const auto &create_database_statement = dynamic_cast<CreateDatabaseStatement &>(*statement);
CreateDatabase(create_database_statement.database_, false, writer);
break;
}
case StatementType::CREATE_TABLE_STATEMENT: {
if (CheckInTransaction(connection)) {
throw DbException("Cannot execute DDL statement within a transaction block");
}
writer.EndHeader();

// 生成查询上下文信息,如查询属于哪个事务,隔离级别等
IsolationLevel isolation_level = DEFAULT_ISOLATION_LEVEL;
if (isolation_levels_.find(&connection) != isolation_levels_.end()) {
isolation_level = isolation_levels_[&connection];
const auto &create_table_statement = dynamic_cast<CreateTableStatement &>(*statement);
CreateTable(create_table_statement.table_, ColumnList(create_table_statement.columns_), writer);
break;
}
case StatementType::CREATE_INDEX_STATEMENT: {
if (CheckInTransaction(connection)) {
throw DbException("Cannot execute DDL statement within a transaction block");
}
auto executor_context = std::make_unique<ExecutorContext>(
*buffer_pool_, *catalog_, *transaction_manager_, *lock_manager_, xids_[&connection], isolation_level,
transaction_manager_->GetCidAndIncrement(xids_[&connection]), is_modification_sql);

// 根据查询上下文和查询计划,生成执行器
auto executor = ExecutorFactory::CreateExecutor(*executor_context, plan);
executor->Init();
size_t record_count = 0;
while (auto record = executor->Next()) {
writer.BeginRow();
for (const auto &value : record->GetValues()) {
writer.WriteCell(value.ToString());
}
writer.EndRow();
record_count++;
const auto &create_index_statement = dynamic_cast<CreateIndexStatement &>(*statement);
CreateIndex(create_index_statement.index_name_, create_index_statement.table_name_,
create_index_statement.column_names_, writer);
break;
}
case StatementType::DROP_DATABASE_STATEMENT: {
if (CheckInTransaction(connection)) {
throw DbException("Cannot execute DDL statement within a transaction block");
}
writer.EndTable();
writer.WriteRowCount(record_count);
} catch (DbException &e) {
if (auto_transaction_set_.find(&connection) != auto_transaction_set_.end()) {
Rollback(connection);
auto_transaction_set_.erase(&connection);
const auto &drop_database_statement = dynamic_cast<DropDatabaseStatement &>(*statement);
DropDatabase(drop_database_statement.database_, drop_database_statement.missing_ok_, writer);
break;
}
case StatementType::DROP_TABLE_STATEMENT: {
if (CheckInTransaction(connection)) {
throw DbException("Cannot execute DDL statement within a transaction block");
}
throw e;
const auto &drop_table_statement = dynamic_cast<DropTableStatement &>(*statement);
DropTable(drop_table_statement.table_, writer);
break;
}
case StatementType::DROP_INDEX_STATEMENT: {
if (CheckInTransaction(connection)) {
throw DbException("Cannot execute DDL statement within a transaction block");
}
const auto &drop_index_statement = dynamic_cast<DropIndexStatement &>(*statement);
DropIndex(drop_index_statement.index_name_, writer);
break;
}
case StatementType::EXPLAIN_STATEMENT: {
const auto &explain_statement = dynamic_cast<ExplainStatement &>(*statement);
Explain(explain_statement, writer);
break;
}
case StatementType::LOCK_STATEMENT: {
const auto &lock_statement = dynamic_cast<LockStatement &>(*statement);
Lock(xids_[&connection], lock_statement, writer);
break;
}
case StatementType::VARIABLE_SET_STATEMENT: {
const auto &variable_set_statement = dynamic_cast<VariableSetStatement &>(*statement);
VariableSet(connection, variable_set_statement, writer);
break;
}
case StatementType::VARIABLE_SHOW_STATEMENT: {
const auto &variable_show_statement = dynamic_cast<VariableShowStatement &>(*statement);
VariableShow(connection, variable_show_statement, writer);
break;
}
case StatementType::ANALYZE_STATEMENT: {
const auto &analyze_statement = dynamic_cast<AnalyzeStatement &>(*statement);
Analyze(analyze_statement, writer);
break;
}
case StatementType::VACUUM_STATEMENT: {
const auto &vacuum_statement = dynamic_cast<VacuumStatement &>(*statement);
Vacuum(vacuum_statement, writer);
break;
}
case StatementType::UPDATE_STATEMENT:
case StatementType::DELETE_STATEMENT:
is_modification_sql = true;
// 对于 DML 查询,需要生成查询计划并执行
default: {
try {
// 生成查询计划
Planner planner(force_join_);
auto plan = planner.PlanQuery(*statement);

if (enable_optimizer_) {
// 查询计划优化
Optimizer optimizer(*catalog_, join_order_algorithm_);
plan = optimizer.Optimize(plan);
}

// 得到优化后的查询计划后,打印表头
auto column_list = plan->OutputColumns();
writer.BeginTable();
writer.BeginHeader();
for (size_t i = 0; i < column_list.Length(); i++) {
writer.WriteHeaderCell(column_list.GetColumn(i).name_);
}
writer.EndHeader();

// 生成查询上下文信息,如查询属于哪个事务,隔离级别等
IsolationLevel isolation_level = DEFAULT_ISOLATION_LEVEL;
if (isolation_levels_.find(&connection) != isolation_levels_.end()) {
isolation_level = isolation_levels_[&connection];
}
auto executor_context = std::make_unique<ExecutorContext>(
*buffer_pool_, *catalog_, *transaction_manager_, *lock_manager_, xids_[&connection], isolation_level,
transaction_manager_->GetCidAndIncrement(xids_[&connection]), is_modification_sql);

// 根据查询上下文和查询计划,生成执行器
auto executor = ExecutorFactory::CreateExecutor(*executor_context, plan);
executor->Init();
size_t record_count = 0;
while (auto record = executor->Next()) {
writer.BeginRow();
for (const auto &value : record->GetValues()) {
writer.WriteCell(value.ToString());
}
writer.EndRow();
record_count++;
}
writer.EndTable();
writer.WriteRowCount(record_count);
} catch (DbException &e) {
if (auto_transaction_set_.find(&connection) != auto_transaction_set_.end()) {
Rollback(connection);
auto_transaction_set_.erase(&connection);
}
throw e;
}
break;
}
break;
}
} catch (DbException &e) {
if (auto_transaction_set_.find(&connection) != auto_transaction_set_.end()) {
Rollback(connection);
auto_transaction_set_.erase(&connection);
}
throw e;
}
// 如果事务是自动开启的,查询结束后需要自动提交
if (auto_transaction_set_.find(&connection) != auto_transaction_set_.end()) {
Expand Down Expand Up @@ -301,6 +330,10 @@ void DatabaseEngine::Help(ResultWriter &writer) const {
WriteOneCell(help, writer);
}

bool DatabaseEngine::CheckInTransaction(const Connection &connection) const {
return auto_transaction_set_.find(&connection) == auto_transaction_set_.end() && InTransaction(connection);
}

void DatabaseEngine::CreateDatabase(const std::string &db_name, bool exists_ok, ResultWriter &writer) {
catalog_->CreateDatabase(db_name, exists_ok);
WriteOneCell("CREATE DATABASE", writer);
Expand Down Expand Up @@ -407,7 +440,7 @@ void DatabaseEngine::DropIndex(const std::string &index_name, ResultWriter &writ
}

void DatabaseEngine::Begin(const Connection &connection) {
if (xids_.find(&connection) != xids_.end()) {
if (InTransaction(connection)) {
throw DbException("There is already a transaction in progress");
} else {
auto xid = transaction_manager_->Begin();
Expand All @@ -417,7 +450,7 @@ void DatabaseEngine::Begin(const Connection &connection) {
}

void DatabaseEngine::Commit(const Connection &connection) {
if (xids_.find(&connection) == xids_.end()) {
if (!InTransaction(connection)) {
throw DbException("There is no transaction in process");
} else {
log_manager_->AppendCommitLog(xids_[&connection]);
Expand All @@ -427,7 +460,7 @@ void DatabaseEngine::Commit(const Connection &connection) {
}

void DatabaseEngine::Rollback(const Connection &connection) {
if (xids_.find(&connection) == xids_.end()) {
if (!InTransaction(connection)) {
throw DbException("There is no transaction in process");
} else {
log_manager_->Rollback(xids_[&connection]);
Expand Down
3 changes: 3 additions & 0 deletions src/database/database_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class DatabaseEngine {
private:
void Help(ResultWriter &writer) const;

// 检查事务是否正在进行(不包含自动开启的事务)
bool CheckInTransaction(const Connection &connection) const;

void CreateDatabase(const std::string &db_name, bool exists_ok, ResultWriter &writer);
void ShowDatabases(ResultWriter &writer) const;
void ChangeDatabase(const std::string &db_name, ResultWriter &writer);
Expand Down

0 comments on commit e5c5b2a

Please sign in to comment.