[Pack-peel-4-level-tiling] Enable Matmul+Trunci : i8->i8 on Phoenix|Strix + Ukernel|Vectorization (#1084)

Add support for handling Matmul + Trunci : i8->i8 on Phoenix/Strix.

On Phoenix it works via both the vectorization and the ukernel path.
On Strix it currently works only via the ukernel path.

Therefore, for now, three e2e CI tests have been added.

Signed-off-by: Abhishek Varma <[email protected]>
Abhishek-Varma authored Feb 12, 2025
1 parent 08e7777 commit 8d01ec9
Showing 11 changed files with 377 additions and 20 deletions.
@@ -0,0 +1,13 @@
// input ${M}x${K}x${TYPE1}
// input ${K}x${N}x${TYPE1}

func.func @matmul_trunci(%arg0: tensor<${M}x${K}x${TYPE1}>, %arg1: tensor<${K}x${N}x${TYPE1}>) -> tensor<${M}x${N}x${TYPE1}>
{
%cst = arith.constant ${ZERO} : ${TYPE2}
%0 = tensor.empty() : tensor<${M}x${N}x${TYPE2}>
%1 = linalg.fill ins(%cst : ${TYPE2}) outs(%0 : tensor<${M}x${N}x${TYPE2}>) -> tensor<${M}x${N}x${TYPE2}>
%2 = linalg.matmul ins(%arg0, %arg1 : tensor<${M}x${K}x${TYPE1}>, tensor<${K}x${N}x${TYPE1}>)
outs(%1: tensor<${M}x${N}x${TYPE2}>) -> tensor<${M}x${N}x${TYPE2}>
%3 = arith.trunci %2 : tensor<${M}x${N}x${TYPE2}> to tensor<${M}x${N}x${TYPE1}>
return %3: tensor<${M}x${N}x${TYPE1}>
}
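
The computation this template encodes is a plain matmul with i32 accumulation followed by an element-wise truncation of the result to i8. Below is a minimal NumPy sketch of the same semantics; the shapes and random inputs are illustrative only and are not taken from the CI configuration.

import numpy as np

# Illustrative shapes; the template substitutes ${M}/${K}/${N} at test-generation time.
M, K, N = 4, 8, 5
A = np.random.randint(-128, 128, size=(M, K), dtype=np.int8)
B = np.random.randint(-128, 128, size=(K, N), dtype=np.int8)

# linalg.fill + linalg.matmul: accumulate in i32 (${TYPE2}).
acc = A.astype(np.int32) @ B.astype(np.int32)

# arith.trunci: keep the low 8 bits of each element (wrapping on overflow), giving ${TYPE1}.
out = acc.astype(np.int8)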
138 changes: 133 additions & 5 deletions build_tools/ci/cpu_comparison/run.py
@@ -171,6 +171,8 @@ def run(self, config):
# does not).
if self.use_chess and not config.vitis_dir:
return False
if self.use_ukernel and not config.vitis_dir:
return False

# If use_chess=0, and config has not provided a valid
# path to peano, then bail: a path to peano must be provided.
@@ -655,11 +657,70 @@ def _execute(self, config):
input_args = generate_inputs(
filename, self.get_dir(config), 1, {1: self.lhs, 2: self.rhs}
)
"""
Currently without function outlining, we run out of program memory.
"""
self.add_aie_compilation_flags(
    ["--iree-amdaie-enable-function-outlining=balanced"]
)
aie_vs_baseline(
config=config,
aie_compilation_flags=self.aie_compilation_flags,
test_file=self.get_filename(config),
input_args=input_args,
baseline_value=self.expected_out,
use_ukernel=self.use_ukernel,
tile_pipeline=self.tile_pipeline,
function_name=None,
seed=1,
rtol=0,
atol=0,
lower_to_aie_pipeline=self.lower_to_aie_pipeline,
n_repeats=self.n_repeats,
output_type=get_output_type(self.get_filename(config)),
)

return True


class MatmulTrunci(BaseMatmul):
"""
A test of the form C = trunci(matmul(A, B)) where A is MxK, B is KxN, and C is MxN
"""

def __init__(
self,
M,
N,
K,
input_type,
acc_type,
lhs,
rhs,
expected_out,
test_params=None,
):
super().__init__(
name=f"matmul_trunci_{M}_{N}_{K}_{input_type}_{acc_type}",
test_params=test_params,
M=M,
N=N,
K=K,
input_type=input_type,
acc_type=acc_type,
)
self.labels.append("MatmulTrunci")

# Assertions on shapes: Check that lhs is MxK, rhs is KxN, and expected_out is MxN
assert lhs.shape == (M, K)
assert rhs.shape == (K, N)
assert expected_out.shape == (M, N)

self.lhs = lhs
self.rhs = rhs
self.expected_out = expected_out

def _execute(self, config):
matmul_template_dir = config.file_dir / "matmul_template"
template_name = matmul_template_dir / "matmul_trunci_MxK_KxN.mlir"
self.generate(config, template_name)
filename = self.get_filename(config)
input_args = generate_inputs(
filename, self.get_dir(config), 1, {1: self.lhs, 2: self.rhs}
)
aie_vs_baseline(
config=config,
@@ -1462,6 +1523,73 @@ def __init__(self):
self.existing_names = []
self.tests = []

# Tests Matmul + Trunci.
# Phoenix : Ukernel + Peano.
self.register(
MatmulTrunci(
256,
128,
32,
"i8",
"i32",
1 * np.ones([256, 32], dtype=np.int8),
1 * np.ones([32, 128], dtype=np.int8),
32 * np.ones([256, 128], dtype=np.int8),
test_params=TestParams(
tile_pipeline="pack-peel-4-level-tiling",
run_on_target=["npu1_4col"],
aie_compilation_flags=[
"--iree-amdaie-num-rows=4",
"--iree-amdaie-num-cols=4",
],
use_ukernel=True,
),
)
)
# Phoenix : Vectorization + Peano.
self.register(
MatmulTrunci(
256,
128,
32,
"i8",
"i32",
1 * np.ones([256, 32], dtype=np.int8),
1 * np.ones([32, 128], dtype=np.int8),
32 * np.ones([256, 128], dtype=np.int8),
test_params=TestParams(
tile_pipeline="pack-peel-4-level-tiling",
run_on_target=["npu1_4col"],
aie_compilation_flags=[
"--iree-amdaie-num-rows=4",
"--iree-amdaie-num-cols=4",
],
),
)
)
# Strix : Ukernel + Chess.
self.register(
MatmulTrunci(
256,
128,
32,
"i8",
"i32",
1 * np.ones([256, 32], dtype=np.int8),
1 * np.ones([32, 128], dtype=np.int8),
32 * np.ones([256, 128], dtype=np.int8),
test_params=TestParams(
tile_pipeline="pack-peel-4-level-tiling",
run_on_target=["npu4"],
aie_compilation_flags=[
"--iree-amdaie-num-rows=4",
"--iree-amdaie-num-cols=8",
],
use_chess=True,
use_ukernel=True,
),
)
)
# Matmul with truncf test(s):
for tile_pipeline in ["pack-peel", "pack-peel-4-level-tiling"]:
self.register(
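The expected outputs in the three MatmulTrunci registrations above follow directly from the data: with all-ones i8 inputs and K = 32, every output element is a sum of 32 ones, i.e. 32, which fits in i8, so the i32 -> i8 truncation is lossless for this test data. A quick NumPy check of that arithmetic (an editorial sketch, not part of the CI harness):

import numpy as np

lhs = np.ones([256, 32], dtype=np.int8)
rhs = np.ones([32, 128], dtype=np.int8)
acc = lhs.astype(np.int32) @ rhs.astype(np.int32)  # every element equals K == 32
out = acc.astype(np.int8)                           # 32 fits in i8, so no wraparound
assert (out == 32 * np.ones([256, 128], dtype=np.int8)).all()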
@@ -802,10 +802,11 @@ struct ToMinorIdentityTransferReadPattern
/// %1 = arith.truncf %0 : vector<6xf32> to vector<6xbf16>
/// %2 = vector.shape_cast %1 : vector<6xbf16> to vector<2x3xbf16>
// clang-format on
struct FlattenArithTruncFOpPattern : public OpRewritePattern<arith::TruncFOp> {
using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
template <typename TruncOpTy>
struct FlattenArithTruncOpPattern : public OpRewritePattern<TruncOpTy> {
using OpRewritePattern<TruncOpTy>::OpRewritePattern;

LogicalResult matchAndRewrite(arith::TruncFOp op,
LogicalResult matchAndRewrite(TruncOpTy op,
PatternRewriter &rewriter) const override {
// Get old shape type.
auto oldShapedType = dyn_cast<VectorType>(op.getType());
@@ -826,7 +827,7 @@ struct FlattenArithTruncFOpPattern : public OpRewritePattern<arith::TruncFOp> {
Value newInputVector = rewriter.create<vector::ShapeCastOp>(
op.getLoc(), newVectorTypeForInput, origInputOfTruncFOp);
// Create new base operation with the linearized input/output.
Value newTruncFOp = rewriter.create<arith::TruncFOp>(
Value newTruncFOp = rewriter.create<TruncOpTy>(
op.getLoc(), newVectorTypeForOutput, newInputVector);
// Delinearize the output back to the original type.
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(),
@@ -1054,11 +1055,12 @@ struct CanonicalizeVectorForAIEVecPass

{
RewritePatternSet patterns(context);
patterns
.add<ExtractTransposeFromContractionOp, FlattenArithTruncFOpPattern,
ToMinorIdentityTransferReadPattern,
ToMinorIdentityTransferWritePattern,
ConvertLeadingUnitDimInsertToReshapePattern>(context);
patterns.add<ExtractTransposeFromContractionOp,
FlattenArithTruncOpPattern<arith::TruncFOp>,
FlattenArithTruncOpPattern<arith::TruncIOp>,
ToMinorIdentityTransferReadPattern,
ToMinorIdentityTransferWritePattern,
ConvertLeadingUnitDimInsertToReshapePattern>(context);
patterns.add<ConvertSplatTransferReadToBroadcastPattern>(context);
patterns
.add<copied_from_mlir::FlattenContiguousRowMajorTransferReadPattern,
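The templated pattern above only linearizes the vector shape around the truncation; this is safe because truncation is element-wise, so casting before or after a reshape yields the same elements. A small NumPy analogue of that argument (illustrative only, mirroring the 2x3 example from the pattern's comment):

import numpy as np

x = np.arange(6, dtype=np.int32).reshape(2, 3)
direct = x.astype(np.int8)                              # truncate in the original 2-D shape
flattened = x.reshape(6).astype(np.int8).reshape(2, 3)  # linearize, truncate, delinearize
assert (direct == flattened).all()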
@@ -169,6 +169,19 @@ func.func @arith_truncf(%inp: vector<2x3xf32>) -> vector<2x3xbf16> {

// -----

// CHECK-LABEL: @arith_trunci(
// CHECK-SAME: %[[INP:.*]]: vector<2x3xi32>)
func.func @arith_trunci(%inp: vector<2x3xi32>) -> vector<2x3xi8> {
// CHECK: %[[LINEARIZE:.*]] = vector.shape_cast %[[INP]] : vector<2x3xi32> to vector<6xi32>
// CHECK: %[[TRUNCI:.*]] = arith.trunci %[[LINEARIZE]] : vector<6xi32> to vector<6xi8>
// CHECK: %[[DELINEARIZE:.*]] = vector.shape_cast %[[TRUNCI]] : vector<6xi8> to vector<2x3xi8>
// CHECK: return %[[DELINEARIZE]]
%0 = arith.trunci %inp : vector<2x3xi32> to vector<2x3xi8>
return %0 : vector<2x3xi8>
}

// -----

// CHECK: #map = affine_map<()[s0] -> (s0 * 256 + 96)>
// CHECK-LABEL: @trivial_read_access
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8x4x8xbf16, strided<[256, 32, 8, 1]>>,