Skip to content

Commit

Permalink
suport for concat aliases added
Browse files Browse the repository at this point in the history
  • Loading branch information
samuel-wj-chapman committed Apr 9, 2024
1 parent da66b7c commit 1b6d3f4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@


from tensorflow.keras.layers import Concatenate
import tensorflow as tf

from model_compression_toolkit.core import common
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.constants import THRESHOLD

MATCHER = NodeOperationMatcher(Concatenate)


class ConcatThresholdUpdate(common.BaseSubstitution):
Expand All @@ -35,7 +35,9 @@ def __init__(self):
"""
Initialize a threshold_updater object.
"""
super().__init__(matcher_instance=MATCHER)
concatination_node = NodeOperationMatcher(Concatenate) | \
NodeOperationMatcher(tf.concat)
super().__init__(matcher_instance=concatination_node)

def substitute(self,
graph: Graph,
Expand All @@ -58,7 +60,7 @@ def substitute(self,
concat_threshold = node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD]
prev_nodes = graph.get_prev_nodes(node)
for prev_node in prev_nodes:
if len(graph.get_next_nodes(prev_node))==1 and prev_node.type != Concatenate:
if len(graph.get_next_nodes(prev_node))==1 and prev_node.type != Concatenate and prev_node.type != tf.concat:
prev_node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD] = concat_threshold

return graph
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@
from model_compression_toolkit.constants import THRESHOLD



MATCHER = NodeOperationMatcher(torch.cat)

class ConcatThresholdUpdate(common.BaseSubstitution):
"""
Find concat layers and match their prior layers thresholds unless prior layer outputs to multiple layers.
Expand All @@ -37,7 +34,9 @@ def __init__(self):
"""
Initialize a threshold_updater object.
"""
super().__init__(matcher_instance=MATCHER)
concatination_node = NodeOperationMatcher(torch.cat) | \
NodeOperationMatcher(torch.concat)
super().__init__(matcher_instance=concatination_node)

def substitute(self,
graph: Graph,
Expand All @@ -60,7 +59,7 @@ def substitute(self,
concat_threshold = node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD]
prev_nodes = graph.get_prev_nodes(node)
for prev_node in prev_nodes:
if len(graph.get_next_nodes(prev_node))==1 and prev_node.type != torch.cat:
if len(graph.get_next_nodes(prev_node))==1 and prev_node.type != torch.cat and prev_node.type != torch.concat:
prev_node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD] = concat_threshold

return graph
Expand Down

0 comments on commit 1b6d3f4

Please sign in to comment.