Skip to content

Commit

Permalink
Fix reviewers' comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Rombur committed Jan 13, 2023
1 parent cd0b631 commit 45acff3
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 50 deletions.
42 changes: 42 additions & 0 deletions core/src/HIP/Kokkos_HIP_Instance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,48 @@ Kokkos::HIP::size_type *HIPInternal::scratch_flags(const std::size_t size) {
return m_scratchFlags;
}

Kokkos::HIP::size_type *HIPInternal::stage_functor_for_execution(
void const *driver, std::size_t const size) const {
if (verify_is_initialized("scratch_functor") && m_scratchFunctorSize < size) {
m_scratchFunctorSize = size;

using Record = Kokkos::Impl::SharedAllocationRecord<Kokkos::HIPSpace, void>;
using RecordHost =
Kokkos::Impl::SharedAllocationRecord<Kokkos::HIPHostPinnedSpace, void>;

if (m_scratchFunctor) {
Record::decrement(Record::get_record(m_scratchFunctor));
RecordHost::decrement(RecordHost::get_record(m_scratchFunctorHost));
}

Record *const r =
Record::allocate(Kokkos::HIPSpace(), "Kokkos::InternalScratchFunctor",
m_scratchFunctorSize);
RecordHost *const r_host = RecordHost::allocate(
Kokkos::HIPHostPinnedSpace(), "Kokkos::InternalScratchFunctorHost",
m_scratchFunctorSize);

Record::increment(r);
RecordHost::increment(r_host);

m_scratchFunctor = reinterpret_cast<size_type *>(r->data());
m_scratchFunctorHost = reinterpret_cast<size_type *>(r_host->data());
}

// When using HSA_XNACK=1, it is necessary to copy the driver to the host to
// ensure that the driver is not destroyed before the computation is done.
// Without this fix, all the atomic tests fail. It is not obvious that this
// problem is limited to HSA_XNACK=1 even if all the tests pass when
// HSA_XNACK=0. That's why we always copy the driver.
KOKKOS_IMPL_HIP_SAFE_CALL(hipStreamSynchronize(m_stream));
std::memcpy(m_scratchFunctorHost, driver, size);
KOKKOS_IMPL_HIP_SAFE_CALL(hipMemcpyAsync(m_scratchFunctor,
m_scratchFunctorHost, size,
hipMemcpyDefault, m_stream));

return m_scratchFunctor;
}

int HIPInternal::acquire_team_scratch_space() {
int current_team_scratch = 0;
int zero = 0;
Expand Down
52 changes: 4 additions & 48 deletions core/src/HIP/Kokkos_HIP_Instance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,10 @@ class HIPInternal {
HIPInternal() = default;

// Resizing of reduction related scratch spaces
size_type *scratch_space(const std::size_t size);
size_type *scratch_flags(const std::size_t size);
template <typename DriverType>
size_type *scratch_functor(DriverType const &driver) const;
size_type *scratch_space(std::size_t const size);
size_type *scratch_flags(std::size_t const size);
size_type *stage_functor_for_execution(void const *driver,
std::size_t const size) const;
uint32_t impl_get_instance_id() const noexcept;
int acquire_team_scratch_space();
// Resizing of team level 1 scratch
Expand All @@ -147,50 +147,6 @@ class HIPInternal {
void release_team_scratch_space(int scratch_pool_id);
};

template <typename DriverType>
Kokkos::HIP::size_type *HIPInternal::scratch_functor(
DriverType const &driver) const {
std::size_t size = sizeof(DriverType);
if (verify_is_initialized("scratch_functor") && m_scratchFunctorSize < size) {
m_scratchFunctorSize = size;

using Record = Kokkos::Impl::SharedAllocationRecord<Kokkos::HIPSpace, void>;
using RecordHost =
Kokkos::Impl::SharedAllocationRecord<Kokkos::HIPHostPinnedSpace, void>;

if (m_scratchFunctor) {
Record::decrement(Record::get_record(m_scratchFunctor));
RecordHost::decrement(RecordHost::get_record(m_scratchFunctorHost));
}

Record *const r =
Record::allocate(Kokkos::HIPSpace(), "Kokkos::InternalScratchFunctor",
m_scratchFunctorSize);
RecordHost *const r_host = RecordHost::allocate(
Kokkos::HIPHostPinnedSpace(), "Kokkos::InternalScratchFunctorHost",
m_scratchFunctorSize);

Record::increment(r);
RecordHost::increment(r_host);

m_scratchFunctor = reinterpret_cast<size_type *>(r->data());
m_scratchFunctorHost = reinterpret_cast<size_type *>(r_host->data());
}

// When using HSA_XNACK=1, it is necessary to copy the driver to the host to
// ensure that the driver is not destroyed before the computation is done.
// Without this fix, all the atomic tests fail. It is not obvious that this
// problem is limited to HSA_XNACK=1 even if all the tests pass when
// HSA_XNACK=0. That's why we always copy the driver.
KOKKOS_IMPL_HIP_SAFE_CALL(hipStreamSynchronize(m_stream));
std::memcpy(m_scratchFunctorHost, &driver, size);
KOKKOS_IMPL_HIP_SAFE_CALL(hipMemcpyAsync(m_scratchFunctor,
m_scratchFunctorHost, size,
hipMemcpyDefault, m_stream));

return m_scratchFunctor;
}

} // namespace Impl

namespace Experimental {
Expand Down
5 changes: 3 additions & 2 deletions core/src/HIP/Kokkos_HIP_KernelLaunch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,9 @@ struct HIPParallelLaunchKernelInvoker<DriverType, LaunchBounds,
static void invoke_kernel(DriverType const &driver, dim3 const &grid,
dim3 const &block, int shmem,
HIPInternal const *hip_instance) {
DriverType *driver_ptr =
reinterpret_cast<DriverType *>(hip_instance->scratch_functor(driver));
DriverType *driver_ptr = reinterpret_cast<DriverType *>(
hip_instance->stage_functor_for_execution(
reinterpret_cast<void const *>(&driver), sizeof(DriverType)));
(base_t::get_kernel_func())<<<grid, block, shmem, hip_instance->m_stream>>>(
driver_ptr);
}
Expand Down

0 comments on commit 45acff3

Please sign in to comment.