Skip to content

Commit

Permalink
profiler changes and mixed-input reference
Browse files (browse the repository at this point in the history)
  • Loading branch information
Manish Gupta committed Sep 19, 2023
1 parent ba3f415 commit f4a751e
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
2 changes: 1 addition & 1 deletion tools/library/scripts/gemm_operation.py
Original file line number | Diff line number | Diff line change
@@ -214,7 +214,7 @@ def procedural_name(self):
ex = self.extended_name(),
tb = threadblock,
l = self.layout_name(),
- a = str(self.A.alignment))
+ a = str(max(self.A.alignment, self.B.alignment)))

#
def configuration_name(self):
Expand Down
5 changes: 3 additions & 2 deletions tools/library/scripts/generator.py
Original file line number | Diff line number | Diff line change
@@ -2203,13 +2203,14 @@ def GenerateSM80_MixedInputTensorOp_16816(manifest, cuda_version):

for math_inst in math_instructions:
tile_descriptions = [
TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
]

data_type = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_accumulator,
math_inst.element_b,
math_inst.element_accumulator,
]

@@ -2254,7 +2255,7 @@ def GenerateSM80_MixedInputTensorOp_16816(manifest, cuda_version):
data_type = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_accumulator,
math_inst.element_a,
math_inst.element_accumulator,
]

Expand Down
16 changes: 8 additions & 8 deletions tools/library/src/reference/gemm_fp_mixed_input.cu
Original file line number | Diff line number | Diff line change
@@ -66,31 +66,31 @@ void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) {
make_gemm_real_canonical_layouts<
uint8_t,
half_t,
float,
half_t,
float,
float
>(manifest);

make_gemm_real_canonical_layouts<
int8_t,
half_t,
float,
half_t,
float,
float
>(manifest);

make_gemm_real_canonical_layouts<
half_t,
uint8_t,
float,
half_t,
float,
float
>(manifest);

make_gemm_real_canonical_layouts<
half_t,
int8_t,
float,
half_t,
float,
float
>(manifest);
@@ -99,31 +99,31 @@ void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) {
make_gemm_real_canonical_layouts<
uint8_t,
bfloat16_t,
float,
bfloat16_t,
float,
float
>(manifest);

make_gemm_real_canonical_layouts<
int8_t,
bfloat16_t,
float,
bfloat16_t,
float,
float
>(manifest);

make_gemm_real_canonical_layouts<
bfloat16_t,
uint8_t,
float,
bfloat16_t,
float,
float
>(manifest);

make_gemm_real_canonical_layouts<
bfloat16_t,
int8_t,
float,
bfloat16_t,
float,
float
>(manifest);
Expand Down

0 comments on commit f4a751e

Please sign in to comment.