Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
# Conflicts:
#	api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt
  • Loading branch information
dosier committed Jun 16, 2021
2 parents 6347d65 + a5f478f commit 208cd56
Show file tree
Hide file tree
Showing 36 changed files with 1,981 additions and 904 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import org.tensorflow.Operand
import org.tensorflow.Shape
import org.tensorflow.op.Ops
import org.tensorflow.op.core.Variable
import java.lang.IllegalArgumentException

/**
* Base abstract class for all layers.
Expand Down Expand Up @@ -171,3 +172,8 @@ public abstract class Layer(public var name: String) {
/** Returns amount of neurons. */
public abstract val paramCount: Int
}

/**
 * Validates that [array] contains exactly [size] elements.
 *
 * @param array the array whose length is being checked.
 * @param size the exact number of elements [array] must contain.
 * @param name human-readable name of the argument, used in the error message.
 * @throws IllegalArgumentException if the length of [array] differs from [size].
 */
internal fun requireArraySize(array: LongArray, size: Int, name: String) {
    if (array.size != size) {
        throw IllegalArgumentException(
            "$name is expected to have size equal $size but got ${array.size}"
        )
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.api.core.layer.convolutional

import org.jetbrains.kotlinx.dl.api.core.KGraph
import org.jetbrains.kotlinx.dl.api.core.activation.Activations
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer
import org.jetbrains.kotlinx.dl.api.core.layer.Layer
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
import org.jetbrains.kotlinx.dl.api.core.shape.*
import org.jetbrains.kotlinx.dl.api.core.util.getDType
import org.tensorflow.Operand
import org.tensorflow.Shape
import org.tensorflow.op.Ops
import org.tensorflow.op.core.Variable
import java.lang.IllegalArgumentException
import kotlin.math.roundToInt

/**
 * Base class for convolutional layers that share the kernel/bias variable
 * lifecycle: shape computation, variable creation and initialization,
 * bias addition and activation application.
 *
 * Concrete subclasses supply the dimension-specific pieces: variable naming
 * ([kernelVarName]/[biasVarName]), the actual convolution op
 * ([convImplementation]) and the output shape formula ([defineOutputShape]).
 *
 * @property filtersInternal dimensionality of the output space (number of filters).
 * @property kernelSizeInternal size of the convolution window per spatial dimension.
 * @property stridesInternal strides of the convolution per dimension of the input tensor.
 * @property dilationsInternal dilation rate per dimension of the input tensor.
 * @property activationInternal activation function applied to the convolution output.
 * @property kernelInitializerInternal initializer for the kernel weights.
 * @property biasInitializerInternal initializer for the bias weights.
 * @property kernelRegularizerInternal optional regularizer for the kernel weights.
 * @property biasRegularizerInternal optional regularizer for the bias weights.
 * @property activityRegularizerInternal optional regularizer for the layer output.
 * @property paddingInternal padding strategy applied by the convolution.
 * @property useBiasInternal if `false`, no bias variable is created or added.
 * @property kernelVariableName fallback kernel variable name used when the layer has no name.
 * @property biasVariableName fallback bias variable name used when the layer has no name.
 * @param name layer name (may be empty; then the fallback variable names are used).
 */
public abstract class AbstractConv(
    protected val filtersInternal: Long,
    protected val kernelSizeInternal: LongArray,
    protected val stridesInternal: LongArray,
    protected val dilationsInternal: LongArray,
    protected val activationInternal: Activations,
    protected val kernelInitializerInternal: Initializer,
    protected val biasInitializerInternal: Initializer,
    protected val kernelRegularizerInternal: Regularizer?,
    protected val biasRegularizerInternal: Regularizer?,
    protected val activityRegularizerInternal: Regularizer?,
    protected val paddingInternal: ConvPadding,
    protected val useBiasInternal: Boolean,
    protected val kernelVariableName: String,
    protected val biasVariableName: String,
    name: String
) : Layer(name) {
    // Weight tensors; `bias` stays null when useBiasInternal is false.
    protected lateinit var kernel: Variable<Float>
    protected var bias: Variable<Float>? = null

    // Weight tensor shapes; both are assigned in build() via computeMatricesShapes().
    protected lateinit var kernelShape: Shape
    protected lateinit var biasShape: Shape

    override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {
        // Amount of channels should be the last value in the inputShape
        // (channels-last / NHWC-style layout).
        val numberOfChannels = inputShape.size(inputShape.numDimensions() - 1)

        // Compute shapes of kernel and bias matrices
        computeMatricesShapes(numberOfChannels)

        // should be calculated before addWeight because it's used in calculation,
        // need to rewrite addWeight to avoid strange behaviour calculate fanIn, fanOut
        val inputDepth = getInputDepth(numberOfChannels) // number of input channels
        val outputDepth = getOutputDepth(numberOfChannels) // number of output channels

        fanIn = (inputDepth * multiply(*kernelSizeInternal)).toInt()
        fanOut = ((outputDepth * multiply(*kernelSizeInternal)).toDouble() /
                multiply(*stridesInternal).toDouble()).roundToInt()

        val (kernelVariableName, biasVariableName) = defineVariableNames()
        createConvVariables(tf, kernelVariableName, biasVariableName, kGraph)
    }

    override fun computeOutputShape(inputShape: Shape): Shape {
        val shape = defineOutputShape(inputShape)
        outputShape = TensorShape(shape)
        return shape
    }

    override fun forward(
        tf: Ops,
        input: Operand<Float>,
        isTraining: Operand<Boolean>,
        numberOfLosses: Operand<Float>?
    ): Operand<Float> {
        var output = convImplementation(tf, input)

        if (useBiasInternal) {
            output = tf.nn.biasAdd(output, bias)
        }

        return Activations.convert(activationInternal).apply(tf, output, name)
    }

    /** Returns the shape of kernel weights. */
    public val kernelShapeArray: LongArray get() = TensorShape(kernelShape).dims()

    /** Returns the shape of bias weights. */
    public val biasShapeArray: LongArray get() = TensorShape(biasShape).dims()

    override var weights: Map<String, Array<*>>
        get() = extractConvWeights()
        set(value) = assignWeights(value)

    override val hasActivation: Boolean get() = true

    override val paramCount: Int
        get() {
            // Bias weights are only created when useBiasInternal is true
            // (see createConvVariables), so they must not be counted otherwise.
            val biasParams = if (useBiasInternal) biasShape.numElements() else 0L
            return (kernelShape.numElements() + biasParams).toInt()
        }

    private fun extractConvWeights(): Map<String, Array<*>> = extractWeights(defineVariableNames().toList())

    // When the layer is named, derive variable names from the layer name;
    // otherwise fall back to the constants supplied by the subclass.
    private fun defineVariableNames(): Pair<String, String> = if (name.isNotEmpty()) {
        Pair(kernelVarName(name), biasVarName(name))
    } else {
        Pair(kernelVariableName, biasVariableName)
    }

    private fun createConvVariables(
        tf: Ops,
        kernelVariableName: String,
        biasVariableName: String,
        kGraph: KGraph
    ) {
        kernel = tf.withName(kernelVariableName).variable(kernelShape, getDType())
        if (useBiasInternal) bias = tf.withName(biasVariableName).variable(biasShape, getDType())

        kernel = addWeight(tf, kGraph, kernelVariableName, kernel, kernelInitializerInternal, kernelRegularizerInternal)
        if (useBiasInternal) bias = addWeight(tf, kGraph, biasVariableName, bias!!, biasInitializerInternal, biasRegularizerInternal)
    }

    /** Number of input channels seen by the kernel; overridable for e.g. depthwise variants. */
    protected open fun getInputDepth(numberOfChannels: Long): Long = numberOfChannels

    /** Number of output channels produced; defaults to the filter count. */
    protected open fun getOutputDepth(numberOfChannels: Long): Long = filtersInternal

    /** Computes [kernelShape] and [biasShape] from the input channel count. */
    protected open fun computeMatricesShapes(numberOfChannels: Long) {
        kernelShape = shapeFromDims(*kernelSizeInternal, numberOfChannels, filtersInternal)
        biasShape = Shape.make(filtersInternal)
    }

    /** Returns the kernel variable name derived from the layer [name]. */
    protected abstract fun kernelVarName(name: String): String

    /** Returns the bias variable name derived from the layer [name]. */
    protected abstract fun biasVarName(name: String): String

    /** Performs the dimension-specific convolution op on [input]. */
    protected abstract fun convImplementation(tf: Ops, input: Operand<Float>): Operand<Float>

    /** Computes the layer's output shape from [inputShape]. */
    protected abstract fun defineOutputShape(inputShape: Shape): Shape
}

private fun multiply(vararg values: Long) = values.fold(1L, Long::times)
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@ import org.jetbrains.kotlinx.dl.api.core.activation.Activations
import org.jetbrains.kotlinx.dl.api.core.initializer.HeNormal
import org.jetbrains.kotlinx.dl.api.core.initializer.HeUniform
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer
import org.jetbrains.kotlinx.dl.api.core.layer.requireArraySize
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
import org.jetbrains.kotlinx.dl.api.core.shape.convOutputLength
import org.jetbrains.kotlinx.dl.api.core.util.convBiasVarName
import org.jetbrains.kotlinx.dl.api.core.util.convKernelVarName
import org.tensorflow.Operand
import org.tensorflow.Shape
import org.tensorflow.op.Ops
import org.tensorflow.op.core.Squeeze
import org.tensorflow.op.nn.Conv2d

private const val KERNEL_VARIABLE_NAME = "conv1d_kernel"

Expand All @@ -36,8 +42,7 @@ private const val EXTRA_DIM = 1L
*
* @property [filters] The dimensionality of the output space (i.e. the number of filters in the convolution).
* @property [kernelSize] Long number, specifying the width of the 1D convolution window.
* @property [strides] Three numbers specifying stride of the pooling
* operation for each dimension of input tensor.
* @property [strides] Three numbers specifying the strides of the pooling operation for each dimension of input tensor.
* NOTE: Specifying stride value != 1 is incompatible with specifying `dilation` value != 1.
* @property [dilations] Three numbers specifying the dilation rate to use for
* dilated convolution sequence dimensions of input tensor.
Expand Down Expand Up @@ -68,7 +73,7 @@ public class Conv1D(
public val padding: ConvPadding = ConvPadding.SAME,
public val useBias: Boolean = true,
name: String = "",
) : Conv2DImpl(
) : AbstractConv(
filtersInternal = filters,
kernelSizeInternal = longArrayOf(1, kernelSize),
stridesInternal = longArrayOf(strides[0], 1, strides[1], strides[2]),
Expand All @@ -85,22 +90,46 @@ public class Conv1D(
biasVariableName = BIAS_VARIABLE_NAME,
name = name
) {
init {
requireArraySize(strides, 3, "strides")
requireArraySize(dilations, 3, "dilations")
}

private val squeezeAxis = Squeeze.axis(listOf(EXTRA_DIM))

override fun forward(
override fun kernelVarName(name: String): String = convKernelVarName(name, dim = 1)

override fun biasVarName(name: String): String = convBiasVarName(name, dim = 1)

override fun convImplementation(
tf: Ops,
input: Operand<Float>,
isTraining: Operand<Boolean>,
numberOfLosses: Operand<Float>?
input: Operand<Float>
): Operand<Float> {
val options = Conv2d.dilations(dilationsInternal.toList()).dataFormat("NHWC")
val reshapedInput = tf.expandDims(input, tf.constant(EXTRA_DIM))
val result = super.forward(tf, reshapedInput, isTraining, numberOfLosses)
val result =
tf.nn.conv2d(reshapedInput, kernel, stridesInternal.toMutableList(), paddingInternal.paddingName, options)
return tf.squeeze(result, squeezeAxis)
}

override fun toString(): String {
return "Conv2D(filters=$filters, kernelSize=$kernelSize, strides=$strides, " +
"dilation=$dilations, activation=$activation, kernelInitializer=$kernelInitializer, " +
"biasInitializer=$biasInitializer, kernelShape=$kernelShape, biasShape=$biasShape, padding=$padding)"
protected override fun defineOutputShape(inputShape: Shape): Shape {
val batchSize = inputShape.size(0)
val colsCount = inputShape.size(1)

val cols = convOutputLength(
colsCount,
kernelSize.toInt(),
paddingInternal,
strides[1].toInt(),
dilations[1].toInt()
)

return Shape.make(batchSize, cols, filtersInternal)
}

override fun toString(): String =
"Conv1D(filters=$filters, kernelSize=$kernelSize, strides=${strides.contentToString()}, " +
"dilation=${dilations.contentToString()}, activation=$activation, kernelInitializer=$kernelInitializer, " +
"biasInitializer=$biasInitializer, kernelShape=$kernelShape, biasShape=$biasShape, padding=$padding, " +
"biasRegularizer=$biasRegularizer, kernelRegularizer=$kernelRegularizer, activityRegularizer=$activityRegularizer)"
}
Loading

0 comments on commit 208cd56

Please sign in to comment.