forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ONNX][DORT] Lazy-import
onnxruntime
(pytorch#134662)
Currently, if installed, `onnxruntime` will be imported when importing `torch._inductor` (which will be imported by some other library, e.g. transformer-engine): ``` /mnt/c.py(53)<module>() -> from torch._inductor.utils import maybe_profile /usr/local/lib/python3.10/site-packages/torch/_inductor/utils.py(49)<module>() -> import torch._export /usr/local/lib/python3.10/site-packages/torch/_export/__init__.py(25)<module>() -> import torch._dynamo /usr/local/lib/python3.10/site-packages/torch/_dynamo/__init__.py(2)<module>() -> from . import convert_frame, eval_frame, resume_execution /usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py(48)<module>() -> from . import config, exc, trace_rules /usr/local/lib/python3.10/site-packages/torch/_dynamo/trace_rules.py(52)<module>() -> from .variables import ( /usr/local/lib/python3.10/site-packages/torch/_dynamo/variables/__init__.py(38)<module>() -> from .higher_order_ops import ( /usr/local/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py(14)<module>() -> import torch.onnx.operators /usr/local/lib/python3.10/site-packages/torch/onnx/__init__.py(62)<module>() -> from ._internal.onnxruntime import ( /usr/local/lib/python3.10/site-packages/torch/onnx/_internal/onnxruntime.py(37)<module>() -> import onnxruntime # type: ignore[import] ``` This issue breaks generated triton kernel because it imported torch, and unexpected runtime libraries as well. I've also added a test for this specific case under `test/onnx`, perhaps we should add more somewhere else? Related issue: huggingface/accelerate#3056 Pull Request resolved: pytorch#134662 Approved by: https://github.com/justinchuby
- Loading branch information
Showing
3 changed files
with
106 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Owner(s): ["module: onnx"] | ||
|
||
import subprocess | ||
import sys | ||
import tempfile | ||
|
||
import pytorch_test_common | ||
|
||
from torch.testing._internal import common_utils | ||
|
||
|
||
class TestLazyONNXPackages(pytorch_test_common.ExportTestCase): | ||
def _test_package_is_lazily_imported(self, pkg, torch_pkg="torch.onnx"): | ||
with tempfile.TemporaryDirectory() as wd: | ||
r = subprocess.run( | ||
[sys.executable, "-Ximporttime", "-c", "import torch.onnx"], | ||
capture_output=True, | ||
text=True, | ||
cwd=wd, | ||
check=True, | ||
) | ||
|
||
# The extra space makes sure we're checking the package, not any package containing its name. | ||
self.assertTrue( | ||
f" {pkg}" not in r.stderr, | ||
f"`{pkg}` should not be imported, full importtime: {r.stderr}", | ||
) | ||
|
||
def test_onnxruntime_is_lazily_imported(self): | ||
self._test_package_is_lazily_imported("onnxruntime") | ||
|
||
def test_onnxscript_is_lazily_imported(self): | ||
self._test_package_is_lazily_imported("onnxscript") | ||
|
||
|
||
if __name__ == "__main__": | ||
common_utils.run_tests() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters