diff --git a/src/square.cc b/src/square.cc index c0c8ab2..0559424 100644 --- a/src/square.cc +++ b/src/square.cc @@ -98,12 +98,13 @@ namespace triton { namespace backend { namespace square { // class ModelParameters { public: - enum DelayType { INFER, OUTPUT }; + enum DelayType { INFER, OUTPUT, CANCEL }; enum InferResultType { SUCCESS, FAIL, EMPTY }; ModelParameters() : custom_infer_delay_ns_(0), custom_output_delay_ns_(0), - custom_fail_count_(0), custom_empty_count_(0) + custom_cancel_delay_ns_(0), custom_fail_count_(0), + custom_empty_count_(0) { } ModelParameters(common::TritonJson::Value& model_config_); @@ -120,6 +121,7 @@ class ModelParameters { size_t custom_infer_delay_ns_; size_t custom_output_delay_ns_; + size_t custom_cancel_delay_ns_; size_t custom_fail_count_; size_t custom_empty_count_; }; @@ -135,6 +137,8 @@ ModelParameters::ModelParameters(common::TritonJson::Value& model_config_) parameters_json, "CUSTOM_INFER_DELAY_NS", &custom_infer_delay_ns_); ReadParameter( parameters_json, "CUSTOM_OUTPUT_DELAY_NS", &custom_output_delay_ns_); + ReadParameter( + parameters_json, "CUSTOM_CANCEL_DELAY_NS", &custom_cancel_delay_ns_); ReadParameter(parameters_json, "CUSTOM_FAIL_COUNT", &custom_fail_count_); ReadParameter(parameters_json, "CUSTOM_EMPTY_COUNT", &custom_empty_count_); } @@ -148,6 +152,8 @@ ModelParameters::Sleep(DelayType delay_type) const Sleep(custom_infer_delay_ns_); } else if (delay_type == DelayType::OUTPUT) { Sleep(custom_output_delay_ns_); + } else if (delay_type == DelayType::CANCEL) { + Sleep(custom_cancel_delay_ns_); } } @@ -526,18 +532,33 @@ ModelInstanceState::RequestThread( uint64_t response_start_ns; SET_TIMESTAMP(response_start_ns); - // Simulate compute delay, if provided. - model_state_->get_model_parameters().Sleep( - ModelParameters::DelayType::INFER); + // Check if the request is cancelled. + bool is_cancelled; + RESPOND_FACTORY_AND_RETURN_IF_ERROR( + factory.get(), + TRITONBACKEND_ResponseFactoryIsCancelled(factory.get(), &is_cancelled)); + + // Simulate compute or cancellation. + if (is_cancelled) { + // Simulate cancel clean-up delay, if provided. + model_state_->get_model_parameters().Sleep( + ModelParameters::DelayType::CANCEL); + } else { + // Simulate compute delay, if provided. + model_state_->get_model_parameters().Sleep( + ModelParameters::DelayType::INFER); + } - // Result type of the simulated inference. + // Result type of the simulated inference, ignore if cancelled. ModelParameters::InferResultType result_type = model_state_->get_model_parameters().InferResult(e, element_count); - // Populate 'compute_output_start_ns' and 'response' if not empty result. + // Simulate compute output based on result type and cancellation, and + // populate 'compute_output_start_ns' and 'response'. uint64_t compute_output_start_ns = 0; TRITONBACKEND_Response* response = nullptr; - if (result_type != ModelParameters::InferResultType::EMPTY) { + if (!is_cancelled && + result_type != ModelParameters::InferResultType::EMPTY) { // Timestamp at start of outputting compute tensors. SET_TIMESTAMP(compute_output_start_ns); @@ -584,13 +605,15 @@ ModelInstanceState::RequestThread( // Set error for simulated failure. TRITONSERVER_Error* error = nullptr; - if (result_type == ModelParameters::InferResultType::FAIL) { + if (is_cancelled) { + error = TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_CANCELLED, "cancelled"); + } else if (result_type == ModelParameters::InferResultType::FAIL) { error = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNKNOWN, "simulated failure"); } - // Send response if not empty. - if (result_type != ModelParameters::InferResultType::EMPTY) { + // Send response, if any. + if (response != nullptr) { LOG_IF_ERROR( TRITONBACKEND_ResponseSend(response, 0 /* flags */, error), "failed sending response"); @@ -612,6 +635,16 @@ ModelInstanceState::RequestThread( (std::string("sent response ") + std::to_string(e + 1) + " of " + std::to_string(element_count)) .c_str()); + + // If cancelled, stop sending remaining responses, if any. + if (is_cancelled) { + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("cancelled at response ") + std::to_string(e + 1) + + " of " + std::to_string(element_count)) + .c_str()); + break; + } } // Add some logging for the case where IN was size 0 and so no