Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add required libcuda.so #6864

Closed
wants to merge 1 commit into from

Conversation

sdake
Copy link

@sdake sdake commented Jul 27, 2024

It is necessary to add libcuda.so so that vllm will run against rebuilds of pytorch. I built my own version of pytorch (v2.3.1) to workaround an issue with GPU compute versions to run Neural Magic's dynamic quantization model. The reason for my recompile, although not super relevant, was the related to BLAS in PyTorch failing with and A40 (compute 8.6) and only expecting compute>8.9, which is unrelated to this issue.

I then tried something less exoticfacebookresearch/opt125m, and a problem was exposed showing that cuTensorMapEncodeTiled could not be loaded as part of import vllm._C. The problem resulted in failure of the llvm python process. The cuTensorMapEncodedTiled symbol is provided by cutlass and used by pytorch. the
_C.abi3.so file shows the symbol is undefined (nm -g /home/sdake/v-llm/lib/python3.11/site-packages/vllm/_C.abi3.so | grep cuTensorMapEncodeTiled shows U meaning the symbol isn't available via the system's dynamic loader and this is caused by the symbol not being linked into the dymamic library. I tried patchelf to add the dynamic library to _C.abi3.so which works.

I could use the patchelf workaround, however, a more pressing problem is that when cloning pytorch v2.3.1, the symbol is consumed. I am not clear why this doesn't show up in broader community testing.

Running this build script with this Dockerfile will reproduce the problem nicely. After building, a pytorch wheel is output to ${CWD}/target/torch-2.3.1-cp311-cp311-linux_x86_64.whl

It is necessary to add libcuda.so so that vllm will run against rebuilds
of pytorch. I built my own version of pytorch (v2.3.1) to workaround an
issue with GPU compute versions to run Neural Magic's [dynamic
quantization model](https://huggingface.co/neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8-dynamic#use-with-vllm).
The reason for my recompile, although not super relevant, was the related to
`BLAS` in PyTorch failing with and A40 (compute 8.6) and only expecting compute>8.9.

I then tried `facebookresearch/opt125m`, and a problem was exposed
showing that `cuTensorMapEncodeTiled` could not be loaded as part of
import vllm._C. The problem resulted in failure of the llvm python process
.The `cuTensorMapEncodedTiled` symbol is provided by `cutlass` and used by
`pytorch`. the
`_C.abi3.so` file shows the symbol is undefined (`nm -g
/home/sdake/v-llm/lib/python3.11/site-packages/vllm/_C.abi3.so | grep
cuTensorMapEncodeTiled` shows `U` meaning the symbol isn't available via
the system's dynamic loader and this is caused by the symbol not being
linked into the dymamic library. I tried `patchelf` to add the dynamic
library to `_C.abi3.so` which works.

I could use the patchelf workaround, however, a more pressing problem is
that when cloning `pytorch v2.3.1`, the symbol is consumed. I am not
clear why this doesn't show up in broader community testing.

Running this [build
script](artificialwisdomai/origin@091699e#diff-1b3ed36bdb219f011fec128d976b51764d50d07472e81d439a4456f73c89ecd6R6)
with this
[Dockerfile](artificialwisdomai/origin@091699e#diff-ad85eb666c2a47809160de9dcdeb14638c67cc18f38398092719d0f5a415a30bR128) will
reproduce the problem nicely. After building, a pytorch wheel is output
to `${CWD}/target/torch-2.3.1-cp311-cp311-linux_x86_64.whl`
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@sdake
Copy link
Author

sdake commented Jul 27, 2024

/ready

@cliffwoolley
Copy link

This looks correct to us on the NVIDIA side. We're testing it locally and @kushanam will report back. Thanks @sdake !

@cliffwoolley
Copy link

LGTM

@sdake
Copy link
Author

sdake commented Oct 9, 2024 via email

@kushanam
Copy link
Contributor

kushanam commented Oct 9, 2024

Some of the vLLM abi for custom ops require functionality from the CUDA driver rather than the runtime. e.g: cuTensorMapEncodeTiled. The library seems to get loaded when built with default packages (possibly by one of the older versions of the dependencies), but upgrading them would result in custom ops not getting built correctly.

@cliffwoolley
Copy link

There are two driver API functions (not runtime API) in use, and that definitely requires linking directly to libcuda or else dlopen/dlsym of the same so that the symbols can be resolved. This dependency is coming in from CUTLASS, it looked like to me.

@DarkLight1337
Copy link
Member

@youkaichao is this still relevant?

@mergify mergify bot added the ci/build label Oct 30, 2024
Copy link

mergify bot commented Oct 30, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @sdake please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 30, 2024
@cliffwoolley
Copy link

Yes this is still relevant. It's needed because of https://github.com/vllm-project/vllm/blob/main/CMakeLists.txt#L399 .

@youkaichao
Copy link
Member

we usually use the libcuda.so brought by pytorch. if you build pytorch from scratch, chances are pytorch statically link against libcuda.so , and you have to make it available in the linker's path.

I don't think this is needed for the general public.

@cliffwoolley
Copy link

cliffwoolley commented Nov 2, 2024 via email

@cliffwoolley
Copy link

This is still needed -- can we go ahead and get it merged please?

@tlrmchlsmth
Copy link
Collaborator

It's needed because of https://github.com/vllm-project/vllm/blob/main/CMakeLists.txt#L399 .

@cliffwoolley could you say a bit more about why it's needed for this?

@kushanam
Copy link
Contributor

It's needed because of https://github.com/vllm-project/vllm/blob/main/CMakeLists.txt#L399 .

@cliffwoolley could you say a bit more about why it's needed for this?

Hi Tyler, the reason this doesn't throw an error by default is that PyTorch loads libcuda.so, and it happens to be loaded before vLLM makes any calls to CUDA. However, if that order is altered for any reason, lack of it leads to issues.

@cliffwoolley
Copy link

cliffwoolley commented Jan 24, 2025

Because

It's needed because of https://github.com/vllm-project/vllm/blob/main/CMakeLists.txt#L399 .

@cliffwoolley could you say a bit more about why it's needed for this?

Where we are using the CUDA driver API -- which the linked line instructs CUTLASS to do -- then we have to have some form of linkage to the CUDA driver library, or else we can get an undefined symbol error like in the OP at runtime. This can be addressed in one of two ways: one is to directly link to the CUDA driver, which is what this PR accomplishes. The other way is via dlopen/dlsym -- or the corresponding convenience wrapper from the CUDA Runtime API that is cudaGetDriverEntryPoint. The latter might be preferable for executables that still have to load without CUDA (perhaps CPU-only, or on some other device type), though it would require some care to ensure that the symbol loaded is given the same name and signature as the CUDA driver would provide while being local/private to to vLLM (so as not to conflict with the actual CUDA driver library, which may also be loaded).

We get away without this only when something else in the address space -- e.g. libtorch -- has already loaded libcuda.so into the process and the dynamic loader can resolve our symbol even though there wasn't actual linkage to it. But it's possible in some cases to call into parts of vLLM that don't happen fo first call through PyTorch to initialize the GPU, and then the problem appears.

@tlrmchlsmth
Copy link
Collaborator

@kushanam and @cliffwoolley that is pretty convincing, thanks for the thorough explanation. Any reason not to do the samefor the other CUDA .so files as well (_moe_C.abi3.so and cumem_allocator.abi3.so)?

@cliffwoolley
Copy link

It's best to limit this to places where it's strictly needed. You can find that out by looking at nm -D on whichever library and looking for U symbols whose names start like cuSomething (rather than cudaSomething, which are from the runtime API, aka libcudart).

@cliffwoolley
Copy link

(rather than cudaSomething, which are from the runtime API, aka libcudart).

... The reason here being that the CUDA Runtime doesn't present this issue, as it does the dlopen/dlsym approach on the application's behalf, when you call through the runtime instead of direct to the driver API. So CUTLASS also wouldn't have this problem if we didn't tell it to bypass the Runtime API.

@tlrmchlsmth
Copy link
Collaborator

@cliffwoolley check out the following:

nm -Du cumem_allocator.abi3.so | grep cu
                 U cuCtxGetCurrent
                 U cuCtxSetCurrent
                 U cuDevicePrimaryCtxRetain
                 U cuGetErrorString
                 U cuMemAddressFree
                 U cuMemAddressReserve
                 U cuMemCreate
                 U cuMemGetAllocationGranularity
                 U cuMemMap
                 U cuMemRelease
                 U cuMemSetAccess
                 U cuMemUnmap

Needed for the cumem_allocator so or no?

@tlrmchlsmth
Copy link
Collaborator

Actually looks like the LIBRARIES ${CUMEM_LIBS} line should handle that one

@mergify mergify bot removed the needs-rebase label Jan 24, 2025
@cliffwoolley
Copy link

Actually looks like the LIBRARIES ${CUMEM_LIBS} line should handle that one

Ah yes, good catch. I agree -- https://github.com/vllm-project/vllm/blob/main/CMakeLists.txt#L195C1-L195C31 seems to be another way to express the same exact intent as this PR, for that other built library. If we had cross-platform builds to contend with (e.g. Windows native), then the https://github.com/vllm-project/vllm/blob/main/CMakeLists.txt#L195C1-L195C31 -ism might be actually better?

@tlrmchlsmth
Copy link
Collaborator

@cliffwoolley I'll do that in a different PR

Actually looks like the LIBRARIES ${CUMEM_LIBS} line should handle that one

Ah yes, good catch. I agree -- https://github.com/vllm-project/vllm/blob/main/CMakeLists.txt#L195C1-L195C31 seems to be another way to express the same exact intent as this PR, for that other built library. If we had cross-platform builds to contend with (e.g. Windows native), then the https://github.com/vllm-project/vllm/blob/main/CMakeLists.txt#L195C1-L195C31 -ism might be actually better?

Closing this PR in favor of that approach in #12424 (also because this one is super stale and unlikely to get through the CI)

@sdake
Copy link
Author

sdake commented Jan 25, 2025

@tlrmchlsmth this would naturally pass ci as ci rebases work.

Thanks for the review

Take care
-steve

@youkaichao
Copy link
Member

just catching up the thread. I'm fine with the change, but in general we depend on pytorch to load these libraries.

We get away without this only when something else in the address space -- e.g. libtorch -- has already loaded libcuda.so into the process and the dynamic loader can resolve our symbol even though there wasn't actual linkage to it. But it's possible in some cases to call into parts of vLLM that don't happen fo first call through PyTorch to initialize the GPU, and then the problem appears.

@cliffwoolley I'm curious, when would you encounter this situation, where you use vllm but torch is not loaded? I think that's quite rare, we use import torch a lot. It is intended to hand it over to pytorch, because we don't want to manage them from vllm side.

Unless you directly import vLLM's C extensions from another file?

@cliffwoolley
Copy link

@youkaichao Each DSO should have linkage to whatever is needed to resolve the symbols it references.

Where we saw the issue originally was in various vLLM tests we've constructed. We weren't just importing the vllm._C alone though no, but basically it was some fast path toward vLLM doing so. (I'm not sure I can quickly find that again.)

Now.. all that said, I discovered the following alternative means of accomplishing this (I think) that might actually have made this (embarrassingly to me) duplicative, which raced with this PR. I haven't had time to go back and verify that that one alone was sufficient.
https://github.com/vllm-project/vllm/blame/main/cmake/utils.cmake#L440 from #9588 which was merged Oct 27. (Our branch was on newer PyTorch but older vLLM at the time of my most recent previous comments, and I didn't think to go look in that other file for relevant behavior changes.)

@tlrmchlsmth
Copy link
Collaborator

Update on my end is that I landed #12424 to fix this issue but ended up reverting it as libcuda.so wasn't being found on some users' systems. I was going to look into copying what pytorch does using dlopen, since that is known to work, but if there's an easier way following https://github.com/vllm-project/vllm/blame/main/cmake/utils.cmake#L440 (or if that already works) that would be great news

@sdake
Copy link
Author

sdake commented Feb 4, 2025

Hello gang,

When I originally filed this PR, I was seeing some strange behavior such as:

  1. collect-env.py reporting no found cuda device.
  2. other cuda code reporting no found cuda device.

Digging back into that original error to respond on this thread.

Find packages owning libcuda.so:

fish> dpkg-query --search libcuda.so
libcuda1:amd64: /usr/lib/x86_64-linux-gnu/libcuda.so
cuda-driver-dev-12-4: /usr/local/cuda-12.4/targets/x86_64-linux/lib/stubs/libcuda.so

Contents of libcuda1:

fish> dpkg-query --listfiles libcuda1
/usr/lib/x86_64-linux-gnu/libcuda.so.570.86.15
/usr/lib/x86_64-linux-gnu/libcuda.so
/usr/lib/x86_64-linux-gnu/libcuda.so.1

Contents of cuda-driver-dev:

fish> dpkg-query --listfiles cuda-driver-dev-12-4
/usr/lib/pkgconfig/cuda-12.4.pc
/usr/local/cuda-12.4/targets/x86_64-linux/lib/stubs/libcuda.so

I presume nvidia created this stubs driver so that test and release can build software with a smaller T&R footprint.

When you link dynamically, the linker specifies the shared object name. The resolution of the name (in this case libcuda.so) occurs during runtime. As @cliffwoolley pointed out, you can achieve this by implementing your own custom dlopen() code, or you can rely on the system linker to automate the process.

The system linker (/usr/bin/ld) maintains a cache of all shared objects in the system and their relative locations, so that runtime linking can occur. The system linker needs a list of paths to scan for *.so files. Typically, the linker is configured via drop-in files placed in /etc/ld.so.conf.d that identify a path to scan when rebuilding the dynamic link cache. When you then run ldconfig, the dynamic linker cache is rebuilt then and only then.

VLLM's test and release team builds a vllm wheel. User sdake downloads the vllm wheel and installs it within his system. The vllm test and release team need not install libcuda1 (which brings in the kernel driver among other parts of the kitchen), instead they simply set the link flag -L /usr/local/cuda-12.4/targets/86_64-linux/lib/stubs and link with the "name" of libcuda.so. When sdake runs vllm, the ld dynamically links /usr/lib/x86_64-gnu-linux/libcuda.so at runtime.

In short, two libcuda.so is provided by NVIDIA for two different use cases. This can confuse some.

Thank you,
-steve

@cliffwoolley
Copy link

/usr/local/cuda-12.4/targets/86_64-linux/lib/stubs
I presume nvidia created this stubs driver so that test and release can build software with a smaller T&R footprint.

CUDA applications are meant to be buildable on systems without GPUs. So the stubs are there for link time, not for runtime.

It's not a real driver -- it just exposes the same entrypoints as the driver (to make the linker happy), but if you try to actually call those functions through the stubs library, every one of them will return cudaErrorStubLibrary without actually doing anything else.

(An implication maybe worth making explicit is that the stubs are useful for building [linking], but not helpful for testing, except perhaps to the extent you're wanting to run CPU-only tests and bypass GPU functionality -- though maybe I'd argue there are better ways to accomplish the CPU-only part than use of these stubs at runtime would be.)

cudaErrorStubLibrary = 34
This indicates that the CUDA driver that the application has loaded is a stub library. Applications that run with the stub rather than a real driver loaded will result in CUDA API returning this error.

@sdake
Copy link
Author

sdake commented Feb 5, 2025

@cliffwoolley understood thanks. I was simply educating @tlrmchlsmth.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci/build ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants