Skip to content

Commit

Permalink
Add RepeatVector layer (#139)
Browse files Browse the repository at this point in the history
* Added missing saving functions for ReLU and ELU activation layers (JetBrains#78)

* Reverted changes to the imports

* Added RepeatVector layer #123

* Added serialisation support for RepeatVector layer #123

* Wrote test for RepeatVector #123

* Made changed requested by avan (see desc.)
- added missing require check in init block of RepeatVector
- updated docs
- reformatted code
- housekeeping

* Removed redundant Obs.repeat ext fun

* Made changed requested by avan (see desc.)

- change require message in computeOutputShape
- used inputShape.size(...) for creating shape
- removed author tag

* Used `=` instead of `return` block, added TODO

* Implemented changes requested by zaleslaw

- save trainability status
- renamed tests

* Added test for negative `n` #123

* Added missing newline
  • Loading branch information
dosier authored Jun 24, 2021
1 parent f7bebf9 commit 34ae33f
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.api.core.layer.reshaping

import org.jetbrains.kotlinx.dl.api.core.KGraph
import org.jetbrains.kotlinx.dl.api.core.layer.Layer
import org.tensorflow.Operand
import org.tensorflow.Shape
import org.tensorflow.op.Ops

/**
* Layer that repeats the input [n] times.
*
* Input shape: `2D tensor of shape (num_samples, features)`.
*
* Output shape: `3D tensor of shape (num_samples, n, features)`.
*
* @property n Repetition factor.
* @property [name] Custom layer name.
* @constructor Creates [RepeatVector] object.
*
* @since 0.3
*/
public class RepeatVector(
public val n: Int,
name: String = ""
) : Layer(name) {

init {
require(n >= 1) { "Number of repetitions (n) in RepeatVector should be positive but got $n" }
}

override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape): Unit = Unit

override fun computeOutputShape(inputShape: Shape): Shape {
require(inputShape.numDimensions() == 2) {
"Input tensor must have 2 dimensions but got ${inputShape.numDimensions()}"
}
return Shape.make(inputShape.size(0), n.toLong(), inputShape.size(1))
}

override fun forward(
tf: Ops,
input: Operand<Float>,
isTraining: Operand<Boolean>,
numberOfLosses: Operand<Float>?
): Operand<Float> {
val x = tf.expandDims(input, tf.constant(1))
val pattern = tf.stack(listOf(tf.constant(1), tf.constant(n), tf.constant(1)))
return tf.tile(x, pattern)
}

override var weights: Map<String, Array<*>>
get() = emptyMap()
set(value) = assignWeights(value)

override val hasActivation: Boolean get() = false

override val paramCount: Int get() = 0

override fun toString(): String {
return "RepeatVector"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ internal const val LAYER_DROPOUT: String = "Dropout"
// Attention layers
// Reshaping layers
internal const val LAYER_FLATTEN: String = "Flatten"
internal const val LAYER_REPEAT_VECTOR: String = "RepeatVector"
internal const val LAYER_RESHAPE: String = "Reshape"
internal const val LAYER_ZERO_PADDING_2D = "ZeroPadding2D"
internal const val LAYER_CROPPING_2D = "Cropping2D"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.pooling.*
import org.jetbrains.kotlinx.dl.api.core.layer.regularization.Dropout
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Cropping2D
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Reshape
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.ZeroPadding2D
import org.jetbrains.kotlinx.dl.api.core.regularizer.L1
Expand Down Expand Up @@ -142,6 +143,7 @@ private fun convertToLayer(
// Attention layers
// Reshaping layers
LAYER_FLATTEN -> createFlattenLayer(kerasLayer.config!!.name!!)
LAYER_REPEAT_VECTOR -> createRepeatVectorLayer(kerasLayer.config!!, kerasLayer.config.name!!)
LAYER_RESHAPE -> createReshapeLayer(kerasLayer.config!!, kerasLayer.config.name!!)
LAYER_CROPPING_2D -> createCropping2DLayer(kerasLayer.config!!, kerasLayer.config.name!!)
LAYER_ZERO_PADDING_2D -> createZeroPadding2DLayer(kerasLayer.config!!, kerasLayer.config.name!!)
Expand Down Expand Up @@ -722,6 +724,10 @@ private fun createFlattenLayer(name: String): Layer {
return Flatten(name = name)
}

private fun createRepeatVectorLayer(config: LayerConfig, name: String): Layer {
return RepeatVector(name = name, n = config.n!!)
}

private fun createReshapeLayer(config: LayerConfig, name: String): Layer {
return Reshape(name = name, targetShape = config.target_shape!!)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.merge.*
import org.jetbrains.kotlinx.dl.api.core.layer.normalization.BatchNorm
import org.jetbrains.kotlinx.dl.api.core.layer.pooling.*
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.ZeroPadding2D
import org.jetbrains.kotlinx.dl.api.core.regularizer.L2L1
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
Expand Down Expand Up @@ -98,6 +99,7 @@ private fun convertToKerasLayer(layer: Layer, isKerasFullyCompatible: Boolean, i
// Attention layers
// Reshaping layers
is Flatten -> createKerasFlattenLayer(layer)
is RepeatVector -> createKerasRepeatVectorLayer(layer)
is ZeroPadding2D -> createKerasZeroPadding2DLayer(layer)
// Merging layers
is Add -> createKerasAddLayer(layer)
Expand Down Expand Up @@ -584,6 +586,17 @@ private fun createKerasFlattenLayer(layer: Flatten): KerasLayer {
return KerasLayer(class_name = LAYER_FLATTEN, config = configX)
}

private fun createKerasRepeatVectorLayer(layer: RepeatVector): KerasLayer {
val configX = LayerConfig(
data_format = CHANNELS_LAST,
dtype = DATATYPE_FLOAT32,
trainable = layer.isTrainable,
name = layer.name,
n = layer.n
)
return KerasLayer(class_name = LAYER_REPEAT_VECTOR, config = configX)
}

private fun createKerasConcatenateLayer(layer: Concatenate): KerasLayer {
val configX = LayerConfig(
dtype = DATATYPE_FLOAT32,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ internal data class LayerConfig(
@Json(serializeNull = false)
val moving_variance_initializer: KerasInitializer? = null,
@Json(serializeNull = false)
val n: Int? = null,
@Json(serializeNull = false)
val name: String? = null,
@Json(serializeNull = false)
val negative_slope: Double? = null,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.api.core.layer

import org.jetbrains.kotlinx.dl.api.core.KGraph
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector
import org.jetbrains.kotlinx.dl.api.core.shape.toIntArray
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test
import org.tensorflow.Graph
import org.tensorflow.Output
import org.tensorflow.Shape
import org.tensorflow.op.Ops

internal class RepeatVectorLayerTest {

@Test
fun testIllegalRepetitions() {
Assertions.assertThrows(IllegalArgumentException::class.java) {
RepeatVector(n = -10)
}
}

@Test
fun testOutputShape() {
val layer = RepeatVector(n = 2)
val x = Array(10) { FloatArray(10) { 1F } }
val y = layer(x)
Assertions.assertArrayEquals(intArrayOf(10, layer.n, 10), y.shape().toIntArray())
}

@Test
fun testOutput() {
val layer = RepeatVector(n = 2)
val x = Array(3) { FloatArray(3) { it.toFloat() } }
val y = layer(x)
val actual = y.tensor().copyTo(Array(3) { Array(layer.n) { FloatArray(3) } })
val expected = arrayOf(
arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)),
arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)),
arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F))
)
Assertions.assertArrayEquals(expected, actual)
}

// TODO: generalise this for Layer, see https://github.com/JetBrains/KotlinDL/issues/145
private operator fun RepeatVector.invoke(input: Array<FloatArray>): Output<Float> = Ops.create().let { tf ->
build(tf, KGraph(Graph().toGraphDef()), Shape.make(10, 10))
val inputOp = tf.constant(input)
val isTraining = tf.constant(true)
val numberOfLosses = tf.constant(1.0f)
forward(tf, inputOp, isTraining, numberOfLosses).asOutput()
}
}

0 comments on commit 34ae33f

Please sign in to comment.