Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

join be aware of cancel signal (#9450) #9453

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dbms/src/DataStreams/AggregatingBlockInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Block AggregatingBlockInputStream::readImpl()
executed = true;
AggregatedDataVariantsPtr data_variants = std::make_shared<AggregatedDataVariants>();

Aggregator::CancellationHook hook = [&]() {
CancellationHook hook = [&]() {
return this->isCancelled();
};
aggregator.setCancellationHook(hook);
Expand Down
2 changes: 0 additions & 2 deletions dbms/src/DataStreams/HashJoinProbeExec.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class HashJoinProbeExec : public std::enable_shared_from_this<HashJoinProbeExec>
const BlockInputStreamPtr & probe_stream,
size_t max_block_size);

using CancellationHook = std::function<bool()>;

HashJoinProbeExec(
const String & req_id,
const JoinPtr & join_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Block ParallelAggregatingBlockInputStream::readImpl()
{
if (!executed)
{
Aggregator::CancellationHook hook = [&]() {
CancellationHook hook = [&]() {
return this->isCancelled();
};
aggregator.setCancellationHook(hook);
Expand Down
10 changes: 7 additions & 3 deletions dbms/src/Flash/Mpp/MPPTask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,9 @@ void MPPTask::runImpl()
GET_METRIC(tiflash_coprocessor_request_duration_seconds, type_run_mpp_task).Observe(stopwatch.elapsedSeconds());
});

// set cancellation hook
context->setCancellationHook([this] { return is_cancelled.load(); });

String err_msg;
try
{
Expand Down Expand Up @@ -746,7 +749,7 @@ void MPPTask::abort(const String & message, AbortType abort_type)
if (previous_status == FINISHED || previous_status == CANCELLED || previous_status == FAILED)
{
LOG_WARNING(log, "task already in {} state", magic_enum::enum_name(previous_status));
return;
break;
}
else if (previous_status == INITIALIZING && switchStatus(INITIALIZING, next_task_status))
{
Expand All @@ -755,7 +758,7 @@ void MPPTask::abort(const String & message, AbortType abort_type)
/// so just close all tunnels here
abortTunnels("", false);
LOG_WARNING(log, "Finish abort task from uninitialized");
return;
break;
}
else if (previous_status == RUNNING && switchStatus(RUNNING, next_task_status))
{
Expand All @@ -769,9 +772,10 @@ void MPPTask::abort(const String & message, AbortType abort_type)
scheduleThisTask(ScheduleState::FAILED);
/// runImpl is running, leave remaining work to runImpl
LOG_WARNING(log, "Finish abort task from running");
return;
break;
}
}
is_cancelled = true;
}

bool MPPTask::switchStatus(TaskStatus from, TaskStatus to)
Expand Down
1 change: 1 addition & 0 deletions dbms/src/Flash/Mpp/MPPTask.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ class MPPTask

MPPTaskManager * manager;
std::atomic<bool> is_registered{false};
std::atomic<bool> is_cancelled{false};

MPPTaskScheduleEntry schedule_entry;

Expand Down
1 change: 1 addition & 0 deletions dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ void PhysicalJoin::probeSideTransform(DAGPipeline & probe_pipeline, Context & co
settings.max_block_size);
stream->setExtraInfo(join_probe_extra_info);
}
join_ptr->setCancellationHook([&] { return context.isCancelled(); });
}

void PhysicalJoin::buildSideTransform(DAGPipeline & build_pipeline, Context & context, size_t max_streams)
Expand Down
2 changes: 2 additions & 0 deletions dbms/src/Flash/Planner/Plans/PhysicalJoinProbe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include <Flash/Coprocessor/InterpreterUtils.h>
#include <Flash/Executor/PipelineExecutorContext.h>
#include <Flash/Pipeline/Exec/PipelineExecBuilder.h>
#include <Flash/Planner/Plans/PhysicalJoinProbe.h>
#include <Interpreters/Context.h>
Expand Down Expand Up @@ -56,6 +57,7 @@ void PhysicalJoinProbe::buildPipelineExecGroupImpl(
max_block_size,
input_header));
});
join_ptr->setCancellationHook([&]() { return exec_context.isCancelled(); });
join_ptr.reset();
}
} // namespace DB
3 changes: 1 addition & 2 deletions dbms/src/Interpreters/Aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <Interpreters/AggSpillContext.h>
#include <Interpreters/AggregateDescription.h>
#include <Interpreters/AggregationCommon.h>
#include <Interpreters/CancellationHook.h>
#include <TiDB/Collation/Collator.h>
#include <common/StringRef.h>
#include <common/logger_useful.h>
Expand Down Expand Up @@ -1199,8 +1200,6 @@ class Aggregator
*/
Blocks convertBlockToTwoLevel(const Block & block);

