Skip to content

Commit

Permalink
Merge pull request open-mpi#7 from RainybIue/huawei
Browse files Browse the repository at this point in the history
increase the check of data size
  • Loading branch information
nsosnsos authored Dec 9, 2020
2 parents 4481e72 + 9d68b69 commit 386000f
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions ompi/mca/coll/ucx/coll_ucx_op.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ static inline int mca_coll_ucx_is_datatype_supported(struct ompi_datatype_t *dty
return ompi_datatype_is_contiguous_memory_layout(dtype, count);
}

/*
 * Check that the total payload of a collective (dtype_size * count bytes)
 * does not exceed the 2^32-byte limit enforced by this component.
 *
 * dtype_size  size in bytes of a single element (callers pass the extent)
 * count       number of elements
 *
 * Returns UCS_OK when the total is <= 2^32 bytes, UCS_ERR_OUT_OF_RANGE
 * otherwise. A negative count is reported as out of range.
 *
 * The comparison is done by division rather than by multiplying first:
 * the product dtype_size * count can wrap around in size_t (always on
 * 32-bit size_t, and for very large extents on 64-bit), which would let
 * an oversized request slip through as a small wrapped value.
 */
static ucs_status_t mca_coll_ucx_check_total_data_size(size_t dtype_size, int count)
{
    static const uint64_t max_size = UINT64_C(1) << 32; /* 4294967296 bytes */

    /* A negative count would otherwise be converted to a huge unsigned value. */
    if (count < 0) {
        return UCS_ERR_OUT_OF_RANGE;
    }

    /* Zero elements or zero-sized elements always fit; also avoids dividing by 0. */
    if (count == 0 || dtype_size == 0) {
        return UCS_OK;
    }

    /* dtype_size * count <= max_size  <=>  dtype_size <= max_size / count */
    if ((uint64_t)dtype_size > max_size / (uint64_t)count) {
        return UCS_ERR_OUT_OF_RANGE;
    }

    return UCS_OK;
}

int mca_coll_ucx_start(size_t count, ompi_request_t** requests)
{
mca_coll_ucx_persistent_op_t *preq = NULL;
Expand Down Expand Up @@ -100,6 +107,13 @@ int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count,
ptrdiff_t extent, dsize, gap = 0;
int err;

ompi_datatype_type_extent(dtype, &extent);
ucs_status_t ret = mca_coll_ucx_check_total_data_size((size_t)extent, count);
if (OPAL_UNLIKELY(ret != UCS_OK)) {
COLL_UCX_ERROR("ucx component only support data size <= 2^32 bytes. please use other component.");
return OMPI_ERROR;
}

dsize = opal_datatype_span(&dtype->super, count, &gap);
if (sbuf == MPI_IN_PLACE && dsize != 0) {
inplace_buff = (char *)malloc(dsize);
Expand All @@ -122,8 +136,7 @@ int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count,

ucs_status_ptr_t req = COLL_UCX_REQ_ALLOCA(ucx_module);

ompi_datatype_type_extent(dtype, &extent);
ucs_status_t ret = ucg_coll_allreduce_init(sbuf, rbuf, count, (size_t)extent, dtype, ucx_module->ucg_group, 0,
ret = ucg_coll_allreduce_init(sbuf, rbuf, count, (size_t)extent, dtype, ucx_module->ucg_group, 0,
op, 0, 0, &coll);
if (OPAL_UNLIKELY(ret != UCS_OK)) {
COLL_UCX_ERROR("ucx allreduce init failed: %s", ucs_status_string(ret));
Expand Down Expand Up @@ -446,7 +459,12 @@ int mca_coll_ucx_bcast(void *buff, int count, struct ompi_datatype_t *dtype, int
ptrdiff_t dtype_size;
ucg_coll_h coll = NULL;
ompi_datatype_type_extent(dtype, &dtype_size);
ucs_status_t ret = ucg_coll_bcast_init(buff, buff, count, (size_t)dtype_size, dtype, ucx_module->ucg_group, 0,
ucs_status_t ret = mca_coll_ucx_check_total_data_size((size_t)dtype_size, count);
if (OPAL_UNLIKELY(ret != UCS_OK)) {
COLL_UCX_ERROR("ucx component only support data size <= 2^32 bytes. please use other component.");
return OMPI_ERROR;
}
ret = ucg_coll_bcast_init(buff, buff, count, (size_t)dtype_size, dtype, ucx_module->ucg_group, 0,
0, root, 0, &coll);
if (OPAL_UNLIKELY(UCS_STATUS_IS_ERR(ret))) {
COLL_UCX_ERROR("ucx bcast init failed: %s", ucs_status_string(ret));
Expand Down

0 comments on commit 386000f

Please sign in to comment.