Skip to content

Commit

Permalink
Change dimensions order in the image shape from WHC to HWC
Browse files Browse the repository at this point in the history
  • Loading branch information
juliabeliaeva committed Apr 24, 2023
1 parent a8b578f commit 1548e0e
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ internal fun BufferedImage.copy(): BufferedImage {
}

internal fun BufferedImage.getShape(): TensorShape {
return TensorShape(width.toLong(), height.toLong(), colorModel.numComponents.toLong())
return TensorShape(height.toLong(), width.toLong(), colorModel.numComponents.toLong())
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class PreprocessingFinalShapeTest {
val image = BufferedImage(10, 20, BufferedImage.TYPE_3BYTE_BGR)
val (_, actualShape) = preprocess.apply(image)

assertEquals(actualShape, preprocess.getOutputShape(TensorShape(10, 20, 1)))
assertEquals(actualShape, preprocess.getOutputShape(TensorShape(20, 10, 1)))
}

@Test
Expand All @@ -139,6 +139,6 @@ class PreprocessingFinalShapeTest {
val image = BufferedImage(10, 20, BufferedImage.TYPE_3BYTE_BGR)
val (_, actualShape) = preprocess.apply(image)

assertEquals(actualShape, preprocess.getOutputShape(TensorShape(10, 20, 3)))
assertEquals(actualShape, preprocess.getOutputShape(TensorShape(20, 10, 3)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class PreprocessingImageTest {
inputImage.setRGB(1, 1, Color.RED.rgb)
val (imageFloats, tensorShape) = preprocess.apply(inputImage)

Assertions.assertEquals(TensorShape(9, 5, 3), tensorShape)
Assertions.assertEquals(TensorShape(5, 9, 3), tensorShape)

val expectedImage = FloatArray(tensorShape.numElements().toInt()) { Color.GRAY.red / 255f }
expectedImage.setRGB(3, 1, Color.BLUE, tensorShape, ColorMode.BGR)
Expand Down Expand Up @@ -198,7 +198,7 @@ class PreprocessingImageTest {
}
}
for (i in colorComponents.indices) {
set3D(y, x, i, tensorShape[0].toInt(), colorMode.channels, colorComponents[i])
set3D(y, x, i, tensorShape[1].toInt(), colorMode.channels, colorComponents[i])
}
}
}
Expand Down

0 comments on commit 1548e0e

Please sign in to comment.