From 78196a88018bad65df0adbf925ffc52862bd6abf Mon Sep 17 00:00:00 2001 From: Stan van der Bend Date: Mon, 14 Jun 2021 21:16:31 +0200 Subject: [PATCH 01/12] Added missing saving functions for ReLU and ELU activation layers (JetBrains#78) --- .../dl/api/inference/keras/ModelSaver.kt | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt index f4ee80356..78609616b 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt @@ -12,10 +12,7 @@ import org.jetbrains.kotlinx.dl.api.core.Sequential import org.jetbrains.kotlinx.dl.api.core.activation.Activations import org.jetbrains.kotlinx.dl.api.core.initializer.* import org.jetbrains.kotlinx.dl.api.core.layer.Layer -import org.jetbrains.kotlinx.dl.api.core.layer.activation.PReLU -import org.jetbrains.kotlinx.dl.api.core.layer.activation.LeakyReLU -import org.jetbrains.kotlinx.dl.api.core.layer.activation.Softmax -import org.jetbrains.kotlinx.dl.api.core.layer.activation.ThresholdedReLU +import org.jetbrains.kotlinx.dl.api.core.layer.activation.* import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.* import org.jetbrains.kotlinx.dl.api.core.layer.core.ActivationLayer import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense @@ -28,6 +25,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.ZeroPadding2D import org.jetbrains.kotlinx.dl.api.core.regularizer.L2L1 import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer import org.jetbrains.kotlinx.dl.api.inference.keras.config.* +import org.tensorflow.op.nn.Elu import java.io.File /** @@ -86,6 +84,8 @@ private fun convertToKerasLayer(layer: Layer, isKerasFullyCompatible: Boolean, i is Input -> createKerasInput(layer) is BatchNorm -> createKerasBatchNorm(layer, isKerasFullyCompatible) is ActivationLayer -> createKerasActivationLayer(layer) + is ELU -> createKerasELU(layer) + is ReLU -> createKerasReLU(layer) is PReLU -> createKerasPReLULayer(layer, isKerasFullyCompatible) is LeakyReLU -> createKerasLeakyReLU(layer) is ThresholdedReLU -> createKerasThresholdedReLULayer(layer) @@ -221,6 +221,26 @@ private fun createKerasActivationLayer(layer: ActivationLayer): KerasLayer { return KerasLayer(class_name = LAYER_ACTIVATION, config = configX) } +private fun createKerasReLU(layer: ReLU): KerasLayer { + val configX = LayerConfig( + dtype = DATATYPE_FLOAT32, + max_value = layer.maxValue?.toDouble(), + negative_slope = layer.negativeSlope.toDouble(), + threshold = layer.threshold.toDouble(), + name = layer.name + ) + return KerasLayer(class_name = LAYER_RELU, config = configX) +} + +private fun createKerasELU(layer: ELU): KerasLayer { + val configX = LayerConfig( + dtype = DATATYPE_FLOAT32, + alpha = layer.alpha.toDouble(), + name = layer.name + ) + return KerasLayer(class_name = LAYER_ELU, config = configX) +} + private fun createKerasPReLULayer(layer: PReLU, isKerasFullyCompatible: Boolean): KerasLayer { val configX = LayerConfig( dtype = DATATYPE_FLOAT32, @@ -604,4 +624,4 @@ private fun createKerasZeroPadding2D(layer: ZeroPadding2D): KerasLayer { padding = KerasPadding.ZeroPadding2D(layer.padding) ) return KerasLayer(class_name = LAYER_ZERO_PADDING_2D, config = configX) -} +} \ No newline at end of file From 6347d656d898fde3d2084e9b9489f6cddf1ed150 Mon Sep 17 00:00:00 2001 From: Stan van der Bend Date: Mon, 14 Jun 2021 21:23:56 +0200 Subject: [PATCH 02/12] Reverted changes to the imports --- .../kotlinx/dl/api/inference/keras/ModelSaver.kt | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt index 78609616b..9e9adcd7f 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt @@ -12,7 +12,12 @@ import org.jetbrains.kotlinx.dl.api.core.Sequential import org.jetbrains.kotlinx.dl.api.core.activation.Activations import org.jetbrains.kotlinx.dl.api.core.initializer.* import org.jetbrains.kotlinx.dl.api.core.layer.Layer -import org.jetbrains.kotlinx.dl.api.core.layer.activation.* +import org.jetbrains.kotlinx.dl.api.core.layer.activation.ELU +import org.jetbrains.kotlinx.dl.api.core.layer.activation.ReLU +import org.jetbrains.kotlinx.dl.api.core.layer.activation.PReLU +import org.jetbrains.kotlinx.dl.api.core.layer.activation.LeakyReLU +import org.jetbrains.kotlinx.dl.api.core.layer.activation.Softmax +import org.jetbrains.kotlinx.dl.api.core.layer.activation.ThresholdedReLU import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.* import org.jetbrains.kotlinx.dl.api.core.layer.core.ActivationLayer import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense @@ -25,7 +30,6 @@ import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.ZeroPadding2D import org.jetbrains.kotlinx.dl.api.core.regularizer.L2L1 import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer import org.jetbrains.kotlinx.dl.api.inference.keras.config.* -import org.tensorflow.op.nn.Elu import java.io.File /** From e1f98d9c3aafd79d0bef5a0531a82122ddb9b165 Mon Sep 17 00:00:00 2001 From: Stan van der Bend Date: Thu, 17 Jun 2021 14:46:58 +0200 Subject: [PATCH 03/12] Added RepeatVector layer #123 --- .../api/core/layer/reshaping/RepeatVector.kt | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt new file mode 100644 index 000000000..92df8079d --- /dev/null +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt @@ -0,0 +1,72 @@ +/* + * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. + */ + +package org.jetbrains.kotlinx.dl.api.core.layer.reshaping + +import org.jetbrains.kotlinx.dl.api.core.KGraph +import org.jetbrains.kotlinx.dl.api.core.layer.Layer +import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape +import org.tensorflow.Operand +import org.tensorflow.Shape +import org.tensorflow.op.Ops +import org.tensorflow.op.core.Tile + +/** + * Layer that repeats the input [n] times. + * + * Input shape: `2D tensor of shape (num_samples, features)`. + * + * Output shape: `3D tensor of shape (num_samples, n, features)`. + * + * @property n Repetition factor. + * @property [name] Custom layer name. + * @constructor Creates [RepeatVector] object. + * + * @author Stan van der Bend + * @since 0.2 + */ +public class RepeatVector( + public val n: Int, + name: String = "" +) : Layer(name) { + + override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) { + //left empty + } + + override fun computeOutputShape(inputShape: Shape): Shape { + require(inputShape.numDimensions() == 2) { "input tensor must have 2 dimensions" } + val tensorShape = TensorShape(inputShape) + // TODO: maybe make `n` of type Long? + return Shape.make(tensorShape[0], n.toLong(), tensorShape[1]) + } + + override fun forward( + tf: Ops, + input: Operand, + isTraining: Operand, + numberOfLosses: Operand? + ): Operand { + return tf.repeat(input, n) + } + + private fun Ops.repeat(input: Operand, n : Int) : Tile { + val x = expandDims(input, constant(1)) + val pattern = stack(listOf(constant(1), constant(n), constant(1))) + return tile(x, pattern) + } + + override var weights: Map> + get() = emptyMap() + set(value) = assignWeights(value) + + override val hasActivation: Boolean get() = false + + override val paramCount: Int get() = 0 + + override fun toString(): String { + return "RepeatVector" + } +} \ No newline at end of file From 56445be532fed06c52867f1834519951fb8a00fa Mon Sep 17 00:00:00 2001 From: Stan van der Bend Date: Thu, 17 Jun 2021 14:59:38 +0200 Subject: [PATCH 04/12] Added serialisation support for RepeatVector layer #123 --- .../kotlinx/dl/api/inference/keras/KerasConstants.kt | 1 + .../kotlinx/dl/api/inference/keras/ModelLoader.kt | 6 ++++++ .../kotlinx/dl/api/inference/keras/ModelSaver.kt | 12 ++++++++++++ .../dl/api/inference/keras/config/LayerConfig.kt | 2 ++ 4 files changed, 21 insertions(+) diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/KerasConstants.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/KerasConstants.kt index 87999086f..9cfb3290c 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/KerasConstants.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/KerasConstants.kt @@ -37,6 +37,7 @@ internal const val LAYER_DROPOUT: String = "Dropout" // Attention layers // Reshaping layers internal const val LAYER_FLATTEN: String = "Flatten" +internal const val LAYER_REPEAT_VECTOR: String = "RepeatVector" internal const val LAYER_RESHAPE: String = "Reshape" internal const val LAYER_ZERO_PADDING_2D = "ZeroPadding2D" internal const val LAYER_CROPPING_2D = "Cropping2D" diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt index 7eb509c16..983da7fcb 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt @@ -22,6 +22,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.pooling.* import org.jetbrains.kotlinx.dl.api.core.layer.regularization.Dropout import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Cropping2D import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten +import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Reshape import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.ZeroPadding2D import org.jetbrains.kotlinx.dl.api.core.regularizer.L1 @@ -141,6 +142,7 @@ private fun convertToLayer( // Attention layers // Reshaping layers LAYER_FLATTEN -> createFlattenLayer(kerasLayer.config!!.name!!) + LAYER_REPEAT_VECTOR -> createRepeatVectorLayer(kerasLayer.config!!, kerasLayer.config.name!!) LAYER_RESHAPE -> createReshapeLayer(kerasLayer.config!!, kerasLayer.config.name!!) LAYER_CROPPING_2D -> createCropping2DLayer(kerasLayer.config!!, kerasLayer.config.name!!) LAYER_ZERO_PADDING_2D -> createZeroPadding2DLayer(kerasLayer.config!!, kerasLayer.config.name!!) @@ -715,6 +717,10 @@ private fun createFlattenLayer(name: String): Layer { return Flatten(name = name) } +private fun createRepeatVectorLayer(config: LayerConfig, name: String): Layer { + return RepeatVector(name = name, n = config.n!!) +} + private fun createReshapeLayer(config: LayerConfig, name: String): Layer { return Reshape(name = name, targetShape = config.target_shape!!) } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt index d7f4e937d..22553bfc2 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt @@ -24,6 +24,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.merge.* import org.jetbrains.kotlinx.dl.api.core.layer.normalization.BatchNorm import org.jetbrains.kotlinx.dl.api.core.layer.pooling.* import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten +import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.ZeroPadding2D import org.jetbrains.kotlinx.dl.api.core.regularizer.L2L1 import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer @@ -100,6 +101,7 @@ private fun convertToKerasLayer(layer: Layer, isKerasFullyCompatible: Boolean, i // Attention layers // Reshaping layers is Flatten -> createKerasFlattenLayer(layer) + is RepeatVector -> createKerasRepeatVectorLayer(layer) is ZeroPadding2D -> createKerasZeroPadding2DLayer(layer) // Merging layers is Add -> createKerasAddLayer(layer) @@ -530,6 +532,16 @@ private fun createKerasFlattenLayer(layer: Flatten): KerasLayer { return KerasLayer(class_name = LAYER_FLATTEN, config = configX) } +private fun createKerasRepeatVectorLayer(layer: RepeatVector): KerasLayer { + val configX = LayerConfig( + data_format = CHANNELS_LAST, + dtype = DATATYPE_FLOAT32, + name = layer.name, + n = layer.n + ) + return KerasLayer(class_name = LAYER_REPEAT_VECTOR, config = configX) +} + private fun createKerasConcatenateLayer(layer: Concatenate): KerasLayer { val configX = LayerConfig( dtype = DATATYPE_FLOAT32, diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/config/LayerConfig.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/config/LayerConfig.kt index 62381cd3d..242f7384a 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/config/LayerConfig.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/config/LayerConfig.kt @@ -85,6 +85,8 @@ internal data class LayerConfig( @Json(serializeNull = false) val moving_variance_initializer: KerasInitializer? = null, @Json(serializeNull = false) + val n: Int? = null, + @Json(serializeNull = false) val name: String? = null, @Json(serializeNull = false) val negative_slope: Double? = null, From bf61c5db2a2308d6b7ea547a7b44ac7fb1ba34c4 Mon Sep 17 00:00:00 2001 From: Stan van der Bend Date: Fri, 18 Jun 2021 11:47:43 +0200 Subject: [PATCH 05/12] Wrote test for RepeatVector #123 --- .../api/core/layer/RepeatVectorLayerTest.kt | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt new file mode 100644 index 000000000..3b2014015 --- /dev/null +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt @@ -0,0 +1,55 @@ +/* + * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. + */ + +package org.jetbrains.kotlinx.dl.api.core.layer + +import org.jetbrains.kotlinx.dl.api.core.KGraph +import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector +import org.jetbrains.kotlinx.dl.api.core.shape.toIntArray +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test +import org.tensorflow.Graph +import org.tensorflow.Output +import org.tensorflow.op.Ops + +/** + * A test for the [RepeatVector] layer. + * + * @author Stan van der Bend + */ +internal class RepeatVectorLayerTest { + + @Test + fun `test output shape`(){ + val layer = RepeatVector(n = 2) + val x = Array(10) { FloatArray(10) { 1F } } + val y = layer(x) + Assertions.assertArrayEquals(intArrayOf(10, layer.n, 10), y.shape().toIntArray()) + } + + @Test + fun `test repetition output`(){ + val layer = RepeatVector(n = 2) + val x = Array(3) { FloatArray(3) { it.toFloat() } } + val y = layer(x) + val actual = y.tensor().copyTo(Array(3) { Array(layer.n) { FloatArray(3) } }) + val expected = arrayOf( + arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)), + arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)), + arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)) + ) + Assertions.assertArrayEquals(expected, actual) + } + + private operator fun RepeatVector.invoke(input : Array) : Output { + return Ops.create().let { tf -> + build(tf, KGraph(Graph().toGraphDef()), org.tensorflow.Shape.make(10, 10)) + val inputOp = tf.constant(input) + val isTraining = tf.constant(true) + val numberOfLosses = tf.constant(1.0f) + forward(tf, inputOp, isTraining, numberOfLosses).asOutput() + } + } +} From d83d44762692b7b5a40798a520bdaf1db5b977f2 Mon Sep 17 00:00:00 2001 From: Stan van der Bend Date: Mon, 21 Jun 2021 00:52:47 +0200 Subject: [PATCH 06/12] Made changed requested by avan (see desc.) - added missing require check in init block of RepeatVector - updated docs - reformatted code - housekeeping --- .../api/core/layer/reshaping/RepeatVector.kt | 16 ++++++++++------ .../dl/api/core/layer/RepeatVectorLayerTest.kt | 18 +++++++----------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt index 92df8079d..76b23a476 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -8,6 +8,7 @@ package org.jetbrains.kotlinx.dl.api.core.layer.reshaping import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape +import org.jetbrains.kotlinx.dl.api.core.shape.toLongArray import org.tensorflow.Operand import org.tensorflow.Shape import org.tensorflow.op.Ops @@ -25,22 +26,25 @@ import org.tensorflow.op.core.Tile * @constructor Creates [RepeatVector] object. * * @author Stan van der Bend - * @since 0.2 + * @since 0.3 */ public class RepeatVector( public val n: Int, name: String = "" ) : Layer(name) { - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) { - //left empty + init { + require(n >= 1) { "Number of repetitions (n) in RepeatVector should be positive but got $n" } } + override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape): Unit = Unit + override fun computeOutputShape(inputShape: Shape): Shape { require(inputShape.numDimensions() == 2) { "input tensor must have 2 dimensions" } val tensorShape = TensorShape(inputShape) // TODO: maybe make `n` of type Long? - return Shape.make(tensorShape[0], n.toLong(), tensorShape[1]) + val input = inputShape.toLongArray() + return Shape.make(input[0], n.toLong(), input[1]) } override fun forward( @@ -52,7 +56,7 @@ public class RepeatVector( return tf.repeat(input, n) } - private fun Ops.repeat(input: Operand, n : Int) : Tile { + private fun Ops.repeat(input: Operand, n: Int): Tile { val x = expandDims(input, constant(1)) val pattern = stack(listOf(constant(1), constant(n), constant(1))) return tile(x, pattern) diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt index 3b2014015..c2d4e4603 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -12,17 +12,13 @@ import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test import org.tensorflow.Graph import org.tensorflow.Output +import org.tensorflow.Shape import org.tensorflow.op.Ops -/** - * A test for the [RepeatVector] layer. - * - * @author Stan van der Bend - */ internal class RepeatVectorLayerTest { @Test - fun `test output shape`(){ + fun `test output shape`() { val layer = RepeatVector(n = 2) val x = Array(10) { FloatArray(10) { 1F } } val y = layer(x) @@ -30,7 +26,7 @@ internal class RepeatVectorLayerTest { } @Test - fun `test repetition output`(){ + fun `test repetition output`() { val layer = RepeatVector(n = 2) val x = Array(3) { FloatArray(3) { it.toFloat() } } val y = layer(x) @@ -43,12 +39,12 @@ internal class RepeatVectorLayerTest { Assertions.assertArrayEquals(expected, actual) } - private operator fun RepeatVector.invoke(input : Array) : Output { + private operator fun RepeatVector.invoke(input: Array): Output { return Ops.create().let { tf -> - build(tf, KGraph(Graph().toGraphDef()), org.tensorflow.Shape.make(10, 10)) + build(tf, KGraph(Graph().toGraphDef()), Shape.make(10, 10)) val inputOp = tf.constant(input) val isTraining = tf.constant(true) - val numberOfLosses = tf.constant(1.0f) + val numberOfLosses = tf.constant(1.0f) forward(tf, inputOp, isTraining, numberOfLosses).asOutput() } } From 130e60af755b33d73de911d0c47e232f77e3eaa1 Mon Sep 17 00:00:00 2001 From: Stan van der Bend Date: Mon, 21 Jun 2021 01:01:39 +0200 Subject: [PATCH 07/12] Removed redundant Obs.repeat ext fun --- .../dl/api/core/layer/reshaping/RepeatVector.kt | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt index 76b23a476..57ae773c7 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt @@ -12,7 +12,6 @@ import org.jetbrains.kotlinx.dl.api.core.shape.toLongArray import org.tensorflow.Operand import org.tensorflow.Shape import org.tensorflow.op.Ops -import org.tensorflow.op.core.Tile /** * Layer that repeats the input [n] times. @@ -53,13 +52,9 @@ public class RepeatVector( isTraining: Operand, numberOfLosses: Operand? ): Operand { - return tf.repeat(input, n) - } - - private fun Ops.repeat(input: Operand, n: Int): Tile { - val x = expandDims(input, constant(1)) - val pattern = stack(listOf(constant(1), constant(n), constant(1))) - return tile(x, pattern) + val x = tf.expandDims(input, tf.constant(1)) + val pattern = tf.stack(listOf(tf.constant(1), tf.constant(n), tf.constant(1))) + return tf.tile(x, pattern) } override var weights: Map> From af442562873396fbdba1dd1c27d11894688f4ecb Mon Sep 17 00:00:00 2001 From: Stan van der Bend Date: Mon, 21 Jun 2021 16:28:26 +0200 Subject: [PATCH 08/12] Made changed requested by avan (see desc.) - change require message in computeOutputShape - used inputShape.size(...) for creating shape - removed author tag --- .../dl/api/core/layer/reshaping/RepeatVector.kt | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt index 57ae773c7..eed5c4553 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt @@ -7,8 +7,6 @@ package org.jetbrains.kotlinx.dl.api.core.layer.reshaping import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer -import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape -import org.jetbrains.kotlinx.dl.api.core.shape.toLongArray import org.tensorflow.Operand import org.tensorflow.Shape import org.tensorflow.op.Ops @@ -24,7 +22,6 @@ import org.tensorflow.op.Ops * @property [name] Custom layer name. * @constructor Creates [RepeatVector] object. * - * @author Stan van der Bend * @since 0.3 */ public class RepeatVector( @@ -39,11 +36,10 @@ public class RepeatVector( override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape): Unit = Unit override fun computeOutputShape(inputShape: Shape): Shape { - require(inputShape.numDimensions() == 2) { "input tensor must have 2 dimensions" } - val tensorShape = TensorShape(inputShape) - // TODO: maybe make `n` of type Long? - val input = inputShape.toLongArray() - return Shape.make(input[0], n.toLong(), input[1]) + require(inputShape.numDimensions() == 2) { + "Input tensor must have 2 dimensions but got ${inputShape.numDimensions()}" + } + return Shape.make(inputShape.size(0), n.toLong(), inputShape.size(1)) } override fun forward( From 0cb1ac89f2735e4d0c21d440d5159180b97c3fa9 Mon Sep 17 00:00:00 2001 From: Stan van der Bend Date: Mon, 21 Jun 2021 16:48:03 +0200 Subject: [PATCH 09/12] Used `=` instead of `return` block, added TODO --- .../dl/api/core/layer/RepeatVectorLayerTest.kt | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt index c2d4e4603..119ed1957 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt @@ -39,13 +39,12 @@ internal class RepeatVectorLayerTest { Assertions.assertArrayEquals(expected, actual) } - private operator fun RepeatVector.invoke(input: Array): Output { - return Ops.create().let { tf -> - build(tf, KGraph(Graph().toGraphDef()), Shape.make(10, 10)) - val inputOp = tf.constant(input) - val isTraining = tf.constant(true) - val numberOfLosses = tf.constant(1.0f) - forward(tf, inputOp, isTraining, numberOfLosses).asOutput() - } + // TODO: generalise this for Layer, see https://github.com/JetBrains/KotlinDL/issues/145 + private operator fun RepeatVector.invoke(input: Array): Output = Ops.create().let { tf -> + build(tf, KGraph(Graph().toGraphDef()), Shape.make(10, 10)) + val inputOp = tf.constant(input) + val isTraining = tf.constant(true) + val numberOfLosses = tf.constant(1.0f) + forward(tf, inputOp, isTraining, numberOfLosses).asOutput() } } From cf398b6990c9c4194bb115b1226e19d74636c26d Mon Sep 17 00:00:00 2001 From: Stan van der Bend Date: Thu, 24 Jun 2021 09:31:23 +0200 Subject: [PATCH 10/12] Implemented changes requested by zaleslaw - save trainability status - renamed tests --- .../jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt | 1 + .../kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt index 22553bfc2..e453ef271 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt @@ -536,6 +536,7 @@ private fun createKerasRepeatVectorLayer(layer: RepeatVector): KerasLayer { val configX = LayerConfig( data_format = CHANNELS_LAST, dtype = DATATYPE_FLOAT32, + trainable = layer.isTrainable, name = layer.name, n = layer.n ) diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt index 119ed1957..e6793aad0 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt @@ -18,7 +18,7 @@ import org.tensorflow.op.Ops internal class RepeatVectorLayerTest { @Test - fun `test output shape`() { + fun testOutputShape() { val layer = RepeatVector(n = 2) val x = Array(10) { FloatArray(10) { 1F } } val y = layer(x) @@ -26,7 +26,7 @@ internal class RepeatVectorLayerTest { } @Test - fun `test repetition output`() { + fun testOutput() { val layer = RepeatVector(n = 2) val x = Array(3) { FloatArray(3) { it.toFloat() } } val y = layer(x) From f5b08c3d5492f61aa60a34187c4112355eafeda8 Mon Sep 17 00:00:00 2001 From: Stan van der Bend Date: Thu, 24 Jun 2021 09:40:38 +0200 Subject: [PATCH 11/12] Added test for negative `n` #123 --- .../kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt index e6793aad0..cf9f40e84 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt @@ -17,6 +17,13 @@ import org.tensorflow.op.Ops internal class RepeatVectorLayerTest { + @Test + fun testIllegalRepetitions() { + Assertions.assertThrows(IllegalArgumentException::class.java) { + RepeatVector(n = -10) + } + } + @Test fun testOutputShape() { val layer = RepeatVector(n = 2) From cc9bf98e611e4f8c158caa146048859b4ff81444 Mon Sep 17 00:00:00 2001 From: Stan van der Bend Date: Thu, 24 Jun 2021 09:43:38 +0200 Subject: [PATCH 12/12] Added missing newline --- .../kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt index eed5c4553..09016d081 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt @@ -64,4 +64,4 @@ public class RepeatVector( override fun toString(): String { return "RepeatVector" } -} \ No newline at end of file +}