From 34ae33f2e3537ae31802eef6e7f5f5260fe69e62 Mon Sep 17 00:00:00 2001 From: Stan van der Bend <10871975+dosier@users.noreply.github.com> Date: Thu, 24 Jun 2021 18:35:53 +0200 Subject: [PATCH] Add RepeatVector layer (#139) * Added missing saving functions for ReLU and ELU activation layers (JetBrains#78) * Reverted changes to the imports * Added RepeatVector layer #123 * Added serialisation support for RepeatVector layer #123 * Wrote test for RepeatVector #123 * Made changed requested by avan (see desc.) - added missing require check in init block of RepeatVector - updated docs - reformatted code - housekeeping * Removed redundant Obs.repeat ext fun * Made changed requested by avan (see desc.) - change require message in computeOutputShape - used inputShape.size(...) for creating shape - removed author tag * Used `=` instead of `return` block, added TODO * Implemented changes requested by zaleslaw - save trainability status - renamed tests * Added test for negative `n` #123 * Added missing newline --- .../api/core/layer/reshaping/RepeatVector.kt | 67 +++++++++++++++++++ .../dl/api/inference/keras/KerasConstants.kt | 1 + .../dl/api/inference/keras/ModelLoader.kt | 6 ++ .../dl/api/inference/keras/ModelSaver.kt | 13 ++++ .../api/inference/keras/config/LayerConfig.kt | 2 + .../api/core/layer/RepeatVectorLayerTest.kt | 57 ++++++++++++++++ 6 files changed, 146 insertions(+) create mode 100644 api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt create mode 100644 api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt new file mode 100644 index 000000000..09016d081 --- /dev/null +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt @@ -0,0 +1,67 @@ +/* + * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. + */ + +package org.jetbrains.kotlinx.dl.api.core.layer.reshaping + +import org.jetbrains.kotlinx.dl.api.core.KGraph +import org.jetbrains.kotlinx.dl.api.core.layer.Layer +import org.tensorflow.Operand +import org.tensorflow.Shape +import org.tensorflow.op.Ops + +/** + * Layer that repeats the input [n] times. + * + * Input shape: `2D tensor of shape (num_samples, features)`. + * + * Output shape: `3D tensor of shape (num_samples, n, features)`. + * + * @property n Repetition factor. + * @property [name] Custom layer name. + * @constructor Creates [RepeatVector] object. + * + * @since 0.3 + */ +public class RepeatVector( + public val n: Int, + name: String = "" +) : Layer(name) { + + init { + require(n >= 1) { "Number of repetitions (n) in RepeatVector should be positive but got $n" } + } + + override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape): Unit = Unit + + override fun computeOutputShape(inputShape: Shape): Shape { + require(inputShape.numDimensions() == 2) { + "Input tensor must have 2 dimensions but got ${inputShape.numDimensions()}" + } + return Shape.make(inputShape.size(0), n.toLong(), inputShape.size(1)) + } + + override fun forward( + tf: Ops, + input: Operand, + isTraining: Operand, + numberOfLosses: Operand? + ): Operand { + val x = tf.expandDims(input, tf.constant(1)) + val pattern = tf.stack(listOf(tf.constant(1), tf.constant(n), tf.constant(1))) + return tf.tile(x, pattern) + } + + override var weights: Map> + get() = emptyMap() + set(value) = assignWeights(value) + + override val hasActivation: Boolean get() = false + + override val paramCount: Int get() = 0 + + override fun toString(): String { + return "RepeatVector" + } +} diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/KerasConstants.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/KerasConstants.kt index 624240717..7e1e64bd4 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/KerasConstants.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/KerasConstants.kt @@ -38,6 +38,7 @@ internal const val LAYER_DROPOUT: String = "Dropout" // Attention layers // Reshaping layers internal const val LAYER_FLATTEN: String = "Flatten" +internal const val LAYER_REPEAT_VECTOR: String = "RepeatVector" internal const val LAYER_RESHAPE: String = "Reshape" internal const val LAYER_ZERO_PADDING_2D = "ZeroPadding2D" internal const val LAYER_CROPPING_2D = "Cropping2D" diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt index ce8ceb90e..96b200f9c 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt @@ -22,6 +22,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.pooling.* import org.jetbrains.kotlinx.dl.api.core.layer.regularization.Dropout import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Cropping2D import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten +import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Reshape import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.ZeroPadding2D import org.jetbrains.kotlinx.dl.api.core.regularizer.L1 @@ -142,6 +143,7 @@ private fun convertToLayer( // Attention layers // Reshaping layers LAYER_FLATTEN -> createFlattenLayer(kerasLayer.config!!.name!!) + LAYER_REPEAT_VECTOR -> createRepeatVectorLayer(kerasLayer.config!!, kerasLayer.config.name!!) LAYER_RESHAPE -> createReshapeLayer(kerasLayer.config!!, kerasLayer.config.name!!) LAYER_CROPPING_2D -> createCropping2DLayer(kerasLayer.config!!, kerasLayer.config.name!!) LAYER_ZERO_PADDING_2D -> createZeroPadding2DLayer(kerasLayer.config!!, kerasLayer.config.name!!) @@ -722,6 +724,10 @@ private fun createFlattenLayer(name: String): Layer { return Flatten(name = name) } +private fun createRepeatVectorLayer(config: LayerConfig, name: String): Layer { + return RepeatVector(name = name, n = config.n!!) +} + private fun createReshapeLayer(config: LayerConfig, name: String): Layer { return Reshape(name = name, targetShape = config.target_shape!!) } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt index 79e833303..2c231670b 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt @@ -21,6 +21,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.merge.* import org.jetbrains.kotlinx.dl.api.core.layer.normalization.BatchNorm import org.jetbrains.kotlinx.dl.api.core.layer.pooling.* import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten +import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.ZeroPadding2D import org.jetbrains.kotlinx.dl.api.core.regularizer.L2L1 import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer @@ -98,6 +99,7 @@ private fun convertToKerasLayer(layer: Layer, isKerasFullyCompatible: Boolean, i // Attention layers // Reshaping layers is Flatten -> createKerasFlattenLayer(layer) + is RepeatVector -> createKerasRepeatVectorLayer(layer) is ZeroPadding2D -> createKerasZeroPadding2DLayer(layer) // Merging layers is Add -> createKerasAddLayer(layer) @@ -584,6 +586,17 @@ private fun createKerasFlattenLayer(layer: Flatten): KerasLayer { return KerasLayer(class_name = LAYER_FLATTEN, config = configX) } +private fun createKerasRepeatVectorLayer(layer: RepeatVector): KerasLayer { + val configX = LayerConfig( + data_format = CHANNELS_LAST, + dtype = DATATYPE_FLOAT32, + trainable = layer.isTrainable, + name = layer.name, + n = layer.n + ) + return KerasLayer(class_name = LAYER_REPEAT_VECTOR, config = configX) +} + private fun createKerasConcatenateLayer(layer: Concatenate): KerasLayer { val configX = LayerConfig( dtype = DATATYPE_FLOAT32, diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/config/LayerConfig.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/config/LayerConfig.kt index 62381cd3d..242f7384a 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/config/LayerConfig.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/config/LayerConfig.kt @@ -85,6 +85,8 @@ internal data class LayerConfig( @Json(serializeNull = false) val moving_variance_initializer: KerasInitializer? = null, @Json(serializeNull = false) + val n: Int? = null, + @Json(serializeNull = false) val name: String? = null, @Json(serializeNull = false) val negative_slope: Double? = null, diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt new file mode 100644 index 000000000..cf9f40e84 --- /dev/null +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt @@ -0,0 +1,57 @@ +/* + * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. + */ + +package org.jetbrains.kotlinx.dl.api.core.layer + +import org.jetbrains.kotlinx.dl.api.core.KGraph +import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector +import org.jetbrains.kotlinx.dl.api.core.shape.toIntArray +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test +import org.tensorflow.Graph +import org.tensorflow.Output +import org.tensorflow.Shape +import org.tensorflow.op.Ops + +internal class RepeatVectorLayerTest { + + @Test + fun testIllegalRepetitions() { + Assertions.assertThrows(IllegalArgumentException::class.java) { + RepeatVector(n = -10) + } + } + + @Test + fun testOutputShape() { + val layer = RepeatVector(n = 2) + val x = Array(10) { FloatArray(10) { 1F } } + val y = layer(x) + Assertions.assertArrayEquals(intArrayOf(10, layer.n, 10), y.shape().toIntArray()) + } + + @Test + fun testOutput() { + val layer = RepeatVector(n = 2) + val x = Array(3) { FloatArray(3) { it.toFloat() } } + val y = layer(x) + val actual = y.tensor().copyTo(Array(3) { Array(layer.n) { FloatArray(3) } }) + val expected = arrayOf( + arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)), + arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)), + arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)) + ) + Assertions.assertArrayEquals(expected, actual) + } + + // TODO: generalise this for Layer, see https://github.com/JetBrains/KotlinDL/issues/145 + private operator fun RepeatVector.invoke(input: Array): Output = Ops.create().let { tf -> + build(tf, KGraph(Graph().toGraphDef()), Shape.make(10, 10)) + val inputOp = tf.constant(input) + val isTraining = tf.constant(true) + val numberOfLosses = tf.constant(1.0f) + forward(tf, inputOp, isTraining, numberOfLosses).asOutput() + } +}