# NOTE(review): this patch was restored from a copy in which all line breaks had been
# collapsed and generic type arguments had been stripped. Hunk bodies were re-split so
# that their line counts match every "@@" header. The type arguments (Operand<Float>,
# Tensor<Float>, Pair<Float, List<Float>>, the String/Layer map keys) and the leading
# indentation are reconstructed -- TODO confirm against the repository before applying.
#
# Review note: Layer.setOutputShape interpolates the TensorShape itself into the error
# message, whereas the removed code interpolated dims.contentToString(). That is why the
# expected text in SequentialCompilationTest changes from "[-1, 0, 0, 128]" to
# "[None, 0, 0, 128]".
diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Functional.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Functional.kt
index 51fed2a75..3c91e4a22 100644
--- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Functional.kt
+++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Functional.kt
@@ -8,6 +8,7 @@ package org.jetbrains.kotlinx.dl.api.core
 import org.jetbrains.kotlinx.dl.api.core.layer.Layer
 import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
 import org.jetbrains.kotlinx.dl.api.core.layer.freeze
+import org.jetbrains.kotlinx.dl.api.core.layer.setOutputShape
 import org.jetbrains.kotlinx.dl.api.core.layer.weights
 import org.jetbrains.kotlinx.dl.api.core.util.sortTopologically
 import org.jetbrains.kotlinx.dl.api.inference.keras.*
