diff --git a/src/ops/layout.rs b/src/ops/layout.rs index e3f30a78..752cf324 100644 --- a/src/ops/layout.rs +++ b/src/ops/layout.rs @@ -194,9 +194,13 @@ impl Operator for Expand { } fn flattened_shape(shape: &[usize], axis: isize) -> Result<[usize; 2], OpError> { - let resolved_axis = resolve_axis(shape.len(), axis)?; - let outer_size = shape.iter().take(resolved_axis).product(); - let inner_size = shape.iter().skip(resolved_axis).product(); + let outer_dims = if axis == shape.len() as isize { + shape.len() + } else { + resolve_axis(shape.len(), axis)? + }; + let outer_size = shape.iter().take(outer_dims).product(); + let inner_size = shape.iter().skip(outer_dims).product(); Ok([outer_size, inner_size]) } @@ -850,17 +854,74 @@ mod tests { fn test_flatten() { let pool = new_pool(); - let input = Tensor::from_data(&[1, 5, 1, 1], vec![1, 2, 3, 4, 5]); - let result = flatten(&pool, input.view(), 1 /* axis */).unwrap(); - assert_eq!(result.shape(), &[1, 5]); + struct Case { + shape: Vec, + axis: isize, + expected: Result, OpError>, + } - let input = Tensor::from_data(&[2, 3, 1, 4], (1..=24).collect::>()); - let result = flatten(&pool, input.view(), 2 /* axis */).unwrap(); - assert_eq!(result.shape(), &[6, 4]); + let cases = [ + Case { + shape: [1, 5, 1, 1].into(), + axis: 1, + expected: Ok([1, 5].into()), + }, + Case { + shape: [2, 3, 1, 4].into(), + axis: 2, + expected: Ok([6, 4].into()), + }, + // Axis = 0 + Case { + shape: [2, 3, 1, 4].into(), + axis: 0, + expected: Ok([1, 24].into()), + }, + // Axis equal to rank of input + Case { + shape: [2, 2].into(), + axis: 2, + expected: Ok([4, 1].into()), + }, + // Negative values count from the back + Case { + shape: [2, 3, 4].into(), + axis: -1, + expected: Ok([6, 4].into()), + }, + Case { + shape: [2, 3, 4].into(), + axis: -2, + expected: Ok([2, 12].into()), + }, + Case { + shape: [2, 3, 4].into(), + axis: -3, + expected: Ok([1, 24].into()), + }, + // Values outside `[-r, r]` are invalid + Case { + shape: [2, 3, 4].into(), + axis: 4, + expected: Err(OpError::InvalidValue("Axis is invalid")), + }, + Case { + shape: [2, 3, 4].into(), + axis: -4, + expected: Err(OpError::InvalidValue("Axis is invalid")), + }, + ]; - // Case when `axis` is zero, first output dim should always be 1 - let result = flatten(&pool, input.view(), 0 /* axis */).unwrap(); - assert_eq!(result.shape(), &[1, 24]); + for Case { + shape, + axis, + expected, + } in cases + { + let input = Tensor::::zeros(shape.as_slice()); + let result = flatten(&pool, input.view(), axis).map(|tensor| tensor.shape().to_vec()); + assert_eq!(result, expected); + } } #[test]