Skip to content

Commit

Permalink
Merge pull request #577 from robertknight/flatten-axis-equal-to-rank
Browse files Browse the repository at this point in the history
Support `axis` equal to tensor rank in `Flatten` operator
  • Loading branch information
robertknight authored Feb 4, 2025
2 parents 0760757 + 4236297 commit eaaf34b
Showing 1 changed file with 73 additions and 12 deletions.
85 changes: 73 additions & 12 deletions src/ops/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,13 @@ impl Operator for Expand {
}

fn flattened_shape(shape: &[usize], axis: isize) -> Result<[usize; 2], OpError> {
let resolved_axis = resolve_axis(shape.len(), axis)?;
let outer_size = shape.iter().take(resolved_axis).product();
let inner_size = shape.iter().skip(resolved_axis).product();
let outer_dims = if axis == shape.len() as isize {
shape.len()
} else {
resolve_axis(shape.len(), axis)?
};
let outer_size = shape.iter().take(outer_dims).product();
let inner_size = shape.iter().skip(outer_dims).product();
Ok([outer_size, inner_size])
}

Expand Down Expand Up @@ -850,17 +854,74 @@ mod tests {
fn test_flatten() {
let pool = new_pool();

let input = Tensor::from_data(&[1, 5, 1, 1], vec![1, 2, 3, 4, 5]);
let result = flatten(&pool, input.view(), 1 /* axis */).unwrap();
assert_eq!(result.shape(), &[1, 5]);
struct Case {
shape: Vec<usize>,
axis: isize,
expected: Result<Vec<usize>, OpError>,
}

let input = Tensor::from_data(&[2, 3, 1, 4], (1..=24).collect::<Vec<_>>());
let result = flatten(&pool, input.view(), 2 /* axis */).unwrap();
assert_eq!(result.shape(), &[6, 4]);
let cases = [
Case {
shape: [1, 5, 1, 1].into(),
axis: 1,
expected: Ok([1, 5].into()),
},
Case {
shape: [2, 3, 1, 4].into(),
axis: 2,
expected: Ok([6, 4].into()),
},
// Axis = 0
Case {
shape: [2, 3, 1, 4].into(),
axis: 0,
expected: Ok([1, 24].into()),
},
// Axis equal to rank of input
Case {
shape: [2, 2].into(),
axis: 2,
expected: Ok([4, 1].into()),
},
// Negative values count from the back
Case {
shape: [2, 3, 4].into(),
axis: -1,
expected: Ok([6, 4].into()),
},
Case {
shape: [2, 3, 4].into(),
axis: -2,
expected: Ok([2, 12].into()),
},
Case {
shape: [2, 3, 4].into(),
axis: -3,
expected: Ok([1, 24].into()),
},
// Values outside `[-r, r]` are invalid
Case {
shape: [2, 3, 4].into(),
axis: 4,
expected: Err(OpError::InvalidValue("Axis is invalid")),
},
Case {
shape: [2, 3, 4].into(),
axis: -4,
expected: Err(OpError::InvalidValue("Axis is invalid")),
},
];

// Case when `axis` is zero, first output dim should always be 1
let result = flatten(&pool, input.view(), 0 /* axis */).unwrap();
assert_eq!(result.shape(), &[1, 24]);
for Case {
shape,
axis,
expected,
} in cases
{
let input = Tensor::<f32>::zeros(shape.as_slice());
let result = flatten(&pool, input.view(), axis).map(|tensor| tensor.shape().to_vec());
assert_eq!(result, expected);
}
}

#[test]
Expand Down

0 comments on commit eaaf34b

Please sign in to comment.