Skip to content

Commit

Permalink
Addressing review comments
Browse files Browse the repository at this point in the history
Change-Id: I284b1f2c121051e672f548d6c6ee2a3267854e31
  • Loading branch information
Giuseppe Rossini committed Sep 25, 2020
1 parent 49ee849 commit e1d2238
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 106 deletions.
32 changes: 16 additions & 16 deletions python/tvm/relay/op/strategy/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,21 +136,21 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
)
elif kernel_layout == "HWIO":
is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm()
has_dot_prod = topi.arm_cpu.arm_utils.is_fast_int8_on_arm()
has_dot_prod = topi.arm_cpu.arm_utils.is_dotprod_available()
if has_dot_prod and data.dtype in ["int8", "uint8"]:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_hybrid),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_hybrid),
name="conv2d_NHWC_quantized_hybrid.arm_cpu",
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
name="conv2d_NHWC_quantized_native.arm_cpu",
)
if is_aarch64 and data.dtype in ["int8", "uint8"]:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
name="conv2d_NHWC_quantized.arm_cpu",
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved),
name="conv2d_NHWC_quantized_interleaved.arm_cpu",
)
if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]):
# TODO
# TODO(@giuseros)
# This strategy errors out for quantized data types when tuning.
# Let's use this only for non-aarch64 or non-quantized cases
strategy.add_implementation(
Expand Down Expand Up @@ -329,18 +329,18 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
data = inputs[0]
strategy = _op.OpStrategy()

interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_without_transform
hybrid_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_hybrid_without_transform
interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved_without_transform
native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native_without_transform
if layout == "NHWC" and data.dtype in ["int8", "uint8"]:
strategy.add_implementation(
wrap_compute_conv2d_gemm(interleaved_compute),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
name="conv2d_NHWC_quantized_without_transform.arm_cpu",
wrap_compute_conv2d_gemm(native_compute),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
name="conv2d_NHWC_quantized_native_without_transform.arm_cpu",
)
strategy.add_implementation(
wrap_compute_conv2d_gemm(hybrid_compute),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_hybrid),
name="conv2d_NHWC_quantized_hybrid_without_transform.arm_cpu",
wrap_compute_conv2d_gemm(interleaved_compute),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved),
name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
)
else:
raise RuntimeError(
Expand Down
47 changes: 46 additions & 1 deletion python/tvm/topi/arm_cpu/arm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import tvm


def is_fast_int8_on_arm():
def is_dotprod_available():
""" Checks whether the hardware has support for fast Int8 arithmetic operations. """
target = tvm.target.Target.current(allow_none=False)
return "+v8.2a" in target.mattr and "+dotprod" in target.mattr
Expand All @@ -30,3 +30,48 @@ def is_aarch64_arm():
""" Checks whether we are compiling for an AArch64 target. """
target = tvm.target.Target.current(allow_none=False)
return "aarch64" in target.attrs.get("mtriple", "")


def get_tiling_B_interleaved_t(interleave_A):
    """Return the tile shape of B', the transposed and interleaved form of B in C = A * B.

    The tile shape is chosen to maximize register usage during the
    computation of a tile. For more information, please refer to:

    - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
    - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h

    Parameters
    ----------
    interleave_A : bool
        Whether matrix A is expected to be interleaved.

    Returns
    -------
    tile_rows_B : int
        Number of rows in a tile of B'.
    tile_cols_B : int
        Number of columns in a tile of B'.
    """
    if not is_dotprod_available():
        # Without dot product, A must be interleaved. In this case we load
        # 4 rows of B' (i.e., 4 columns of B), each of them holding 16 elements.
        return 4, 16

    # With dot product, the number of tile rows of B' depends on the strategy:
    #   * If A is interleaved, a tile takes 12 rows of B' (i.e., 12 columns of B).
    #   * If A is not interleaved, a tile takes 16 rows of B' (i.e., 16 columns of B).
    if interleave_A:
        tile_rows_B = 12
    else:
        tile_rows_B = 16

    # The dot product instruction splits its two (u)int8x16 input vectors
    # into groups of 4 elements and computes the dot product of matching
    # groups. Hence the number of columns in a tile of B' (i.e., the rows
    # of the original matrix B) needs to be 4.
    tile_cols_B = 4

    return tile_rows_B, tile_cols_B
52 changes: 31 additions & 21 deletions python/tvm/topi/arm_cpu/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,31 @@
from ..nn import conv2d_alter_layout
from ..util import get_const_tuple
from ..x86.conv2d import _get_default_config as _get_x86_default_config
from .arm_utils import is_fast_int8_on_arm
from .arm_utils import get_tiling_B_interleaved_t

logger = logging.getLogger("topi")


def interleave_transpose_B(inputs, data, kernel, interleave_A):
"""Return the new placeholder and the expression that represent
the matrix B transposed and interleaved"""

def interleave_transpose_weights(inputs, data, kernel, interleave_A):
""" Transform the weight matrix by reshaping, interleaving and transposing it
Parameters
----------
inputs : tvm.relay.Expr
Grouped input symbols
data :
Input shape and dtype
kernel :
Input shape and dtype
interleave_A: indicates if we expect matrix A to be interleaved
Returns
----------
new_kernel : tvm.te.placeholder
A placeholder with the new shape
new_kernel_expr : tvm.relay.Expr
The relay expression of the weights
"""
assert (
data.dtype == "int8"
and kernel.dtype == "int8"
Expand All @@ -47,12 +63,8 @@ def interleave_transpose_B(inputs, data, kernel, interleave_A):
K = KH * KW * IC
N = OC

if is_fast_int8_on_arm():
tile_rows_B = 12 if interleave_A else 16
tile_cols_B = 4
else:
tile_rows_B = 4
tile_cols_B = 16
# Get tiling information for the interleaved transposed version of B
tile_rows_B, tile_cols_B = get_tiling_B_interleaved_t(interleave_A)

pad_K = 0
pad_N = 0
Expand Down Expand Up @@ -321,12 +333,11 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
)
dispatch_ctx.update(target, new_workload, cfg)
return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs)
if topi_tmpl == "conv2d_NHWC_quantized.arm_cpu":
if topi_tmpl == "conv2d_NHWC_quantized_interleaved.arm_cpu":
assert data_layout == "NHWC" and kernel_layout == "HWIO"
KH, KW, IC, OC = get_const_tuple(kernel.shape)
N = OC
new_workload_name = "conv2d_NHWC_quantized_without_transform.arm_cpu"
new_kernel, new_kernel_expr = interleave_transpose_B(
KH, KW, _, OC = get_const_tuple(kernel.shape)
new_workload_name = "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu"
new_kernel, new_kernel_expr = interleave_transpose_weights(
inputs, data, kernel, interleave_A=True
)
new_workload = autotvm.task.args_to_workload(
Expand All @@ -338,12 +349,11 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
return relay.nn.contrib_conv2d_gemm_without_weight_transform(
inputs[0], new_kernel_expr, **new_attrs
)
if topi_tmpl == "conv2d_NHWC_quantized_hybrid.arm_cpu":
if topi_tmpl == "conv2d_NHWC_quantized_native.arm_cpu":
assert data_layout == "NHWC" and kernel_layout == "HWIO"
KH, KW, IC, OC = get_const_tuple(kernel.shape)
N = OC
new_workload_name = "conv2d_NHWC_quantized_hybrid_without_transform.arm_cpu"
new_kernel, new_kernel_expr = interleave_transpose_B(
KH, KW, _, OC = get_const_tuple(kernel.shape)
new_workload_name = "conv2d_NHWC_quantized_native_without_transform.arm_cpu"
new_kernel, new_kernel_expr = interleave_transpose_weights(
inputs, data, kernel, interleave_A=False
)
new_workload = autotvm.task.args_to_workload(
Expand Down
66 changes: 41 additions & 25 deletions python/tvm/topi/arm_cpu/conv2d_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
mmla_4x4_int8_int8_int32,
mmla_16x4_int8_int8_int32,
)
from .arm_utils import is_aarch64_arm, is_fast_int8_on_arm
from .arm_utils import is_aarch64_arm, is_dotprod_available


def configure_knobs(cfg, M, K):
Expand All @@ -48,7 +48,7 @@ def configure_knobs(cfg, M, K):
cfg["reorder_gemm"] = ReorderEntity([0, 1])
cfg["A_interleaved_unroll_vec"] = AnnotateEntity(["unroll", "vec"])

if not is_fast_int8_on_arm():
if not is_dotprod_available():
cfg.define_knob("gemm_quantized_unroll", [True, False])
cfg.define_knob("gemm_quantized_interleave", [True, False])

Expand Down Expand Up @@ -76,8 +76,7 @@ def compute_conv2d_gemm_without_weight_transform(

KH, KW = get_const_tuple(kernel_size)
OC = get_const_int(output_channels)

K_AREA = KH * KW
kernel_area = KH * KW

if isinstance(dilation, int):
dilation_h = dilation_w = dilation
Expand All @@ -101,13 +100,13 @@ def compute_conv2d_gemm_without_weight_transform(
else:
data_pad = data

# --- Im2col
# Im2col
M = OH * OW
K = IC * K_AREA
K = IC * kernel_area
N = OC

A_shape = (batches, M, K)
if K_AREA == 1:
if kernel_area == 1:
A = tvm.topi.reshape(data_pad, A_shape)
else:
A = te.compute(
Expand All @@ -121,35 +120,48 @@ def compute_conv2d_gemm_without_weight_transform(
name="data_im2col",
)

# --- Pad if necessary
# Pad if necessary
N_transformed = B_interleaved_t.shape[0]
tile_rows_B = B_interleaved_t.shape[2]
tile_cols_B = B_interleaved_t.shape[3]

if is_fast_int8_on_arm() and interleave_A:
# Select the tiling strategy for A.
# The tiling information is chosen to maximize register usage during
# the tile computation.
#
# For more information, please refer to:
# - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
# - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
#
if is_dotprod_available() and interleave_A:
# If dot product has been enabled, and we are interleaving A
# tile size should be 8x4
tile_rows_A = 8
tile_cols_A = 4
else:
# If either there is no dot product or if we are using a native strategy
# tile size should be 4x16
tile_rows_A = 4
tile_cols_A = 16

pad_m = 0
pad_k = 0
pad_M = 0
pad_K = 0

if M % tile_rows_A != 0:
pad_m = tile_rows_A - (M % tile_rows_A)
pad_M = tile_rows_A - (M % tile_rows_A)

if K % tile_cols_A != 0:
pad_k = tile_cols_A - (K % tile_cols_A)
pad_K = tile_cols_A - (K % tile_cols_A)

M_padded = M + pad_m
K_padded = K + pad_k
M_padded = M + pad_M
K_padded = K + pad_K
N_padded = N_transformed * tile_rows_B

pad_before = (0, 0, 0)
pad_after = (0, pad_m, pad_k)
pad_after = (0, pad_M, pad_K)

if pad_m != 0 or pad_k != 0:
if pad_M != 0 or pad_K != 0:
A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded")

idxm = tvm.tir.indexmod
Expand All @@ -159,13 +171,13 @@ def compute_conv2d_gemm_without_weight_transform(
# Configuration space
configure_knobs(cfg, M_padded, K_padded)

# Pack A
# Pack the input data
A_interleaved = te.compute(
(batches, M_padded // tile_rows_A, K_padded // tile_cols_A, tile_rows_A, tile_cols_A),
lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y],
name="A_interleaved",
)
# Compute C
# Execute GEMM
C_interleaved = te.compute(
(batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B),
lambda b, x, y, w, z: te.sum(
Expand All @@ -175,7 +187,7 @@ def compute_conv2d_gemm_without_weight_transform(
),
name="C_interleaved",
)
# Unpack C
# Unpack the result
C = te.compute(
(batches, M, N),
lambda b, x, y: C_interleaved[
Expand All @@ -185,7 +197,7 @@ def compute_conv2d_gemm_without_weight_transform(
)
zero = tvm.tir.const(0)
else:
# No need to pack/unpack
# No need to pack/unpack, execute GEMM directly
C = te.compute(
(batches, M_padded, N_padded),
lambda b, x, y: te.sum(
Expand All @@ -197,12 +209,16 @@ def compute_conv2d_gemm_without_weight_transform(
),
name="C",
)

# We need to ensure that infer bound pass does not remove the padding
# which is necessary for the tensorizations to work. So we need to
# add a dummy reference to the padding area of the result
zero = (
tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
- tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
)

# --- Produce the conv output
# Reshape the result into a convolution output
out_shape = (batches, OH, OW, OC)
out = te.compute(
out_shape,
Expand All @@ -212,7 +228,7 @@ def compute_conv2d_gemm_without_weight_transform(
return out


def schedule_conv2d_gemm(cfg, s, out, final_out):
def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
""" Schedule the conv2d_gemm interleaved strategy """
C = out.op.input_tensors[0]
C_interleaved = C.op.input_tensors[0]
Expand Down Expand Up @@ -255,7 +271,7 @@ def schedule_conv2d_gemm(cfg, s, out, final_out):

k = C_interleaved.op.reduce_axis[0]
_, M, N = C.shape
if is_fast_int8_on_arm():
if is_dotprod_available():
mmla = mmla_4x4_int8_int8_int32(in_type)
xi_outer, yi_outer, xi_inner, yi_inner = s[C_interleaved].tile(
xi, yi, x_factor=8, y_factor=4
Expand Down Expand Up @@ -299,7 +315,7 @@ def schedule_conv2d_gemm(cfg, s, out, final_out):
return s


def schedule_conv2d_gemm_hybrid(cfg, s, out, final_out):
def schedule_conv2d_gemm_native(cfg, s, out, final_out):
""" Schedule the conv2d_gemm native strategy """
C = out.op.input_tensors[0]
A = C.op.input_tensors[0]
Expand Down
Loading

0 comments on commit e1d2238

Please sign in to comment.