Skip to content

Commit

Permalink
add flexible input options for smem_args
Browse files Browse the repository at this point in the history
  • Loading branch information
benvanwerkhoven committed Jun 17, 2021
1 parent 4fcf095 commit 069d1c6
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
2 changes: 1 addition & 1 deletion kernel_tuner/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def compile_and_benchmark(self, kernel_source, gpu_args, params, kernel_options,

#add shared memory arguments to compiled module
if kernel_options.smem_args is not None:
self.dev.copy_shared_memory_args(kernel_options.smem_args)
self.dev.copy_shared_memory_args(util.get_smem_args(kernel_options.smem_args, params))
#add constant memory arguments to compiled module
if kernel_options.cmem_args is not None:
self.dev.copy_constant_memory_args(kernel_options.cmem_args)
Expand Down
2 changes: 1 addition & 1 deletion kernel_tuner/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def run_kernel(kernel_name, kernel_string, problem_size, arguments, params, grid

#add shared memory arguments to compiled module
if smem_args is not None:
dev.copy_shared_memory_args(smem_args)
dev.copy_shared_memory_args(util.get_smem_args(smem_args, params))
#add constant memory arguments to compiled module
if cmem_args is not None:
dev.copy_constant_memory_args(cmem_args)
Expand Down
13 changes: 13 additions & 0 deletions kernel_tuner/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,19 @@ def get_problem_size(problem_size, params):
return current_problem_size


def get_smem_args(smem_args, params):
""" return a dict with kernel instance specific size """
result = smem_args.copy()
if 'size' in result:
size = result['size']
if callable(size):
size = size(params)
elif isinstance(size, str):
size = util.replace_param_occurrences(size, params)
size = int(eval(size))
return result


def get_temp_filename(suffix=None):
""" return a string in the form of temp_X, where X is a large integer """
file = tempfile.mkstemp(
Expand Down

0 comments on commit 069d1c6

Please sign in to comment.