From 1b6d3f46583af4a894868c602d6091580d08e031 Mon Sep 17 00:00:00 2001 From: samuel-wj-chapman Date: Tue, 9 Apr 2024 13:41:28 +0100 Subject: [PATCH] suport for concat aliases added --- .../substitutions/concat_threshold_update.py | 8 +++++--- .../substitutions/concat_threshold_update.py | 9 ++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py index d45199406..db04e5b41 100644 --- a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py +++ b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py @@ -15,13 +15,13 @@ from tensorflow.keras.layers import Concatenate +import tensorflow as tf from model_compression_toolkit.core import common from model_compression_toolkit.core.common import Graph, BaseNode from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher from model_compression_toolkit.constants import THRESHOLD -MATCHER = NodeOperationMatcher(Concatenate) class ConcatThresholdUpdate(common.BaseSubstitution): @@ -35,7 +35,9 @@ def __init__(self): """ Initialize a threshold_updater object. """ - super().__init__(matcher_instance=MATCHER) + concatination_node = NodeOperationMatcher(Concatenate) | \ + NodeOperationMatcher(tf.concat) + super().__init__(matcher_instance=concatination_node) def substitute(self, graph: Graph, @@ -58,7 +60,7 @@ def substitute(self, concat_threshold = node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD] prev_nodes = graph.get_prev_nodes(node) for prev_node in prev_nodes: - if len(graph.get_next_nodes(prev_node))==1 and prev_node.type != Concatenate: + if len(graph.get_next_nodes(prev_node))==1 and prev_node.type != Concatenate and prev_node.type != tf.concat: prev_node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD] = concat_threshold return graph diff --git a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py index 8a346b888..6947b784e 100644 --- a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py +++ b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py @@ -24,9 +24,6 @@ from model_compression_toolkit.constants import THRESHOLD - -MATCHER = NodeOperationMatcher(torch.cat) - class ConcatThresholdUpdate(common.BaseSubstitution): """ Find concat layers and match their prior layers thresholds unless prior layer outputs to multiple layers. @@ -37,7 +34,9 @@ def __init__(self): """ Initialize a threshold_updater object. """ - super().__init__(matcher_instance=MATCHER) + concatination_node = NodeOperationMatcher(torch.cat) | \ + NodeOperationMatcher(torch.concat) + super().__init__(matcher_instance=concatination_node) def substitute(self, graph: Graph, @@ -60,7 +59,7 @@ def substitute(self, concat_threshold = node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD] prev_nodes = graph.get_prev_nodes(node) for prev_node in prev_nodes: - if len(graph.get_next_nodes(prev_node))==1 and prev_node.type != torch.cat: + if len(graph.get_next_nodes(prev_node))==1 and prev_node.type != torch.cat and prev_node.type != torch.concat: prev_node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD] = concat_threshold return graph