diff --git a/packages/ml/src/Comm/ml_comm.c b/packages/ml/src/Comm/ml_comm.c index d9ad14a91c6c..7e52c54e8d57 100644 --- a/packages/ml/src/Comm/ml_comm.c +++ b/packages/ml/src/Comm/ml_comm.c @@ -21,6 +21,15 @@ ML_Comm *global_comm = NULL; /* should not be used to avoid side effect */ /* -------------------------------------------------------------------- */ int ML_Comm_Create( ML_Comm ** com ) +{ +#ifdef ML_MPI + return ML_Comm_Create2(com, MPI_COMM_WORLD); +#else + return ML_Comm_Create2(com, 0); +#endif +} + +int ML_Comm_Create2( ML_Comm ** com, USR_COMM in_comm ) { ML_Comm *com_ptr; @@ -36,13 +45,13 @@ int ML_Comm_Create( ML_Comm ** com ) com_ptr->USR_cheapwaitbytes = ML_Comm_CheapWait; #ifdef ML_MPI - MPI_Comm_size(MPI_COMM_WORLD, &(com_ptr->ML_nprocs)); - MPI_Comm_rank(MPI_COMM_WORLD, &(com_ptr->ML_mypid)); + MPI_Comm_size(in_comm, &(com_ptr->ML_nprocs)); + MPI_Comm_rank(in_comm, &(com_ptr->ML_mypid)); com_ptr->USR_sendbytes = ML_Comm_Send; com_ptr->USR_irecvbytes = ML_Comm_Irecv; com_ptr->USR_waitbytes = ML_Comm_Wait; com_ptr->USR_cheapwaitbytes = ML_Comm_CheapWait; - com_ptr->USR_comm = MPI_COMM_WORLD; + com_ptr->USR_comm = in_comm; #ifdef ML_CATCH_MPI_ERRORS_IN_DEBUGGER /* register the error handling function */ ML_Comm_ErrorHandlerCreate((USR_ERRHANDLER_FUNCTION *) ML_Comm_ErrorHandler, diff --git a/packages/ml/src/Comm/ml_comm.h b/packages/ml/src/Comm/ml_comm.h index 24fdb63db783..a8d40cd6893a 100644 --- a/packages/ml/src/Comm/ml_comm.h +++ b/packages/ml/src/Comm/ml_comm.h @@ -103,6 +103,7 @@ extern "C" #endif extern int ML_Comm_Create( ML_Comm ** comm ); +extern int ML_Comm_Create2( ML_Comm ** comm, USR_COMM com ); extern int ML_Comm_Destroy( ML_Comm ** comm ); extern int ML_Comm_Check( ML_Comm *comm ); diff --git a/packages/ml/src/Main/ml_struct.c b/packages/ml/src/Main/ml_struct.c index f81453cd0046..cf845e7b01ba 100755 --- a/packages/ml/src/Main/ml_struct.c +++ b/packages/ml/src/Main/ml_struct.c @@ -39,6 +39,15 @@ ML_PrintControl ML_PrintLevel = {0}; int ml_defines_have_printed = 0; int ML_Create(ML **ml_ptr, int Nlevels) +{ +#ifdef ML_MPI + return ML_Create2(ml_ptr, Nlevels, MPI_COMM_WORLD); +#else + return ML_Create2(ml_ptr, Nlevels, 0); +#endif +} + +int ML_Create2(ML **ml_ptr, int Nlevels, USR_COMM in_comm) { int i, length; double *max_eigen; @@ -77,7 +86,7 @@ int ML_Create(ML **ml_ptr, int Nlevels) (*ml_ptr)->repartitionStartLevel = -1; (*ml_ptr)->RAP_storage_type=ML_MSR_MATRIX; - ML_Comm_Create( &((*ml_ptr)->comm) ); + ML_Comm_Create2( &((*ml_ptr)->comm), in_comm ); if (global_comm == NULL) global_comm = (*ml_ptr)->comm; diff --git a/packages/ml/src/Main/ml_struct.h b/packages/ml/src/Main/ml_struct.h index 19e7f73edb16..b0b9ca350ea3 100755 --- a/packages/ml/src/Main/ml_struct.h +++ b/packages/ml/src/Main/ml_struct.h @@ -140,6 +140,13 @@ extern ML_PrintControl ML_PrintLevel; /* ******************************************************************** */ /* ******************************************************************** */ +#ifdef ML_MPI +#include "mpi.h" +#define USR_COMM MPI_Comm +#else +#define USR_COMM int +#endif + #ifndef ML_CPP #ifdef __cplusplus extern "C" { @@ -147,6 +154,7 @@ extern "C" { #endif extern int ML_Create(ML **ml, int Nlevels); +extern int ML_Create2(ML **ml, int Nlevels, USR_COMM comm); extern int ML_build_ggb( ML *ml, void *data); extern void ML_build_ggb_cheap(ML *ml, void *data); extern void ML_build_ggb_fat(ML *ml, void *data);