Commit 7e8c7d9
zaleslaw committed Jun 2, 2021 (2 parents: b9970a5 + a672162)
Showing 7 changed files with 147 additions and 0 deletions.
@@ -0,0 +1,48 @@
/*
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.api.core.layer.pooling

import org.jetbrains.kotlinx.dl.api.core.KGraph
import org.jetbrains.kotlinx.dl.api.core.layer.Layer
import org.jetbrains.kotlinx.dl.api.core.util.TF
import org.tensorflow.Operand
import org.tensorflow.Shape
import org.tensorflow.op.Ops

/**
 * Global average pooling operation for temporal data.
 * NOTE: Works only on tensors of rank 3 (batch, steps, features).
 * Input shape: 3D tensor with shape `(batch_size, steps, features)`.
 * Output shape: 2D tensor with shape `(batch_size, features)`.
 * @property [name] Custom layer name.
 * @constructor Creates [GlobalAvgPool1D] object.
 */
public class GlobalAvgPool1D(
    name: String = ""
) : Layer(name) {
    override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {}

    override fun computeOutputShape(inputShape: Shape): Shape {
        return Shape.make(inputShape.size(0), inputShape.size(2))
    }

    override fun forward(tf: Ops, input: Operand<Float>, isTraining: Operand<Boolean>, numberOfLosses: Operand<Float>?): Operand<Float> {
        // TODO: support different dataFormat values ("channels_last", "channels_first")
        val stepAxis = 1
        // TODO: support masking
        // Reduce over the steps axis: (batch, steps, features) -> (batch, features).
        return TF.mean(tf, input, tf.constant(stepAxis))
    }

    override val weights: Map<String, Array<*>> get() = emptyMap()

    override val hasActivation: Boolean get() = false

    override val paramCount: Int get() = 0

    override fun toString(): String {
        return "GlobalAvgPool1D(name=$name)"
    }
}
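
For orientation, the forward pass above is just a mean over the `steps` axis. A minimal plain-Kotlin sketch of the same computation on nested arrays (a hypothetical standalone helper, not part of this commit):

// Averages over the steps axis: (batch, steps, features) -> (batch, features).
fun globalAvgPool1DReference(input: Array<Array<FloatArray>>): Array<FloatArray> {
    val steps = input[0].size
    val features = input[0][0].size
    return Array(input.size) { b ->
        FloatArray(features) { f ->
            var sum = 0f
            for (s in 0 until steps) sum += input[b][s][f]
            sum / steps
        }
    }
}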
@@ -36,6 +36,7 @@ internal const val LAYER_MAXIMUM: String = "Maximum"
internal const val LAYER_MINIMUM: String = "Minimum"
internal const val LAYER_CONCATENATE: String = "Concatenate"
internal const val LAYER_GLOBAL_AVG_POOLING_2D: String = "GlobalAveragePooling2D"
internal const val LAYER_GLOBAL_AVG_POOLING_1D: String = "GlobalAveragePooling1D"

// Keras data types
internal const val DATATYPE_FLOAT32: String = "float32"
@@ -25,6 +25,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.layer.merge.*
import org.jetbrains.kotlinx.dl.api.core.layer.normalization.BatchNorm
import org.jetbrains.kotlinx.dl.api.core.layer.pooling.AvgPool2D
import org.jetbrains.kotlinx.dl.api.core.layer.pooling.GlobalAvgPool1D
import org.jetbrains.kotlinx.dl.api.core.layer.pooling.GlobalAvgPool2D
import org.jetbrains.kotlinx.dl.api.core.layer.pooling.MaxPool2D
import org.jetbrains.kotlinx.dl.api.core.layer.regularization.Dropout
@@ -159,6 +160,7 @@ private fun convertToSequentialLayer(
        LAYER_GLOBAL_AVG_POOLING_2D -> createGlobalAvgPooling2D(
            kerasLayer.config!!.name!!
        )
        LAYER_GLOBAL_AVG_POOLING_1D -> createGlobalAvgPooling1D(kerasLayer.config!!.name!!)
        else -> throw IllegalStateException("${kerasLayer.class_name} is not supported for Sequential model!")
    }
}
@@ -306,6 +308,7 @@ private fun convertToLayer(
        LAYER_GLOBAL_AVG_POOLING_2D -> createGlobalAvgPooling2D(
            kerasLayer.config!!.name!!
        )
        LAYER_GLOBAL_AVG_POOLING_1D -> createGlobalAvgPooling1D(kerasLayer.config!!.name!!)
        else -> throw IllegalStateException("${kerasLayer.class_name} is not supported yet!")
    }

@@ -332,6 +335,14 @@ private fun createGlobalAvgPooling2D(
    )
}

private fun createGlobalAvgPooling1D(
    name: String
): Layer {
    return GlobalAvgPool1D(
        name = name
    )
}

private fun createAddLayer(
    name: String
): Layer {
@@ -25,6 +25,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.merge.Add
import org.jetbrains.kotlinx.dl.api.core.layer.merge.Concatenate
import org.jetbrains.kotlinx.dl.api.core.layer.normalization.BatchNorm
import org.jetbrains.kotlinx.dl.api.core.layer.pooling.AvgPool2D
import org.jetbrains.kotlinx.dl.api.core.layer.pooling.GlobalAvgPool1D
import org.jetbrains.kotlinx.dl.api.core.layer.pooling.GlobalAvgPool2D
import org.jetbrains.kotlinx.dl.api.core.layer.pooling.MaxPool2D
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten
@@ -86,6 +87,7 @@ private fun convertToKerasLayer(layer: Layer, isKerasFullyCompatible: Boolean, i
        DepthwiseConv2D::class -> createKerasDepthwiseConv2D(layer as DepthwiseConv2D, isKerasFullyCompatible)
        SeparableConv2D::class -> createSeparableConv2D(layer as SeparableConv2D, isKerasFullyCompatible)
        Concatenate::class -> createKerasConcatenate(layer as Concatenate)
        GlobalAvgPool1D::class -> createKerasGlobalAveragePooling1DLayer(layer as GlobalAvgPool1D)
        else -> throw IllegalStateException("${layer.name} with type ${layer::class.simpleName} is not supported yet!")
}

@@ -123,6 +125,14 @@ private fun createKerasGlobalAveragePooling2DLayer(layer: GlobalAvgPool2D): KerasLayer {
return KerasLayer(class_name = LAYER_GLOBAL_AVG_POOLING_2D, config = configX)
}

private fun createKerasGlobalAveragePooling1DLayer(layer: GlobalAvgPool1D): KerasLayer {
    val configX = LayerConfig(
        dtype = DATATYPE_FLOAT32,
        name = layer.name
    )
    return KerasLayer(class_name = LAYER_GLOBAL_AVG_POOLING_1D, config = configX)
}

private fun createKerasAddLayer(layer: Add): KerasLayer {
    val configX = LayerConfig(
        dtype = DATATYPE_FLOAT32,
@@ -0,0 +1,13 @@
package org.jetbrains.kotlinx.dl.api.core.layer

import org.jetbrains.kotlinx.dl.api.core.layer.pooling.GlobalAvgPool1D
import org.junit.jupiter.api.Test

internal class GlobalAvgPooling1DTest : PoolLayerTest() {
    @Test
    fun globalAvgPool1DTest() {
        val input = Array(2) { Array(3) { FloatArray(4) { 0f } } }
        val expected = Array(2) { FloatArray(4) { 0f } }
        assertGlobalAvgPool1DEquals(GlobalAvgPool1D(), input, expected)
    }
}
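
The committed test only covers an all-zeros input, for which any reduction returns zero. A sketch of an additional case inside the same class with non-constant values (the test name is hypothetical; it reuses the assertGlobalAvgPool1DEquals helper and is not part of this commit):

    @Test
    fun globalAvgPool1DNonConstantTest() {
        // Every feature sees the values 1, 2, 3 across the steps axis, so the mean is 2.
        val input = Array(2) { Array(3) { step -> FloatArray(4) { (step + 1).toFloat() } } }
        val expected = Array(2) { FloatArray(4) { 2f } }
        assertGlobalAvgPool1DEquals(GlobalAvgPool1D(), input, expected)
    }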
@@ -0,0 +1,64 @@
package org.jetbrains.kotlinx.dl.api.core.layer

import org.jetbrains.kotlinx.dl.api.core.KGraph
import org.jetbrains.kotlinx.dl.api.core.activation.EPS
import org.jetbrains.kotlinx.dl.api.core.shape.*
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.tensorflow.EagerSession
import org.tensorflow.Graph
import org.tensorflow.Shape
import org.tensorflow.op.Ops

open class PoolLayerTest {
    protected fun assertGlobalAvgPool1DEquals(
        layer: Layer,
        input: Array<Array<FloatArray>>,
        expected: Array<FloatArray>
    ) {
        val actual = Array(expected.size) { FloatArray(expected[0].size) }
        assertPoolingLayer(layer, input, expected, actual, ::assertGlobalAvgPool1DEquals)
    }

    private fun assertPoolingLayer(
        layer: Layer,
        input: Array<Array<FloatArray>>,
        expected: Array<FloatArray>,
        actual: Array<FloatArray>,
        assertEqual: (Array<FloatArray>, Array<FloatArray>) -> Unit
    ) {
        // Only the batch dimension is passed here; GlobalAvgPool1D.build ignores the input shape.
        val inputShape = Shape.make(input.size.toLong())
        EagerSession.create().use {
            val tf = Ops.create(it)
            val inputOp = tf.constant(input)
            layer.build(tf, KGraph(Graph().toGraphDef()), inputShape)
            val isTraining = tf.constant(true)
            val numberOfLosses = tf.constant(1.0f)
            val output = layer.forward(tf, inputOp, isTraining, numberOfLosses).asOutput().tensor()

            val expectedShape = Shape.make(
                expected.size.toLong(),
                expected[0].size.toLong()
            )

            val actualShape = shapeFromDims(*output.shape())
            output.copyTo(actual)
            assertEquals(expectedShape, actualShape)
            assertEqual(expected, actual)
        }
    }

    private fun assertGlobalAvgPool1DEquals(
        expected: Array<FloatArray>,
        actual: Array<FloatArray>
    ) {
        val msg = "Expected ${expected.contentDeepToString()} to equal ${actual.contentDeepToString()}"
        for (i in expected.indices) {
            assertArrayEquals(expected[i], actual[i], EPS, msg)
        }
    }
}
