Skip to content

Commit

Permalink
Made SPGEMM driver not crash in arg parsing
Browse files Browse the repository at this point in the history
If "--flag" expects another argument after, check that there is actually
another arg before trying to read it.
  • Loading branch information
brian-kelley committed Feb 21, 2020
1 parent ae14746 commit 6a32a49
Showing 1 changed file with 43 additions and 33 deletions.
76 changes: 43 additions & 33 deletions perf_test/sparse/KokkosSparse_spgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,42 +66,52 @@ void print_options(){
std::cerr << "\tVerbose Output: '--verbose'" << std::endl;
}

static char* getNextArg(int& i, int argc, char** argv)
{
i++;
if(i >= argc)
{
std::cerr << "Error: expected additional command-line argument!\n";
exit(1);
}
return argv[i];
}

int parse_inputs (KokkosKernels::Experiment::Parameters &params, int argc, char **argv){
for ( int i = 1 ; i < argc ; ++i ) {
if ( 0 == strcasecmp( argv[i] , "--threads" ) ) {
params.use_threads = atoi( argv[++i] );
params.use_threads = atoi(getNextArg(i, argc, argv));
}
else if ( 0 == strcasecmp( argv[i] , "--openmp" ) ) {
params.use_openmp = atoi( argv[++i] );
params.use_openmp = atoi(getNextArg(i, argc, argv));
}
else if ( 0 == strcasecmp( argv[i] , "--cuda" ) ) {
params.use_cuda = atoi( argv[++i] ) + 1;
params.use_cuda = atoi(getNextArg(i, argc, argv)) + 1;
}
else if ( 0 == strcasecmp( argv[i] , "--repeat" ) ) {
params.repeat = atoi( argv[++i] );
params.repeat = atoi(getNextArg(i, argc, argv));
}
else if ( 0 == strcasecmp( argv[i] , "--hashscale" ) ) {
params.minhashscale = atoi( argv[++i] );
params.minhashscale = atoi(getNextArg(i, argc, argv));
}
else if ( 0 == strcasecmp( argv[i] , "--chunksize" ) ) {
params.chunk_size = atoi( argv[++i] ) ;
params.chunk_size = atoi(getNextArg(i, argc, argv));
}
else if ( 0 == strcasecmp( argv[i] , "--teamsize" ) ) {
params.team_size = atoi( argv[++i] ) ;
params.team_size = atoi(getNextArg(i, argc, argv));
}
else if ( 0 == strcasecmp( argv[i] , "--vectorsize" ) ) {
params.vector_size = atoi( argv[++i] ) ;
params.vector_size = atoi(getNextArg(i, argc, argv));
}

else if ( 0 == strcasecmp( argv[i] , "--compression2step" ) ) {
params.compression2step = true ;
}
else if ( 0 == strcasecmp( argv[i] , "--shmem" ) ) {
params.shmemsize = atoi( argv[++i] ) ;
params.shmemsize = atoi(getNextArg(i, argc, argv));
}
else if ( 0 == strcasecmp( argv[i] , "--memspaces" ) ) {
int memspaces = atoi( argv[++i] ) ;
int memspaces = atoi(getNextArg(i, argc, argv));
int memspaceinfo = memspaces;
std::cout << "memspaceinfo:" << memspaceinfo << std::endl;
if (memspaceinfo & 1){
Expand Down Expand Up @@ -145,19 +155,19 @@ int parse_inputs (KokkosKernels::Experiment::Parameters &params, int argc, char
params.calculate_read_write_cost = 1;
}
else if ( 0 == strcasecmp( argv[i] , "--CIF" ) ) {
params.coloring_input_file = argv[++i];
params.coloring_input_file = getNextArg(i, argc, argv);
}
else if ( 0 == strcasecmp( argv[i] , "--COF" ) ) {
params.coloring_output_file = argv[++i];
params.coloring_output_file = getNextArg(i, argc, argv);
}
else if ( 0 == strcasecmp( argv[i] , "--CCO" ) ) {
//if 0.85 set, if compression does not reduce flops by at least 15% symbolic will run on original matrix.
//otherwise, it will compress the graph and run symbolic on compressed one.
params.compression_cut_off = atof( argv[++i] ) ;
params.compression_cut_off = atof(getNextArg(i, argc, argv));
}
else if ( 0 == strcasecmp( argv[i] , "--FLHCO" ) ) {
//if linear probing is used as hash, what is the max occupancy percantage we allow in the hash.
params.first_level_hash_cut_off = atof( argv[++i] ) ;
params.first_level_hash_cut_off = atof(getNextArg(i, argc, argv));
}

else if ( 0 == strcasecmp( argv[i] , "--flop" ) ) {
Expand All @@ -169,30 +179,30 @@ int parse_inputs (KokkosKernels::Experiment::Parameters &params, int argc, char
//when mkl2 is run, the sort option to use.
//7:not to sort the output
//8:to sort the output
params.mkl_sort_option = atoi( argv[++i] ) ;
params.mkl_sort_option = atoi(getNextArg(i, argc, argv));
}
else if ( 0 == strcasecmp( argv[i] , "--mklkeepout" ) ) {
//mkl output is not kept.
params.mkl_keep_output = atoi( argv[++i] ) ;
params.mkl_keep_output = atoi(getNextArg(i, argc, argv));
}
else if ( 0 == strcasecmp( argv[i] , "--checkoutput" ) ) {
//check correctness
params.check_output = 1;
}
else if ( 0 == strcasecmp( argv[i] , "--amtx" ) ) {
//A at C=AxB
params.a_mtx_bin_file = argv[++i];
params.a_mtx_bin_file = getNextArg(i, argc, argv);
}

else if ( 0 == strcasecmp( argv[i] , "--bmtx" ) ) {
//B at C=AxB.
//if not provided, C = AxA will be performed.
params.b_mtx_bin_file = argv[++i];
params.b_mtx_bin_file = getNextArg(i, argc, argv);
}
else if ( 0 == strcasecmp( argv[i] , "--cmtx" ) ) {
//if provided, C will be written to given file.
//has to have ".bin", or ".crs" extension.
params.c_mtx_bin_file = argv[++i];
params.c_mtx_bin_file = getNextArg(i, argc, argv);
}
else if ( 0 == strcasecmp( argv[i] , "--dynamic" ) ) {
//dynamic scheduling will be used for loops.
Expand All @@ -207,7 +217,7 @@ int parse_inputs (KokkosKernels::Experiment::Parameters &params, int argc, char
//this parameter overwrites this.
//with cache mode, or CPUs with smaller thread count, where memory bandwidth is not an issue,
//this cut-off can be increased to be more than 250,000
params.MaxColDenseAcc= atoi( argv[++i] ) ;
params.MaxColDenseAcc = atoi(getNextArg(i, argc, argv));
}
else if ( 0 == strcasecmp( argv[i] , "--verbose" ) ) {
//print the timing and information about the inner steps.
Expand All @@ -216,43 +226,43 @@ int parse_inputs (KokkosKernels::Experiment::Parameters &params, int argc, char
params.verbose = 1;
}
else if ( 0 == strcasecmp( argv[i] , "--algorithm" ) ) {
++i;
char* algoStr = getNextArg(i, argc, argv);

if ( 0 == strcasecmp( argv[i] , "DEFAULT" ) ) {
if ( 0 == strcasecmp( algoStr, "DEFAULT" ) ) {
params.algorithm = KokkosSparse::SPGEMM_KK;
}
else if ( 0 == strcasecmp( argv[i] , "KKDEFAULT" ) ) {
else if ( 0 == strcasecmp( algoStr, "KKDEFAULT" ) ) {
params.algorithm = KokkosSparse::SPGEMM_KK;
}
else if ( 0 == strcasecmp( argv[i] , "KKSPGEMM" ) ) {
else if ( 0 == strcasecmp( algoStr, "KKSPGEMM" ) ) {
params.algorithm = KokkosSparse::SPGEMM_KK;
}

else if ( 0 == strcasecmp( argv[i] , "KKMEM" ) ) {
else if ( 0 == strcasecmp( algoStr, "KKMEM" ) ) {
params.algorithm = KokkosSparse::SPGEMM_KK_MEMORY;
}
else if ( 0 == strcasecmp( argv[i] , "KKDENSE" ) ) {
else if ( 0 == strcasecmp( algoStr, "KKDENSE" ) ) {
params.algorithm = KokkosSparse::SPGEMM_KK_DENSE;
}
else if ( 0 == strcasecmp( argv[i] , "KKLP" ) ) {
else if ( 0 == strcasecmp( algoStr, "KKLP" ) ) {
params.algorithm = KokkosSparse::SPGEMM_KK_LP;
}
else if ( 0 == strcasecmp( argv[i] , "MKL" ) ) {
else if ( 0 == strcasecmp( algoStr, "MKL" ) ) {
params.algorithm = KokkosSparse::SPGEMM_MKL;
}
else if ( 0 == strcasecmp( argv[i] , "CUSPARSE" ) ) {
else if ( 0 == strcasecmp( algoStr, "CUSPARSE" ) ) {
params.algorithm = KokkosSparse::SPGEMM_CUSPARSE;
}
else if ( 0 == strcasecmp( argv[i] , "CUSP" ) ) {
else if ( 0 == strcasecmp( algoStr, "CUSP" ) ) {
params.algorithm = KokkosSparse::SPGEMM_CUSP;
}
else if ( 0 == strcasecmp( argv[i] , "KKDEBUG" ) ) {
else if ( 0 == strcasecmp( algoStr, "KKDEBUG" ) ) {
params.algorithm = KokkosSparse::SPGEMM_KK_LP;
}
else if ( 0 == strcasecmp( argv[i] , "MKL2" ) ) {
else if ( 0 == strcasecmp( algoStr, "MKL2" ) ) {
params.algorithm = KokkosSparse::SPGEMM_MKL2PHASE;
}
else if ( 0 == strcasecmp( argv[i] , "VIENNA" ) ) {
else if ( 0 == strcasecmp( algoStr, "VIENNA" ) ) {
params.algorithm = KokkosSparse::SPGEMM_VIENNA;
}

Expand Down

0 comments on commit 6a32a49

Please sign in to comment.