Add FastAPI v1/completions/ endpoint #12101

Status: Draft · wants to merge 18 commits into base: main. Showing changes from all commits.
149 changes: 108 additions & 41 deletions nemo/collections/llm/api.py
@@ -420,16 +420,23 @@ def ptq(
@run.cli.entrypoint(namespace="llm")
def deploy(
nemo_checkpoint: AnyPath = None,
backend: str = "in-framework",
model_type: str = "llama",
triton_model_name: str = "triton_model",
triton_model_version: Optional[int] = 1,
triton_http_port: int = 8000,
triton_grpc_port: int = 8001,
triton_http_address: str = "0.0.0.0",
triton_model_repository: AnyPath = None,
start_fastapi_server: bool = True,
fastapi_http_address: str = "0.0.0.0",
fastapi_port: int = 8080,
num_gpus: int = 1,
num_nodes: int = 1,
tensor_parallelism_size: int = 1,
pipeline_parallelism_size: int = 1,
context_parallel_size: int = 1,
expert_model_parallel_size: int = 1,
dtype: str = "bfloat16",
max_input_len: int = 256,
max_output_len: int = 256,
@@ -438,10 +445,12 @@ def deploy(
output_generation_logits: bool = True,
):
"""
Deploys nemo model on a PyTriton server by converting the nemo ckpt to trtllm.
Deploys a NeMo model on a PyTriton server, either "in-framework" or by converting the checkpoint to TensorRT-LLM, depending on the backend.
This deploy method is intended to be used for evaluation.

Args:
nemo_checkpoint (Path): Path for nemo checkpoint.
backend (str): Deployment backend, either "in-framework" or "trtllm". Default: "in-framework".
model_type (str): Type of the model. Choices: gpt, llama, falcon, starcoder. Default: llama.
triton_model_name (str): Name for the model that gets deployed on PyTriton. Please ensure that the same model
name is passed to the evaluate method for the model to be accessible while sending evaluation requests.
@@ -452,7 +461,10 @@ def deploy(
triton_http_address (str): HTTP address for the PyTriton server. Default: "0.0.0.0".
triton_model_repository (Path): Folder for the trt-llm conversion, trt-llm engine gets saved in this specified
path. If None, saves it in /tmp dir. Default: None.
num_gpus (int): Number of GPUs for export to trtllm and deploy. Default: 1.
start_fastapi_server (bool): Whether to start the FastAPI server that exposes the OpenAI-compatible v1/completions and
v1/chat/completions endpoints. Only supported for in-framework deployment, not with the 'trtllm' backend. Default: True.
fastapi_http_address (str): HTTP address for the FastAPI server. Default: "0.0.0.0".
fastapi_port (int): Port for the FastAPI server. Default: 8080.
num_gpus (int): Number of GPUs per node for export to trtllm and deploy. Default: 1.
tensor_parallelism_size (int): Tensor parallelism size. Default: 1.
pipeline_parallelism_size (int): Pipeline parallelism size. Default: 1.
dtype (str): dtype of the TensorRT-LLM model. Default: "bfloat16".
@@ -468,53 +480,105 @@ def deploy(
generation_logits are used to compute the logProb of the output token in case of single token prediction
benchmarks (like MMLU, lambada). Default: True.
"""
import os

import uvicorn

from nemo.collections.llm.deploy.base import get_trtllm_deployable, unset_environment_variables
from nemo.deploy import DeployPyTriton

unset_environment_variables()
# unset_environment_variables()  # TODO: commented out for in-framework deployment
if backend == 'in-framework':
assert start_fastapi_server is True, (
'In-framework deployment exposes the OpenAI API endpoints v1/completions and '
'v1/chat/completions, so the FastAPI interface is needed in front of PyTriton. '
'Please set start_fastapi_server to True.'
)
if triton_http_port == fastapi_port:
logging.error("FastAPI port and Triton server port cannot use the same port.")
return
# Store the Triton address and port as env vars so fastapi_interface_to_pytriton.py can reach the Triton server
os.environ["TRITON_HTTP_ADDRESS"] = triton_http_address
os.environ["TRITON_PORT"] = str(triton_http_port)

triton_deployable = get_trtllm_deployable(
nemo_checkpoint,
model_type,
triton_model_repository,
num_gpus,
tensor_parallelism_size,
pipeline_parallelism_size,
max_input_len,
max_output_len,
max_batch_size,
dtype,
output_context_logits,
output_generation_logits,
)
try:
from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployableNemo2
except Exception as e:
raise ValueError(
f"Unable to import MegatronLLMDeployableNemo2 ({type(e).__name__}: {e}); cannot run "
f"evaluation with in-framework deployment."
)

try:
nm = DeployPyTriton(
model=triton_deployable,
triton_model_name=triton_model_name,
triton_model_version=triton_model_version,
max_batch_size=max_batch_size,
http_port=triton_http_port,
grpc_port=triton_grpc_port,
address=triton_http_address,
triton_deployable = MegatronLLMDeployableNemo2(
nemo_checkpoint_filepath=nemo_checkpoint,
num_devices=num_gpus,  # TODO: confirm whether this is per node. For TRT-LLM it is per node; TRT-LLM derives the
# GPU count from the TP and PP sizes, and with TP=PP=1 it uses a single GPU since DP is not supported.
num_nodes=num_nodes,  # TODO: appears to apply only to in-framework deployment; verify and document it in the docstring.
tensor_model_parallel_size=tensor_parallelism_size,
pipeline_model_parallel_size=pipeline_parallelism_size,
context_parallel_size=context_parallel_size,
expert_model_parallel_size=expert_model_parallel_size,
)

logging.info("Triton deploy function will be called.")
nm.deploy()
nm.run()
except Exception as error:
logging.error("Error message has occurred during deploy function. Error message: " + str(error))
return

try:
logging.info("Model serving on Triton will be started.")
nm.serve()
except Exception as error:
logging.error("Error message has occurred during deploy function. Error message: " + str(error))
return
elif backend == 'trtllm':
triton_deployable = get_trtllm_deployable(
nemo_checkpoint,
model_type,
triton_model_repository,
num_gpus,
tensor_parallelism_size,
pipeline_parallelism_size,
max_input_len,
max_output_len,
max_batch_size,
dtype,
output_context_logits,
output_generation_logits,
)
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:  # added for in-framework deployment
try:
nm = DeployPyTriton(
model=triton_deployable,
triton_model_name=triton_model_name,
triton_model_version=triton_model_version,
max_batch_size=max_batch_size,
http_port=triton_http_port,
grpc_port=triton_grpc_port,
address=triton_http_address,
)

logging.info("Model serving will be stopped.")
nm.stop()
logging.info("Triton deploy function will be called.")
nm.deploy()
nm.run()
except Exception as error:
logging.error("Error message has occurred during deploy function. Error message: " + str(error))
return

try:
if start_fastapi_server:
try:
logging.info("REST service will be started.")
uvicorn.run(
'nemo.collections.llm.deploy.fastapi_interface_to_pytriton:app',
host=fastapi_http_address,
port=fastapi_port,
reload=True,
)
except Exception as error:
logging.error(
"An error occurred while starting the REST service. Error message: " + str(error)
)
logging.info("Model serving on Triton will be started.")
nm.serve()
except Exception as error:
logging.error("Error message has occurred during deploy function. Error message: " + str(error))
return

logging.info("Model serving will be stopped.")
nm.stop()
elif torch.distributed.get_rank() > 0:  # TODO: added for in-framework deployment
triton_deployable.generate_other_ranks()


def evaluate(
@@ -743,6 +807,7 @@ def generate(
max_batch_size: int = 4,
random_seed: Optional[int] = None,
inference_batch_times_seqlen_threshold: int = 1000,
inference_max_seq_length: int = 4096,
inference_params: Optional["CommonInferenceParams"] = None,
text_only: bool = False,
output_path: Optional[AnyPath] = None,
@@ -807,6 +872,8 @@ def generate(
random_seed (Optional[int], optional): The random seed. Defaults to None.
inference_batch_times_seqlen_threshold (int, optional): If batch-size times sequence-length is smaller than
this threshold then we will not use pipelining, otherwise we will. Defaults to 1000.
inference_max_seq_length (int, optional): Maximum sequence length for inference. Required by MCoreEngine (>= 0.12).
Defaults to 4096.
inference_params (Optional["CommonInferenceParams"], optional): The inference parameters defined in
Mcore's CommonInferenceParams. Defaults to None.
text_only (bool, optional): Whether to return only the generated text as a string. Defaults to False.
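For context, a minimal sketch of how the new in-framework path might be exercised end to end: call deploy() with backend='in-framework', then query the FastAPI v1/completions/ endpoint it exposes from a separate client process. The checkpoint path is hypothetical and the request payload assumes an OpenAI-style completions schema; the exact fields accepted are defined by fastapi_interface_to_pytriton.py.

# --- server side (one process): start the in-framework deployment; this call blocks while serving ---
from nemo.collections.llm.api import deploy

deploy(
    nemo_checkpoint="/path/to/checkpoint",  # hypothetical path
    backend="in-framework",
    triton_model_name="triton_model",
    triton_http_port=8000,
    start_fastapi_server=True,
    fastapi_port=8080,
    num_gpus=1,
    tensor_parallelism_size=1,
    pipeline_parallelism_size=1,
)

# --- client side (separate process, once both servers are up) ---
import requests

resp = requests.post(
    "http://0.0.0.0:8080/v1/completions/",
    json={
        "model": "triton_model",  # matches triton_model_name above
        "prompt": "The capital of France is",
        "max_tokens": 32,  # assumed OpenAI-style field names
        "temperature": 0.0,
    },
    timeout=120,
)
print(resp.json())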
60 changes: 60 additions & 0 deletions nemo/collections/llm/deploy/base.py
@@ -18,6 +18,66 @@

from nemo.utils import logging

# Define the chat template
chat_template = """
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}
{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}
{%- for message in messages %}
{%- if message['role'] in ['system', 'developer'] %}
{%- if ns.is_first_sp %}
{% set ns.system_prompt = ns.system_prompt + message['content'] %}
{% set ns.is_first_sp = false %}
{%- else %}
{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}
{%- endif %}
{%- endif %}
{%- endfor %}
{{ bos_token }}{{ ns.system_prompt }}
{%- for message in messages %}
{%- if message['role'] == 'user' %}
{%- set ns.is_tool = false -%}
{{'<|User|>' + message['content']}}
{%- endif %}
{%- if message['role'] == 'assistant' and 'tool_calls' in message %}
{%- set ns.is_tool = false -%}
{%- for tool in message['tool_calls'] %}
{%- if not ns.is_first %}
{%- if message['content'] is none %}
{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '``````' + '<|tool▁call▁end|>'}}
{%- else %}
{{'<|Assistant|>' + message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '``````' + '<|tool▁call▁end|>'}}
{%- endif %}
{%- set ns.is_first = true -%}
{%- else %}
{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '``````' + '<|tool▁call▁end|>'}}
{%- endif %}
{%- endfor %}
{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
{%- endif %}
{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}
{%- if ns.is_tool %}
{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}
{% set ns.is_tool = false %}
{%- else %}
{% set content = message['content'] %}
{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}
{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}
{%- endif %}
{%- endif %}
{%- if message['role'] == 'tool' %}
{% set ns.is_tool = true -%}
{% if ns.is_output_first -%}
{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
{% set ns.is_output_first = false -%}
{% else -%}
{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
{% endif -%}
{% endif -%}
{%- endfor -%}
{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}
{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\\n'}}{% endif %}
"""


def unset_environment_variables() -> None:
"""
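As a quick sanity check, the chat template above can be rendered standalone with Jinja2, outside of the deployment path. A minimal sketch under assumptions: the message list is illustrative and the bos_token value is a placeholder that would normally come from the model's tokenizer.

# Render the module-level chat_template with Jinja2 (sanity-check sketch).
from jinja2 import Template

from nemo.collections.llm.deploy.base import chat_template

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a haiku about GPUs."},
]

prompt = Template(chat_template).render(
    messages=messages,
    bos_token="<|begin▁of▁sentence|>",  # placeholder; use the tokenizer's bos_token in practice
    add_generation_prompt=True,
)
print(prompt)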