Skip to content

Commit

Permalink
Update linear_transform.py
Browse files Browse the repository at this point in the history
  • Loading branch information
LiamBindle authored May 12, 2022
1 parent e7affa4 commit 1e537c5
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion sparselt/linear_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ def __init__(self, weights, row_ind, col_ind,

self._input_core_dims = tuple(input_transform_dims[0])
self._input_core_shape = tuple(input_transform_dims[1])
input_size = np.product(self._input_core_shape)

self._output_core_dims = tuple(output_transform_dims[0])
self._output_core_shape = tuple(output_transform_dims[1])
output_size = np.product(self._output_core_shape)

self._matrix = scipy.sparse.csr_matrix((weights, (row_ind, col_ind)))
self._matrix = scipy.sparse.csr_matrix((weights, (row_ind, col_ind)), shape=(output_size, input_size))
self._order = order

self._vfunc = self._create_vfunc()
Expand Down

0 comments on commit 1e537c5

Please sign in to comment.