From c0799b0e0eb27578f9e0e5a8697e8d003ebd9590 Mon Sep 17 00:00:00 2001 From: Jiashu Yao Date: Thu, 6 Feb 2025 03:49:39 +0000 Subject: [PATCH] [Fix] compile errors on new image --- .github/workflows/docs-build.yaml | 4 +- .github/workflows/docs-sched-rebuild.yaml | 2 +- .../core_kernels/group_lock_kernels.cuh | 65 +++++++++++-------- include/merlin/types.cuh | 1 + 4 files changed, 43 insertions(+), 29 deletions(-) diff --git a/.github/workflows/docs-build.yaml b/.github/workflows/docs-build.yaml index 40d933a2b..d86733786 100644 --- a/.github/workflows/docs-build.yaml +++ b/.github/workflows/docs-build.yaml @@ -25,7 +25,7 @@ jobs: run: | make -C docs html - name: Upload HTML - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: html-build-artifact path: docs/build/html @@ -38,7 +38,7 @@ jobs: echo ${{ github.event.pull_request.merged }} > ./pr/merged.txt echo ${{ github.event.action }} > ./pr/action.txt - name: Upload PR information - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: pr path: pr/ diff --git a/.github/workflows/docs-sched-rebuild.yaml b/.github/workflows/docs-sched-rebuild.yaml index e307f18a2..eb8d255a4 100644 --- a/.github/workflows/docs-sched-rebuild.yaml +++ b/.github/workflows/docs-sched-rebuild.yaml @@ -38,7 +38,7 @@ jobs: find docs/build -name .doctrees -prune -exec rm -rf {} \; find docs/build -name .buildinfo -exec rm {} \; - name: Upload HTML - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: html-build-artifact path: docs/build/html diff --git a/include/merlin/core_kernels/group_lock_kernels.cuh b/include/merlin/core_kernels/group_lock_kernels.cuh index ae2f38466..a9fe9c696 100644 --- a/include/merlin/core_kernels/group_lock_kernels.cuh +++ b/include/merlin/core_kernels/group_lock_kernels.cuh @@ -15,23 +15,29 @@ */ #pragma once +#include #include namespace nv { namespace merlin { namespace group_lock { -static __global__ void init_kernel( - cuda::atomic* update_count, - cuda::atomic* read_count, +template +__global__ void init_kernel( + cuda::atomic* update_count, + cuda::atomic* read_count, cuda::atomic* unique_flag) { - new (update_count) cuda::atomic{0}; - new (read_count) cuda::atomic{0}; - new (unique_flag) cuda::atomic{false}; + if (blockIdx.x == 0 && threadIdx.x == 0) { + new (update_count) cuda::atomic{0}; + new (read_count) cuda::atomic{0}; + new (unique_flag) cuda::atomic{false}; + } } -static __global__ void lock_read_kernel( - cuda::atomic* update_count, - cuda::atomic* read_count) { + +template +__global__ void lock_read_kernel( + cuda::atomic* update_count, + cuda::atomic* read_count) { for (;;) { while (update_count->load(cuda::std::memory_order_relaxed)) { } @@ -43,14 +49,16 @@ static __global__ void lock_read_kernel( } } -static __global__ void unlock_read_kernel( - cuda::atomic* read_count) { +template +__global__ void unlock_read_kernel( + cuda::atomic* read_count) { read_count->fetch_sub(1, cuda::std::memory_order_relaxed); } -static __global__ void lock_update_kernel( - cuda::atomic* update_count, - cuda::atomic* read_count) { +template +__global__ void lock_update_kernel( + cuda::atomic* update_count, + cuda::atomic* read_count) { for (;;) { while (read_count->load(cuda::std::memory_order_relaxed)) { } @@ -62,14 +70,16 @@ static __global__ void lock_update_kernel( } } -static __global__ void unlock_update_kernel( - cuda::atomic* update_count) { +template +__global__ void unlock_update_kernel( + cuda::atomic* update_count) { update_count->fetch_sub(1, cuda::std::memory_order_relaxed); } -static __global__ void lock_update_read_kernel( - cuda::atomic* update_count, - cuda::atomic* read_count, +template +__global__ void lock_update_read_kernel( + cuda::atomic* update_count, + cuda::atomic* read_count, cuda::atomic* unique_flag) { /* Lock unique flag */ bool expected = false; @@ -101,22 +111,25 @@ static __global__ void lock_update_read_kernel( } } -static __global__ void unlock_update_read_kernel( - cuda::atomic* update_count, - cuda::atomic* read_count, +template +__global__ void unlock_update_read_kernel( + cuda::atomic* update_count, + cuda::atomic* read_count, cuda::atomic* unique_flag) { read_count->fetch_sub(1, cuda::std::memory_order_relaxed); update_count->fetch_sub(1, cuda::std::memory_order_relaxed); unique_flag->store(false, cuda::std::memory_order_relaxed); } -static __global__ void update_count_kernel( - int* counter, cuda::atomic* update_count) { +template +__global__ void update_count_kernel( + T* counter, cuda::atomic* update_count) { *counter = update_count->load(cuda::std::memory_order_relaxed); } -static __global__ void read_count_kernel( - int* counter, cuda::atomic* read_count) { +template +__global__ void read_count_kernel( + T* counter, cuda::atomic* read_count) { *counter = read_count->load(cuda::std::memory_order_relaxed); } diff --git a/include/merlin/types.cuh b/include/merlin/types.cuh index ebbe1bffd..a9446c3bd 100644 --- a/include/merlin/types.cuh +++ b/include/merlin/types.cuh @@ -18,6 +18,7 @@ #include #include +#include #include #include "debug.hpp"