Skip to content

Commit

Permalink
ShyLU - Basker : tune memory pre-allocation
Browse files Browse the repository at this point in the history
  • Loading branch information
iyamazaki committed Nov 3, 2024
1 parent 7ab9d36 commit dcfc61c
Show file tree
Hide file tree
Showing 11 changed files with 247 additions and 191 deletions.
19 changes: 0 additions & 19 deletions packages/shylu/shylu_node/basker/src/shylubasker_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,25 +528,6 @@ namespace BaskerNS
Int off_diag
);

BASKER_INLINE
void L_blk_sfactor
(
BASKER_MATRIX &MV,
BASKER_SYMBOLIC_TREE &ST,
INT_1DARRAY gcol,
INT_1DARRAY grow
);

//old
BASKER_INLINE
void L_blk_sfactor
(
BASKER_MATRIX_VIEW &MV,
BASKER_SYMBOLIC_TREE &ST,
INT_1DARRAY gcol,
INT_1DARRAY grow
);

BASKER_INLINE
void S_sfactor_reduce
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ namespace BaskerNS
<< " DOMBLK MALLOC : blk=" << thread_array(ti).error_blk
<< " subblk=" << thread_array(ti).error_subblk
<< " newsize=" << thread_array(ti).error_info
<< std::endl;
<< std::endl << std::flush;
}

