Skip to content

Commit

Permalink
implement review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kpedro88 committed Jul 23, 2020
1 parent c96a30a commit 49bf75b
Show file tree
Hide file tree
Showing 17 changed files with 107 additions and 107 deletions.
1 change: 1 addition & 0 deletions HeterogeneousCore/SonicCore/BuildFile.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
<use name="FWCore/Concurrency"/>
<use name="FWCore/MessageLogger"/>
<use name="FWCore/ParameterSet"/>
<use name="FWCore/Utilities"/>
<export>
<lib name="1"/>
</export>
7 changes: 4 additions & 3 deletions HeterogeneousCore/SonicCore/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,15 @@ A Python configuration parameter can be provided to enable retries with a specif
The client must also provide a static method `fillPSetDescription()` to populate its parameters in the `fillDescriptions()` for the producers that use the client:
```cpp
void MyClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) {
edm::ParameterSetDescription descClient(basePSetDescription());
edm::ParameterSetDescription descClient;
fillBasePSetDescription(descClient);
//add parameters
iDesc.add<edm::ParameterSetDescription>("Client",descClient);
}
```
As indicated, the `descClient` object should always be initialized using the `basePSetDescription()` function,
As indicated, the `fillBasePSetDescription()` function should always be applied to the `descClient` object,
to ensure that it includes the necessary parameters.
(Calling `basePSetDescription(false)` will omit the `allowedTries` parameter, disabling retries.)
(Calling `fillBasePSetDescription(descClient, false)` will omit the `allowedTries` parameter, disabling retries.)
Example client code can be found in the `interface` and `src` directories of the other Sonic packages in this repository.
3 changes: 0 additions & 3 deletions HeterogeneousCore/SonicCore/interface/SonicClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ class SonicClient : public SonicClientBase, public SonicClientTypes<InputT, Outp
public:
//constructor
SonicClient(const edm::ParameterSet& params) : SonicClientBase(params), SonicClientTypes<InputT, OutputT>() {}

//do nothing by default
void reset() override {}
};

