[Core] add an option to log every function call for debugging hang/crash in distributed inference #4079

Merged (21 commits, Apr 18, 2024)
2 changes: 1 addition & 1 deletion .buildkite/test-pipeline.yaml
@@ -40,7 +40,7 @@ steps:
    - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py

- label: Engine Test
  command: pytest -v -s engine tokenization test_sequence.py test_config.py
  command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py

- label: Entrypoints Test
  commands:
2 changes: 2 additions & 0 deletions .github/ISSUE_TEMPLATE/400-bug report.yml
@@ -57,6 +57,8 @@ body:
If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.

Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.

If you experience crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1`. All function calls in vllm will then be recorded. Inspect these log files to identify which function crashes or hangs.
placeholder: |
A clear and concise description of what the bug is.

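Note: for context, a minimal sketch of how the tracing option described in the issue-template text above might be enabled for an offline run. The model name and the `LLM` entry point here are illustrative assumptions, not part of this diff.

```python
# Hypothetical usage sketch: enable vLLM function tracing before the engine starts.
import os

os.environ["VLLM_TRACE_FUNCTION"] = "1"  # must be set before workers initialize

from vllm import LLM

llm = LLM(model="meta-llama/Llama-2-7b-hf")
print(llm.generate(["Hello, my name is"]))
# Per-process/per-thread trace logs land under the system temp directory;
# see the worker_base.py changes below for the exact path construction.
```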
27 changes: 27 additions & 0 deletions tests/test_logger.py
@@ -0,0 +1,27 @@
import os
import sys
import tempfile

from vllm.logger import enable_trace_function_call


def f1(x):
    return f2(x)


def f2(x):
    return x


def test_trace_function_call():
    fd, path = tempfile.mkstemp()
    cur_dir = os.path.dirname(__file__)
    enable_trace_function_call(path, cur_dir)
    f1(1)
    with open(path, 'r') as f:
        content = f.read()

    assert "f1" in content
    assert "f2" in content
    sys.settrace(None)
    os.remove(path)
12 changes: 9 additions & 3 deletions vllm/executor/ray_gpu_executor.py
@@ -10,7 +10,7 @@
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                        make_async)
                        get_vllm_instance_id, make_async)

if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -133,12 +133,18 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        # Set CUDA_VISIBLE_DEVICES for the driver and workers.
        VLLM_INSTANCE_ID = get_vllm_instance_id()

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = []
        for (node_id, _) in worker_node_and_gpu_ids:
            all_args_to_update_environment_variables.append([{
                "CUDA_VISIBLE_DEVICES":
                ",".join(map(str, node_gpus[node_id]))
                ",".join(map(str, node_gpus[node_id])),
                "VLLM_INSTANCE_ID":
                VLLM_INSTANCE_ID,
                "VLLM_TRACE_FUNCTION":
                os.getenv("VLLM_TRACE_FUNCTION", "0"),
            }])
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)
52 changes: 52 additions & 0 deletions vllm/logger.py
@@ -1,9 +1,11 @@
# Adapted from
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration for vLLM."""
import datetime
import logging
import os
import sys
from functools import partial
from typing import Optional

VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
@@ -65,3 +67,53 @@ def init_logger(name: str):
        logger.addHandler(_default_handler)
        logger.propagate = False
    return logger


logger = init_logger(__name__)


def _trace_calls(log_path, root_dir, frame, event, arg=None):
    if event in ['call', 'return']:
        # Extract the filename, line number, function name, and the code object
        filename = frame.f_code.co_filename
        lineno = frame.f_lineno
        func_name = frame.f_code.co_name
        if not filename.startswith(root_dir):
            # only log the functions in the vllm root_dir
            return
        # Log every function call or return
        try:
            with open(log_path, 'a') as f:
                if event == 'call':
                    f.write(f"{datetime.datetime.now()} Call to"
                            f" {func_name} in {filename}:{lineno}\n")
                else:
                    f.write(f"{datetime.datetime.now()} Return from"
                            f" {func_name} in {filename}:{lineno}\n")
        except NameError:
            # modules are deleted during shutdown
            pass
    return partial(_trace_calls, log_path, root_dir)


def enable_trace_function_call(log_file_path: str,
                               root_dir: Optional[str] = None):
    """
    Enable tracing of every function call in code under `root_dir`.
    This is useful for debugging hangs or crashes.
    `log_file_path` is the path to the log file.
    `root_dir` is the root directory of the code to trace. If None, it is the
    vllm root directory.

    Note that this call is thread-level: only threads calling this function
    will have the trace enabled. Other threads will not be affected.
    """
    logger.warning(
        "VLLM_TRACE_FUNCTION is enabled. It will record every"
        " function executed by Python. This will slow down the code. It "
        "is suggested to be used for debugging hangs or crashes only.")
    logger.info(f"Trace frame log is saved to {log_file_path}")
    if root_dir is None:
        # by default, this is the vllm root directory
        root_dir = os.path.dirname(os.path.dirname(__file__))
    sys.settrace(partial(_trace_calls, log_file_path, root_dir))
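Note on the mechanism used by `_trace_calls`: `sys.settrace` installs a global trace function that is invoked on 'call' events, and whose return value becomes the local trace function for that frame, which then also receives 'return' events. A minimal standalone sketch of this pattern, independent of vLLM:

```python
import sys


def tracer(frame, event, arg=None):
    # Global trace function: invoked on 'call'; returning it makes it the
    # local trace function for the frame, so it also sees 'return' events.
    if event in ("call", "return"):
        code = frame.f_code
        print(f"{event}: {code.co_name} ({code.co_filename}:{frame.f_lineno})")
    return tracer


def add(a, b):
    return a + b


sys.settrace(tracer)
add(1, 2)
sys.settrace(None)  # always disable tracing afterwards, as the test above does
```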
13 changes: 12 additions & 1 deletion vllm/utils.py
@@ -163,6 +163,17 @@ def random_uuid() -> str:
    return str(uuid.uuid4().hex)


@lru_cache(maxsize=None)
def get_vllm_instance_id():
    """
    If the environment variable VLLM_INSTANCE_ID is set, return it.
    Otherwise, return a random UUID.
    Instance id represents an instance of vLLM. All processes in the same
    instance should have the same instance id.
    """
    return os.environ.get("VLLM_INSTANCE_ID", f"vllm-instance-{random_uuid()}")


@lru_cache(maxsize=None)
def in_wsl() -> bool:
    # Reference: https://github.com/microsoft/WSL/issues/4071
@@ -274,7 +285,7 @@ def get_open_port() -> int:

def update_environment_variables(envs: Dict[str, str]):
    for k, v in envs.items():
        if k in os.environ:
        if k in os.environ and os.environ[k] != v:
            logger.warning(f"Overwriting environment variable {k} "
                           f"from '{os.environ[k]}' to '{v}'")
        os.environ[k] = v
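A quick sketch of the instance-id behavior added above (assumes vLLM is importable; the pinned value is illustrative):

```python
import os

os.environ["VLLM_INSTANCE_ID"] = "vllm-instance-demo"  # optional: pin the id

from vllm.utils import get_vllm_instance_id

# Returns the pinned value; if the variable were unset, a random
# "vllm-instance-<uuid>" would be generated. lru_cache keeps the result
# stable within a process, and the Ray executor exports it to all workers
# so that the driver and workers share one id.
print(get_vllm_instance_id())  # -> vllm-instance-demo
print(get_vllm_instance_id())  # -> vllm-instance-demo (cached)
```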
20 changes: 17 additions & 3 deletions vllm/worker/worker_base.py
@@ -1,12 +1,15 @@
import datetime
import importlib
import os
import tempfile
import threading
from abc import ABC, abstractmethod
from typing import Dict, List, Set, Tuple

from vllm.logger import init_logger
from vllm.logger import enable_trace_function_call, init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import update_environment_variables
from vllm.utils import get_vllm_instance_id, update_environment_variables

logger = init_logger(__name__)

@@ -115,9 +118,20 @@ def update_environment_variables(self, envs: Dict[str, str]) -> None:

    def init_worker(self, *args, **kwargs):
        """
        Actual initialization of the worker class.
        Actual initialization of the worker class, and set up
        function tracing if required.
        Arguments are passed to the worker class constructor.
        """
        if int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
            tmp_dir = tempfile.gettempdir()
            filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
                        f"_thread_{threading.get_ident()}_"
                        f"at_{datetime.datetime.now()}.log").replace(" ", "_")
            log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
                                    filename)
            os.makedirs(os.path.dirname(log_path), exist_ok=True)
Review comment (Collaborator): this should be moved into enable_trace_function_call

Reply (Member, author): I separate and simplify the logic in enable_trace_function_call so that it can be tested in a standalone way. The caller should be responsible for the logic of creating the log file path.

            enable_trace_function_call(log_path)

        mod = importlib.import_module(self.worker_module_name)
        worker_class = getattr(mod, self.worker_class_name)
        self.worker = worker_class(*args, **kwargs)
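To make the resulting log layout concrete, a small sketch of how one might list the trace files when debugging a hang. The path components follow the code above; the glob pattern itself is an illustrative assumption, not part of this PR.

```python
import glob
import os
import tempfile

# init_worker writes one trace file per process/thread, grouped per instance:
#   <tmpdir>/vllm/<VLLM_INSTANCE_ID>/VLLM_TRACE_FUNCTION_for_process_<pid>_thread_<tid>_at_<time>.log
pattern = os.path.join(tempfile.gettempdir(), "vllm", "*",
                       "VLLM_TRACE_FUNCTION_for_process_*.log")
for log_file in sorted(glob.glob(pattern)):
    print(log_file)
# The tail of each file shows the last call/return recorded by that worker,
# which is usually where the hang or crash occurred.
```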