Cleanup model copying #503

Merged · 3 commits · Jan 2, 2023
@@ -37,14 +37,9 @@ public interface InferenceModel : AutoCloseable {
public fun reshape(vararg dims: Long)

/**
* Creates a copy.
* Creates a copy of this model.
*
* @param [copiedModelName] Set up this name to make a copy with a new name.
* @return A copied inference model.
*/
public fun copy(
copiedModelName: String? = null,
saveOptimizerState: Boolean = false,
copyWeights: Boolean = true
): InferenceModel
public fun copy(): InferenceModel
}
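
With the interface trimmed to a parameterless `copy()`, code that only sees the static type `InferenceModel` can still clone a model; the parameterized overloads move to the concrete classes below. A minimal usage sketch (the helper function and import path are assumptions for illustration, not part of this change):

```kotlin
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel

// Sketch: clone an arbitrary InferenceModel, run some work on the clone, then release it.
fun <R> withModelCopy(model: InferenceModel, block: (InferenceModel) -> R): R {
    val copy = model.copy()   // parameterless copy from the interface
    return try {
        block(copy)
    } finally {
        copy.close()          // InferenceModel is AutoCloseable
    }
}
```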
@@ -271,13 +271,8 @@ public open class OnnxInferenceModel private constructor(
}
}

override fun copy(
copiedModelName: String?,
saveOptimizerState: Boolean,
copyWeights: Boolean
): OnnxInferenceModel {
override fun copy(): OnnxInferenceModel {
val model = OnnxInferenceModel(modelSource)
model.name = copiedModelName
if (inputShape != null) {
model.reshape(*inputDimensions)
}
@@ -9,7 +9,6 @@ import org.jetbrains.kotlinx.dl.api.core.layer.Layer
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.layer.freeze
import org.jetbrains.kotlinx.dl.api.core.layer.setOutputShape
import org.jetbrains.kotlinx.dl.api.core.layer.weights
import org.jetbrains.kotlinx.dl.api.core.util.sortTopologically
import org.jetbrains.kotlinx.dl.api.inference.keras.*
import org.tensorflow.Operand
@@ -43,28 +42,26 @@ public class Functional(vararg layers: Layer) : GraphTrainableModel(*layers) {
return input to output[layers.last()]!!
}

/** Returns a copy of this model. */
// TODO: support saveOptimizerState=true with assignment of intermediate optimizer state
public fun copy(saveOptimizerState: Boolean = false, copyWeights: Boolean = true): Functional {
val serializedModel = serializeModel(true)
val deserializedModel = deserializeFunctionalModel(serializedModel)
if (!copyWeights) {
return deserializedModel
} else {
// TODO: make deep copies, not just links
deserializedModel.compile(
optimizer = this.optimizer,
loss = this.loss,
metrics = this.metrics
)

deserializedModel.layers.forEach {
it.weights = this.getLayer(it.name).weights
}

deserializedModel.isModelInitialized = true
override fun copy(): Functional {
return copy(copiedModelName = null, copyOptimizerState = false, copyWeights = true)
}

return deserializedModel
/**
* Creates a copy of this model.
*
* @param [copiedModelName] a name for the copy
* @param [copyOptimizerState] whether optimizer state needs to be copied
* @param [copyWeights] whether model weights need to be copied
* @return A copied inference model.
*/
public fun copy(copiedModelName: String? = null,
copyOptimizerState: Boolean = false,
copyWeights: Boolean = true
): Functional {
val serializedModel = serializeModel(true)
return deserializeFunctionalModel(serializedModel).also { modelCopy ->
if (copiedModelName != null) modelCopy.name = copiedModelName
if (copyWeights) copyWeightsTo(modelCopy, copyOptimizerState)
}
}
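
For reference, a hedged usage sketch of the new `Functional.copy` overload, assuming `trainedModel` is a `Functional` that has already been compiled and had its weights initialized (variable and copy names are illustrative):

```kotlin
// Sketch: duplicate a compiled Functional model.
// Note: layer weights are linked, not deep-copied (see the TODO above).
val modelCopy: Functional = trainedModel.copy(
    copiedModelName = "functionalCopy", // illustrative name for the copy
    copyOptimizerState = false,         // do not transfer optimizer variables
    copyWeights = true                  // compile the copy and link its layer weights
)
// ... use modelCopy for inference or further training, then release it.
modelCopy.close()
```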

@@ -961,6 +961,27 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
variable.initializerOperation.run(session)
}

protected fun copyWeightsTo(model: GraphTrainableModel, copyOptimizerState: Boolean) {
// TODO: make deep copies, not just links
model.compile(
optimizer = this.optimizer,
loss = this.loss,
metrics = this.metrics
)

model.layers.forEach {
it.weights = this.getLayer(it.name).weights
}

if (copyOptimizerState) {
val optimizerVariables = kGraph.variableNames().filter(::isOptimizerVariable)
copyVariablesToModel(model, optimizerVariables)
model.isOptimizerVariableInitialized = true
}

model.isModelInitialized = true
}

/**
* Return layer by [layerName].
*
@@ -8,7 +8,6 @@ package org.jetbrains.kotlinx.dl.api.core
import org.jetbrains.kotlinx.dl.api.core.layer.Layer
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.layer.setOutputShape
import org.jetbrains.kotlinx.dl.api.core.layer.weights
import org.jetbrains.kotlinx.dl.api.inference.keras.*
import org.tensorflow.Operand
import org.tensorflow.op.core.Placeholder
@@ -41,27 +40,26 @@ public class Sequential(vararg layers: Layer) : GraphTrainableModel(*layers) {
return input to output
}

/** Returns a copy of this model. */
// TODO: implement the saving of optimizer state
public fun copy(saveOptimizerState: Boolean = false, copyWeights: Boolean = true): Sequential {
val serializedModel = serializeModel(true)
val deserializedModel = deserializeSequentialModel(serializedModel)
if (!copyWeights) {
return deserializedModel
} else {
deserializedModel.compile(
optimizer = this.optimizer,
loss = this.loss,
metrics = this.metrics
)

deserializedModel.layers.forEach {
it.weights = this.getLayer(it.name).weights
}

deserializedModel.isModelInitialized = true
override fun copy(): Sequential {
return copy(copiedModelName = null, copyOptimizerState = false, copyWeights = true)
}

return deserializedModel
/**
* Creates a copy of this model.
*
* @param [copiedModelName] a name for the copy
* @param [copyOptimizerState] whether optimizer state needs to be copied
* @param [copyWeights] whether model weights need to be copied
* @return A copied inference model.
*/
public fun copy(copiedModelName: String? = null,
copyOptimizerState: Boolean = false,
copyWeights: Boolean = true
): Sequential {
val serializedModel = serializeModel(true)
return deserializeSequentialModel(serializedModel).also { modelCopy ->
if (copiedModelName != null) modelCopy.name = copiedModelName
if (copyWeights) copyWeightsTo(modelCopy, copyOptimizerState)
}
}
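
Similarly for `Sequential`, a hedged sketch of forking a training run, assuming `model` has already been fitted so optimizer variables exist in its graph (names are illustrative):

```kotlin
// Sketch: copy weights plus optimizer state so training can continue on the copy.
val fork: Sequential = model.copy(
    copiedModelName = "trainingFork",
    copyOptimizerState = true, // also copies optimizer variables via copyWeightsTo
    copyWeights = true
)
// The copy is compiled and marked initialized; continue training on `fork`,
// then close both models when done.
fork.close()
```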

@@ -145,11 +145,12 @@ public open class TensorFlowInferenceModel : InferenceModel {
}
}

override fun copy(
copiedModelName: String?,
saveOptimizerState: Boolean, // TODO, check this case
copyWeights: Boolean
): TensorFlowInferenceModel {
override fun copy(): TensorFlowInferenceModel {
return copy(copiedModelName = null)
}

/** Returns a copy of this model. */
public fun copy(copiedModelName: String? = null): TensorFlowInferenceModel {
val model = TensorFlowInferenceModel()
model.kGraph = this.kGraph.copy()
model.tf = Ops.create(model.kGraph.tfGraph)
@@ -158,27 +159,21 @@
model.input = input
model.output = output
if (copiedModelName != null) model.name = name
// TODO: check that tensors are closed after usage
if (copyWeights) {
val modelWeightsExtractorRunner = session.runner()
val variableNames = kGraph.variableNames()
check(variableNames.isNotEmpty()) {
"Found 0 variable names in TensorFlow graph $kGraph. " +
"If copied model has no weights, set flag `copyWeights` to `false`."
}
copyVariablesToModel(model, kGraph.variableNames())
model.isModelInitialized = true
return model
}

val variableNamesToCopy = variableNames.filter { variableName ->
saveOptimizerState || !isOptimizerVariable(variableName)
}
variableNamesToCopy.forEach(modelWeightsExtractorRunner::fetch)
val modelWeights = variableNamesToCopy.zip(modelWeightsExtractorRunner.run()).toMap()
protected fun copyVariablesToModel(model: TensorFlowInferenceModel, variableNames: List<String>) {
if (variableNames.isEmpty()) return

model.loadVariables(modelWeights.keys) { variableName, _ ->
modelWeights[variableName]!!.use { it.convertTensorToMultiDimArray() }
}
val modelWeightsExtractorRunner = session.runner()
variableNames.forEach(modelWeightsExtractorRunner::fetch)
val modelWeights = variableNames.zip(modelWeightsExtractorRunner.run()).toMap()

model.loadVariables(modelWeights.keys) { variableName, _ ->
modelWeights[variableName]!!.use { it.convertTensorToMultiDimArray() }
}
model.isModelInitialized = true
return model
}
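
Finally, a hedged sketch for the plain inference model, assuming `model` is an already initialized `TensorFlowInferenceModel` (the copy name is illustrative):

```kotlin
// Sketch: the two copies are independent models and must be closed separately.
val named = model.copy(copiedModelName = "inferenceCopy")
val unnamed = model.copy() // interface override, same as copy(copiedModelName = null)
named.close()
unnamed.close()
```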

/**
@@ -306,7 +306,7 @@ class TransferLearningTest : IntegrationTest() {

it.loadWeights(hdfFile)

val copy = it.copy()
val copy = it.copy(copyOptimizerState = false, copyWeights = true)
assertTrue(copy.layers.size == 11)
copy.close()
