diff --git a/ompi/communicator/comm_cid.c b/ompi/communicator/comm_cid.c index deaaefc8ccd..f93a89d4f29 100644 --- a/ompi/communicator/comm_cid.c +++ b/ompi/communicator/comm_cid.c @@ -103,10 +103,6 @@ struct ompi_comm_allreduce_context_t { ompi_comm_cid_context_t *cid_context; int *tmpbuf; - /* for intercomm allreduce */ - int *rcounts; - int *rdisps; - /* for group allreduce */ int peers_comm[3]; }; @@ -121,8 +117,6 @@ static void ompi_comm_allreduce_context_construct (ompi_comm_allreduce_context_t static void ompi_comm_allreduce_context_destruct (ompi_comm_allreduce_context_t *context) { free (context->tmpbuf); - free (context->rcounts); - free (context->rdisps); } OBJ_CLASS_INSTANCE (ompi_comm_allreduce_context_t, opal_object_t, @@ -602,7 +596,7 @@ static int ompi_comm_allreduce_intra_nb (int *inbuf, int *outbuf, int count, str /* Non-blocking version of ompi_comm_allreduce_inter */ static int ompi_comm_allreduce_inter_leader_exchange (ompi_comm_request_t *request); static int ompi_comm_allreduce_inter_leader_reduce (ompi_comm_request_t *request); -static int ompi_comm_allreduce_inter_allgather (ompi_comm_request_t *request); +static int ompi_comm_allreduce_inter_bcast (ompi_comm_request_t *request); static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf, int count, struct ompi_op_t *op, @@ -636,18 +630,19 @@ static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf, rsize = ompi_comm_remote_size (intercomm); local_rank = ompi_comm_rank (intercomm); - context->tmpbuf = (int *) calloc (count, sizeof(int)); - context->rdisps = (int *) calloc (rsize, sizeof(int)); - context->rcounts = (int *) calloc (rsize, sizeof(int)); - if (OPAL_UNLIKELY (NULL == context->tmpbuf || NULL == context->rdisps || NULL == context->rcounts)) { - ompi_comm_request_return (request); - return OMPI_ERR_OUT_OF_RESOURCE; + if (0 == local_rank) { + context->tmpbuf = (int *) calloc (count, sizeof(int)); + if (OPAL_UNLIKELY (NULL == context->tmpbuf)) { + 
ompi_comm_request_return (request); + return OMPI_ERR_OUT_OF_RESOURCE; + } } /* Execute the inter-allreduce: the result from the local will be in the buffer of the remote group * and vise-versa. */ - rc = intercomm->c_coll.coll_iallreduce (inbuf, context->tmpbuf, count, MPI_INT, op, intercomm, - &subreq, intercomm->c_coll.coll_iallreduce_module); + rc = intercomm->c_local_comm->c_coll.coll_ireduce (inbuf, context->tmpbuf, count, MPI_INT, op, 0, + intercomm->c_local_comm, &subreq, + intercomm->c_local_comm->c_coll.coll_ireduce_module); if (OPAL_UNLIKELY(OMPI_SUCCESS != rc)) { ompi_comm_request_return (request); return rc; @@ -656,7 +651,7 @@ static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf, if (0 == local_rank) { ompi_comm_request_schedule_append (request, ompi_comm_allreduce_inter_leader_exchange, &subreq, 1); } else { - ompi_comm_request_schedule_append (request, ompi_comm_allreduce_inter_allgather, &subreq, 1); + ompi_comm_request_schedule_append (request, ompi_comm_allreduce_inter_bcast, &subreq, 1); } ompi_comm_request_start (request); @@ -696,33 +691,20 @@ static int ompi_comm_allreduce_inter_leader_reduce (ompi_comm_request_t *request ompi_op_reduce (context->op, context->tmpbuf, context->outbuf, context->count, MPI_INT); - return ompi_comm_allreduce_inter_allgather (request); + return ompi_comm_allreduce_inter_bcast (request); } -static int ompi_comm_allreduce_inter_allgather (ompi_comm_request_t *request) +static int ompi_comm_allreduce_inter_bcast (ompi_comm_request_t *request) { ompi_comm_allreduce_context_t *context = (ompi_comm_allreduce_context_t *) request->context; - ompi_communicator_t *intercomm = context->cid_context->comm; + ompi_communicator_t *comm = context->cid_context->comm->c_local_comm; ompi_request_t *subreq; int scount = 0, rc; - /* distribute the overall result to all processes in the other group. - Instead of using bcast, we are using here allgatherv, to avoid the - possible deadlock. 
Else, we need an algorithm to determine, - which group sends first in the inter-bcast and which receives - the result first. - */ - - if (0 != ompi_comm_rank (intercomm)) { - context->rcounts[0] = context->count; - } else { - scount = context->count; - } - - rc = intercomm->c_coll.coll_iallgatherv (context->outbuf, scount, MPI_INT, context->outbuf, - context->rcounts, context->rdisps, MPI_INT, intercomm, - &subreq, intercomm->c_coll.coll_iallgatherv_module); + /* Both roots have the same result. Broadcast it to the local group. */ + rc = comm->c_coll.coll_ibcast (context->outbuf, context->count, MPI_INT, 0, comm, + &subreq, comm->c_coll.coll_ibcast_module); if (OMPI_SUCCESS != rc) { return rc; }