Skip to content

Commit

Permalink
Merge pull request open-mpi#10412 from MamziB/mamzi/get-accum
Browse files Browse the repository at this point in the history
OSC/UCX: Fix data validation issue in get accumulate and intrinsic atomic ops
  • Loading branch information
janjust authored May 24, 2022
2 parents bab0bd7 + 0031008 commit c05c23b
Showing 1 changed file with 38 additions and 20 deletions.
58 changes: 38 additions & 20 deletions ompi/mca/osc/ucx/osc_ucx_comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,35 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module
return ret;
}

static inline
bool osc_is_atomic_dt_op_supported(
struct ompi_datatype_t *dt,
struct ompi_op_t *op,
size_t dt_bytes,
uint64_t remote_addr)
{
/* UCX atomics are only supported on 32 and 64 bit values */
if (!ompi_datatype_is_predefined(dt) ||
!ompi_osc_base_is_atomic_size_supported(remote_addr, dt_bytes)) {
return false;
}
/* Hardware-based atomic add for floating point is not supported */
else if ((
op == &ompi_mpi_op_no_op.op
|| op == &ompi_mpi_op_replace.op
|| op == &ompi_mpi_op_sum.op
)
&& !(
op == &ompi_mpi_op_sum.op
&& (dt == MPI_FLOAT || dt == MPI_DOUBLE
|| dt == MPI_LONG_DOUBLE || dt == MPI_FLOAT_INT)
)) {
return true;
}

return false;
}

static inline
bool use_atomic_op(
ompi_osc_ucx_module_t *module,
Expand All @@ -388,25 +417,16 @@ bool use_atomic_op(
int origin_count,
int target_count)
{
size_t origin_dt_bytes;

if (module->acc_single_intrinsic &&
ompi_datatype_is_predefined(origin_dt) &&
origin_count == 1 &&
(op == &ompi_mpi_op_replace.op ||
op == &ompi_mpi_op_sum.op ||
op == &ompi_mpi_op_no_op.op)) {
size_t origin_dt_bytes;
size_t target_dt_bytes;
if (!module->acc_single_intrinsic || origin_count != 1 || target_count != 1
|| origin_dt != target_dt) {
return false;
} else {
ompi_datatype_type_size(origin_dt, &origin_dt_bytes);
ompi_datatype_type_size(target_dt, &target_dt_bytes);
/* UCX only supports 32 and 64-bit operands atm */
if (ompi_osc_base_is_atomic_size_supported(remote_addr, origin_dt_bytes) &&
origin_dt_bytes == target_dt_bytes && origin_count == target_count) {
return true;
}
return osc_is_atomic_dt_op_supported(origin_dt, op, origin_dt_bytes,
remote_addr);
}

return false;
}

static int do_atomic_op_intrinsic(
Expand Down Expand Up @@ -859,10 +879,7 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
ompi_datatype_type_size(dt, &dt_bytes);

/* UCX atomics are only supported on 32 and 64 bit values */
if (ompi_osc_base_is_atomic_size_supported(remote_addr, dt_bytes) &&
(op == &ompi_mpi_op_no_op.op || op == &ompi_mpi_op_replace.op ||
op == &ompi_mpi_op_sum.op)) {
if (osc_is_atomic_dt_op_supported(dt, op, dt_bytes, remote_addr)) {
uint64_t value;
ucp_atomic_fetch_op_t opcode;
bool lock_acquired = false;
Expand Down Expand Up @@ -973,6 +990,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
if (ret != OMPI_SUCCESS) {
return ret;
}
temp_count *= target_count;
}
ompi_datatype_get_true_extent(temp_dt, &temp_lb, &temp_extent);
temp_addr = free_addr = malloc(temp_extent * temp_count);
Expand Down

0 comments on commit c05c23b

Please sign in to comment.