-
Notifications
You must be signed in to change notification settings - Fork 105
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added missing saving functions for ReLU and ELU activation layers (JetBrains#78) * Reverted changes to the imports * Added RepeatVector layer #123 * Added serialisation support for RepeatVector layer #123 * Wrote test for RepeatVector #123 * Made changed requested by avan (see desc.) - added missing require check in init block of RepeatVector - updated docs - reformatted code - housekeeping * Removed redundant Obs.repeat ext fun * Made changed requested by avan (see desc.) - change require message in computeOutputShape - used inputShape.size(...) for creating shape - removed author tag * Used `=` instead of `return` block, added TODO * Implemented changes requested by zaleslaw - save trainability status - renamed tests * Added test for negative `n` #123 * Added missing newline
- Loading branch information
Showing
6 changed files
with
146 additions
and
0 deletions.
There are no files selected for viewing
67 changes: 67 additions & 0 deletions
67
api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
/* | ||
* Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. | ||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. | ||
*/ | ||
|
||
package org.jetbrains.kotlinx.dl.api.core.layer.reshaping | ||
|
||
import org.jetbrains.kotlinx.dl.api.core.KGraph | ||
import org.jetbrains.kotlinx.dl.api.core.layer.Layer | ||
import org.tensorflow.Operand | ||
import org.tensorflow.Shape | ||
import org.tensorflow.op.Ops | ||
|
||
/** | ||
* Layer that repeats the input [n] times. | ||
* | ||
* Input shape: `2D tensor of shape (num_samples, features)`. | ||
* | ||
* Output shape: `3D tensor of shape (num_samples, n, features)`. | ||
* | ||
* @property n Repetition factor. | ||
* @property [name] Custom layer name. | ||
* @constructor Creates [RepeatVector] object. | ||
* | ||
* @since 0.3 | ||
*/ | ||
public class RepeatVector( | ||
public val n: Int, | ||
name: String = "" | ||
) : Layer(name) { | ||
|
||
init { | ||
require(n >= 1) { "Number of repetitions (n) in RepeatVector should be positive but got $n" } | ||
} | ||
|
||
override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape): Unit = Unit | ||
|
||
override fun computeOutputShape(inputShape: Shape): Shape { | ||
require(inputShape.numDimensions() == 2) { | ||
"Input tensor must have 2 dimensions but got ${inputShape.numDimensions()}" | ||
} | ||
return Shape.make(inputShape.size(0), n.toLong(), inputShape.size(1)) | ||
} | ||
|
||
override fun forward( | ||
tf: Ops, | ||
input: Operand<Float>, | ||
isTraining: Operand<Boolean>, | ||
numberOfLosses: Operand<Float>? | ||
): Operand<Float> { | ||
val x = tf.expandDims(input, tf.constant(1)) | ||
val pattern = tf.stack(listOf(tf.constant(1), tf.constant(n), tf.constant(1))) | ||
return tf.tile(x, pattern) | ||
} | ||
|
||
override var weights: Map<String, Array<*>> | ||
get() = emptyMap() | ||
set(value) = assignWeights(value) | ||
|
||
override val hasActivation: Boolean get() = false | ||
|
||
override val paramCount: Int get() = 0 | ||
|
||
override fun toString(): String { | ||
return "RepeatVector" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 57 additions & 0 deletions
57
api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
/* | ||
* Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. | ||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. | ||
*/ | ||
|
||
package org.jetbrains.kotlinx.dl.api.core.layer | ||
|
||
import org.jetbrains.kotlinx.dl.api.core.KGraph | ||
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector | ||
import org.jetbrains.kotlinx.dl.api.core.shape.toIntArray | ||
import org.junit.jupiter.api.Assertions | ||
import org.junit.jupiter.api.Test | ||
import org.tensorflow.Graph | ||
import org.tensorflow.Output | ||
import org.tensorflow.Shape | ||
import org.tensorflow.op.Ops | ||
|
||
internal class RepeatVectorLayerTest { | ||
|
||
@Test | ||
fun testIllegalRepetitions() { | ||
Assertions.assertThrows(IllegalArgumentException::class.java) { | ||
RepeatVector(n = -10) | ||
} | ||
} | ||
|
||
@Test | ||
fun testOutputShape() { | ||
val layer = RepeatVector(n = 2) | ||
val x = Array(10) { FloatArray(10) { 1F } } | ||
val y = layer(x) | ||
Assertions.assertArrayEquals(intArrayOf(10, layer.n, 10), y.shape().toIntArray()) | ||
} | ||
|
||
@Test | ||
fun testOutput() { | ||
val layer = RepeatVector(n = 2) | ||
val x = Array(3) { FloatArray(3) { it.toFloat() } } | ||
val y = layer(x) | ||
val actual = y.tensor().copyTo(Array(3) { Array(layer.n) { FloatArray(3) } }) | ||
val expected = arrayOf( | ||
arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)), | ||
arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)), | ||
arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)) | ||
) | ||
Assertions.assertArrayEquals(expected, actual) | ||
} | ||
|
||
// TODO: generalise this for Layer, see https://github.com/JetBrains/KotlinDL/issues/145 | ||
private operator fun RepeatVector.invoke(input: Array<FloatArray>): Output<Float> = Ops.create().let { tf -> | ||
build(tf, KGraph(Graph().toGraphDef()), Shape.make(10, 10)) | ||
val inputOp = tf.constant(input) | ||
val isTraining = tf.constant(true) | ||
val numberOfLosses = tf.constant(1.0f) | ||
forward(tf, inputOp, isTraining, numberOfLosses).asOutput() | ||
} | ||
} |