@@ -257,39 +258,20 @@ public class Functional(vararg layers: Layer) : GraphTrainableModel(*layers) {
         inputLayer.build(tf)
         inputLayer.computeOutputShape()
 
-        layers.filter { it !is Input }.forEach {
-            it.buildFromInboundLayers(tf)
-
-            val outputShape = it.computeOutputShapeFromInboundLayers()
-            val dims = outputShape.dims()
-
-            check(outputShape.tail().all { elem -> elem > 0 })
-            {
-                "The last dimensions (except first = -1) of shape of layer ${it.name} contains zero or negative dimension values: ${dims.contentToString()}.\n" +
-                        "Analyze your model architecture and layer output shapes carefully to discover a problem."
-            }
-
-            it.outputShape = outputShape //TODO: Refactoring: it could be done inside computeOutputShapeMethods
-
-            logger.debug { "${it.name}; outputShape: $outputShape $it" }
+        layers.filter { it !is Input }.forEach { layer ->
+            layer.buildFromInboundLayers(tf)
+            val outputShape = layer.computeOutputShapeFromInboundLayers()
+            layer.setOutputShape(outputShape)
+            logger.debug { "${layer.name}; $layer; outputShape: $outputShape" }
         }
     }
 
     override fun forward(input: Operand<Float>, inputLayer: Input): Operand<Float> {
-        var output: Operand<Float> = input
-        val outputByLayerName = mutableMapOf<String, Operand<Float>>()
-        val outputs = mutableListOf<Operand<Float>>()
-        outputs.add(input)
-        outputByLayerName[inputLayer.name] = input
-        for (layer in layers) {
-            for (inboundLayer in layer.inboundLayers) {
-                outputs.add(outputByLayerName[inboundLayer.name]!!)
-            }
-            output = layer.forward(tf, outputs, training, numberOfLossesOp)
-            outputByLayerName[layer.name] = output
-            outputs.clear()
+        val output = mutableMapOf<Layer, Operand<Float>>(inputLayer to inputLayer.forward(tf, input, training, numberOfLossesOp))
+        for (layer in layers.filter { it !is Input }) {
+            output[layer] = layer.forward(tf, layer.inboundLayers.map { output[it]!! }, training, numberOfLossesOp)
         }
-        return output
+        return output[layers.last()]!!
     }
 
     override fun save(
diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt
index 64ae295f1..d89a50353 100644
--- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt
+++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt
@@ -78,7 +78,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
     private lateinit var predictionOp: Operand<Float>
 
     /** TensorFlow prediction operand. */
-    private lateinit var metricOps: MutableList<Operand<Float>>
+    private lateinit var metricOps: List<Operand<Float>>
 
     /** A list of targets to be optimized. */
     protected lateinit var targets: List<Operand<Float>>
@@ -185,10 +185,8 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
             is SoftmaxCrossEntropyWithLogits -> tf.withName(OUTPUT_NAME).nn.softmax(yPredOp)
             else -> tf.withName(OUTPUT_NAME).identity(yPredOp)
         }
-        metricOps = mutableListOf()
-        metrics.forEach {
-            metricOps.add(it.apply(tf, predictionOp, yTrueOp, numberOfLossesOp))
-        }
+
+        metricOps = metrics.map { it.apply(tf, predictionOp, yTrueOp, numberOfLossesOp) }
 
         isModelCompiled = true
     }
@@ -435,7 +433,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
         batchLabels: Tensor<Float>,
         numberOfLosses: Tensor<Float>,
         isTraining: Tensor<Float>,
-        metricOps: MutableList<Operand<Float>>
+        metricOps: List<Operand<Float>>
     ): Pair<Float, List<Float>> {
         val runner = session.runner()
 
diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Sequential.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Sequential.kt
index aa6ae4175..7e1c6f023 100644
--- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Sequential.kt
+++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Sequential.kt
@@ -7,6 +7,7 @@ package org.jetbrains.kotlinx.dl.api.core
 
 import org.jetbrains.kotlinx.dl.api.core.layer.Layer
 import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
+import org.jetbrains.kotlinx.dl.api.core.layer.setOutputShape
 import org.jetbrains.kotlinx.dl.api.core.layer.weights
 import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
 import org.jetbrains.kotlinx.dl.api.inference.keras.*
@@ -145,22 +146,12 @@ public class Sequential(vararg layers: Layer) : GraphTrainableModel(*layers) {
         inputLayer.build(tf)
         var inputShape: Shape = inputLayer.computeOutputShape()
 
-        layers.filter { it !is Input }.forEach {
-            it.build(tf, inputShape)
+        layers.filter { it !is Input }.forEach { layer ->
+            layer.build(tf, inputShape)
 
-            inputShape = it.computeOutputShape(inputShape)
-            val tensorShape = TensorShape(inputShape)
-            val dims = tensorShape.dims()
-
-            check(tensorShape.tail().all { elem -> elem > 0 })
-            {
-                "The last dimensions (except first = -1) of shape of layer ${it.name} contains zero or negative dimension values: ${dims.contentToString()}.\n" +
-                        "Analyze your model architecture and layer output shapes carefully to discover a problem."
-            }
-
-            it.outputShape = tensorShape //TODO: Refactoring: it could be done inside computeOutputShapeMethods
-
-            logger.debug { "${it.name}; $it; outputShape: $tensorShape" }
+            inputShape = layer.computeOutputShape(inputShape)
+            layer.setOutputShape(inputShape)
+            logger.debug { "${layer.name}; $layer; outputShape: $inputShape" }
         }
     }
 
diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Layer.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Layer.kt
index 6b1e6c408..f6e75ef76 100644
--- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Layer.kt
+++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Layer.kt
@@ -126,4 +126,17 @@ internal fun LongArray.toIntArray(): IntArray {
         1 -> intArrayOf(this[0].toInt())
         else -> IntArray(size) { this[it].toInt() }
     }
+}
+
+internal fun Layer.setOutputShape(shape: Shape) {
+    setOutputShape(TensorShape(shape))
+}
+
+internal fun Layer.setOutputShape(tensorShape: TensorShape) {
+    check(tensorShape.tail().all { elem -> elem > 0 })
+    {
+        "The last dimensions (except first = -1) of shape of layer $name contains zero or negative dimension values: ${tensorShape}.\n" +
+                "Analyze your model architecture and layer output shapes carefully to discover a problem."
+    }
+    outputShape = tensorShape
 }
\ No newline at end of file
diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/models/SequentialCompilationTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/models/SequentialCompilationTest.kt
index 6e1d3eddf..e205610df 100644
--- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/models/SequentialCompilationTest.kt
+++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/models/SequentialCompilationTest.kt
@@ -670,7 +670,7 @@ internal class SequentialModelTest {
             )
         }
         assertEquals(
-            "The last dimensions (except first = -1) of shape of layer maxpool2d_14 contains zero or negative dimension values: [-1, 0, 0, 128].\n" +
+            "The last dimensions (except first = -1) of shape of layer maxpool2d_14 contains zero or negative dimension values: [None, 0, 0, 128].\n" +
                     "Analyze your model architecture and layer output shapes carefully to discover a problem.",
             exception.message
         )