//If on diagonal, want to compare L and U
Expand Down Expand Up @@ -113,7 +113,7 @@ namespace BaskerNS
{
if(Options.verbose == BASKER_TRUE)
{
std::cout << " ++ resize L( tid = " << ti << " ): new size = " << resize_L << std::endl;
std::cout << " ++ resize L( tid = " << ti << " ): new size = " << resize_L << std::endl << std::flush;
}
BASKER_MATRIX &L =
LL(thread_array(ti).error_blk)(thread_array(ti).error_subblk);
Expand All @@ -139,7 +139,7 @@ namespace BaskerNS
{
if(Options.verbose == BASKER_TRUE)
{
std::cout << " ++ resize U( tid = " << ti << " ): new size = " << resize_U << std::endl;
std::cout << " ++ resize U( tid = " << ti << " ): new size = " << resize_U << std::endl << std::flush;
}
BASKER_MATRIX &U =
LU(thread_array(ti).error_blk)(0);
Expand Down
44 changes: 25 additions & 19 deletions packages/shylu/shylu_node/basker/src/shylubasker_nfactor_blk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -546,37 +546,37 @@ namespace BaskerNS
if (Options.replace_tiny_pivot && normA_blk > abs(zero)) {
// just insert tiny pivot on diagonal
maxindex = k;
while (gperm(maxindex+brow_g) != BASKER_MAX_IDX && maxindex < M.ncol) {
while (gperm(maxindex+brow_g) != BASKER_MAX_IDX && maxindex < M.ncol) {
maxindex ++;
}
if (maxindex < M.ncol) {
}
if (maxindex < M.ncol) {
if (Options.verbose == BASKER_TRUE)
{
cout << " thread-" << kid << " Explicit tiny pivot for maxind = " << maxindex << endl;
}
}
pivot = normA_blk * sqrt(eps);
lastU = pivot;
npivots ++;
explicit_pivot = true;
}
explicit_pivot = true;
}
} else if (Options.replace_zero_pivot && normA_blk > abs(zero)) {
// just insert tiny pivot on diagonal
maxindex = k;
while (gperm(maxindex+brow_g) != BASKER_MAX_IDX && maxindex < M.ncol-1) {
while (gperm(maxindex+brow_g) != BASKER_MAX_IDX && maxindex < M.ncol-1) {
maxindex ++;
}
if (maxindex < M.ncol) {
}
if (maxindex < M.ncol) {
if (Options.verbose == BASKER_TRUE)
{
cout << " thread-" << kid << " Explicit nonzero pivot for maxind = " << maxindex << "(" << gperm(maxindex+brow_g) << ")" << endl;
}
}
pivot = normA_blk * eps;
lastU = pivot;
npivots ++;
explicit_pivot = true;
}
explicit_pivot = true;
}
}
if (!explicit_pivot) {
if (!explicit_pivot) {
thread_array(kid).error_type =
BASKER_ERROR_SINGULAR;
thread_array(kid).error_blk = b;
Expand Down Expand Up @@ -1543,8 +1543,8 @@ namespace BaskerNS


#ifdef BASKER_DEBUG_NFACTOR_BLK
printf("t_dense_move_offdiag_L, kid=%d, k=%d: L (%d %d) X (%d %d)\n",
kid, k, blkcol,blkrow, X_col, X_row);
printf("t_dense_move_offdiag_L, kid=%d, k=%d: L (%d %d) X (%d %d), nnz=%d\n",
kid, k, blkcol,blkrow, X_col, X_row, L.nnz);
#endif


Expand All @@ -1565,15 +1565,22 @@ namespace BaskerNS
}
*/

///for(Int i = 0; i < p_size; i++)
for(Int j = 0; j < L.nrow; ++j)
{
//Int j = pattern[i];
//Int t = gperm(j);
if(X(j) != (Entry)(0) )
{
//Int t = gperm(j+brow);

if (lnnz >= L.nnz) { // this should not happen since allocated as dense separator blocks
if (Options.verbose == BASKER_TRUE)
{
printf("Move Off-diag L failed with insufficient storage L(%d,%d).nnz = %d\n",
(int)blkcol, (int)blkrow, (int)L.nnz );
}
BASKER_ASSERT(true, "\n Not enough memory allocated for off-diagonal L\n");
return BASKER_ERROR;
}
#ifdef BASKER_DEBUG_NFACTOR_BLK
printf("L-Moving, kid: %d j: %d val: %f lnnz: %d \n",
kid, j, X[j]/pivot, lnnz);
Expand All @@ -1594,7 +1601,6 @@ namespace BaskerNS
#ifdef BASKER_INC_LVL
L.inc_lvl[lnnz] = INC_LVL_TEMP[j];
#endif

lnnz++;
}
}
Expand Down Expand Up @@ -1756,7 +1762,7 @@ namespace BaskerNS
printf("t_back_solve_diag, kid: %d, ws: %d starting psize: %d \n",
kid, ws_size, nnz);
printf("t_back_solve_diag, kid: %d, ALM(%d)(%d): %dx%d\n",kid,blkcol,blkrow,B.nrow,B.ncol );
printf("t_back_solve_diag, kid: %d, LL(%d)(%d): %dx%d\n",kid,blkcol,blkrow,L.nrow,L.ncol );
printf("t_back_solve_diag, kid: %d, LL(%d)(%d): %dx%d, nnz=%d, X.nnz=%d\n",kid,blkcol,blkrow,L.nrow,L.ncol,LL(blkcol)(blkrow).nnz,X.extent(0) );
printf("\n\n"); fflush(stdout);
#endif
//B.info();
Expand Down
26 changes: 15 additions & 11 deletions packages/shylu/shylu_node/basker/src/shylubasker_nfactor_col2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ namespace BaskerNS
}//for - over all sublevel 1...lvl-2
#ifdef BASKER_TIMER
printf("Time Upper-Col(%d): %lf \n", (int)kid, timer.seconds());
timer.reset();
fflush(stdout); timer.reset();
#endif

//---------Lower Factor (old sublevel lvl-1)-------
Expand All @@ -255,11 +255,11 @@ namespace BaskerNS
}
}
#endif
#ifdef BASKER_DEBUG_NFACTOR_COL2
#ifdef BASKER_TIMER
printf("\n done with UPPER, kid: %d \n\n", kid);
printf("\n\n======= LOWER, KID: %d ======= \n\n", kid);
fflush(stdout);
#endif