#endif
6 changes: 3 additions & 3 deletions HeterogeneousCore/SonicCore/interface/SonicClientBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ class SonicClientBase {
//main operation
virtual void dispatch(edm::WaitingTaskWithArenaHolder holder) { dispatcher_->dispatch(std::move(holder)); }

//helper
virtual void reset() = 0;
//helper: does nothing by default
virtual void reset() {}

//provide base params
static edm::ParameterSetDescription basePSetDescription(bool allowRetry = true);
static void fillBasePSetDescription(edm::ParameterSetDescription& desc, bool allowRetry = true);

protected:
virtual void evaluate() = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,7 @@ class SonicDispatcherPseudoAsync : public SonicDispatcher {
SonicDispatcherPseudoAsync(SonicClientBase* client);

//destructor
~SonicDispatcherPseudoAsync() override {
stop_ = true;
cond_.notify_one();
if (thread_) {
try {
thread_->join();
thread_.reset();
} catch (...) {
}
}
}
~SonicDispatcherPseudoAsync() override;

//main operation
void dispatch(edm::WaitingTaskWithArenaHolder holder) override;
Expand Down
4 changes: 2 additions & 2 deletions HeterogeneousCore/SonicCore/interface/SonicEDProducer.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ class SonicEDProducer : public edm::stream::EDProducer<edm::ExternalWork, Capabi
//(no need to interact with callback holder)
void acquire(edm::Event const& iEvent, edm::EventSetup const& iSetup, edm::WaitingTaskWithArenaHolder holder) final {
auto t0 = std::chrono::high_resolution_clock::now();
//reset client data
client_.reset();
acquire(iEvent, iSetup, client_.input());
auto t1 = std::chrono::high_resolution_clock::now();
if (!client_.debugName().empty())
Expand All @@ -53,6 +51,8 @@ class SonicEDProducer : public edm::stream::EDProducer<edm::ExternalWork, Capabi
if (!client_.debugName().empty())
edm::LogInfo(client_.debugName()) << "produce() time: "
<< std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count();
//reset client data
client_.reset();
}
virtual void produce(edm::Event& iEvent, edm::EventSetup const& iSetup, Output const& iOutput) = 0;

Expand Down
10 changes: 4 additions & 6 deletions HeterogeneousCore/SonicCore/src/SonicClientBase.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ SonicClientBase::SonicClientBase(const edm::ParameterSet& params)
else if (modeName == "PseudoAsync")
mode_ = SonicMode::PseudoAsync;
else
throw cms::Exception("UnknownMode") << "Unknown mode for SonicClient: " << modeName;
throw cms::Exception("Configuration") << "Unknown mode for SonicClient: " << modeName;

//get correct dispatcher for mode
if (mode_ == SonicMode::Sync or mode_ == SonicMode::Async)
Expand Down Expand Up @@ -59,10 +59,8 @@ void SonicClientBase::finish(bool success, std::exception_ptr eptr) {
holder_.doneWaiting(eptr);
}

edm::ParameterSetDescription SonicClientBase::basePSetDescription(bool allowRetry) {
edm::ParameterSetDescription descClient;
descClient.add<std::string>("mode");
void SonicClientBase::fillBasePSetDescription(edm::ParameterSetDescription& desc, bool allowRetry) {
desc.add<std::string>("mode");
if (allowRetry)
descClient.addUntracked<unsigned>("allowedTries", 0);
return descClient;
desc.addUntracked<unsigned>("allowedTries", 0);
}
14 changes: 14 additions & 0 deletions HeterogeneousCore/SonicCore/src/SonicDispatcherPseudoAsync.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
#include "HeterogeneousCore/SonicCore/interface/SonicDispatcherPseudoAsync.h"
#include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h"
#include "FWCore/Utilities/interface/thread_safety_macros.h"

//constructor: passes the client to the SonicDispatcher base, clears the hasCall_ and stop_ flags,
//then creates the thread that runs waitForNext()
SonicDispatcherPseudoAsync::SonicDispatcherPseudoAsync(SonicClientBase* client)
: SonicDispatcher(client), hasCall_(false), stop_(false) {
//the lambda captures this, so the thread uses this object for as long as it runs
//NOTE(review): presumably created in the body (not the initializer list) so the flags are set before waitForNext() starts - confirm
thread_ = std::make_unique<std::thread>([this]() { waitForNext(); });
}

//destructor: raises the stop flag, signals the condition variable,
//then joins and releases the thread if one exists
SonicDispatcherPseudoAsync::~SonicDispatcherPseudoAsync() {
  //request shutdown and signal the condition variable
  //NOTE(review): presumably this wakes the thread running waitForNext(); the wait predicate and
  //its lock are not visible here - confirm the notification cannot be missed
  stop_ = true;
  cond_.notify_one();

  //no thread object: nothing to join
  if (!thread_)
    return;

  //a destructor must not throw, so anything thrown while joining is deliberately discarded
  CMS_SA_ALLOW try {
    thread_->join();
    thread_.reset();
  } catch (...) {
  }
}

void SonicDispatcherPseudoAsync::dispatch(edm::WaitingTaskWithArenaHolder holder) {
//do all read/writes inside lock to ensure cache synchronization
{
Expand Down
3 changes: 2 additions & 1 deletion HeterogeneousCore/SonicCore/test/DummyClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class DummyClient : public SonicClient<int> {

//for fillDescriptions
static void fillPSetDescription(edm::ParameterSetDescription& iDesc) {
edm::ParameterSetDescription descClient(basePSetDescription());
edm::ParameterSetDescription descClient;
fillBasePSetDescription(descClient);
descClient.add<int>("factor", -1);
descClient.add<int>("wait", 10);
descClient.add<unsigned>("fails", 0);
Expand Down
6 changes: 3 additions & 3 deletions HeterogeneousCore/SonicTriton/interface/TritonClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ class TritonClient : public SonicClient<TritonInputMap, TritonOutputMap> {
//constructor
TritonClient(const edm::ParameterSet& params);

//helper
bool getResults(std::map<std::string, std::unique_ptr<InferContext::Result>>& results);

//accessors
unsigned batchSize() const { return batchSize_; }
bool verbose() const { return verbose_; }
Expand All @@ -42,6 +39,9 @@ class TritonClient : public SonicClient<TritonInputMap, TritonOutputMap> {
static void fillPSetDescription(edm::ParameterSetDescription& iDesc);

protected:
//helper
bool getResults(std::map<std::string, std::unique_ptr<InferContext::Result>>& results);

void evaluate() override;

void reportServerSideStats(const ServerSideStats& stats) const;
Expand Down
5 changes: 2 additions & 3 deletions HeterogeneousCore/SonicTriton/interface/TritonData.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
#include <string>
#include <unordered_map>
#include <numeric>
#include <functional>
#include <algorithm>
#include <memory>
#include <any>

#include "request_grpc.h"

Expand All @@ -35,7 +35,6 @@ class TritonData {
void fromServer(std::vector<DT>& data_out) const;

//const accessors
const std::shared_ptr<IO>& data() const { return data_; }
const std::vector<int64_t>& dims() const { return dims_; }
const std::vector<int64_t>& shape() const { return shape_.empty() ? dims() : shape_; }
int64_t byteSize() const { return byteSize_; }
Expand Down Expand Up @@ -68,7 +67,7 @@ class TritonData {
int64_t byteSize_;
std::vector<int64_t> shape_;
unsigned batchSize_;
std::function<void(void)> callback_;
std::any holder_;
std::unique_ptr<Result> result_;
};

Expand Down
26 changes: 0 additions & 26 deletions HeterogeneousCore/SonicTriton/interface/TritonUtils.h

This file was deleted.

29 changes: 29 additions & 0 deletions HeterogeneousCore/SonicTriton/interface/triton_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef HeterogeneousCore_SonicTriton_triton_utils
#define HeterogeneousCore_SonicTriton_triton_utils

#include <string>
#include <string_view>
#include <vector>

#include "request_grpc.h"

//free helper functions used by the Triton client code
namespace triton_utils {

//shorthand for the Triton client library error type
using Error = nvidia::inferenceserver::client::Error;

//returns a string built from vec; delim defaults to ", "
//NOTE(review): presumably the elements are joined with delim between them - definition not visible here, confirm
//the template is only declared in this header; extern template declarations for
//int64_t, uint8_t and float follow the namespace
template <typename T>
std::string printVec(const std::vector<T>& vec, const std::string& delim = ", ");

//helper to turn triton error into exception
//err: error object returned by a Triton client call
//msg: text supplied by the caller describing the operation that was attempted
//NOTE(review): std::string_view is conventionally passed by value; changing it requires updating the definition too
void throwIfError(const Error& err, const std::string_view& msg);

//helper to turn triton error into warning
//err and msg: as for throwIfError
//returns a bool status; the callers in TritonClient.cc treat false as failure
bool warnIfError(const Error& err, const std::string_view& msg);

} // namespace triton_utils

extern template std::string triton_utils::printVec(const std::vector<int64_t>& vec, const std::string& delim);
extern template std::string triton_utils::printVec(const std::vector<uint8_t>& vec, const std::string& delim);
extern template std::string triton_utils::printVec(const std::vector<float>& vec, const std::string& delim);

#endif
39 changes: 20 additions & 19 deletions HeterogeneousCore/SonicTriton/src/TritonClient.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonClient.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonUtils.h"
#include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"

#include "request_grpc.h"

Expand Down Expand Up @@ -31,12 +31,12 @@ TritonClient::TritonClient(const edm::ParameterSet& params)
fullDebugName_ = clientName_;

//connect to the server
TritonUtils::wrap(nic::InferGrpcContext::Create(&context_, url_, modelName_, modelVersion_, false),
"TritonClient(): unable to create inference context");
triton_utils::throwIfError(nic::InferGrpcContext::Create(&context_, url_, modelName_, modelVersion_, false),
"TritonClient(): unable to create inference context");

//get options
TritonUtils::wrap(nic::InferContext::Options::Create(&options_),
"TritonClient(): unable to create inference context options");
triton_utils::throwIfError(nic::InferContext::Options::Create(&options_),
"TritonClient(): unable to create inference context options");

//get input and output (which know their sizes)
const auto& nicInputs = context_->Inputs();
Expand Down Expand Up @@ -70,7 +70,7 @@ TritonClient::TritonClient(const edm::ParameterSet& params)
if (verbose_) {
const auto& curr_input = curr_itr.first->second;
io_msg << " " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
<< " b) : " << TritonUtils::printVec(curr_input.dims()) << "\n";
<< " b) : " << triton_utils::printVec(curr_input.dims()) << "\n";
}
}

Expand All @@ -83,11 +83,11 @@ TritonClient::TritonClient(const edm::ParameterSet& params)
const auto& curr_itr = output_.emplace(
std::piecewise_construct, std::forward_as_tuple(oname), std::forward_as_tuple(oname, nicOutput));
const auto& curr_output = curr_itr.first->second;
TritonUtils::wrap(options_->AddRawResult(curr_output.data()),
"TritonClient(): unable to add raw result " + curr_itr.first->first);
triton_utils::throwIfError(options_->AddRawResult(nicOutput),
"TritonClient(): unable to add raw result " + curr_itr.first->first);
if (verbose_) {
io_msg << " " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
<< " b) : " << TritonUtils::printVec(curr_output.dims()) << "\n";
<< " b) : " << triton_utils::printVec(curr_output.dims()) << "\n";
}
}

Expand All @@ -102,7 +102,7 @@ TritonClient::TritonClient(const edm::ParameterSet& params)
setBatchSize(params.getUntrackedParameter<unsigned>("batchSize"));

//initial server settings
TritonUtils::wrap(context_->SetRunOptions(*options_), "TritonClient(): unable to set run options");
triton_utils::throwIfError(context_->SetRunOptions(*options_), "TritonClient(): unable to set run options");

//print model info
std::stringstream model_msg;
Expand All @@ -118,8 +118,8 @@ TritonClient::TritonClient(const edm::ParameterSet& params)
//print model info
edm::LogInfo(fullDebugName_) << model_msg.str() << io_msg.str();

has_server = TritonUtils::warn(nic::ServerStatusGrpcContext::Create(&serverCtx_, url_, false),
"TritonClient(): unable to create server context");
has_server = triton_utils::warnIfError(nic::ServerStatusGrpcContext::Create(&serverCtx_, url_, false),
"TritonClient(): unable to create server context");
}
if (!has_server)
serverCtx_ = nullptr;
Expand All @@ -142,7 +142,7 @@ bool TritonClient::setBatchSize(unsigned bsize) {
//set for server (and Input objects)
if (!noBatch_) {
options_->SetBatchSize(batchSize_);
TritonUtils::wrap(context_->SetRunOptions(*options_), "setBatchSize(): unable to set run options");
triton_utils::throwIfError(context_->SetRunOptions(*options_), "setBatchSize(): unable to set run options");
}
return true;
}
Expand Down Expand Up @@ -173,7 +173,7 @@ bool TritonClient::getResults(std::map<std::string, std::unique_ptr<nic::InferCo
//set shape here before output becomes const
if (output.variableDims()) {
bool status =
TritonUtils::warn(result->GetRawShape(&(output.shape())), "getResults(): unable to get output shape");
triton_utils::warnIfError(result->GetRawShape(&(output.shape())), "getResults(): unable to get output shape");
if (!status)
return status;
}
Expand All @@ -198,13 +198,13 @@ void TritonClient::evaluate() {
if (mode_ == SonicMode::Async) {
//non-blocking call
auto t1 = std::chrono::high_resolution_clock::now();
bool status = TritonUtils::warn(
bool status = triton_utils::warnIfError(
context_->AsyncRun([t1, start_status, this](nic::InferContext* ctx,
const std::shared_ptr<nic::InferContext::Request>& request) {
//get results
bool status = true;
std::map<std::string, std::unique_ptr<nic::InferContext::Result>> results;
status = TritonUtils::warn(ctx->GetAsyncRunResults(request, &results), "evaluate(): unable to get result");
bool status =
triton_utils::warnIfError(ctx->GetAsyncRunResults(request, &results), "evaluate(): unable to get result");
if (!status) {
finish(false);
return;
Expand Down Expand Up @@ -237,7 +237,7 @@ void TritonClient::evaluate() {
//blocking call
auto t1 = std::chrono::high_resolution_clock::now();
std::map<std::string, std::unique_ptr<nic::InferContext::Result>> results;
bool status = TritonUtils::warn(context_->Run(&results), "evaluate(): unable to run and/or get result");
bool status = triton_utils::warnIfError(context_->Run(&results), "evaluate(): unable to run and/or get result");
if (!status) {
finish(false);
return;
Expand Down Expand Up @@ -347,7 +347,8 @@ ni::ModelStatus TritonClient::getServerSideStatus() const {

//for fillDescriptions
void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) {
edm::ParameterSetDescription descClient(basePSetDescription());
edm::ParameterSetDescription descClient;
fillBasePSetDescription(descClient);
descClient.add<std::string>("modelName");
descClient.add<int>("modelVersion", -1);
//server parameters should not affect the physics results
Expand Down
Loading

0 comments on commit 49bf75b

Please sign in to comment.