Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More code cleanup #1406

Merged
merged 5 commits into from
Jul 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions onnxruntime/core/common/logging/logging.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ LoggingManager::LoggingManager(std::unique_ptr<ISink> sink, Severity default_min
default_filter_user_data_{filter_user_data},
default_max_vlog_level_{default_max_vlog_level},
owns_default_logger_{false} {
if (!sink_) {
if (sink_ == nullptr) {
throw std::logic_error("ISink must be provided.");
}

Expand Down Expand Up @@ -126,7 +126,6 @@ LoggingManager::~LoggingManager() {

void LoggingManager::CreateDefaultLogger(const std::string& logger_id) {
// this method is only called from ctor in scope where DefaultLoggerMutex() is already locked

if (s_default_logger_ != nullptr) {
throw std::logic_error("Default logger already set. ");
}
Expand Down Expand Up @@ -186,7 +185,8 @@ std::exception LoggingManager::LogFatalAndCreateException(const char* category,
// create Capture in separate scope so it gets destructed (leading to log output) before we throw.
{
::onnxruntime::logging::Capture c{::onnxruntime::logging::LoggingManager::DefaultLogger(),
::onnxruntime::logging::Severity::kFATAL, category, ::onnxruntime::logging::DataType::SYSTEM, location};
::onnxruntime::logging::Severity::kFATAL, category,
::onnxruntime::logging::DataType::SYSTEM, location};
va_list args;
va_start(args, format_str);

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1389,7 +1389,7 @@ Example 4:
pads_initializer->int64_data().end());

// fill with zeros if needed to reach appropriate size
if (pads_data.size() != static_cast<size_t>(2 * input_rank))
if (pads_data.size() != 2 * static_cast<size_t>(input_rank))
pads_data.resize(2 * input_rank, 0);

const auto& output_shape =
Expand Down
32 changes: 13 additions & 19 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ inline std::basic_string<T> GetCurrentTimeString() {
}
} // namespace

InferenceSession::InferenceSession(const SessionOptions& session_options, logging::LoggingManager* logging_manager)
InferenceSession::InferenceSession(const SessionOptions& session_options,
logging::LoggingManager* logging_manager)
: session_options_{session_options},
graph_transformation_mgr_{session_options_.max_num_graph_transformation_steps},
logging_manager_{logging_manager},
Expand Down Expand Up @@ -603,7 +604,7 @@ common::Status InferenceSession::ValidateInputs(const std::vector<std::string>&

common::Status InferenceSession::ValidateOutputs(const std::vector<std::string>& output_names,
const std::vector<OrtValue>* p_fetches) {
if (!p_fetches) {
if (p_fetches == nullptr) {
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"Output vector pointer is NULL");
}
Expand Down Expand Up @@ -646,8 +647,6 @@ Status InferenceSession::Run(const RunOptions& run_options, const std::vector<st
}

ORT_RETURN_IF_ERROR(ValidateInputs(feed_names, feeds));

// if the output vector is non-empty, ensure that its the same size as the output_names
ORT_RETURN_IF_ERROR(ValidateOutputs(output_names, p_fetches));

FeedsFetchesInfo info(feed_names, output_names);
Expand Down Expand Up @@ -886,13 +885,12 @@ const logging::Logger& InferenceSession::CreateLoggerForRun(const RunOptions& ru
run_log_id += run_options.run_tag;

logging::Severity severity = logging::Severity::kWARNING;

if (run_options.run_log_severity_level < 0) {
if (run_options.run_log_severity_level == -1) {
severity = session_logger_->GetSeverity();
} else {
ORT_ENFORCE(run_options.run_log_severity_level >= 0 &&
run_options.run_log_severity_level <= static_cast<int>(logging::Severity::kFATAL),
"Invalid run log severity level. Must be a valid onnxruntime::logging::Severity value. Got ",
run_options.run_log_severity_level <= static_cast<int>(logging::Severity::kFATAL),
"Invalid run log severity level. Not a valid onnxruntime::logging::Severity value: ",
run_options.run_log_severity_level);
severity = static_cast<logging::Severity>(run_options.run_log_severity_level);
}
Expand All @@ -915,26 +913,22 @@ const logging::Logger& InferenceSession::CreateLoggerForRun(const RunOptions& ru

void InferenceSession::InitLogger(logging::LoggingManager* logging_manager) {
// create logger for session, using provided logging manager if possible
if (logging_manager != nullptr) {
std::string session_logid = !session_options_.session_logid.empty()
? session_options_.session_logid
: "InferenceSession"; // there's probably a better default...

if (logging_manager != nullptr && !session_options_.session_logid.empty()) {
logging::Severity severity = logging::Severity::kWARNING;

if (session_options_.session_log_severity_level < 0) {
if (session_options_.session_log_severity_level == -1) {
severity = logging::LoggingManager::DefaultLogger().GetSeverity();
} else {
ORT_ENFORCE(session_options_.session_log_severity_level >= 0 &&
session_options_.session_log_severity_level <= static_cast<int>(logging::Severity::kFATAL),
"Invalid session log severity level. Must be a valid onnxruntime::logging::Severity value. Got ",
session_options_.session_log_severity_level <= static_cast<int>(logging::Severity::kFATAL),
"Invalid session log severity level. Not a valid onnxruntime::logging::Severity value: ",
session_options_.session_log_severity_level);
severity = static_cast<logging::Severity>(session_options_.session_log_severity_level);
}

owned_session_logger_ = logging_manager_->CreateLogger(session_logid, severity, false,
owned_session_logger_ = logging_manager_->CreateLogger(session_options_.session_logid,
severity,
false,
session_options_.session_log_verbosity_level);

session_logger_ = owned_session_logger_.get();
} else {
session_logger_ = &logging::LoggingManager::DefaultLogger();
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -391,14 +391,15 @@ class InferenceSession {
std::vector<std::string> transformers_to_enable_;

/// Logging manager if provided.
logging::LoggingManager* logging_manager_;
logging::LoggingManager* logging_manager_ = nullptr;

/// Logger for this session. WARNING: Will contain nullptr if logging_manager_ is nullptr.
std::unique_ptr<logging::Logger> owned_session_logger_;
std::unique_ptr<logging::Logger> owned_session_logger_ = nullptr;

// Profiler for this session.
profiling::Profiler session_profiler_;

// The list of execution providers.
ExecutionProviders execution_providers_;

protected:
Expand Down