diff --git a/src/ops/gather.rs b/src/ops/gather.rs index e4c449ab..08938c00 100644 --- a/src/ops/gather.rs +++ b/src/ops/gather.rs @@ -26,7 +26,8 @@ pub fn gather( let axis = resolve_axis(input.ndim(), axis)?; for index in indices.iter().copied() { - if index < 0 || index >= input.size(axis) as i32 { + let size = input.size(axis) as i32; + if index < -size || index >= size { return Err(OpError::InvalidValue("Entry in `indices` is out of range")); } } @@ -500,6 +501,13 @@ mod tests { let result = gather(input.view(), 1, indices.view()).unwrap(); expect_equal(&result, &expected)?; + // Negative index values. + let input = Tensor::from([1, 2, 3]); + let indices = Tensor::from([-1, -2, -3]); + let expected = Tensor::from([3, 2, 1]); + let result = gather(input.view(), 0, indices.view()).unwrap(); + assert_eq!(&result, &expected); + Ok(()) }