Skip to content

Commit

Permalink
Optimize copying nalgebra matrices into NumPy arrays.
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreichold committed Sep 10, 2022
1 parent 7c07bff commit 6e47b97
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,14 @@ where

fn to_pyarray<'py>(&self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
unsafe {
let array = PyArray::new(py, (self.nrows(), self.ncols()), false);
for r in 0..self.nrows() {
for c in 0..self.ncols() {
*array.uget_mut((r, c)) = self.get_unchecked((r, c)).clone();
let array = PyArray::<N, _>::new(py, (self.nrows(), self.ncols()), true);
let mut data_ptr = array.data();
if self.data.is_contiguous() {
ptr::copy_nonoverlapping(self.data.ptr(), data_ptr, self.len());
} else {
for item in self.iter() {
data_ptr.write(item.clone());
data_ptr = data_ptr.add(1);
}
}
array
Expand Down
10 changes: 10 additions & 0 deletions tests/to_py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ fn slice_container_type_confusion() {
#[test]
fn matrix_to_numpy() {
let matrix = nalgebra::Matrix3::<i32>::new(0, 1, 2, 3, 4, 5, 6, 7, 8);
assert!(nalgebra::storage::RawStorage::is_contiguous(&matrix.data));

Python::with_gil(|py| {
let array = matrix.to_pyarray(py);
Expand All @@ -312,4 +313,13 @@ fn matrix_to_numpy() {
array![[0, 1, 2], [3, 4, 5], [6, 7, 8]],
);
});

let matrix = matrix.row(0);
assert!(!nalgebra::storage::RawStorage::is_contiguous(&matrix.data));

Python::with_gil(|py| {
let array = matrix.to_pyarray(py);

assert_eq!(array.readonly().as_array(), array![[0, 1, 2]]);
});
}

0 comments on commit 6e47b97

Please sign in to comment.