Do not import transformer_engine on import #3056

merged 10 commits into from
Aug 28, 2024


@oraluben oraluben commented Aug 28, 2024

What does this PR do?

Fix a circular import that occurs when several packages are involved, including deepspeed, apex, transformer_engine, and onnx.
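For readers unfamiliar with the pattern named in the PR title, here is a minimal sketch of deferring an import until it is needed. It is not the diff from this PR: `colorsys` is a standard-library stand-in for a heavy optional dependency such as `transformer_engine`, and the helper names are hypothetical.

```python
import sys

# Stand-in for a heavy optional dependency such as transformer_engine.
# `colorsys` is used only because it ships with Python; this is an assumption
# made for illustration, not what the PR touches.
HEAVY_MODULE = "colorsys"


def heavy_is_loaded() -> bool:
    """Report whether the stand-in module has already been imported."""
    return HEAVY_MODULE in sys.modules


def convert_to_hls(r: float, g: float, b: float):
    """Hypothetical helper that needs the heavy dependency only when called."""
    import colorsys  # deferred: runs on the first call, not when this file is imported

    return colorsys.rgb_to_hls(r, g, b)
```

Importing the file that defines `convert_to_hls` does not import `colorsys`; only calling the function does. Because the dependency is no longer loaded as a side effect of importing the package, it cannot take part in an import cycle at that point.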

A simple repro can be built on top of an NGC container image:


RUN pip install onnxruntime-training deepspeed transformers[torch]
RUN python -c 'import transformers.integrations.deepspeed, deepspeed; deepspeed.runtime'
full `pip list`
Package                   Version
------------------------- -------------------------
absl-py                   2.1.0
accelerate                0.33.0
aiohttp                   3.9.5
aiosignal                 1.3.1
annotated-types           0.7.0
argon2-cffi               23.1.0
argon2-cffi-bindings      21.2.0
asciitree                 0.3.3
asttokens                 2.4.1
astunparse                1.6.3
async-timeout             4.0.3
attrs                     23.2.0
audioread                 3.0.1
beautifulsoup4            4.12.3
bleach                    6.1.0
blis                      0.7.11
cachetools                5.3.3
catalogue                 2.0.10
Cerberus                  1.3.5
certifi                   2024.7.4
cffi                      1.16.0
charset-normalizer        3.3.2
click                     8.1.7
cloudpathlib              0.18.1
cloudpickle               3.0.0
cmake                     3.30.0
comm                      0.2.2
confection                0.1.5
contourpy                 1.2.1
cuda-python               12.5.0
cudf                      24.4.0
cugraph                   24.4.0
cugraph-dgl               24.4.0
cugraph-equivariant       24.4.0
cugraph-pyg               24.4.0
cugraph-service-client    24.4.0
cugraph-service-server    24.4.0
cuml                      24.4.0
cupy-cuda12x              13.0.0
cycler                    0.12.1
cymem                     2.0.8
Cython                    3.0.10
dask                      2024.1.1
dask-cuda                 24.4.0
dask-cudf                 24.4.0
dask-expr                 0.4.0
debugpy                   1.8.2
decorator                 5.1.1
deepspeed                 0.15.0
defusedxml                0.7.1
distributed               2024.1.1
dm-tree                   0.1.8
einops                    0.8.0
entrypoints               0.4
exceptiongroup            1.2.1
execnet                   2.1.1
executing                 2.0.1
expecttest                0.1.3
fasteners                 0.19
fastjsonschema            2.20.0
fastrlock                 0.8.2
filelock                  3.15.4
flash-attn                2.4.2
flatbuffers               24.3.25
fonttools                 4.53.1
frozenlist                1.4.1
fsspec                    2024.5.0
gast                      0.6.0
google-auth               2.32.0
google-auth-oauthlib      0.4.6
grpcio                    1.62.1
h5py                      3.11.0
hjson                     3.1.0
huggingface-hub           0.24.6
hypothesis                5.35.1
idna                      3.7
igraph                    0.11.6
importlib_metadata        7.1.0
iniconfig                 2.0.0
intel-openmp              2021.4.0
ipykernel                 6.29.5
ipython                   8.21.0
ipython-genutils          0.2.0
jedi                      0.19.1
Jinja2                    3.1.4
joblib                    1.4.2
json5                     0.9.25
jsonschema                4.23.0
jsonschema-specifications 2023.12.1
jupyter_client            8.6.2
jupyter_core              5.7.2
jupyter-tensorboard       0.2.0
jupyterlab                2.3.2
jupyterlab_pygments       0.3.0
jupyterlab-server         1.2.0
jupytext                  1.16.2
kiwisolver                1.4.5
kvikio                    24.4.0
langcodes                 3.4.0
language_data             1.2.0
lazy_loader               0.4
librosa                   0.10.1
lightning-thunder         0.2.0.dev0
lightning-utilities       0.11.3.post0
lintrunner                0.12.5
llvmlite                  0.42.0
locket                    1.0.0
looseversion              1.3.0
marisa-trie               1.2.0
Markdown                  3.6
markdown-it-py            3.0.0
MarkupSafe                2.1.5
matplotlib                3.9.1
matplotlib-inline         0.1.7
mdit-py-plugins           0.4.1
mdurl                     0.1.2
mistune                   3.0.2
mkl                       2021.1.1
mkl-devel                 2021.1.1
mkl-include               2021.1.1
mock                      5.1.0
mpmath                    1.3.0
msgpack                   1.0.8
multidict                 6.0.5
murmurhash                1.0.10
nbclient                  0.10.0
nbconvert                 7.16.4
nbformat                  5.10.4
nest-asyncio              1.6.0
networkx                  3.3
notebook                  6.4.10
numba                     0.59.1
numcodecs                 0.11.0
numpy                     1.24.4
nvfuser                   0.2.6a0+f73ff1b
nvidia-cudnn-frontend     1.5.1
nvidia-dali-cuda120       1.39.0
nvidia-modelopt           0.13.0
nvidia-pyindex            1.0.9
nvtx                      0.2.5
nx-cugraph                24.4.0
oauthlib                  3.2.2
onnxruntime-training      1.19.1
opencv                    4.7.0
opt-einsum                3.3.0
optree                    0.12.1
packaging                 24.0
pandas                    2.2.1
pandocfilters             1.5.1
parso                     0.8.4
partd                     1.4.2
pexpect                   4.9.0
pillow                    10.4.0
pip                       24.2
platformdirs              4.2.2
pluggy                    1.5.0
ply                       3.11
polygraphy                0.49.12
pooch                     1.8.2
preshed                   3.0.9
prometheus_client         0.20.0
prompt_toolkit            3.0.47
protobuf                  4.24.4
psutil                    5.9.8
ptyprocess                0.7.0
pure-eval                 0.2.2
py-cpuinfo                9.0.0
pyarrow                   14.0.2
pyasn1                    0.6.0
pyasn1_modules            0.4.0
pybind11                  2.13.1
pybind11_global           2.13.1
pycocotools               2.0+nv0.8.0
pycparser                 2.22
pydantic                  2.8.2
pydantic_core             2.20.1
Pygments                  2.18.0
pylibcugraph              24.4.0
pylibcugraphops           24.4.0
pylibraft                 24.4.0
pylibwholegraph           24.4.0
pynvjitlink               0.2.3
pynvml                    11.4.1
pyparsing                 3.1.2
pytest                    8.1.1
pytest-flakefinder        1.1.0
pytest-rerunfailures      14.0
pytest-shard              0.1.2
pytest-xdist              3.6.1
python-dateutil           2.9.0.post0
python-hostlist           1.23.0
pytorch-triton            3.0.0+989adb9a2
pytz                      2024.1
PyYAML                    6.0.1
pyzmq                     26.0.3
raft-dask                 24.4.0
rapids-dask-dependency    24.4.0a0
referencing               0.35.1
regex                     2024.5.15
requests                  2.32.3
requests-oauthlib         2.0.0
rich                      13.7.1
rmm                       24.4.0
rpds-py                   0.19.0
rsa                       4.9
safetensors               0.4.4
scikit-learn              1.5.1
scipy                     1.13.1
Send2Trash                1.8.3
setuptools                68.2.2
shellingham               1.5.4
six                       1.16.0
smart-open                7.0.4
sortedcontainers          2.4.0
soundfile                 0.12.1
soupsieve                 2.5
soxr                      0.3.7
spacy                     3.7.5
spacy-legacy              3.0.12
spacy-loggers             1.0.5
srsly                     2.4.8
stack-data                0.6.3
sympy                     1.13.0
tabulate                  0.9.0
tbb                       2021.13.0
tblib                     3.0.0
tensorboard               2.9.0
tensorboard-data-server   0.6.1
tensorboard-plugin-wit    1.8.1
tensorrt                  10.2.0
terminado                 0.18.1
texttable                 1.7.0
thinc                     8.2.5
threadpoolctl             3.5.0
thriftpy2                 0.5.0
tinycss2                  1.3.0
tokenizers                0.19.1
tomli                     2.0.1
toolz                     0.12.1
torch                     2.4.0a0+3bcc3cddb5.nv24.7
torch-tensorrt            2.5.0a0
torchvision               0.19.0a0
tornado                   6.4
tqdm                      4.66.4
traitlets                 5.9.0
transformer-engine        1.8.0+37280ec
transformers              4.44.2
treelite                  4.1.2
typer                     0.12.3
types-dataclasses         0.6.6
typing_extensions         4.12.0
tzdata                    2024.1
ucx-py                    0.37.0
urllib3                   2.0.7
wasabi                    1.1.3
wcwidth                   0.2.13
weasel                    0.4.1
webencodings              0.5.1
Werkzeug                  3.0.3
wheel                     0.43.0
wrapt                     1.16.0
xdoctest                  1.0.2
xgboost                   2.0.3
yarl                      1.9.4
zarr                      2.18.2
zict                      3.0.0
zipp                      3.19.0

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Maybe @muellerzr ?

@muellerzr muellerzr left a comment

Thanks for the PR! While I get why we're doing this, let's keep the solution simple and not over-abstract it, please.

src/accelerate/utils/ Outdated Show resolved Hide resolved
Comment on lines 87 to 100
class LazyImportTester(TempDirTestCase):
    """
    Test suite which checks if specific packages are lazy-loaded.

    Eager-import will trigger circular import in some case,
    e.g. in huggingface/accelerate#3056.
    """

    # @require_transformer_engine
    def test_te_import(self):
        output = run_import_time("import accelerate, accelerate.utils.transformer_engine")

        self.assertFalse(' transformer_engine' in output, '`transformer_engine` should not be imported on import')
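The body of `run_import_time` is not shown in this thread. One way such a check could work, assuming it wraps CPython's `-X importtime` flag, is sketched below; `import_time_report` and `imported_modules` are hypothetical names, not accelerate APIs.

```python
import subprocess
import sys


def import_time_report(statement: str) -> str:
    """Run `statement` in a fresh interpreter and return the -X importtime report."""
    result = subprocess.run(
        [sys.executable, "-X", "importtime", "-c", statement],
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stderr  # CPython writes the import-time report to stderr


def imported_modules(report: str) -> set:
    """Extract the module names listed in an -X importtime report."""
    names = set()
    for line in report.splitlines():
        if line.startswith("import time:") and "|" in line:
            # The module name is the last '|'-separated column, indented by nesting depth.
            names.add(line.rsplit("|", 1)[1].strip())
    names.discard("imported package")  # column header, not a module
    return names
```

A lazy-loading test can then assert that an optional dependency is absent from `imported_modules(import_time_report("import some_package"))`. Running the statement in a subprocess matters: modules already loaded in the test process would otherwise hide the result.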

We have no CI for transformer_engine, so I'm not sure this test is actually valuable, imo.

Contributor Author

Thanks for your review! I added this test as a sort of self-contained doc, but it's indeed not very valuable if it isn't run in CI. I've just added a one-line comment in the other file, and I can remove the test if you think that's better.


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@muellerzr muellerzr left a comment

Great! One last nit and we're good to go. Thanks for finding this edge case

tests/ Outdated Show resolved Hide resolved
Co-authored-by: Zach Mueller <[email protected]>

@oraluben for the failing quality test, run `pip install -e .[quality]; make style; make quality`

@muellerzr muellerzr merged commit 3fcc946 into huggingface:main Aug 28, 2024
25 checks passed
@oraluben oraluben deleted the patch-1 branch August 28, 2024 13:12
Contributor Author

Thanks for helping land this!

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Aug 31, 2024
Currently, if installed, `onnxruntime` will be imported when importing `torch._inductor` (which will be imported by some other library, e.g. transformer-engine):

-> from torch._inductor.utils import maybe_profile
-> import torch._export
-> import torch._dynamo
-> from . import convert_frame, eval_frame, resume_execution
-> from . import config, exc, trace_rules
-> from .variables import (
-> from .higher_order_ops import (
-> import torch.onnx.operators
-> from ._internal.onnxruntime import (
-> import onnxruntime  # type: ignore[import]

This breaks generated Triton kernels, because they import torch and, with it, unexpected runtime libraries.
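An import chain like the one listed above can be recovered programmatically. The following is a minimal sketch, not the method used in the referenced commit: it places a pass-through finder at the front of `sys.meta_path` and records the call stack the first time a target module is looked up. `ImportTracer` and `trace_import` are hypothetical names.

```python
import importlib
import sys
import traceback


class ImportTracer:
    """Record the call stack the first time `target` is looked up by the import system."""

    def __init__(self, target: str):
        self.target = target
        self.stack = None

    def find_spec(self, fullname, path=None, target=None):
        if fullname == self.target and self.stack is None:
            # Drop this method's own frame and the frozen importlib internals.
            self.stack = [
                f"{frame.filename}:{frame.lineno}: {frame.line}"
                for frame in traceback.extract_stack()[:-1]
                if not frame.filename.startswith("<frozen")
            ]
        return None  # never handle the import; the regular finders take over


def trace_import(module: str, target: str):
    """Import `module` and return the stack that first triggered the import of `target`."""
    tracer = ImportTracer(target)
    sys.meta_path.insert(0, tracer)
    try:
        importlib.import_module(module)
    finally:
        sys.meta_path.remove(tracer)
    return tracer.stack
```

For the case described here, a call such as `trace_import("torch._inductor", "onnxruntime")` in a fresh interpreter would be the natural use. The result is `None` when the target is never imported or was already present in `sys.modules` before the call, since the import system only consults finders for modules that are not yet loaded.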

I've also added a test for this specific case under `test/onnx`, perhaps we should add more somewhere else?

Related issue: huggingface/accelerate#3056
Pull Request resolved: #134662
Approved by:
tolleybot pushed a commit to tolleybot/pytorch that referenced this pull request Sep 14, 2024
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024