//printf("\n\n======= LOWER, KID: %d ======= \n\n", kid);
//return;
// > accumulate the last update
// > factor the diagonal block LU(U_col)(U_row)
Expand All @@ -284,7 +284,8 @@ namespace BaskerNS
if (info == BASKER_SUCCESS)
{
#ifdef BASKER_DEBUG_NFACTOR_COL2
printf( " kid=%d: calling t_add_extend(k=%d/%d)\n",kid,k,ncol ); fflush(stdout);
printf( " kid=%d: calling t_add_extend(k=%d/%d) with LU(%d,%d).nnz = %d\n",
kid,k,ncol,U_col,U_row,LU(U_col)(U_row).nnz ); fflush(stdout);
#endif
t_add_extend(thread, kid,lvl,lvl-1, k,
LU(U_col)(U_row).scol,
Expand Down Expand Up @@ -316,13 +317,13 @@ namespace BaskerNS
}
}
#ifdef BASKER_DEBUG_NFACTOR_COL2
printf(" > done calling lower factor, kid: %d k: %d info=%d\n", kid, k, info); fflush(stdout);
#endif
#ifdef BASKER_DEBUG_NFACTOR_COL2
else {
printf(" + skipping lower factor, kid: %d k: %d \n", kid, k); fflush(stdout);
}
#endif
#ifdef BASKER_DEBUG_NFACTOR_COL2
printf(" > done calling lower factor, kid: %d k: %d info=%d\n", kid, k, info); fflush(stdout);
#endif
//need barrier if multiple thread uppdate
#ifdef USE_TEAM_BARRIER_NFACTOR_COL2
thread.team_barrier();
Expand Down Expand Up @@ -356,12 +357,12 @@ namespace BaskerNS
timer_facoff.reset();
#endif
#ifdef BASKER_DEBUG_NFACTOR_COL2
printf(" calling lower diag factor, kid: %d k: %d \n",
printf(" calling lower offdiag factor, kid: %d k: %d \n",
kid, k); fflush(stdout);
#endif
t_lower_col_factor_offdiag2(kid, lvl, lvl-1, k, pivot);
#ifdef BASKER_DEBUG_NFACTOR_COL2
printf(" done lower diag factor, kid: %d k: %d \n",
printf(" done lower offdiag factor, kid: %d k: %d \n",
kid, k); fflush(stdout);
#endif
}
Expand Down Expand Up @@ -906,7 +907,10 @@ namespace BaskerNS
L_row < LL_size(L_col);
X_row+=(lteam_size), L_row+=(lteam_size))
{
//printf("OFF_DIAG_LOWER. kid: %d k: %d U: %d %d L: %d %d X: %d %d pivot: %f \n", kid, k, U_col, U_row, L_col, L_row, X_col, X_row, pivot);
#ifdef BASKER_TIMER
printf("OFF_DIAG_LOWER. kid: %d k: %d U(%d, %d).nnz = %d L(%d, %d) X(%d, %d) pivot: %f \n",
kid, k, U_col, U_row, LU(U_col)(U_row).nnz, L_col, L_row, X_col, X_row, pivot);
#endif
/*old
t_back_solve_offdiag(leader_id,
L_col, L_row,
Expand Down
8 changes: 6 additions & 2 deletions packages/shylu/shylu_node/basker/src/shylubasker_order.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1160,8 +1160,10 @@ static int basker_sort_matrix_col(const void *arg1, const void *arg2)
std::cout << " > scotch_partition returned with info = " << info_scotch << " and apply_nd = " << apply_nd << std::endl;
}
return info_scotch;
} else if(Options.verbose == BASKER_TRUE) {
printf( "\n part_scotch done (num_threads = %d,%d)\n",num_threads,part_tree.leaf_nnz.extent(0) );
//for (Int i = 0; i < num_threads; i++) printf( " nnz_leaf[%d] = %d\n",i,part_tree.leaf_nnz[i] ); printf( "\n" );
}

nd_flag = BASKER_TRUE;
//permute
permute_row(M, part_tree.permtab);
Expand Down Expand Up @@ -2200,7 +2202,9 @@ static int basker_sort_matrix_col(const void *arg1, const void *arg2)
INT_1DARRAY row
)
{
permute_row(M.nnz, &(M.row_idx(0)), &(row(0)));
if (M.nnz > 0) {
permute_row(M.nnz, &(M.row_idx(0)), &(row(0)));
}
return 0;
}//end permute_row(matrix,int)

Expand Down
40 changes: 25 additions & 15 deletions packages/shylu/shylu_node/basker/src/shylubasker_order_scotch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ namespace BaskerNS

// id of the first leaf node (BF order, post_order maps from BF to ND)
Int leaves_id = pow(2.0, (double)(num_levels)) - 1;
//printf( " num_levels = %d, num_doms = %d, leves_id = %d\n",num_levels,num_doms,leaves_id );
if (Options.verbose == BASKER_TRUE) {
printf( " num_domains = %d: num_levels = %d, num_doms = %d, leves_id = %d\n",num_domains,num_levels,num_doms,leaves_id );
}

// > insert root
Int num_queued = 0;
Expand Down Expand Up @@ -297,11 +299,14 @@ namespace BaskerNS
// level goes to num_leaves so that we can call ND on the final leaf nodes
last_level = num_levels;
}
if (Options.verbose == BASKER_TRUE) {
if (run_nd_on_leaves) {
if (run_nd_on_leaves) {
if (Options.verbose == BASKER_TRUE) {
std::cout << std::endl << " + Using ND on leaves + " << std::endl;
} else if (run_amd_on_leaves) {
std::cout << std::endl << " + Using AMD on leaves + " << std::endl;
}
} else if (run_amd_on_leaves) {
MALLOC_INT_1DARRAY(BT.leaf_nnz, num_doms);
if (Options.verbose == BASKER_TRUE) {
std::cout << std::endl << " + Using AMD on leaves (# doms = " << num_doms << ") + " << std::endl;
}
}
// -------------------------------------------------- //
Expand Down Expand Up @@ -551,11 +556,16 @@ namespace BaskerNS
for(Int i = 0; i < metis_size_k; i++) {
metis_iperm_k(metis_perm_k(i)) = i;
}
if (Options.verbose == BASKER_TRUE) {
std::cout << std::endl << " > Basker AMD on leaf : estimated nnz(L(" << leaf_id << ") = " << l_nnz
<< " <" << std::endl << std::endl;
}
info = METIS_OK;
} else {
std::cout << std::endl << " > Basker AMD failed < " << std::endl << std::endl;
return BASKER_ERROR; // TODO: what to do here?
}
BT.leaf_nnz(leaf_id) = l_nnz;
}

// update perm/
Expand Down Expand Up @@ -888,7 +898,7 @@ namespace BaskerNS
sg.nz = sg.Ap[sg.m];

//printf("num self_edge: %d sg.m: %d \n",
// self_edge, sg.m);
// self_edge, sg.m);
if(self_edge != (sg.m))
{
BASKER_ASSERT(self_edge == (sg.m-1),
Expand Down Expand Up @@ -990,11 +1000,11 @@ namespace BaskerNS
#ifdef BASKER_DEBUG_ORDER_SCOTCH
printf("FIX SCOTCH PRINT OUT\n");
printf("SCOTCH: NUM_LEVELS ASKED = %d, NUM DOMS GOT = %d, NUM TREES = %d \n",
num_levels, sg.cblk, num_trees);
num_levels, sg.cblk, num_trees);
printf("\n");
printf("%d %d should blks: %f \n",
2, ((Int)num_levels+1),
pow(2.0,((double)num_levels+1))-1);
2, ((Int)num_levels+1),
pow(2.0,((double)num_levels+1))-1);
#endif

if(((sg.cblk) != pow(2.0,((double)num_levels+1))-1) || (num_trees != 1))
Expand Down Expand Up @@ -1028,7 +1038,7 @@ namespace BaskerNS
#ifdef BASKER_DEBUG_ORDER_SCOTCH
printf("\n\n Starting DEBUG COMPLETE OUT \n\n");
printf("Tree: ");
` for(Int i = 0; i < iblks+1; i++)
for(Int i = 0; i < iblks+1; i++)
{
printf("%d, ", ttree(i));
}
Expand Down Expand Up @@ -1217,11 +1227,11 @@ namespace BaskerNS
Int mynum = iblks-1;
otree(iblks) = -1;
rec_build_tree(lvl,
lpos,rpos,
mynum,
otree);
lpos,rpos,
mynum,
otree);

INT_1DARRAY ws;
BASKER_ASSERT((iblks+1)>0, "scotch iblks 2");
MALLOC_INT_1DARRAY(ws, iblks+1);
Expand Down Expand Up @@ -1486,7 +1496,7 @@ namespace BaskerNS
)
{
//printf("assign, lpos: %d rpos: %d number: %d\n",
// lpos, rpos, mynum);
// lpos, rpos, mynum);

if(lvl > 0)
{
Expand Down
Loading

0 comments on commit dcfc61c

Please sign in to comment.