diff --git a/model_compression_toolkit/core/common/substitutions/remove_identity.py b/model_compression_toolkit/core/common/substitutions/remove_identity.py index fa9929306..c391b8194 100644 --- a/model_compression_toolkit/core/common/substitutions/remove_identity.py +++ b/model_compression_toolkit/core/common/substitutions/remove_identity.py @@ -5,7 +5,7 @@ def remove_identity_node(graph: Graph, node: BaseNode) -> Graph: """ - The method to perform the substitution of the `torch.nn.Identity` node by + The method to perform the substitution of the identity node by reconnecting its input directly to its output, effectively removing the node from the graph. diff --git a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_identity.py b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_identity.py index c92f8bcdb..b57d17f42 100644 --- a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_identity.py +++ b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_identity.py @@ -25,7 +25,7 @@ class RemoveIdentity(common.BaseSubstitution): """ - Remove `torch.nn.Identity` layers from the graph. + Remove Identity layers from the graph. """ def __init__(self): @@ -36,7 +36,7 @@ def substitute(self, graph: Graph, node: BaseNode) -> Graph: """ - The method to perform the substitution of the `torch.nn.Identity` node by + The method to perform the substitution of the identity keras node by reconnecting its input directly to its output, effectively removing the node from the graph.