using CancellationHook = std::function<bool()>;

/** Set a function that checks whether the current task can be aborted.
*/
void setCancellationHook(CancellationHook cancellation_hook);
Expand Down
22 changes: 22 additions & 0 deletions dbms/src/Interpreters/CancellationHook.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>

namespace DB
{
/// Type-erased predicate taking no arguments and returning `bool`.
/// Owners store one and invoke it to ask whether the work they are doing
/// has been cancelled; a `true` result means "cancelled".
using CancellationHook = std::function<bool()>;
} // namespace DB
15 changes: 11 additions & 4 deletions dbms/src/Interpreters/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <Debug/MockServerInfo.h>
#include <Encryption/FileProvider_fwd.h>
#include <IO/CompressionSettings.h>
#include <Interpreters/CancellationHook.h>
#include <Interpreters/ClientInfo.h>
#include <Interpreters/Context_fwd.h>
#include <Interpreters/Settings.h>
Expand Down Expand Up @@ -181,6 +182,9 @@ class Context
TimezoneInfo timezone_info;

DAGContext * dag_context = nullptr;
/// Cancellation predicate invoked by isCancelled().
/// The default hook always returns false, i.e. it never reports cancellation
/// until a different hook is installed through setCancellationHook().
CancellationHook is_cancelled{[]() {
return false;
}};
using DatabasePtr = std::shared_ptr<IDatabase>;
using Databases = std::map<String, std::shared_ptr<IDatabase>>;
/// Use copy constructor or createGlobal() instead
Expand Down Expand Up @@ -238,8 +242,8 @@ class Context
/// Compute and set actual user settings, client_info.current_user should be set
void calculateUserSettings();

ClientInfo & getClientInfo() { return client_info; };
const ClientInfo & getClientInfo() const { return client_info; };
ClientInfo & getClientInfo() { return client_info; }
const ClientInfo & getClientInfo() const { return client_info; }

void setQuota(
const String & name,
Expand Down Expand Up @@ -376,6 +380,9 @@ class Context
void setDAGContext(DAGContext * dag_context);
DAGContext * getDAGContext() const;

/// Invokes the stored cancellation hook and returns its result.
bool isCancelled() const { return is_cancelled(); }
/// Replaces the stored cancellation hook with `cancellation_hook`.
/// NOTE(review): the argument is taken by value and then copy-assigned, so the
/// callable is copied twice; moving it into the member would avoid one copy.
void setCancellationHook(CancellationHook cancellation_hook) { is_cancelled = cancellation_hook; }

/// List all queries.
ProcessList & getProcessList();
const ProcessList & getProcessList() const;
Expand Down Expand Up @@ -502,8 +509,8 @@ class Context

SharedQueriesPtr getSharedQueries();

const TimezoneInfo & getTimezoneInfo() const { return timezone_info; };
TimezoneInfo & getTimezoneInfo() { return timezone_info; };
const TimezoneInfo & getTimezoneInfo() const { return timezone_info; }
TimezoneInfo & getTimezoneInfo() { return timezone_info; }

/// User name and session identifier. Named sessions are local to users.
using SessionKey = std::pair<String, String>;
Expand Down
30 changes: 28 additions & 2 deletions dbms/src/Interpreters/Join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,10 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info) const

