Skip to content

Commit

Permalink
Wrote test for RepeatVector Kotlin#123
Browse files Browse the repository at this point in the history
  • Loading branch information
dosier committed Jun 18, 2021
1 parent 56445be commit bf61c5d
Showing 1 changed file with 55 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.api.core.layer

import org.jetbrains.kotlinx.dl.api.core.KGraph
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector
import org.jetbrains.kotlinx.dl.api.core.shape.toIntArray
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test
import org.tensorflow.Graph
import org.tensorflow.Output
import org.tensorflow.op.Ops

/**
* A test for the [RepeatVector] layer.
*
* @author Stan van der Bend
*/
internal class RepeatVectorLayerTest {

@Test
fun `test output shape`(){
val layer = RepeatVector(n = 2)
val x = Array(10) { FloatArray(10) { 1F } }
val y = layer(x)
Assertions.assertArrayEquals(intArrayOf(10, layer.n, 10), y.shape().toIntArray())
}

@Test
fun `test repetition output`(){
val layer = RepeatVector(n = 2)
val x = Array(3) { FloatArray(3) { it.toFloat() } }
val y = layer(x)
val actual = y.tensor().copyTo(Array(3) { Array(layer.n) { FloatArray(3) } })
val expected = arrayOf(
arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)),
arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F)),
arrayOf(floatArrayOf(0F, 1F, 2F), floatArrayOf(0F, 1F, 2F))
)
Assertions.assertArrayEquals(expected, actual)
}

private operator fun RepeatVector.invoke(input : Array<FloatArray>) : Output<Float> {
return Ops.create().let { tf ->
build(tf, KGraph(Graph().toGraphDef()), org.tensorflow.Shape.make(10, 10))
val inputOp = tf.constant(input)
val isTraining = tf.constant(true)
val numberOfLosses = tf.constant(1.0f)
forward(tf, inputOp, isTraining, numberOfLosses).asOutput()
}
}
}

0 comments on commit bf61c5d

Please sign in to comment.