Skip to content

Commit

Permalink
Merge pull request microsoft#55 from NonStatic2014/bohu/shortcut
Browse files Browse the repository at this point in the history
Add shortcut endpoint for predict request and health endpoint for orchestration systems
  • Loading branch information
NonStatic2014 authored Apr 26, 2019
2 parents 6abc722 + e8da7c8 commit 17c6220
Show file tree
Hide file tree
Showing 10 changed files with 114 additions and 25 deletions.
2 changes: 1 addition & 1 deletion docs/Hosting_Application_Usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ ONNX Hosting provides a REST API for prediction. The goal of the project is to m
ONNX Hosting: host an ONNX model for inferencing with ONNXRuntime
Allowed options:
-h [ --help ] Shows a help message and exits
--logging_level arg (=verbose) Logging level. Allowed options (case
--log_level arg (=info) Logging level. Allowed options (case
sensitive): verbose, info, warning, error,
fatal
-m [ --model_path ] arg Path to ONNX model
Expand Down
1 change: 0 additions & 1 deletion onnxruntime/hosting/http/core/http_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,3 @@ class App {
};
} // namespace hosting
} // namespace onnxruntime

18 changes: 14 additions & 4 deletions onnxruntime/hosting/http/core/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,16 @@ template <typename Body, typename Allocator>
void HttpSession::HandleRequest(http::request<Body, http::basic_fields<Allocator> >&& req) {
HttpContext context{};
context.request = std::move(req);
// TODO: set request id

auto status = ExecuteUserFunction(context);
// Special handle the liveness probe endpoint for orchestration systems like Kubernetes.
if (context.request.method() == http::verb::get && context.request.target().to_string() == "/") {
context.response.body() = "Healthy";
} else {
auto status = ExecuteUserFunction(context);

if (status != http::status::ok) {
routes_.on_error(context);
if (status != http::status::ok) {
routes_.on_error(context);
}
}

context.response.keep_alive(context.request.keep_alive());
Expand All @@ -116,6 +120,12 @@ http::status HttpSession::ExecuteUserFunction(HttpContext& context) {
context.client_request_id = context.request["x-ms-client-request-id"].to_string();
}

if (path == "/score") {
// This is a shortcut since we have only one model instance currently.
// This code path will be removed once we start supporting multiple models or multiple versions of one model.
path = "/v1/models/default/versions/1:predict";
}

auto status = routes_.ParseUrl(context.request.method(), path, model_name, model_version, action, func);

if (status != http::status::ok) {
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/hosting/http/predict_request_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ void Predict(const std::string& name,
/* in, out */ HttpContext& context,
std::shared_ptr<HostingEnvironment> env) {
auto logger = env->GetLogger(context.request_id);
LOGS(*logger, VERBOSE) << "Name: " << name << " Version: " << version << " Action: " << action;
LOGS(*logger, INFO) << "Model Name: " << name << ", Version: " << version << ", Action: " << action;

if (!context.client_request_id.empty()) {
LOGS(*logger, VERBOSE) << "x-ms-client-request-id: [" << context.client_request_id << "]";
LOGS(*logger, INFO) << "x-ms-client-request-id: [" << context.client_request_id << "]";
}

// Request and Response content type information
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/hosting/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ int main(int argc, char* argv[]) {
auto env = std::make_shared<hosting::HostingEnvironment>(config.logging_level);
auto logger = env->GetAppLogger();
LOGS(logger, VERBOSE) << "Logging manager initialized.";
LOGS(logger, VERBOSE) << "Model path: " << config.model_path;
LOGS(logger, INFO) << "Model path: " << config.model_path;

auto status = env->InitializeModel(config.model_path);
if (!status.IsOK()) {
Expand All @@ -47,8 +47,8 @@ int main(int argc, char* argv[]) {
app.RegisterStartup(
[env](const auto& details) -> void {
auto logger = env->GetAppLogger();
LOGS(logger, VERBOSE) << "Listening at: "
<< "http://" << details.address << ":" << details.port;
LOGS(logger, INFO) << "Listening at: "
<< "http://" << details.address << ":" << details.port;
});

app.RegisterError(
Expand Down
13 changes: 6 additions & 7 deletions onnxruntime/hosting/server_configuration.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class ServerConfiguration {

ServerConfiguration() {
desc.add_options()("help,h", "Shows a help message and exits");
desc.add_options()("logging_level", po::value(&logging_level_str)->default_value(logging_level_str), "Logging level. Allowed options (case sensitive): verbose, info, warning, error, fatal");
desc.add_options()("log_level", po::value(&log_level_str)->default_value(log_level_str), "Logging level. Allowed options (case sensitive): verbose, info, warning, error, fatal");
desc.add_options()("model_path", po::value(&model_path)->required(), "Path to ONNX model");
desc.add_options()("address", po::value(&address)->default_value(address), "The base HTTP address");
desc.add_options()("http_port", po::value(&http_port)->default_value(http_port), "HTTP port to listen to requests");
Expand Down Expand Up @@ -77,7 +77,7 @@ class ServerConfiguration {
Result result = ValidateOptions();

if (result == Result::ContinueSuccess) {
logging_level = supported_log_levels[logging_level_str];
logging_level = supported_log_levels[log_level_str];
}

return result;
Expand All @@ -86,13 +86,13 @@ class ServerConfiguration {
private:
po::options_description desc{"Allowed options"};
po::variables_map vm{};
std::string logging_level_str = "verbose";
std::string log_level_str = "info";

// Print help and return if there is a bad value
Result ValidateOptions() {
if (vm.count("logging_level") &&
supported_log_levels.find(logging_level_str) == supported_log_levels.end()) {
PrintHelp(std::cerr, "logging_level must be one of verbose, info, warning, error, or fatal");
if (vm.count("log_level") &&
supported_log_levels.find(log_level_str) == supported_log_levels.end()) {
PrintHelp(std::cerr, "log_level must be one of verbose, info, warning, error, or fatal");
return Result::ExitFailure;
} else if (num_http_threads <= 0) {
PrintHelp(std::cerr, "num_http_threads must be greater than 0");
Expand Down Expand Up @@ -129,4 +129,3 @@ class ServerConfiguration {

} // namespace hosting
} // namespace onnxruntime

85 changes: 83 additions & 2 deletions onnxruntime/test/hosting/integration_tests/function_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class HttpJsonPayloadTests(unittest.TestCase):

@classmethod
def setUpClass(cls):
cmd = [cls.hosting_app_path, '--http_port', str(cls.server_port), '--model_path', os.path.join(cls.model_path, 'mnist.onnx'), '--logging_level', cls.log_level]
cmd = [cls.hosting_app_path, '--http_port', str(cls.server_port), '--model_path', os.path.join(cls.model_path, 'mnist.onnx'), '--log_level', cls.log_level]
print('Launching hosting app: [{0}]'.format(' '.join(cmd)))
cls.hosting_app_proc = subprocess.Popen(cmd)
print('Hosting app PID: {0}'.format(cls.hosting_app_proc.pid))
Expand Down Expand Up @@ -143,6 +143,52 @@ def test_mnist_missing_content_type(self):
self.assertEqual(r.content.decode('utf-8'), '{"error_code": 400, "error_message": "Missing or unknown \'Content-Type\' header field in the request"}\n')


def test_single_model_shortcut(self):
    # Verify the single-model "/score" shortcut behaves like the full
    # /v1/models/.../versions/...:predict route for a JSON MNIST request.
    payload_path = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.json')
    expected_path = os.path.join(self.test_data_path, 'mnist_test_data_set_0_output.json')

    with open(payload_path, 'r') as payload_file:
        request_payload = payload_file.read()

    with open(expected_path, 'r') as expected_file:
        expected_response = json.loads(expected_file.read())

    request_headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
        'x-ms-client-request-id': 'This~is~my~id'
    }

    url = "http://{0}:{1}/score".format(self.server_ip, self.server_port)
    print(url)
    r = requests.post(url, headers=request_headers, data=request_payload)

    # Transport-level checks: success, content negotiation, request-id echo.
    self.assertEqual(r.status_code, 200)
    self.assertEqual(r.headers.get('Content-Type'), 'application/json')
    self.assertTrue(r.headers.get('x-ms-request-id'))
    self.assertEqual(r.headers.get('x-ms-client-request-id'), 'This~is~my~id')

    actual_response = json.loads(r.content.decode('utf-8'))

    # Note:
    # The 'dims' field is defined as "repeated int64" in protobuf.
    # When it is serialized to JSON, all int64/fixed64/uint64 numbers are converted to string
    # Reference: https://developers.google.com/protocol-buffers/docs/proto3#json
    outputs = actual_response['outputs']
    self.assertTrue(outputs)
    prediction = outputs['Plus214_Output_0']
    self.assertTrue(prediction)
    self.assertTrue(prediction['dims'])
    self.assertEqual(prediction['dims'], ['1', '10'])
    self.assertTrue(prediction['dataType'])
    self.assertEqual(prediction['dataType'], 1)
    self.assertTrue(prediction['rawData'])

    # Decode the 10 float logits from base64 and compare element-wise
    # with a tolerance-aware float comparison.
    actual_data = test_util.decode_base64_string(prediction['rawData'], '10f')
    expected_data = test_util.decode_base64_string(expected_response['outputs']['Plus214_Output_0']['rawData'], '10f')

    for actual_value, expected_value in zip(actual_data, expected_data):
        self.assertTrue(test_util.compare_floats(actual_value, expected_value))


class HttpProtobufPayloadTests(unittest.TestCase):
server_ip = '127.0.0.1'
server_port = 54321
Expand All @@ -156,7 +202,7 @@ class HttpProtobufPayloadTests(unittest.TestCase):

@classmethod
def setUpClass(cls):
cmd = [cls.hosting_app_path, '--http_port', str(cls.server_port), '--model_path', os.path.join(cls.model_path, 'mnist.onnx'), '--logging_level', cls.log_level]
cmd = [cls.hosting_app_path, '--http_port', str(cls.server_port), '--model_path', os.path.join(cls.model_path, 'mnist.onnx'), '--log_level', cls.log_level]
print('Launching hosting app: [{0}]'.format(' '.join(cmd)))
cls.hosting_app_proc = subprocess.Popen(cmd)
print('Hosting app PID: {0}'.format(cls.hosting_app_proc.pid))
Expand Down Expand Up @@ -272,5 +318,40 @@ def test_any_accept_header(self):
self.assertEqual(r.headers.get('Content-Type'), 'application/octet-stream')


class HttpEndpointTests(unittest.TestCase):
    """Integration tests for non-prediction HTTP endpoints.

    Spawns the hosting app once for the whole class (setUpClass) and kills
    it afterwards (tearDownClass). The class attributes below are expected
    to be filled in by the test driver before the suite runs.
    """
    server_ip = '127.0.0.1'
    server_port = 54321
    hosting_app_path = ''    # path to the hosting app binary; set by the test driver
    test_data_path = ''
    model_path = ''          # directory containing mnist.onnx
    log_level = 'verbose'
    hosting_app_proc = None
    wait_server_ready_in_seconds = 1

    @classmethod
    def setUpClass(cls):
        cmd = [cls.hosting_app_path, '--http_port', str(cls.server_port), '--model_path', os.path.join(cls.model_path, 'mnist.onnx'), '--log_level', cls.log_level]
        print('Launching hosting app: [{0}]'.format(' '.join(cmd)))
        cls.hosting_app_proc = subprocess.Popen(cmd)
        print('Hosting app PID: {0}'.format(cls.hosting_app_proc.pid))
        print('Sleep {0} second(s) to wait for server initialization'.format(cls.wait_server_ready_in_seconds))
        time.sleep(cls.wait_server_ready_in_seconds)

    @classmethod
    def tearDownClass(cls):
        print('Shutdown hosting app')
        cls.hosting_app_proc.kill()
        print('PID {0} has been killed: {1}'.format(cls.hosting_app_proc.pid, test_util.is_process_killed(cls.hosting_app_proc.pid)))

    def test_health_endpoint(self):
        """GET / should return 200 with body 'Healthy' (liveness probe)."""
        # Fixed: original had a duplicated assignment (`url = url = "..."`).
        url = "http://{0}:{1}/".format(self.server_ip, self.server_port)
        print(url)
        r = requests.get(url)
        self.assertEqual(r.status_code, 200)
        self.assertEqual(r.content.decode('utf-8'), 'Healthy')

if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_models_from_model_zoo(self):
for model_path, data_paths in model_data_map.items():
hosting_app_proc = None
try:
cmd = [self.hosting_app_path, '--http_port', str(self.server_port), '--model_path', os.path.join(model_path, 'model.onnx'), '--logging_level', self.log_level]
cmd = [self.hosting_app_path, '--http_port', str(self.server_port), '--model_path', os.path.join(model_path, 'model.onnx'), '--log_level', self.log_level]
test_util.test_log(cmd)
hosting_app_proc = test_util.launch_hosting_app(cmd, self.server_ip, self.server_port, self.server_ready_in_seconds)

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/hosting/integration_tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
if __name__ == '__main__':
loader = unittest.TestLoader()

test_classes = [function_tests.HttpJsonPayloadTests, function_tests.HttpProtobufPayloadTests]
test_classes = [function_tests.HttpJsonPayloadTests, function_tests.HttpProtobufPayloadTests, function_tests.HttpEndpointTests]

test_suites = []
for tests in test_classes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ TEST(ConfigParsingTests, AllArgs) {
const_cast<char*>("--address"), const_cast<char*>("4.4.4.4"),
const_cast<char*>("--http_port"), const_cast<char*>("80"),
const_cast<char*>("--num_http_threads"), const_cast<char*>("1"),
const_cast<char*>("--logging_level"), const_cast<char*>("info")};
const_cast<char*>("--log_level"), const_cast<char*>("info")};

onnxruntime::hosting::ServerConfiguration config{};
Result res = config.ParseInput(11, test_argv);
Expand All @@ -42,7 +42,7 @@ TEST(ConfigParsingTests, Defaults) {
EXPECT_EQ(config.address, "0.0.0.0");
EXPECT_EQ(config.http_port, 8001);
EXPECT_EQ(config.num_http_threads, 3);
EXPECT_EQ(config.logging_level, onnxruntime::logging::Severity::kVERBOSE);
EXPECT_EQ(config.logging_level, onnxruntime::logging::Severity::kINFO);
}

TEST(ConfigParsingTests, Help) {
Expand Down Expand Up @@ -81,7 +81,7 @@ TEST(ConfigParsingTests, ModelNotFound) {
TEST(ConfigParsingTests, WrongLoggingLevel) {
char* test_argv[] = {
const_cast<char*>("/path/to/binary"),
const_cast<char*>("--logging_level"), const_cast<char*>("not a logging level"),
const_cast<char*>("--log_level"), const_cast<char*>("not a logging level"),
const_cast<char*>("--model_path"), const_cast<char*>("testdata/mul_1.pb"),
const_cast<char*>("--address"), const_cast<char*>("4.4.4.4"),
const_cast<char*>("--http_port"), const_cast<char*>("80"),
Expand Down

0 comments on commit 17c6220

Please sign in to comment.