update HIP_UMA #7399 #7414
Conversation
Add use of hipMemAdviseSetCoarseGrain when LLAMA_HIP_UMA is enabled. I get 2x on prompt eval and 1.5x on token gen with ROCm 6.0 on a Ryzen 7940HX iGPU (780M/gfx1103).
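For reference, the core of the change is simply advising coarse-grained coherence on the managed (UMA) allocation. A minimal sketch of the idea using HIP's runtime API; the helper name and surrounding structure are illustrative, not the exact llama.cpp code:

#include <hip/hip_runtime.h>

// illustrative helper: allocate unified (managed) memory and mark it coarse-grained
static hipError_t alloc_uma_coarse(void ** ptr, size_t size, int device) {
    // managed memory is shared between the CPU and the iGPU, so no copies are needed
    hipError_t err = hipMallocManaged(ptr, size);
    if (err == hipSuccess) {
        // coarse-grained coherence avoids the cost of keeping the allocation
        // fine-grain coherent with the CPU, which is where the speedup comes from
        err = hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);
    }
    return err;
}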
I would definitely benchmark this on Windows and Linux separately, I think.
Is "HIP_UMA" possible on windows? https://hipsolver.readthedocs.io/en/rocm-6.1.1/conceptual/gpu-memory.html#coherence |
First, hipMallocManaged is used only with … Note: I would be happy if someone with large VRAM on a Ryzen 7940HS could bench both ... 🤞
I have made some minor changes to the code to simplify it a bit and make it more consistent; if everything looks good, let's merge this.
I only just had time to read it, but it looks good to me.
Or like this:

static inline cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
    auto res = cudaMalloc(ptr, size);
#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA)
    // not enough space in VRAM => fall back to UMA (managed memory)
    if (res == hipErrorOutOfMemory) {
        GGML_CUDA_LOG_INFO(" Device %d: cannot alloc %u MB in VRAM, trying UMA\n", device, (uint32_t)(size / 1024 / 1024));
        res = hipMallocManaged(ptr, size);
        if (res == hipSuccess) {
            // configure the memory for best speed (this is not supposed to fail)
            CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
        }
    }
#endif
    return res;
}
👍 |
Add UMA config for higher speed like in ggerganov/llama.cpp#7414, but with 2 changes: remove the UMA build option, and use UMA in all cases when hipMalloc fails for lack of memory. Another change is to look for 'hipcc' on Linux instead of 'amdclang++'. The compiler change is sketched below.
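The compiler lookup change could look roughly like this in a Makefile (an illustrative sketch assuming a standard ROCm install path, not the project's actual build script):

# illustrative: prefer the hipcc wrapper on Linux rather than amdclang++
ROCM_PATH ?= /opt/rocm
HIPCC     ?= $(ROCM_PATH)/bin/hipcc
HIPFLAGS  += --offload-arch=gfx1101

%.o: %.cpp
	$(HIPCC) $(HIPFLAGS) -c $< -o $@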
Add use of hipMemAdviseSetCoarseGrain when LLAMA_HIP_UMA is enabled.
On my Ryzen 7940HS I get some speedup:
build with:
# gfx1103 is not supported; use gfx1101 in its place
make -j16 LLAMA_HIPBLAS=1 LLAMA_HIP_UMA=1 AMDGPU_TARGETS=gfx1101
run bench with:
I get: