Skip to content

Commit

Permalink
SWDEV-351980 - Remove HipApi{Callback|Activity}{Enable|Disable}Check
Browse files Browse the repository at this point in the history
The code is easier to read if calling HIPActivityCallbackTracker
enable/disable_check directly. Both enable/disable_check return the
new mask, and the check whether a callback is already installed is
clearer.

Change-Id: Ic90d34489b5b4d9929dc08b4d9e93cc974b136b1
  • Loading branch information
lmoriche committed Sep 7, 2022
1 parent 88c6e0a commit f0e082f
Showing 1 changed file with 27 additions and 56 deletions.
83 changes: 27 additions & 56 deletions src/roctracer/roctracer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,29 +186,6 @@ class HIPActivityCallbackTracker {

static HIPActivityCallbackTracker hip_act_cb_tracker;

inline uint32_t HipApiCallbackEnableCheck(uint32_t op) {
const uint32_t mask = hip_act_cb_tracker.enable_check(op, API_CB_MASK);
const uint32_t ret = (mask & API_ACT_MASK);
return ret;
}

inline uint32_t HipApiCallbackDisableCheck(uint32_t op) {
const uint32_t mask = hip_act_cb_tracker.disable_check(op, API_CB_MASK);
const uint32_t ret = (mask & API_ACT_MASK);
return ret;
}

inline uint32_t HipApiActivityEnableCheck(uint32_t op) {
hip_act_cb_tracker.enable_check(op, API_ACT_MASK);
return 0;
}

inline uint32_t HipApiActivityDisableCheck(uint32_t op) {
const uint32_t mask = hip_act_cb_tracker.disable_check(op, API_ACT_MASK);
const uint32_t ret = (mask & API_CB_MASK);
return ret;
}

void HIP_ApiCallback(uint32_t op_id, roctracer_record_t* record, void* callback_data, void* arg) {
hip_api_data_t* data = static_cast<hip_api_data_t*>(callback_data);
MemoryPool* pool = static_cast<MemoryPool*>(arg);
Expand Down Expand Up @@ -522,17 +499,16 @@ static void roctracer_enable_callback_fun(roctracer_domain_t domain, uint32_t op
if (!HipLoader::Instance().Enabled()) break;
std::lock_guard lock(hip_activity_mutex);

hipError_t hip_err =
HipLoader::Instance().RegisterApiCallback(op, (void*)callback, user_data);
if (hip_err != hipSuccess)
FATAL_LOGGING("HIP::RegisterApiCallback(" << op << ") error(" << hip_err << ")");

if (HipApiCallbackEnableCheck(op) == 0) {
hip_err =
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr);
if (hip_err != hipSuccess)
FATAL_LOGGING("HIPAPI: HIP::RegisterActivityCallback(" << op << ") error(" << hip_err
<< ")");
if (hipError_t err =
HipLoader::Instance().RegisterApiCallback(op, (void*)callback, user_data);
err != hipSuccess)
FATAL_LOGGING("HIP::RegisterApiCallback(" << op << ") error(" << err << ")");

if ((hip_act_cb_tracker.enable_check(op, API_CB_MASK) & API_ACT_MASK) == 0) {
if (hipError_t err =
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr);
err != hipSuccess)
FATAL_LOGGING("HIPAPI: HIP::RegisterActivityCallback(" << op << ") error(" << err << ")");
}
break;
}
Expand Down Expand Up @@ -594,14 +570,12 @@ static void roctracer_disable_callback_fun(roctracer_domain_t domain, uint32_t o
if (!HipLoader::Instance().Enabled()) break;
std::lock_guard lock(hip_activity_mutex);

const hipError_t hip_err = HipLoader::Instance().RemoveApiCallback(op);
if (hip_err != hipSuccess)
FATAL_LOGGING("HIP::RemoveApiCallback(" << op << "), error(" << hip_err << ")");
if (hipError_t err = HipLoader::Instance().RemoveApiCallback(op); err != hipSuccess)
FATAL_LOGGING("HIP::RemoveApiCallback(" << op << "), error(" << err << ")");

if (HipApiCallbackDisableCheck(op) == 0) {
const hipError_t hip_err = HipLoader::Instance().RemoveActivityCallback(op);
if (hip_err != hipSuccess)
FATAL_LOGGING("HIPAPI: HIP::RemoveActivityCallback op(" << op << "), error(" << hip_err
if ((hip_act_cb_tracker.disable_check(op, API_CB_MASK) & API_ACT_MASK) == 0) {
if (hipError_t err = HipLoader::Instance().RemoveActivityCallback(op); err != hipSuccess)
FATAL_LOGGING("HIPAPI: HIP::RemoveActivityCallback op(" << op << "), error(" << err
<< ")");
}
break;
Expand Down Expand Up @@ -739,12 +713,11 @@ static void roctracer_enable_activity_fun(roctracer_domain_t domain, uint32_t op
if (!HipLoader::Instance().Enabled()) break;
std::lock_guard lock(hip_activity_mutex);

if (HipApiActivityEnableCheck(op) == 0) {
const hipError_t hip_err =
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, pool);
if (hip_err != hipSuccess)
FATAL_LOGGING("HIP::RegisterActivityCallback(" << op << " error(" << hip_err << ")");
}
hip_act_cb_tracker.enable_check(op, API_ACT_MASK);
if (hipError_t err =
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, pool);
err != hipSuccess)
FATAL_LOGGING("HIP::RegisterActivityCallback(" << op << " error(" << err << ")");
break;
}
case ACTIVITY_DOMAIN_ROCTX:
Expand Down Expand Up @@ -835,16 +808,14 @@ static void roctracer_disable_activity_fun(roctracer_domain_t domain, uint32_t o
if (!HipLoader::Instance().Enabled()) break;
std::lock_guard lock(hip_activity_mutex);

if (HipApiActivityDisableCheck(op) == 0) {
const hipError_t hip_err = HipLoader::Instance().RemoveActivityCallback(op);
if (hip_err != hipSuccess)
FATAL_LOGGING("HIP::RemoveActivityCallback op(" << op << "), error(" << hip_err << ")");
if ((hip_act_cb_tracker.disable_check(op, API_ACT_MASK) & API_CB_MASK) == 0) {
if (hipError_t err = HipLoader::Instance().RemoveActivityCallback(op); err != hipSuccess)
FATAL_LOGGING("HIP::RemoveActivityCallback op(" << op << "), error(" << err << ")");
} else {
const hipError_t hip_err =
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr);
if (hip_err != hipSuccess)
FATAL_LOGGING("HIPACT: HIP::RegisterActivityCallback(" << op << ") error(" << hip_err
<< ")");
if (hipError_t err =
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr);
err != hipSuccess)
FATAL_LOGGING("HIPACT: HIP::RegisterActivityCallback(" << op << ") error(" << err << ")");
}
break;
}
Expand Down

0 comments on commit f0e082f

Please sign in to comment.