Skip to content

Commit

Permalink
Cleanup code in model classes (#406)
Browse files Browse the repository at this point in the history
* Simplify Functional#forward function

* Make GraphTrainableModel#metricOps list immutable

* Extract Layer#setOutputShape function
  • Loading branch information
juliabeliaeva authored Jul 25, 2022
1 parent c6685c1 commit 8ab9d85
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 50 deletions.
38 changes: 10 additions & 28 deletions api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Functional.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package org.jetbrains.kotlinx.dl.api.core
import org.jetbrains.kotlinx.dl.api.core.layer.Layer
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.layer.freeze
import org.jetbrains.kotlinx.dl.api.core.layer.setOutputShape
import org.jetbrains.kotlinx.dl.api.core.layer.weights
import org.jetbrains.kotlinx.dl.api.core.util.sortTopologically
import org.jetbrains.kotlinx.dl.api.inference.keras.*
Expand Down Expand Up @@ -257,39 +258,20 @@ public class Functional(vararg layers: Layer) : GraphTrainableModel(*layers) {
inputLayer.build(tf)
inputLayer.computeOutputShape()

layers.filter { it !is Input }.forEach {
it.buildFromInboundLayers(tf)

val outputShape = it.computeOutputShapeFromInboundLayers()
val dims = outputShape.dims()

check(outputShape.tail().all { elem -> elem > 0 })
{
"The last dimensions (except first = -1) of shape of layer ${it.name} contains zero or negative dimension values: ${dims.contentToString()}.\n" +
"Analyze your model architecture and layer output shapes carefully to discover a problem."
}

it.outputShape = outputShape //TODO: Refactoring: it could be done inside computeOutputShapeMethods

logger.debug { "${it.name}; outputShape: $outputShape $it" }
layers.filter { it !is Input }.forEach { layer ->
layer.buildFromInboundLayers(tf)
val outputShape = layer.computeOutputShapeFromInboundLayers()
layer.setOutputShape(outputShape)
logger.debug { "${layer.name}; $layer; outputShape: $outputShape" }
}
}

override fun forward(input: Operand<Float>, inputLayer: Input): Operand<Float> {
var output: Operand<Float> = input
val outputByLayerName = mutableMapOf<String, Operand<Float>>()
val outputs = mutableListOf<Operand<Float>>()
outputs.add(input)
outputByLayerName[inputLayer.name] = input
for (layer in layers) {
for (inboundLayer in layer.inboundLayers) {
outputs.add(outputByLayerName[inboundLayer.name]!!)
}
output = layer.forward(tf, outputs, training, numberOfLossesOp)
outputByLayerName[layer.name] = output
outputs.clear()
val output = mutableMapOf<Layer, Operand<Float>>(inputLayer to inputLayer.forward(tf, input, training, numberOfLossesOp))
for (layer in layers.filter { it !is Input }) {
output[layer] = layer.forward(tf, layer.inboundLayers.map { output[it]!! }, training, numberOfLossesOp)
}
return output
return output[layers.last()]!!
}

override fun save(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
private lateinit var predictionOp: Operand<Float>

/** TensorFlow prediction operand. */
private lateinit var metricOps: MutableList<Operand<Float>>
private lateinit var metricOps: List<Operand<Float>>

/** A list of targets to be optimized. */
protected lateinit var targets: List<Operand<Float>>
Expand Down Expand Up @@ -185,10 +185,8 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
is SoftmaxCrossEntropyWithLogits -> tf.withName(OUTPUT_NAME).nn.softmax(yPredOp)
else -> tf.withName(OUTPUT_NAME).identity(yPredOp)
}
metricOps = mutableListOf()
metrics.forEach {
metricOps.add(it.apply(tf, predictionOp, yTrueOp, numberOfLossesOp))
}

metricOps = metrics.map { it.apply(tf, predictionOp, yTrueOp, numberOfLossesOp) }

isModelCompiled = true
}
Expand Down Expand Up @@ -435,7 +433,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
batchLabels: Tensor<Float>,
numberOfLosses: Tensor<Float>,
isTraining: Tensor<Float>,
metricOps: MutableList<Operand<Float>>
metricOps: List<Operand<Float>>
): Pair<Float, List<Float>> {
val runner = session.runner()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package org.jetbrains.kotlinx.dl.api.core

import org.jetbrains.kotlinx.dl.api.core.layer.Layer
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.layer.setOutputShape
import org.jetbrains.kotlinx.dl.api.core.layer.weights
import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import org.jetbrains.kotlinx.dl.api.inference.keras.*
Expand Down Expand Up @@ -145,22 +146,12 @@ public class Sequential(vararg layers: Layer) : GraphTrainableModel(*layers) {
inputLayer.build(tf)
var inputShape: Shape = inputLayer.computeOutputShape()

layers.filter { it !is Input }.forEach {
it.build(tf, inputShape)
layers.filter { it !is Input }.forEach { layer ->
layer.build(tf, inputShape)

inputShape = it.computeOutputShape(inputShape)
val tensorShape = TensorShape(inputShape)
val dims = tensorShape.dims()

check(tensorShape.tail().all { elem -> elem > 0 })
{
"The last dimensions (except first = -1) of shape of layer ${it.name} contains zero or negative dimension values: ${dims.contentToString()}.\n" +
"Analyze your model architecture and layer output shapes carefully to discover a problem."
}

it.outputShape = tensorShape //TODO: Refactoring: it could be done inside computeOutputShapeMethods

logger.debug { "${it.name}; $it; outputShape: $tensorShape" }
inputShape = layer.computeOutputShape(inputShape)
layer.setOutputShape(inputShape)
logger.debug { "${layer.name}; $layer; outputShape: $inputShape" }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,17 @@ internal fun LongArray.toIntArray(): IntArray {
1 -> intArrayOf(this[0].toInt())
else -> IntArray(size) { this[it].toInt() }
}
}

/**
 * Converts the given TensorFlow [Shape] to a [TensorShape] and assigns it as this layer's
 * output shape, delegating to the validating [setOutputShape] overload.
 */
internal fun Layer.setOutputShape(shape: Shape) = setOutputShape(TensorShape(shape))

/**
 * Validates and assigns [tensorShape] as this layer's output shape.
 *
 * All dimensions after the first (the first is the batch dimension and is allowed to be -1)
 * must be strictly positive; otherwise the model architecture produced a degenerate shape.
 *
 * @throws IllegalStateException if any non-batch dimension is zero or negative.
 */
internal fun Layer.setOutputShape(tensorShape: TensorShape) {
    // The batch dimension (head) is excluded from validation: it is typically -1 (unknown).
    val allTailDimensionsPositive = tensorShape.tail().all { dimension -> dimension > 0 }
    check(allTailDimensionsPositive) {
        "The last dimensions (except first = -1) of shape of layer $name contains zero or negative dimension values: ${tensorShape}.\n" +
            "Analyze your model architecture and layer output shapes carefully to discover a problem."
    }
    outputShape = tensorShape
}
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ internal class SequentialModelTest {
)
}
assertEquals(
"The last dimensions (except first = -1) of shape of layer maxpool2d_14 contains zero or negative dimension values: [-1, 0, 0, 128].\n" +
"The last dimensions (except first = -1) of shape of layer maxpool2d_14 contains zero or negative dimension values: [None, 0, 0, 128].\n" +
"Analyze your model architecture and layer output shapes carefully to discover a problem.",
exception.message
)
Expand Down

0 comments on commit 8ab9d85

Please sign in to comment.