Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RepeatVector layer #139

Merged
merged 15 commits into from
Jun 24, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
dosier marked this conversation as resolved.
Show resolved Hide resolved
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.api.core.layer.reshaping

import org.jetbrains.kotlinx.dl.api.core.KGraph
import org.jetbrains.kotlinx.dl.api.core.layer.Layer
import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import org.tensorflow.Operand
import org.tensorflow.Shape
import org.tensorflow.op.Ops
import org.tensorflow.op.core.Tile

/**
* Layer that repeats the input [n] times.
*
* Input shape: `2D tensor of shape (num_samples, features)`.
*
* Output shape: `3D tensor of shape (num_samples, n, features)`.
*
* @property n Repetition factor.
* @property [name] Custom layer name.
* @constructor Creates [RepeatVector] object.
*
* @author Stan van der Bend
* @since 0.2
dosier marked this conversation as resolved.
Show resolved Hide resolved
*/
public class RepeatVector(
public val n: Int,
name: String = ""
) : Layer(name) {
dosier marked this conversation as resolved.
Show resolved Hide resolved

override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {
dosier marked this conversation as resolved.
Show resolved Hide resolved
//left empty
}

override fun computeOutputShape(inputShape: Shape): Shape {
require(inputShape.numDimensions() == 2) { "input tensor must have 2 dimensions" }
dosier marked this conversation as resolved.
Show resolved Hide resolved
val tensorShape = TensorShape(inputShape)
// TODO: maybe make `n` of type Long?
dosier marked this conversation as resolved.
Show resolved Hide resolved
return Shape.make(tensorShape[0], n.toLong(), tensorShape[1])
}

override fun forward(
tf: Ops,
input: Operand<Float>,
isTraining: Operand<Boolean>,
numberOfLosses: Operand<Float>?
): Operand<Float> {
return tf.repeat(input, n)
dosier marked this conversation as resolved.
Show resolved Hide resolved
}

private fun Ops.repeat(input: Operand<Float>, n : Int) : Tile<Float> {
val x = expandDims(input, constant(1))
val pattern = stack(listOf(constant(1), constant(n), constant(1)))
return tile(x, pattern)
}

override var weights: Map<String, Array<*>>
get() = emptyMap()
set(value) = assignWeights(value)

override val hasActivation: Boolean get() = false

override val paramCount: Int get() = 0

override fun toString(): String {
return "RepeatVector"
}
}
dosier marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ internal const val LAYER_DROPOUT: String = "Dropout"
// Attention layers
// Reshaping layers
internal const val LAYER_FLATTEN: String = "Flatten"
internal const val LAYER_REPEAT_VECTOR: String = "RepeatVector"
internal const val LAYER_RESHAPE: String = "Reshape"
internal const val LAYER_ZERO_PADDING_2D = "ZeroPadding2D"
internal const val LAYER_CROPPING_2D = "Cropping2D"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.pooling.*
import org.jetbrains.kotlinx.dl.api.core.layer.regularization.Dropout
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Cropping2D
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Reshape
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.ZeroPadding2D
import org.jetbrains.kotlinx.dl.api.core.regularizer.L1
Expand Down Expand Up @@ -141,6 +142,7 @@ private fun convertToLayer(
// Attention layers
// Reshaping layers
LAYER_FLATTEN -> createFlattenLayer(kerasLayer.config!!.name!!)
LAYER_REPEAT_VECTOR -> createRepeatVectorLayer(kerasLayer.config!!, kerasLayer.config.name!!)
LAYER_RESHAPE -> createReshapeLayer(kerasLayer.config!!, kerasLayer.config.name!!)
LAYER_CROPPING_2D -> createCropping2DLayer(kerasLayer.config!!, kerasLayer.config.name!!)
LAYER_ZERO_PADDING_2D -> createZeroPadding2DLayer(kerasLayer.config!!, kerasLayer.config.name!!)
Expand Down Expand Up @@ -715,6 +717,10 @@ private fun createFlattenLayer(name: String): Layer {
return Flatten(name = name)
}

private fun createRepeatVectorLayer(config: LayerConfig, name: String): Layer {
return RepeatVector(name = name, n = config.n!!)
}

private fun createReshapeLayer(config: LayerConfig, name: String): Layer {
return Reshape(name = name, targetShape = config.target_shape!!)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.merge.*
import org.jetbrains.kotlinx.dl.api.core.layer.normalization.BatchNorm
import org.jetbrains.kotlinx.dl.api.core.layer.pooling.*
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.ZeroPadding2D
import org.jetbrains.kotlinx.dl.api.core.regularizer.L2L1
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
Expand Down Expand Up @@ -100,6 +101,7 @@ private fun convertToKerasLayer(layer: Layer, isKerasFullyCompatible: Boolean, i
// Attention layers
// Reshaping layers
is Flatten -> createKerasFlattenLayer(layer)
is RepeatVector -> createKerasRepeatVectorLayer(layer)
is ZeroPadding2D -> createKerasZeroPadding2DLayer(layer)
// Merging layers
is Add -> createKerasAddLayer(layer)
Expand Down Expand Up @@ -530,6 +532,16 @@ private fun createKerasFlattenLayer(layer: Flatten): KerasLayer {
return KerasLayer(class_name = LAYER_FLATTEN, config = configX)
}

private fun createKerasRepeatVectorLayer(layer: RepeatVector): KerasLayer {
val configX = LayerConfig(
dosier marked this conversation as resolved.
Show resolved Hide resolved
data_format = CHANNELS_LAST,
dtype = DATATYPE_FLOAT32,
name = layer.name,
n = layer.n
)
return KerasLayer(class_name = LAYER_REPEAT_VECTOR, config = configX)
}

private fun createKerasConcatenateLayer(layer: Concatenate): KerasLayer {
val configX = LayerConfig(
dtype = DATATYPE_FLOAT32,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ internal data class LayerConfig(
@Json(serializeNull = false)
val moving_variance_initializer: KerasInitializer? = null,
@Json(serializeNull = false)
val n: Int? = null,
@Json(serializeNull = false)
val name: String? = null,
@Json(serializeNull = false)
val negative_slope: Double? = null,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
dosier marked this conversation as resolved.
Show resolved Hide resolved
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.api.core.layer

import org.jetbrains.kotlinx.dl.api.core.KGraph
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector
import org.jetbrains.kotlinx.dl.api.core.shape.toIntArray
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test
import org.tensorflow.Graph
import org.tensorflow.Output
import org.tensorflow.op.Ops

/**
dosier marked this conversation as resolved.
Show resolved Hide resolved
* A test for the [RepeatVector] layer.
*
* @author Stan van der Bend
*/
internal class RepeatVectorLayerTest {

@Test
fun `test output shape`(){
dosier marked this conversation as resolved.
Show resolved Hide resolved
val layer = RepeatVector(n = 2)
val x = Array(10) { FloatArray(10) { 1F } }
val y = layer(x)
Assertions.assertArrayEquals(intArrayOf(10, layer.n, 10), y.shape().toIntArray())
}

@Test
fun `test repetition output`(){
val layer = RepeatVector(n = 2)
val x = Array(3) { FloatArray(3) { it.toFloat() } }
val y = layer(x)
val actual = y.tensor().copyTo(Array(3) { Array(layer.n) { FloatArray(3) } })
val expected = arrayOf(
arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)),
arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)),
arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F))
)
Assertions.assertArrayEquals(expected, actual)
}

private operator fun RepeatVector.invoke(input : Array<FloatArray>) : Output<Float> {
dosier marked this conversation as resolved.
Show resolved Hide resolved
return Ops.create().let { tf ->
build(tf, KGraph(Graph().toGraphDef()), org.tensorflow.Shape.make(10, 10))
val inputOp = tf.constant(input)
val isTraining = tf.constant(true)
val numberOfLosses = tf.constant(1.0f)
forward(tf, inputOp, isTraining, numberOfLosses).asOutput()
}
}
}