feat: Support for request id field in generate API #7392

Merged: 9 commits, Jul 10, 2024
12 changes: 9 additions & 3 deletions docs/protocol/extension_generate.md
@@ -1,5 +1,5 @@
<!--
-# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -87,10 +87,12 @@ return an error.

$generate_request =
{
"id" : $string, #optional
"text_input" : $string,
"parameters" : $parameters #optional
}

* "id" : An identifier for this request. Optional; if specified, this identifier must be returned in the response.
* "text_input" : The text input that the model should generate output from.
* "parameters" : An optional object containing zero or more parameters for this
generate request expressed as key/value pairs. See
@@ -121,14 +123,15 @@ specification to set the parameters.
Below is an example of sending a generate request with the additional model parameters `stream` and `temperature`.

```
-$ curl -X POST localhost:8000/v2/models/mymodel/generate -d '{"text_input": "client input", "parameters": {"stream": false, "temperature": 0}}'
+$ curl -X POST localhost:8000/v2/models/mymodel/generate -d '{"id": "42", "text_input": "client input", "parameters": {"stream": false, "temperature": 0}}'

POST /v2/models/mymodel/generate HTTP/1.1
Host: localhost:8000
Content-Type: application/json
Content-Length: <xx>
{
-"text_input": "client input",
+"id" : "42",
+"text_input" : "client input",
"parameters" :
{
"stream": false,
```

@@ -145,11 +148,13 @@ the HTTP body.

$generate_response =
{
"id" : $string,
"model_name" : $string,
"model_version" : $string,
"text_output" : $string
}

* "id" : The "id" identifier given in the request, if any.
* "model_name" : The name of the model used for inference.
* "model_version" : The specific model version used for inference.
* "text_output" : The output of the inference.
@@ -159,6 +164,7 @@ the HTTP body.
```
200
{
"id" : "42",
"model_name" : "mymodel",
"model_version" : "1",
"text_output" : "model output"
```
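The request/response contract documented above can be sketched with a small helper. This is an illustrative sketch only; `build_generate_request` is a hypothetical name, not part of Triton's client API:

```python
import json


def build_generate_request(text_input, request_id=None, parameters=None):
    """Build a JSON body for POST /v2/models/<model>/generate.

    Per the schema above, "id" and "parameters" are optional and are
    included only when the caller supplies them.
    """
    body = {"text_input": text_input}
    if request_id is not None:
        body["id"] = request_id
    if parameters is not None:
        body["parameters"] = parameters
    return json.dumps(body)


# Matches the curl example above: an "id" of "42" plus two parameters.
body = build_generate_request(
    "client input",
    request_id="42",
    parameters={"stream": False, "temperature": 0},
)
print(body)
```

Because "id" is optional, omitting `request_id` simply leaves the field out of the body rather than sending an empty string.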
43 changes: 43 additions & 0 deletions qa/L0_http/generate_endpoint_test.py
@@ -142,6 +142,49 @@ def test_generate(self):
self.assertIn("TEXT", data)
self.assertEqual(text, data["TEXT"])

def test_request_id(self):
Review thread:

rmccorm4 (Contributor), Jul 9, 2024:
Thanks for adding the test! I will run a pipeline with these changes.
Pipeline ID: 16450437

rmccorm4 (Contributor), Jul 10, 2024:
Pipeline looks mostly good, and the new unit test passed. However, please update this variable from 15 to 16 to account for the new test added:
-EXPECTED_NUM_TESTS=15
+EXPECTED_NUM_TESTS=16

PR author (Contributor):
Thanks. Updated count to 16.

rmccorm4 (Contributor):
Thanks! Ran a new pipeline 16473896 and it passed 🚀 Once I verify the CLA, this can be merged 👍

# Setup text based input
text = "hello world"
request_id = "42"

# Test when request id in request body
inputs = {"PROMPT": text, "id": request_id, "STREAM": False}
r = self.generate(self._model_name, inputs)
r.raise_for_status()

self.assertIn("Content-Type", r.headers)
self.assertEqual(r.headers["Content-Type"], "application/json")

data = r.json()
self.assertIn("id", data)
self.assertEqual(request_id, data["id"])
self.assertIn("TEXT", data)
self.assertEqual(text, data["TEXT"])

# Test when request id not in request body
inputs = {"PROMPT": text, "STREAM": False}
r = self.generate(self._model_name, inputs)
r.raise_for_status()

self.assertIn("Content-Type", r.headers)
self.assertEqual(r.headers["Content-Type"], "application/json")

data = r.json()
self.assertNotIn("id", data)

# Test when request id is empty
inputs = {"PROMPT": text, "id": "", "STREAM": False}
r = self.generate(self._model_name, inputs)
r.raise_for_status()

self.assertIn("Content-Type", r.headers)
self.assertEqual(r.headers["Content-Type"], "application/json")

data = r.json()
self.assertNotIn("id", data)
self.assertIn("TEXT", data)
self.assertEqual(text, data["TEXT"])
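
The three cases this test exercises (id present, id absent, id empty) follow one rule: the request "id" is echoed into the response only when it is present and non-empty. A minimal sketch of that rule, with a hypothetical function name and not the server's actual C++ implementation:

```python
def response_with_optional_id(request_body, response_body):
    """Return a copy of response_body, adding "id" only when the
    request supplied a non-empty identifier (illustrative sketch)."""
    request_id = request_body.get("id", "")
    response = dict(response_body)
    if request_id:  # absent or empty string: no "id" in the response
        response["id"] = request_id
    return response


resp = {"model_name": "mymodel", "model_version": "1",
        "text_output": "model output"}
print(response_with_optional_id({"text_input": "hello world", "id": "42"}, resp))
```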

def test_generate_stream(self):
# Setup text-based input
text = "hello world"
2 changes: 1 addition & 1 deletion qa/L0_http/test.sh
@@ -662,7 +662,7 @@ fi
## Python Unit Tests
TEST_RESULT_FILE='test_results.txt'
PYTHON_TEST=generate_endpoint_test.py
-EXPECTED_NUM_TESTS=15
+EXPECTED_NUM_TESTS=16
set +e
python $PYTHON_TEST >$CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
2 changes: 2 additions & 0 deletions src/http_server.cc
@@ -3327,6 +3327,8 @@ HTTPAPIServer::HandleGenerate(
// thus the string must live as long as the JSON message).
triton::common::TritonJson::Value request;
RETURN_AND_CALLBACK_IF_ERR(EVRequestToJson(req, &request), error_callback);
RETURN_AND_CALLBACK_IF_ERR(
ParseJsonTritonRequestID(request, irequest), error_callback);

RETURN_AND_CALLBACK_IF_ERR(
generate_request->ConvertGenerateRequest(