diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 54613e5b05..362329b5de 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -196,9 +196,7 @@ def _get_min_alignment(size: int, alignment_value: int) -> int: 16 ``` """ - if size % alignment_value == 0: - return size - return (1 + (size // alignment_value)) * alignment_value + return (1 + ((size - 1) // alignment_value)) * alignment_value def pad_tensor_for_matmul( @@ -234,10 +232,6 @@ def pad_tensor_for_matmul( dim1_aligned = _get_min_alignment(dim1, 16) if 0 in dims else dim1 dim2_aligned = _get_min_alignment(dim2, 16) if 1 in dims else dim2 - # Check if padding is needed for either dimension - if dim1 == dim1_aligned and dim2 == dim2_aligned: - return tensor - # Calculate padding values for both dimensions pad_dim1 = dim1_aligned - dim1 pad_dim2 = dim2_aligned - dim2