Skip to content

Commit

Permalink
[1] Review comments handled
Browse files Browse the repository at this point in the history
  • Loading branch information
ANSHUMAN TRIPATHY authored and ANSHUMAN TRIPATHY committed Oct 18, 2020
1 parent e11bd5b commit 73271f1
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
12 changes: 7 additions & 5 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,8 +891,8 @@ def _impl(inputs, attr, params, mod):


def _sparse_tensor_dense_matmul():
# Sparse utility from Numpy
from scipy import sparse
# Sparse utility from scipy
from scipy.sparse import csr_matrix

def _impl(inputs, attr, params, mod):
assert len(inputs) == 4, "There should be 4 input tensors"
Expand All @@ -906,11 +906,11 @@ def _impl(inputs, attr, params, mod):
rows = [x[0] for x in indices_tensor]
cols = [x[1] for x in indices_tensor]

# Create Numpy sparse Tensor(CSR)
weight_sp = sparse.csr_matrix(
# Create scipy sparse Tensor(CSR)
weight_sp = csr_matrix(
(values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist())
)
weight_sp = sparse.csr_matrix(weight_sp.transpose())
weight_sp = csr_matrix(weight_sp.transpose())

weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype)
weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype)
Expand All @@ -922,6 +922,8 @@ def _impl(inputs, attr, params, mod):
# TODO: Support other adjoint option too
if attr.get("adjoint_a") and attr.get("adjoint_b"):
ret = _op.transpose(ret)
else:
raise tvm.error.OpAttributeUnImplemented("Adjoint option is not supported yet.")

return ret

Expand Down
2 changes: 1 addition & 1 deletion tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1763,7 +1763,7 @@ def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=Fal
for adjoint_b in [False]:
with tf.Graph().as_default():
A_sp = tf.sparse.SparseTensor(
indices=[[0, 0], [1, 2]], values=[4.0, 8.0], dense_shape=A_shape
indices=indices, values=values, dense_shape=A_shape
)
B = tf.placeholder(shape=B_shape, dtype=dtype, name="B")

Expand Down

0 comments on commit 73271f1

Please sign in to comment.