Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

join be aware of cancel signal (#9450) #9452

Merged
2 changes: 1 addition & 1 deletion dbms/src/DataStreams/AggregatingBlockInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Block AggregatingBlockInputStream::readImpl()
executed = true;
AggregatedDataVariantsPtr data_variants = std::make_shared<AggregatedDataVariants>();

Aggregator::CancellationHook hook = [&]() {
CancellationHook hook = [&]() {
return this->isCancelled();
};
aggregator.setCancellationHook(hook);
Expand Down
2 changes: 0 additions & 2 deletions dbms/src/DataStreams/HashJoinProbeExec.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class HashJoinProbeExec : public std::enable_shared_from_this<HashJoinProbeExec>
const BlockInputStreamPtr & probe_stream,
size_t max_block_size);

using CancellationHook = std::function<bool()>;

HashJoinProbeExec(
const String & req_id,
const JoinPtr & join_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Block ParallelAggregatingBlockInputStream::readImpl()
{
if (!executed)
{
Aggregator::CancellationHook hook = [&]() {
CancellationHook hook = [&]() {
return this->isCancelled();
};
aggregator.setCancellationHook(hook);
Expand Down
10 changes: 7 additions & 3 deletions dbms/src/Flash/Mpp/MPPTask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,9 @@ void MPPTask::runImpl()
GET_METRIC(tiflash_coprocessor_request_duration_seconds, type_run_mpp_task).Observe(stopwatch.elapsedSeconds());
});

// set cancellation hook
context->setCancellationHook([this] { return is_cancelled.load(); });

String err_msg;
try
{
Expand Down Expand Up @@ -746,7 +749,7 @@ void MPPTask::abort(const String & message, AbortType abort_type)
if (previous_status == FINISHED || previous_status == CANCELLED || previous_status == FAILED)
{
LOG_WARNING(log, "task already in {} state", magic_enum::enum_name(previous_status));
return;
break;
}
else if (previous_status == INITIALIZING && switchStatus(INITIALIZING, next_task_status))
{
Expand All @@ -755,7 +758,7 @@ void MPPTask::abort(const String & message, AbortType abort_type)
/// so just close all tunnels here
abortTunnels("", false);
LOG_WARNING(log, "Finish abort task from uninitialized");
return;
break;
}
else if (previous_status == RUNNING && switchStatus(RUNNING, next_task_status))
{
Expand All @@ -769,9 +772,10 @@ void MPPTask::abort(const String & message, AbortType abort_type)
scheduleThisTask(ScheduleState::FAILED);
/// runImpl is running, leave remaining work to runImpl
LOG_WARNING(log, "Finish abort task from running");
return;
break;
}
}
is_cancelled = true;
}

bool MPPTask::switchStatus(TaskStatus from, TaskStatus to)
Expand Down
1 change: 1 addition & 0 deletions dbms/src/Flash/Mpp/MPPTask.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ class MPPTask

MPPTaskManager * manager;
std::atomic<bool> is_registered{false};
std::atomic<bool> is_cancelled{false};

MPPTaskScheduleEntry schedule_entry;

Expand Down
1 change: 1 addition & 0 deletions dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ void PhysicalJoin::probeSideTransform(DAGPipeline & probe_pipeline, Context & co
execId(),
needScanHashMapAfterProbe(join_ptr->getKind()));
join_ptr->initProbe(probe_pipeline.firstStream()->getHeader(), probe_pipeline.streams.size());
join_ptr->setCancellationHook([&] { return context.isCancelled(); });
size_t probe_index = 0;
for (auto & stream : probe_pipeline.streams)
{
Expand Down
2 changes: 2 additions & 0 deletions dbms/src/Flash/Planner/Plans/PhysicalJoinBuild.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <Flash/Coprocessor/DAGContext.h>
#include <Flash/Coprocessor/InterpreterUtils.h>
#include <Flash/Executor/PipelineExecutorContext.h>
#include <Flash/Pipeline/Exec/PipelineExecBuilder.h>
#include <Flash/Planner/Plans/PhysicalJoinBuild.h>
#include <Interpreters/Context.h>
Expand All @@ -39,6 +40,7 @@ void PhysicalJoinBuild::buildPipelineExecGroupImpl(
join_execute_info.join_build_profile_infos = group_builder.getCurProfileInfos();
join_ptr->initBuild(group_builder.getCurrentHeader(), group_builder.concurrency());
join_ptr->setInitActiveBuildThreads();
join_ptr->setCancellationHook([&]() { return exec_context.isCancelled(); });
join_ptr.reset();
}
} // namespace DB
3 changes: 1 addition & 2 deletions dbms/src/Interpreters/Aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <Interpreters/AggSpillContext.h>
#include <Interpreters/AggregateDescription.h>
#include <Interpreters/AggregationCommon.h>
#include <Interpreters/CancellationHook.h>
#include <TiDB/Collation/Collator.h>
#include <common/StringRef.h>
#include <common/logger_useful.h>
Expand Down Expand Up @@ -1199,8 +1200,6 @@ class Aggregator
*/
Blocks convertBlockToTwoLevel(const Block & block);

using CancellationHook = std::function<bool()>;

/** Set a function that checks whether the current task can be aborted.
*/
void setCancellationHook(CancellationHook cancellation_hook);
Expand Down
22 changes: 22 additions & 0 deletions dbms/src/Interpreters/CancellationHook.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>

namespace DB
{
using CancellationHook = std::function<bool()>;
} // namespace DB
16 changes: 12 additions & 4 deletions dbms/src/Interpreters/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <Core/Types.h>
#include <Debug/MockServerInfo.h>
#include <IO/FileProvider/FileProvider_fwd.h>
#include <Interpreters/CancellationHook.h>
#include <Interpreters/ClientInfo.h>
#include <Interpreters/Context_fwd.h>
#include <Interpreters/Settings.h>
Expand All @@ -34,6 +35,7 @@
#include <mutex>
#include <thread>


namespace pingcap
{
struct ClusterConfig;
Expand Down Expand Up @@ -178,6 +180,9 @@ class Context
TimezoneInfo timezone_info;

DAGContext * dag_context = nullptr;
CancellationHook is_cancelled{[]() {
return false;
}};
using DatabasePtr = std::shared_ptr<IDatabase>;
using Databases = std::map<String, std::shared_ptr<IDatabase>>;
/// Use copy constructor or createGlobal() instead
Expand Down Expand Up @@ -235,8 +240,8 @@ class Context
/// Compute and set actual user settings, client_info.current_user should be set
void calculateUserSettings();

ClientInfo & getClientInfo() { return client_info; };
const ClientInfo & getClientInfo() const { return client_info; };
ClientInfo & getClientInfo() { return client_info; }
const ClientInfo & getClientInfo() const { return client_info; }

void setQuota(
const String & name,
Expand Down Expand Up @@ -373,6 +378,9 @@ class Context
void setDAGContext(DAGContext * dag_context);
DAGContext * getDAGContext() const;

bool isCancelled() const { return is_cancelled(); }
void setCancellationHook(CancellationHook cancellation_hook) { is_cancelled = cancellation_hook; }

/// List all queries.
ProcessList & getProcessList();
const ProcessList & getProcessList() const;
Expand Down Expand Up @@ -499,8 +507,8 @@ class Context

SharedQueriesPtr getSharedQueries();

const TimezoneInfo & getTimezoneInfo() const { return timezone_info; };
TimezoneInfo & getTimezoneInfo() { return timezone_info; };
const TimezoneInfo & getTimezoneInfo() const { return timezone_info; }
TimezoneInfo & getTimezoneInfo() { return timezone_info; }

/// User name and session identifier. Named sessions are local to users.
using SessionKey = std::pair<String, String>;
Expand Down
28 changes: 25 additions & 3 deletions dbms/src/Interpreters/Join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1415,6 +1415,8 @@ Block Join::joinBlockHash(ProbeProcessInfo & probe_process_info) const
restore_config.restore_round);
while (true)
{
if (is_cancelled())
return {};
auto block = doJoinBlockHash(probe_process_info, join_build_info);
assert(block);
block = removeUselessColumn(block);
Expand Down Expand Up @@ -1519,6 +1521,8 @@ Block Join::joinBlockCross(ProbeProcessInfo & probe_process_info) const

while (true)
{
if (is_cancelled())
return {};
Block block = doJoinBlockCross(probe_process_info);
assert(block);
block = removeUselessColumn(block);
Expand Down Expand Up @@ -1600,6 +1604,9 @@ Block Join::joinBlockNullAwareSemiImpl(const ProbeProcessInfo & probe_process_in

RUNTIME_ASSERT(res.size() == rows, "SemiJoinResult size {} must be equal to block size {}", res.size(), rows);

if (is_cancelled())
return {};

Block block{};
for (size_t i = 0; i < probe_process_info.block.columns(); ++i)
{
Expand Down Expand Up @@ -1636,13 +1643,17 @@ Block Join::joinBlockNullAwareSemiImpl(const ProbeProcessInfo & probe_process_in
blocks,
null_rows,
max_block_size,
non_equal_conditions);
non_equal_conditions,
is_cancelled);

helper.joinResult(res_list);

RUNTIME_CHECK_MSG(res_list.empty(), "NASemiJoinResult list must be empty after calculating join result");
}

if (is_cancelled())
return {};

/// Now all results are known.

std::unique_ptr<IColumn::Filter> filter;
Expand Down Expand Up @@ -1787,6 +1798,8 @@ Block Join::joinBlockSemiImpl(const JoinBuildInfo & join_build_info, const Probe
probe_process_info);

RUNTIME_ASSERT(res.size() == rows, "SemiJoinResult size {} must be equal to block size {}", res.size(), rows);
if (is_cancelled())
return {};

const NameSet & probe_output_name_set = has_other_condition
? output_columns_names_set_for_other_condition_after_finalize
Expand Down Expand Up @@ -1821,15 +1834,23 @@ Block Join::joinBlockSemiImpl(const JoinBuildInfo & join_build_info, const Probe
{
if (!res_list.empty())
{
SemiJoinHelper<KIND, typename Maps::MappedType>
helper(block, left_columns, right_column_indices_to_add, max_block_size, non_equal_conditions);
SemiJoinHelper<KIND, typename Maps::MappedType> helper(
block,
left_columns,
right_column_indices_to_add,
max_block_size,
non_equal_conditions,
is_cancelled);

helper.joinResult(res_list);

RUNTIME_CHECK_MSG(res_list.empty(), "SemiJoinResult list must be empty after calculating join result");
}
}

if (is_cancelled())
return {};

/// Now all results are known.

std::unique_ptr<IColumn::Filter> filter;
Expand Down Expand Up @@ -2488,6 +2509,7 @@ std::optional<RestoreInfo> Join::getOneRestoreStream(size_t max_block_size_)
restore_join->initBuild(build_sample_block, restore_join_build_concurrency);
restore_join->setInitActiveBuildThreads();
restore_join->initProbe(probe_sample_block, restore_join_build_concurrency);
restore_join->setCancellationHook(is_cancelled);
BlockInputStreams restore_scan_hash_map_streams;
restore_scan_hash_map_streams.resize(restore_join_build_concurrency, nullptr);
if (needScanHashMapAfterProbe(kind))
Expand Down
6 changes: 6 additions & 0 deletions dbms/src/Interpreters/Join.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <Flash/Coprocessor/JoinInterpreterHelper.h>
#include <Flash/Coprocessor/RuntimeFilterMgr.h>
#include <Interpreters/AggregationCommon.h>
#include <Interpreters/CancellationHook.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/HashJoinSpillContext.h>
#include <Interpreters/JoinHashMap.h>
Expand Down Expand Up @@ -311,6 +312,8 @@ class Join
void flushProbeSideMarkedSpillData(size_t stream_index);
size_t getProbeCacheColumnThreshold() const { return probe_cache_column_threshold; }

void setCancellationHook(CancellationHook cancellation_hook) { is_cancelled = cancellation_hook; }

static const String match_helper_prefix;
static const DataTypePtr match_helper_type;
static const String flag_mapped_entry_helper_prefix;
Expand Down Expand Up @@ -446,6 +449,9 @@ class Join
// the index of vector is the stream_index.
std::vector<MarkedSpillData> build_side_marked_spilled_data;
std::vector<MarkedSpillData> probe_side_marked_spilled_data;
CancellationHook is_cancelled{[]() {
return false;
}};

private:
/** Set information about structure of right hand of JOIN (joined data).
Expand Down
14 changes: 10 additions & 4 deletions dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,15 @@ NASemiJoinHelper<KIND, STRICTNESS, Mapped>::NASemiJoinHelper(
const BlocksList & right_blocks_,
const std::vector<RowsNotInsertToMap *> & null_rows_,
size_t max_block_size_,
const JoinNonEqualConditions & non_equal_conditions_)
const JoinNonEqualConditions & non_equal_conditions_,
CancellationHook is_cancelled_)
: block(block_)
, left_columns(left_columns_)
, right_column_indices_to_add(right_column_indices_to_add_)
, right_blocks(right_blocks_)
, null_rows(null_rows_)
, max_block_size(max_block_size_)
, is_cancelled(is_cancelled_)
, non_equal_conditions(non_equal_conditions_)
{
static_assert(KIND == NullAware_Anti || KIND == NullAware_LeftOuterAnti || KIND == NullAware_LeftOuterSemi);
Expand Down Expand Up @@ -280,17 +282,17 @@ void NASemiJoinHelper<KIND, STRICTNESS, Mapped>::joinResult(std::list<NASemiJoin
res_list.swap(next_step_res_list);
}

if (res_list.empty())
if (is_cancelled() || res_list.empty())
return;

runStep<NASemiJoinStep::NOT_NULL_KEY_CHECK_NULL_ROWS>(res_list, next_step_res_list);
res_list.swap(next_step_res_list);
if (res_list.empty())
if (is_cancelled() || res_list.empty())
return;

runStep<NASemiJoinStep::NULL_KEY_CHECK_NULL_ROWS>(res_list, next_step_res_list);
res_list.swap(next_step_res_list);
if (res_list.empty())
if (is_cancelled() || res_list.empty())
return;

runStepAllBlocks(res_list);
Expand Down Expand Up @@ -324,6 +326,8 @@ void NASemiJoinHelper<KIND, STRICTNESS, Mapped>::runStep(

while (!res_list.empty())
{
if (is_cancelled())
return;
MutableColumns columns(block_columns);
for (size_t i = 0; i < block_columns; ++i)
{
Expand Down Expand Up @@ -384,6 +388,8 @@ void NASemiJoinHelper<KIND, STRICTNESS, Mapped>::runStepAllBlocks(std::list<NASe
NASemiJoinHelper::Result * res = *res_list.begin();
for (const auto & right_block : right_blocks)
{
if (is_cancelled())
return;
if (res->getStep() == NASemiJoinStep::DONE)
break;

Expand Down
5 changes: 4 additions & 1 deletion dbms/src/Interpreters/NullAwareSemiJoinHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <Common/Logger.h>
#include <Core/Block.h>
#include <Flash/Coprocessor/JoinInterpreterHelper.h>
#include <Interpreters/CancellationHook.h>
#include <Interpreters/SemiJoinHelper.h>
#include <Parsers/ASTTablesInSelectQuery.h>

Expand Down Expand Up @@ -167,7 +168,8 @@ class NASemiJoinHelper
const BlocksList & right_blocks,
const std::vector<RowsNotInsertToMap *> & null_rows,
size_t max_block_size,
const JoinNonEqualConditions & non_equal_conditions);
const JoinNonEqualConditions & non_equal_conditions,
CancellationHook is_cancelled_);

void joinResult(std::list<Result *> & res_list);

Expand All @@ -192,6 +194,7 @@ class NASemiJoinHelper
const BlocksList & right_blocks;
const std::vector<RowsNotInsertToMap *> & null_rows;
size_t max_block_size;
CancellationHook is_cancelled;

const JoinNonEqualConditions & non_equal_conditions;
};
Expand Down
Loading