diff --git a/src/ops/pad.rs b/src/ops/pad.rs index 5708f5a8..f21f833c 100644 --- a/src/ops/pad.rs +++ b/src/ops/pad.rs @@ -78,9 +78,12 @@ pub fn pad( let pad_dims = input.ndim() - batch_dims; let (pad_top, pad_left) = if pad_dims == 1 { - (0, padding[[0]] as usize) + (0, padding[[batch_dims]] as usize) } else { - (padding[[0]] as usize, padding[[1]] as usize) + ( + padding[[batch_dims]] as usize, + padding[[batch_dims + 1]] as usize, + ) }; let mut input = input.view(); @@ -338,12 +341,24 @@ mod tests { pads: NdTensor::from([]), expected: Ok(Tensor::from(2.)), }, + // Pad start columns of a 3D tensor. + Case { + input: [[[1., 2., 3.]]].into(), + pads: [0, 0, 2, 0, 0, 0].into(), + expected: Ok(Tensor::from([[[3., 2., 1., 2., 3.]]])), + }, // Pad end columns of a 3D tensor. Case { input: [[[1., 2., 3.]]].into(), pads: [0, 0, 0, 0, 0, 2].into(), expected: Ok(Tensor::from([[[1., 2., 3., 2., 1.]]])), }, + // Pad start rows of a 3D tensor. + Case { + input: [[[1.], [2.], [3.]]].into(), + pads: [0, 2, 0, 0, 0, 0].into(), + expected: Ok(Tensor::from([[[3.], [2.], [1.], [2.], [3.]]])), + }, // Pad channel dimension of a 3D tensor. Case { input: [[[1., 2., 3.]]].into(),