From 7cf7adff444b88a9a661219b7cfc6e9fb61dd98f Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 13 Aug 2021 04:54:43 +0300 Subject: [PATCH] [Torch] chunk and unsafe chunk (#8718) * alternative chunk op was implemented in pytorch frontend. aten::unsafe_chunk was added to op map in pytorch frontend * chunk was replaced by new one in pytorch frontend. it is faster in 2.5 times Co-authored-by: Valery Chernov --- python/tvm/relay/frontend/pytorch.py | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 83ee1d3377f4..9406c3b2ea9b 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1594,28 +1594,11 @@ def chunk(self, inputs, input_types): else: unif_size = int(dim / num_chunks) - chunks = [] - for i in range(0, dim, unif_size): - begin = [0] * len(shape) - end = shape[:] - begin[axis] = i - end[axis] = i + unif_size - stride = [1] * len(shape) + indeces = [] + for i in range(unif_size, dim, unif_size): + indeces.append(i) - chunk_out = _op.transform.strided_slice(data, begin=begin, end=end, strides=stride) - chunks.append(chunk_out) - - if dim % num_chunks: - begin = [0] * len(shape) - end = shape[:] - begin[axis] = unif_size * (num_chunks - 1) - end[axis] = dim - stride = [1] * len(shape) - - chunk_out = _op.transform.strided_slice(data, begin=begin, end=end, strides=stride) - chunks.append(chunk_out) - - return chunks + return _op.split(data, indeces, axis) def matmul(self, inputs, input_types): @@ -2681,6 +2664,7 @@ def create_convert_map(self): "aten::alpha_dropout": self.dropout, "aten::mean": self.mean, "aten::chunk": self.chunk, + "aten::unsafe_chunk": self.chunk, "aten::matmul": self.matmul, "aten::bmm": self.matmul, "aten::expand": self.expand,