diff --git a/sparselt/linear_transform.py b/sparselt/linear_transform.py index 98971b9..e9f229c 100644 --- a/sparselt/linear_transform.py +++ b/sparselt/linear_transform.py @@ -23,11 +23,13 @@ def __init__(self, weights, row_ind, col_ind, self._input_core_dims = tuple(input_transform_dims[0]) self._input_core_shape = tuple(input_transform_dims[1]) + input_size = np.product(self._input_core_shape) self._output_core_dims = tuple(output_transform_dims[0]) self._output_core_shape = tuple(output_transform_dims[1]) + output_size = np.product(self._output_core_shape) - self._matrix = scipy.sparse.csr_matrix((weights, (row_ind, col_ind))) + self._matrix = scipy.sparse.csr_matrix((weights, (row_ind, col_ind)), shape=(output_size, input_size)) self._order = order self._vfunc = self._create_vfunc()