[ONNX][DORT] Lazy-import onnxruntime (pytorch#134662)
Currently, if `onnxruntime` is installed, it gets imported as a side effect of importing `torch._inductor` (which may itself be imported by other libraries, e.g. transformer-engine):

```
  /mnt/c.py(53)<module>()
-> from torch._inductor.utils import maybe_profile
  /usr/local/lib/python3.10/site-packages/torch/_inductor/utils.py(49)<module>()
-> import torch._export
  /usr/local/lib/python3.10/site-packages/torch/_export/__init__.py(25)<module>()
-> import torch._dynamo
  /usr/local/lib/python3.10/site-packages/torch/_dynamo/__init__.py(2)<module>()
-> from . import convert_frame, eval_frame, resume_execution
  /usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py(48)<module>()
-> from . import config, exc, trace_rules
  /usr/local/lib/python3.10/site-packages/torch/_dynamo/trace_rules.py(52)<module>()
-> from .variables import (
  /usr/local/lib/python3.10/site-packages/torch/_dynamo/variables/__init__.py(38)<module>()
-> from .higher_order_ops import (
  /usr/local/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py(14)<module>()
-> import torch.onnx.operators
  /usr/local/lib/python3.10/site-packages/torch/onnx/__init__.py(62)<module>()
-> from ._internal.onnxruntime import (
  /usr/local/lib/python3.10/site-packages/torch/onnx/_internal/onnxruntime.py(37)<module>()
-> import onnxruntime  # type: ignore[import]
```

This breaks generated Triton kernels, because they import torch and thereby pull in these unexpected runtime libraries as well.
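
The fix (in the diff below) stops importing `onnxruntime` when `torch.onnx._internal.onnxruntime` is imported: availability is cached in a module-level flag and the heavy imports only happen when they are first needed. A minimal sketch of that lazy-import pattern, with the surrounding details stripped:

```python
# Minimal sketch of the lazy-import pattern used by the fix (details omitted).
import importlib
from typing import Optional

_SUPPORT_ONNXRT: Optional[bool] = None  # unresolved until first queried


def is_onnxrt_backend_supported() -> bool:
    """Check (and cache) whether onnxruntime is importable, without paying the
    import cost at module-import time."""
    global _SUPPORT_ONNXRT
    if _SUPPORT_ONNXRT is None:
        try:
            importlib.import_module("onnxruntime")  # deferred, possibly heavy
            _SUPPORT_ONNXRT = True
        except ImportError:
            _SUPPORT_ONNXRT = False
    return _SUPPORT_ONNXRT
```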

I've also added a test for this specific case under `test/onnx`; perhaps we should add more somewhere else?
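
For reference, the new test drives `python -X importtime` in a subprocess; the same check can be done by hand with a snippet along these lines (illustrative only, not part of this change):

```python
# Illustrative manual check: -X importtime logs every imported module to stderr,
# so " onnxruntime" must not appear after importing torch.onnx.
import subprocess
import sys

stderr = subprocess.run(
    [sys.executable, "-X", "importtime", "-c", "import torch.onnx"],
    capture_output=True,
    text=True,
    check=True,
).stderr
print("onnxruntime eagerly imported:", " onnxruntime" in stderr)
```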

Related issue: huggingface/accelerate#3056
Pull Request resolved: pytorch#134662
Approved by: https://github.com/justinchuby
oraluben authored and tolleybot committed Sep 14, 2024
1 parent c63ced0 commit 1ae54ad
Showing 3 changed files with 106 additions and 24 deletions.
8 changes: 5 additions & 3 deletions test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py
```diff
@@ -51,17 +51,19 @@ def tearDown(self):
         OrtBackend.clear_cached_instances()
 
     def test_get_ort_device_type(self):
+        from onnxruntime.capi import _pybind_state as ORTC
+
         self.assertEqual(
             torch.onnx._internal.onnxruntime._get_ort_device_type("cuda"),
-            torch.onnx._internal.onnxruntime.ORTC.OrtDevice.cuda(),
+            ORTC.OrtDevice.cuda(),
         )
         self.assertEqual(
             torch.onnx._internal.onnxruntime._get_ort_device_type("cpu"),
-            torch.onnx._internal.onnxruntime.ORTC.OrtDevice.cpu(),
+            ORTC.OrtDevice.cpu(),
         )
         self.assertEqual(
             torch.onnx._internal.onnxruntime._get_ort_device_type("maia"),
-            torch.onnx._internal.onnxruntime.ORTC.OrtDevice.npu(),
+            ORTC.OrtDevice.npu(),
         )
 
     def test_torch_compile_backend_registration(self):
```
37 changes: 37 additions & 0 deletions test/onnx/test_lazy_import.py
```diff
@@ -0,0 +1,37 @@
+# Owner(s): ["module: onnx"]
+
+import subprocess
+import sys
+import tempfile
+
+import pytorch_test_common
+
+from torch.testing._internal import common_utils
+
+
+class TestLazyONNXPackages(pytorch_test_common.ExportTestCase):
+    def _test_package_is_lazily_imported(self, pkg, torch_pkg="torch.onnx"):
+        with tempfile.TemporaryDirectory() as wd:
+            r = subprocess.run(
+                [sys.executable, "-Ximporttime", "-c", f"import {torch_pkg}"],
+                capture_output=True,
+                text=True,
+                cwd=wd,
+                check=True,
+            )
+
+            # The extra space makes sure we're checking the package, not any package containing its name.
+            self.assertTrue(
+                f" {pkg}" not in r.stderr,
+                f"`{pkg}` should not be imported, full importtime: {r.stderr}",
+            )
+
+    def test_onnxruntime_is_lazily_imported(self):
+        self._test_package_is_lazily_imported("onnxruntime")
+
+    def test_onnxscript_is_lazily_imported(self):
+        self._test_package_is_lazily_imported("onnxscript")
+
+
+if __name__ == "__main__":
+    common_utils.run_tests()
```
85 changes: 64 additions & 21 deletions torch/onnx/_internal/onnxruntime.py
```diff
@@ -34,32 +34,18 @@
 
 if TYPE_CHECKING:
     import onnx
 
-try:
-    # Use try-except to initialize package-dependent global variables.
-    import onnxruntime  # type: ignore[import]
-    from onnxruntime.capi import _pybind_state as ORTC  # type: ignore[import]
-
-    # This is not use directly in DORT but needed by underlying exporter,
-    # so we still need to check if it exists.
-    importlib.import_module("onnxscript")
+    import onnxruntime
+    from onnxruntime.capi import _pybind_state as ORTC
 
-    import torch.onnx
-    import torch.onnx._internal
-    import torch.onnx._internal._exporter_legacy
-    import torch.onnx._internal.diagnostics
-    import torch.onnx._internal.fx.decomposition_table
-    import torch.onnx._internal.fx.passes
-    from torch.onnx._internal.fx import fx_onnx_interpreter
-    from torch.onnx._internal.fx.type_utils import (
-        _TORCH_DTYPE_TO_NUMPY_DTYPE,
-        _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE,
-        from_python_type_to_onnx_tensor_element_type,
-    )
+    import torch.onnx._internal.fx.passes  # noqa: TCH004
 
-    _SUPPORT_ONNXRT = True
-except ImportError:
-    _SUPPORT_ONNXRT = False
 
+_SUPPORT_ONNXRT: Optional[bool] = None
 
 __all__ = [
     "is_onnxrt_backend_supported",
```
```diff
@@ -87,6 +73,35 @@ def is_onnxrt_backend_supported() -> bool:
         ...     print("pip install onnx onnxscript onnxruntime")
         ...
     """
+    global _SUPPORT_ONNXRT
+
+    if _SUPPORT_ONNXRT is None:
+        # `onnxruntime` might import a lot of other runtime packages,
+        # e.g. apex, deepspeed, transformers.
+        # So lazy-importing onnxruntime to avoid possible circular import.
+        try:
+            importlib.import_module("onnxruntime")
+            importlib.import_module("onnxruntime.capi._pybind_state")
+
+            # This is not use directly in DORT but needed by underlying exporter,
+            # so we still need to check if it exists.
+            importlib.import_module("onnxscript")
+
+            import torch.onnx  # noqa: F401
+            import torch.onnx._internal  # noqa: F401
+            import torch.onnx._internal._exporter_legacy  # noqa: F401
+            import torch.onnx._internal.diagnostics  # noqa: F401
+            from torch.onnx._internal.fx import (  # noqa: F401
+                decomposition_table,
+                fx_onnx_interpreter,
+                passes,
+                type_utils,
+            )
+
+            _SUPPORT_ONNXRT = True
+        except ImportError:
+            _SUPPORT_ONNXRT = False
+
+    return _SUPPORT_ONNXRT
@@ -143,6 +158,8 @@ def _nvtx_range_pop():
 
 
 def _get_ort_device_type(device_type: str):
+    from onnxruntime.capi import _pybind_state as ORTC
+
     if device_type == "cuda":
         return ORTC.OrtDevice.cuda()
     if device_type == "cpu":
@@ -305,6 +322,8 @@ def _get_onnx_devices(
         ...,
     ],
 ) -> Tuple["ORTC.OrtDevice", ...]:
+    from onnxruntime.capi import _pybind_state as ORTC
+
     def _device_id_or_zero(device_id: int) -> int:
         return device_id or 0
 
@@ -338,6 +357,10 @@ def _map_tensor_or_sym_to_device(
 def _get_ortvalues_from_torch_tensors(
     tensors: Tuple[torch.Tensor, ...], devices: Tuple["ORTC.OrtDevice", ...]
 ) -> Tuple[torch.Tensor, ...]:
+    from onnxruntime.capi import _pybind_state as ORTC
+
+    from torch.onnx._internal.fx.type_utils import _TORCH_DTYPE_TO_NUMPY_DTYPE
+
     ortvalues = ORTC.OrtValueVector()
     ortvalues.reserve(len(tensors))
     dtypes = []
@@ -436,6 +459,9 @@ def _run_onnx_session_with_ortvaluevector(
         ...,
     ],
 ) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
+    import onnxruntime
+    from onnxruntime.capi import _pybind_state as ORTC
+
     _nvtx_range_push("contiguous")
     inputs = tuple(
         _adjust_scalar_from_fx_to_onnx(arg, value_info)
@@ -514,6 +540,8 @@ def _run_onnx_session_with_fetch(
         ...,
     ],
 ) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
+    import onnxruntime
+
     inputs = tuple(
         _adjust_scalar_from_fx_to_onnx(arg, value_info)
         for arg, value_info in zip(inputs, input_value_infos)
@@ -570,6 +598,11 @@ def __init__(
         )
 
     def is_supported(self, *args):
+        from torch.onnx._internal.fx.type_utils import (
+            _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE,
+            from_python_type_to_onnx_tensor_element_type,
+        )
+
         # Compare the args and the input schema in ONNX model and
         # return the first match.
         if len(args) != len(self.input_value_infos):
@@ -728,6 +761,12 @@ class OrtBackend:
     """
 
     def __init__(self, options: Optional[OrtBackendOptions] = None):
+        from onnxruntime.capi import _pybind_state as ORTC
+
+        import torch.onnx
+        import torch.onnx._internal._exporter_legacy
+        import torch.onnx._internal.fx.decomposition_table
+
         self._options: Final = OrtBackendOptions() if options is None else options
 
         # options.export_options contains information shared between exporter and DORT.
@@ -849,6 +888,10 @@ def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwar
         it means we delegate the computation to _ort_acclerated_call and therefore
         onnxruntime.InferenceSession.
         """
+        import onnxruntime
+
+        from torch.onnx._internal.fx import fx_onnx_interpreter, passes
+
         cached_execution_info_per_session = (
             self._all_ort_execution_info.search_reusable_session_execution_info(
                 graph_module, *args
@@ -867,7 +910,7 @@ def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwar
             # It's first time seeing such as graph. Let's make a new session
             # (type: onnxruntime.InferenceSession) for it.
 
-            graph_module = torch.onnx._internal.fx.passes.MovePlaceholderToFront(
+            graph_module = passes.MovePlaceholderToFront(
                 self._resolved_onnx_exporter_options.diagnostic_context,
                 graph_module,
             ).run()
@@ -915,7 +958,7 @@ def maybe_map_to_meta_val(value):
             # Cast FX variables if they will result schema-mismatch when searching
             # for ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch,
             # but ONNX expects add(double_tensor, double_tensor).
-            graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion(
+            graph_module = passes.InsertTypePromotion(
                 self._resolved_onnx_exporter_options.diagnostic_context, graph_module
             ).run()
             # Start the per-node exporting process. It's conceptually a for loop
```
