diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt new file mode 100644 index 000000000..09016d081 --- /dev/null +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt @@ -0,0 +1,67 @@ +/* + * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. + */ + +package org.jetbrains.kotlinx.dl.api.core.layer.reshaping + +import org.jetbrains.kotlinx.dl.api.core.KGraph +import org.jetbrains.kotlinx.dl.api.core.layer.Layer +import org.tensorflow.Operand +import org.tensorflow.Shape +import org.tensorflow.op.Ops + +/** + * Layer that repeats the input [n] times. + * + * Input shape: `2D tensor of shape (num_samples, features)`. + * + * Output shape: `3D tensor of shape (num_samples, n, features)`. + * + * @property n Repetition factor. + * @property [name] Custom layer name. + * @constructor Creates [RepeatVector] object. + * + * @since 0.3 + */ +public class RepeatVector( + public val n: Int, + name: String = "" +) : Layer(name) { + + init { + require(n >= 1) { "Number of repetitions (n) in RepeatVector should be positive but got $n" } + } + + override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape): Unit = Unit + + override fun computeOutputShape(inputShape: Shape): Shape { + require(inputShape.numDimensions() == 2) { + "Input tensor must have 2 dimensions but got ${inputShape.numDimensions()}" + } + return Shape.make(inputShape.size(0), n.toLong(), inputShape.size(1)) + } + + override fun forward( + tf: Ops, + input: Operand, + isTraining: Operand, + numberOfLosses: Operand? + ): Operand { + val x = tf.expandDims(input, tf.constant(1)) + val pattern = tf.stack(listOf(tf.constant(1), tf.constant(n), tf.constant(1))) + return tf.tile(x, pattern) + } + + override var weights: Map> + get() = emptyMap() + set(value) = assignWeights(value) + + override val hasActivation: Boolean get() = false + + override val paramCount: Int get() = 0 + + override fun toString(): String { + return "RepeatVector" + } +} diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/KerasConstants.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/KerasConstants.kt index 87999086f..9cfb3290c 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/KerasConstants.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/KerasConstants.kt @@ -37,6 +37,7 @@ internal const val LAYER_DROPOUT: String = "Dropout" // Attention layers // Reshaping layers internal const val LAYER_FLATTEN: String = "Flatten" +internal const val LAYER_REPEAT_VECTOR: String = "RepeatVector" internal const val LAYER_RESHAPE: String = "Reshape" internal const val LAYER_ZERO_PADDING_2D = "ZeroPadding2D" internal const val LAYER_CROPPING_2D = "Cropping2D" diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt index 7eb509c16..983da7fcb 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt @@ -22,6 +22,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.pooling.* import org.jetbrains.kotlinx.dl.api.core.layer.regularization.Dropout import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Cropping2D import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten +import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Reshape import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.ZeroPadding2D import org.jetbrains.kotlinx.dl.api.core.regularizer.L1 @@ -141,6 +142,7 @@ private fun convertToLayer( // Attention layers // Reshaping layers LAYER_FLATTEN -> createFlattenLayer(kerasLayer.config!!.name!!) + LAYER_REPEAT_VECTOR -> createRepeatVectorLayer(kerasLayer.config!!, kerasLayer.config.name!!) LAYER_RESHAPE -> createReshapeLayer(kerasLayer.config!!, kerasLayer.config.name!!) LAYER_CROPPING_2D -> createCropping2DLayer(kerasLayer.config!!, kerasLayer.config.name!!) LAYER_ZERO_PADDING_2D -> createZeroPadding2DLayer(kerasLayer.config!!, kerasLayer.config.name!!) @@ -715,6 +717,10 @@ private fun createFlattenLayer(name: String): Layer { return Flatten(name = name) } +private fun createRepeatVectorLayer(config: LayerConfig, name: String): Layer { + return RepeatVector(name = name, n = config.n!!) +} + private fun createReshapeLayer(config: LayerConfig, name: String): Layer { return Reshape(name = name, targetShape = config.target_shape!!) } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt index d7f4e937d..e453ef271 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt @@ -24,6 +24,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.merge.* import org.jetbrains.kotlinx.dl.api.core.layer.normalization.BatchNorm import org.jetbrains.kotlinx.dl.api.core.layer.pooling.* import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten +import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.ZeroPadding2D import org.jetbrains.kotlinx.dl.api.core.regularizer.L2L1 import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer @@ -100,6 +101,7 @@ private fun convertToKerasLayer(layer: Layer, isKerasFullyCompatible: Boolean, i // Attention layers // Reshaping layers is Flatten -> createKerasFlattenLayer(layer) + is RepeatVector -> createKerasRepeatVectorLayer(layer) is ZeroPadding2D -> createKerasZeroPadding2DLayer(layer) // Merging layers is Add -> createKerasAddLayer(layer) @@ -530,6 +532,17 @@ private fun createKerasFlattenLayer(layer: Flatten): KerasLayer { return KerasLayer(class_name = LAYER_FLATTEN, config = configX) } +private fun createKerasRepeatVectorLayer(layer: RepeatVector): KerasLayer { + val configX = LayerConfig( + data_format = CHANNELS_LAST, + dtype = DATATYPE_FLOAT32, + trainable = layer.isTrainable, + name = layer.name, + n = layer.n + ) + return KerasLayer(class_name = LAYER_REPEAT_VECTOR, config = configX) +} + private fun createKerasConcatenateLayer(layer: Concatenate): KerasLayer { val configX = LayerConfig( dtype = DATATYPE_FLOAT32, diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/config/LayerConfig.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/config/LayerConfig.kt index 62381cd3d..242f7384a 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/config/LayerConfig.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/config/LayerConfig.kt @@ -85,6 +85,8 @@ internal data class LayerConfig( @Json(serializeNull = false) val moving_variance_initializer: KerasInitializer? = null, @Json(serializeNull = false) + val n: Int? = null, + @Json(serializeNull = false) val name: String? = null, @Json(serializeNull = false) val negative_slope: Double? = null, diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt new file mode 100644 index 000000000..cf9f40e84 --- /dev/null +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt @@ -0,0 +1,57 @@ +/* + * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. + */ + +package org.jetbrains.kotlinx.dl.api.core.layer + +import org.jetbrains.kotlinx.dl.api.core.KGraph +import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector +import org.jetbrains.kotlinx.dl.api.core.shape.toIntArray +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test +import org.tensorflow.Graph +import org.tensorflow.Output +import org.tensorflow.Shape +import org.tensorflow.op.Ops + +internal class RepeatVectorLayerTest { + + @Test + fun testIllegalRepetitions() { + Assertions.assertThrows(IllegalArgumentException::class.java) { + RepeatVector(n = -10) + } + } + + @Test + fun testOutputShape() { + val layer = RepeatVector(n = 2) + val x = Array(10) { FloatArray(10) { 1F } } + val y = layer(x) + Assertions.assertArrayEquals(intArrayOf(10, layer.n, 10), y.shape().toIntArray()) + } + + @Test + fun testOutput() { + val layer = RepeatVector(n = 2) + val x = Array(3) { FloatArray(3) { it.toFloat() } } + val y = layer(x) + val actual = y.tensor().copyTo(Array(3) { Array(layer.n) { FloatArray(3) } }) + val expected = arrayOf( + arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)), + arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)), + arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)) + ) + Assertions.assertArrayEquals(expected, actual) + } + + // TODO: generalise this for Layer, see https://github.com/JetBrains/KotlinDL/issues/145 + private operator fun RepeatVector.invoke(input: Array): Output = Ops.create().let { tf -> + build(tf, KGraph(Graph().toGraphDef()), Shape.make(10, 10)) + val inputOp = tf.constant(input) + val isTraining = tf.constant(true) + val numberOfLosses = tf.constant(1.0f) + forward(tf, inputOp, isTraining, numberOfLosses).asOutput() + } +}