Block Join::removeUselessColumn(Block & block) const
{
// cancelled
if (!block)
return block;

Block projected_block;
for (const auto & name : tidb_output_column_names)
{
Expand All @@ -1368,6 +1372,8 @@ Block Join::joinBlockHash(ProbeProcessInfo & probe_process_info) const
probe_process_info.prepareForHashProbe(key_names_left, non_equal_conditions.left_filter_column, kind, strictness);
while (true)
{
if (is_cancelled())
return {};
auto block = doJoinBlockHash(probe_process_info);
assert(block);
block = removeUselessColumn(block);
Expand Down Expand Up @@ -1468,6 +1474,8 @@ Block Join::joinBlockCross(ProbeProcessInfo & probe_process_info) const

while (true)
{
if (is_cancelled())
return {};
Block block = doJoinBlockCross(probe_process_info);
assert(block);
block = removeUselessColumn(block);
Expand Down Expand Up @@ -1582,6 +1590,8 @@ Block Join::joinBlockNullAware(ProbeProcessInfo & probe_process_info) const

/// Null aware join never expand the left block, just handle the whole block at one time is enough
probe_process_info.all_rows_joined_finish = true;
if (is_cancelled())
return {};

return removeUselessColumn(block);
}
Expand Down Expand Up @@ -1616,19 +1626,31 @@ void Join::joinBlockNullAwareImpl(
right_side_info);

RUNTIME_ASSERT(res.size() == rows, "NASemiJoinResult size {} must be equal to block size {}", res.size(), rows);
if (is_cancelled())
return;

size_t right_columns = block.columns() - left_columns;

if (!res_list.empty())
{
NASemiJoinHelper<KIND, STRICTNESS, typename Maps::MappedType::Base_t>
helper(block, left_columns, right_columns, blocks, null_rows, max_block_size, non_equal_conditions);
NASemiJoinHelper<KIND, STRICTNESS, typename Maps::MappedType::Base_t> helper(
block,
left_columns,
right_columns,
blocks,
null_rows,
max_block_size,
non_equal_conditions,
is_cancelled);

helper.joinResult(res_list);

RUNTIME_CHECK_MSG(res_list.empty(), "NASemiJoinResult list must be empty after calculating join result");
}

if (is_cancelled())
return;

/// Now all results are known.

std::unique_ptr<IColumn::Filter> filter;
Expand Down Expand Up @@ -2004,6 +2026,9 @@ Block Join::joinBlock(ProbeProcessInfo & probe_process_info, bool dry_run) const
else
block = joinBlockHash(probe_process_info);

// if cancelled, just return empty block
if (!block)
return block;
/// for (cartesian)antiLeftSemi join, the meaning of "match-helper" is `non-matched` instead of `matched`.
if (kind == LeftOuterAnti || kind == Cross_LeftOuterAnti)
{
Expand Down Expand Up @@ -2275,6 +2300,7 @@ std::optional<RestoreInfo> Join::getOneRestoreStream(size_t max_block_size_)
restore_join->initBuild(build_sample_block, restore_join_build_concurrency);
restore_join->setInitActiveBuildThreads();
restore_join->initProbe(probe_sample_block, restore_join_build_concurrency);
restore_join->setCancellationHook(is_cancelled);
BlockInputStreams restore_scan_hash_map_streams;
restore_scan_hash_map_streams.resize(restore_join_build_concurrency, nullptr);
if (needScanHashMapAfterProbe(kind))
Expand Down
6 changes: 6 additions & 0 deletions dbms/src/Interpreters/Join.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <Flash/Coprocessor/JoinInterpreterHelper.h>
#include <Flash/Coprocessor/RuntimeFilterMgr.h>
#include <Interpreters/AggregationCommon.h>
#include <Interpreters/CancellationHook.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/HashJoinSpillContext.h>
#include <Interpreters/JoinHashMap.h>
Expand Down Expand Up @@ -306,6 +307,8 @@ class Join
bool hasProbeSideMarkedSpillData(size_t stream_index) const;
void flushProbeSideMarkedSpillData(size_t stream_index);

/// Replaces the cancellation predicate stored in `is_cancelled`.
/// NOTE(review): the argument is taken by value and then copy-assigned, so the
/// callable is copied twice; moving it into the member would avoid one copy.
void setCancellationHook(CancellationHook cancellation_hook) { is_cancelled = cancellation_hook; }

static const String match_helper_prefix;
static const DataTypePtr match_helper_type;
static const String flag_mapped_entry_helper_prefix;
Expand Down Expand Up @@ -433,6 +436,9 @@ class Join
// the index of vector is the stream_index.
std::vector<MarkedSpillData> build_side_marked_spilled_data;
std::vector<MarkedSpillData> probe_side_marked_spilled_data;
/// Cancellation predicate for this join; replaced via setCancellationHook().
/// The default hook always returns false, i.e. it never reports cancellation.
CancellationHook is_cancelled{[]() {
return false;
}};

private:
/** Set information about structure of right hand of JOIN (joined data).
Expand Down
14 changes: 10 additions & 4 deletions dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,15 @@ NASemiJoinHelper<KIND, STRICTNESS, Mapped>::NASemiJoinHelper(
const BlocksList & right_blocks_,
const std::vector<RowsNotInsertToMap *> & null_rows_,
size_t max_block_size_,
const JoinNonEqualConditions & non_equal_conditions_)
const JoinNonEqualConditions & non_equal_conditions_,
CancellationHook is_cancelled_)
: block(block_)
, left_columns(left_columns_)
, right_columns(right_columns_)
, right_blocks(right_blocks_)
, null_rows(null_rows_)
, max_block_size(max_block_size_)
, is_cancelled(is_cancelled_)
, non_equal_conditions(non_equal_conditions_)
{
static_assert(KIND == NullAware_Anti || KIND == NullAware_LeftOuterAnti || KIND == NullAware_LeftOuterSemi);
Expand Down Expand Up @@ -265,17 +267,17 @@ void NASemiJoinHelper<KIND, STRICTNESS, Mapped>::joinResult(std::list<NASemiJoin
res_list.swap(next_step_res_list);
}

if (res_list.empty())
if (is_cancelled() || res_list.empty())
return;

runStep<NASemiJoinStep::NOT_NULL_KEY_CHECK_NULL_ROWS>(res_list, next_step_res_list);
res_list.swap(next_step_res_list);
if (res_list.empty())
if (is_cancelled() || res_list.empty())
return;

runStep<NASemiJoinStep::NULL_KEY_CHECK_NULL_ROWS>(res_list, next_step_res_list);
res_list.swap(next_step_res_list);
if (res_list.empty())
if (is_cancelled() || res_list.empty())
return;

runStepAllBlocks(res_list);
Expand Down Expand Up @@ -309,6 +311,8 @@ void NASemiJoinHelper<KIND, STRICTNESS, Mapped>::runStep(

while (!res_list.empty())
{
if (is_cancelled())
return;
MutableColumns columns(block_columns);
for (size_t i = 0; i < block_columns; ++i)
{
Expand Down Expand Up @@ -369,6 +373,8 @@ void NASemiJoinHelper<KIND, STRICTNESS, Mapped>::runStepAllBlocks(std::list<NASe
NASemiJoinHelper::Result * res = *res_list.begin();
for (const auto & right_block : right_blocks)
{
if (is_cancelled())
return;
if (res->getStep() == NASemiJoinStep::DONE)
break;

Expand Down
5 changes: 4 additions & 1 deletion dbms/src/Interpreters/NullAwareSemiJoinHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <Common/Logger.h>
#include <Core/Block.h>
#include <Flash/Coprocessor/JoinInterpreterHelper.h>
#include <Interpreters/CancellationHook.h>
#include <Parsers/ASTTablesInSelectQuery.h>


Expand Down Expand Up @@ -173,7 +174,8 @@ class NASemiJoinHelper
const BlocksList & right_blocks,
const std::vector<RowsNotInsertToMap *> & null_rows,
size_t max_block_size,
const JoinNonEqualConditions & non_equal_conditions);
const JoinNonEqualConditions & non_equal_conditions,
CancellationHook is_cancelled_);

void joinResult(std::list<Result *> & res_list);

Expand All @@ -197,6 +199,7 @@ class NASemiJoinHelper
const BlocksList & right_blocks;
const std::vector<RowsNotInsertToMap *> & null_rows;
size_t max_block_size;
CancellationHook is_cancelled;

const JoinNonEqualConditions & non_equal_conditions;
};
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/Operators/AggregateContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace DB
void AggregateContext::initBuild(
const Aggregator::Params & params,
size_t max_threads_,
Aggregator::CancellationHook && hook,
CancellationHook && hook,
const RegisterOperatorSpillContext & register_operator_spill_context)
{
assert(status.load() == AggStatus::init);
Expand Down
Loading