Skip to content

Commit

Permalink
Split UCC initialization and context creation
Browse files Browse the repository at this point in the history
Signed-off-by: ferrol aderholdt <[email protected]>
  • Loading branch information
ferrol aderholdt authored and wckzhang committed May 10, 2022
1 parent 30784a0 commit 90a0550
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 14 deletions.
4 changes: 3 additions & 1 deletion oshmem/mca/scoll/ucc/scoll_ucc.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ struct mca_scoll_ucc_component_t {
char * cts;
int nr_modules;
bool libucc_initialized;
ucc_context_h ucc_context;
ucc_lib_h ucc_lib;
ucc_lib_attr_t ucc_lib_attr;
ucc_coll_type_t cts_requested;
ucc_context_h ucc_context;
};
typedef struct mca_scoll_ucc_component_t mca_scoll_ucc_component_t;

Expand Down Expand Up @@ -85,6 +85,8 @@ int mca_scoll_ucc_init_query(bool enable_progress_threads, bool enable_mpi_threa
int mca_scoll_ucc_team_create(mca_scoll_ucc_module_t *ucc_module,
oshmem_group_t *osh_group);

int mca_scoll_ucc_init_ctx(oshmem_group_t *osh_group);

mca_scoll_base_module_t* mca_scoll_ucc_comm_query(oshmem_group_t *osh_group, int *priority);

int mca_scoll_ucc_barrier(struct oshmem_group_t *group, long *pSync, int alg);
Expand Down
6 changes: 6 additions & 0 deletions oshmem/mca/scoll/ucc/scoll_ucc_alltoall.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ static inline ucc_status_t mca_scoll_ucc_alltoall_init(const void *sbuf, void *r
.global_work_buffer = ucc_module->pSync,
};

if (NULL == mca_scoll_ucc_component.ucc_context) {
if (OSHMEM_ERROR == mca_scoll_ucc_init_ctx(ucc_module->group)) {
return OSHMEM_ERROR;
}
}

if (NULL == ucc_module->ucc_team) {
if (OSHMEM_ERROR == mca_scoll_ucc_team_create(ucc_module, ucc_module->group)) {
return OSHMEM_ERROR;
Expand Down
9 changes: 8 additions & 1 deletion oshmem/mca/scoll/ucc/scoll_ucc_barrier.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,19 @@ static inline ucc_status_t mca_scoll_ucc_barrier_init(mca_scoll_ucc_module_t * u
.mask = 0,
.coll_type = UCC_COLL_TYPE_BARRIER
};

if (NULL == mca_scoll_ucc_component.ucc_context) {
if (OSHMEM_ERROR == mca_scoll_ucc_init_ctx(ucc_module->group)) {
return OSHMEM_ERROR;
}
}

if (NULL == ucc_module->ucc_team) {
if (OSHMEM_ERROR == mca_scoll_ucc_team_create(ucc_module, ucc_module->group)) {
return OSHMEM_ERROR;
}
}

SCOLL_UCC_REQ_INIT(req, coll, ucc_module);
return UCC_OK;
fallback:
Expand All @@ -49,4 +57,3 @@ int mca_scoll_ucc_barrier(struct oshmem_group_t *group, long *pSync, int alg)
pSync, alg);
return rc;
}

7 changes: 7 additions & 0 deletions oshmem/mca/scoll/ucc/scoll_ucc_broadcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ static inline ucc_status_t mca_scoll_ucc_broadcast_init(void * buf, int count,
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
}
};

if (NULL == mca_scoll_ucc_component.ucc_context) {
if (OSHMEM_ERROR == mca_scoll_ucc_init_ctx(ucc_module->group)) {
return OSHMEM_ERROR;
}
}

if (NULL == ucc_module->ucc_team) {
if (OSHMEM_ERROR == mca_scoll_ucc_team_create(ucc_module, ucc_module->group)) {
return OSHMEM_ERROR;
Expand Down
6 changes: 6 additions & 0 deletions oshmem/mca/scoll/ucc/scoll_ucc_collect.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ static inline ucc_status_t mca_scoll_ucc_collect_init(const void * sbuf, void *
},
};

if (NULL == mca_scoll_ucc_component.ucc_context) {
if (OSHMEM_ERROR == mca_scoll_ucc_init_ctx(ucc_module->group)) {
return OSHMEM_ERROR;
}
}

if (NULL == ucc_module->ucc_team) {
if (OSHMEM_ERROR == mca_scoll_ucc_team_create(ucc_module, ucc_module->group)) {
return OSHMEM_ERROR;
Expand Down
3 changes: 2 additions & 1 deletion oshmem/mca/scoll/ucc/scoll_ucc_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ mca_scoll_ucc_component_t mca_scoll_ucc_component = {
"basic", /* cls */
SCOLL_UCC_CTS_STR, /* cts */
0, /* nr_modules */
false /* libucc_initialized */
false, /* libucc_initialized */
NULL /* ucc_context */
};

static int ucc_register(void)
Expand Down
37 changes: 26 additions & 11 deletions oshmem/mca/scoll/ucc/scoll_ucc_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@ static void mca_scoll_ucc_module_destruct(mca_scoll_ucc_module_t *ucc_module)
}

if (0 == mca_scoll_ucc_component.nr_modules) {
if (mca_scoll_ucc_component.libucc_initialized) {
if (mca_scoll_ucc_component.libucc_initialized) {
if (mca_scoll_ucc_component.ucc_context) {
opal_progress_unregister(mca_scoll_ucc_progress);
ucc_context_destroy(mca_scoll_ucc_component.ucc_context);
}
UCC_VERBOSE(1, "finalizing ucc library");
opal_progress_unregister(mca_scoll_ucc_progress);
ucc_context_destroy(mca_scoll_ucc_component.ucc_context);
ucc_finalize(mca_scoll_ucc_component.ucc_lib);
mca_scoll_ucc_component.libucc_initialized = false;
}
Expand Down Expand Up @@ -199,17 +201,12 @@ static ucc_status_t oob_allgather_test(void *req)
return oob_probe_test(oob_req);
}

static int mca_scoll_ucc_init_ctx(oshmem_group_t *osh_group)
static int mca_scoll_ucc_init(oshmem_group_t *osh_group)
{
mca_scoll_ucc_component_t *cm = &mca_scoll_ucc_component;
ucc_mem_map_t *maps = NULL;
char str_buf[256];
ucc_lib_config_h lib_config;
ucc_context_config_h ctx_config;
ucc_thread_mode_t tm_requested;
ucc_lib_params_t lib_params;
ucc_context_params_t ctx_params;
int segment;

tm_requested = oshmem_mpi_thread_multiple ? UCC_THREAD_MULTIPLE :
UCC_THREAD_SINGLE;
Expand Down Expand Up @@ -247,6 +244,25 @@ static int mca_scoll_ucc_init_ctx(oshmem_group_t *osh_group)
goto cleanup_lib;
}

cm->libucc_initialized = true;
return OSHMEM_SUCCESS;

cleanup_lib:
ucc_finalize(cm->ucc_lib);
cm->ucc_enable = 0;
cm->libucc_initialized = false;
return OSHMEM_ERROR;
}

int mca_scoll_ucc_init_ctx(oshmem_group_t *osh_group)
{
mca_scoll_ucc_component_t *cm = &mca_scoll_ucc_component;
ucc_mem_map_t *maps = NULL;
char str_buf[256];
ucc_context_config_h ctx_config;
ucc_context_params_t ctx_params;
int segment;

maps = (ucc_mem_map_t *)malloc(sizeof(ucc_mem_map_t) *
memheap_map->n_segments);
if (NULL == maps) {
Expand Down Expand Up @@ -398,7 +414,6 @@ static int mca_scoll_ucc_module_enable(mca_scoll_base_module_t *module,
opal_show_help("help-oshmem-scoll-ucc.txt",
"module_enable:fatal", true,
"UCC module enable failed - aborting to prevent inconsistent application state");

goto err;
}
UCC_VERBOSE(1, "ucc enabled");
Expand Down Expand Up @@ -446,7 +461,7 @@ mca_scoll_ucc_comm_query(oshmem_group_t *osh_group, int *priority)

if (!cm->libucc_initialized) {
if (memheap_map && memheap_map->n_segments > 0) {
if (OSHMEM_SUCCESS != mca_scoll_ucc_init_ctx(osh_group)) {
if (OSHMEM_SUCCESS != mca_scoll_ucc_init(osh_group)) {
cm->ucc_enable = 0;
return NULL;
}
Expand Down
7 changes: 7 additions & 0 deletions oshmem/mca/scoll/ucc/scoll_ucc_reduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ static inline ucc_status_t mca_scoll_ucc_reduce_init(const void *sbuf, void *rbu
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
}

if (NULL == mca_scoll_ucc_component.ucc_context) {
if (OSHMEM_ERROR == mca_scoll_ucc_init_ctx(ucc_module->group)) {
return OSHMEM_ERROR;
}
}

if (NULL == ucc_module->ucc_team) {
if (OSHMEM_ERROR == mca_scoll_ucc_team_create(ucc_module, ucc_module->group)) {
return OSHMEM_ERROR;
Expand Down

0 comments on commit 90a0550

Please sign in to comment.