Skip to content

Commit

Permalink
Support TF 2.15
Browse files Browse the repository at this point in the history
  • Loading branch information
elad-c committed Mar 12, 2024
1 parent 0026b21 commit 10bf8c5
Show file tree
Hide file tree
Showing 15 changed files with 89 additions and 30 deletions.
16 changes: 16 additions & 0 deletions .github/workflows/run_tests_python310_keras215.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
name: Python 3.10, Keras 2.15
on:
workflow_dispatch: # Allow manual triggers
schedule:
- cron: 0 0 * * *
pull_request:
branches:
- main


jobs:
run-tests:
uses: ./.github/workflows/run_keras_tests.yml
with:
python-version: "3.10"
tf-version: "2.15.*"
16 changes: 16 additions & 0 deletions .github/workflows/run_tests_python311_keras215.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
name: Python 3.11, Keras 2.15
on:
workflow_dispatch: # Allow manual triggers
schedule:
- cron: 0 0 * * *
pull_request:
branches:
- main


jobs:
run-tests:
uses: ./.github/workflows/run_keras_tests.yml
with:
python-version: "3.11"
tf-version: "2.15.*"
16 changes: 16 additions & 0 deletions .github/workflows/run_tests_python39_keras215.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
name: Python 3.9, Keras 2.15
on:
workflow_dispatch: # Allow manual triggers
schedule:
- cron: 0 0 * * *
pull_request:
branches:
- main


jobs:
run-tests:
uses: ./.github/workflows/run_keras_tests.yml
with:
python-version: "3.9"
tf-version: "2.15.*"
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ Currently, MCT is being tested on various Python, Pytorch and TensorFlow version
| Python 3.11 | | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch20.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch20.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch21.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch21.yml) |


| | TensorFlow 2.12 | TensorFlow 2.13 | TensorFlow 2.14 |
|-------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Python 3.9 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml) |
| Python 3.10 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml) |
| Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml) |
| | TensorFlow 2.12 | TensorFlow 2.13 | TensorFlow 2.14 | TensorFlow 2.15 |
|-------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Python 3.9 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml) |
| Python 3.10 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml) |
| Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml) |


## Supported Features
Expand Down
14 changes: 8 additions & 6 deletions model_compression_toolkit/core/common/fusion/layer_fusing.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ def filter_fusing_patterns(fusing_patterns: List[List[Any]], node: BaseNode, idx
fusing_patterns after filtering non-relevant fusions
"""
valid_fusing_patterns = []
for i,fusing_pattern in enumerate(fusing_patterns):
for i, fusing_pattern in enumerate(fusing_patterns):
if idx < len(fusing_pattern):
if (type(fusing_pattern[idx]) == LayerFilterParams and node.is_match_filter_params(fusing_pattern[idx])) or fusing_pattern[idx] == node.type:
if (type(fusing_pattern[idx]) == LayerFilterParams and node.is_match_filter_params(fusing_pattern[idx])) or \
node.is_match_type(fusing_pattern[idx]):
valid_fusing_patterns.append(fusing_pattern)

# Return only valid patterns for this node
Expand All @@ -44,7 +45,7 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List[BaseNode]) ->
"""
Check if the fusion is valid: exist in fusing_patterns
Args:
fusing_patterns: supported fusings
fusing_patterns: supported fusing patterns
nodes: nodes which are participating in fusion
Returns:
whether the fusion in valid
Expand All @@ -56,8 +57,9 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List[BaseNode]) ->
if fusion_depth != len(fusing_pattern):
continue
counter = 0
for i,layer in enumerate(fusing_pattern):
if (type(layer) == LayerFilterParams and nodes[i].is_match_filter_params(layer)) or layer == nodes[i].type:
for i, layer in enumerate(fusing_pattern):
if (type(layer) == LayerFilterParams and nodes[i].is_match_filter_params(layer)) or \
nodes[i].is_match_type(layer):
counter += 1
if counter == fusion_depth:
return True
Expand Down Expand Up @@ -107,7 +109,7 @@ def fusion(graph: Graph, tpc: TargetPlatformCapabilities) -> Graph:
if node in fused_nodes:
continue
# Start fusing search
fusing_nodes = [] # nodes that are candidates for participating in fusing
fusing_nodes = [] # nodes that are candidates for participating in fusing
patterns = copy.deepcopy(fusing_patterns)
next_nodes = [node]
for i in range(max_layers_fusing):
Expand Down
5 changes: 4 additions & 1 deletion model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,9 @@ def get_qco(self, tpc: TargetPlatformCapabilities) -> QuantizationConfigOptions:
return tpc.layer2qco.get(self.type)
return tpc.tp_model.default_qco

def is_match_type(self, _type):
return _type == self.type or _type.__name__ == self.type.__name__

def is_match_filter_params(self, layer_filter_params: LayerFilterParams) -> bool:
"""
Check if the node matches a LayerFilterParams according to its
Expand All @@ -572,7 +575,7 @@ def is_match_filter_params(self, layer_filter_params: LayerFilterParams) -> bool
return False

# Check the node has the same type as the layer in LayerFilterParams
if layer_filter_params.layer != self.type:
if not self.is_match_type(layer_filter_params.layer):
return False

# Get attributes from node to filter
Expand Down
4 changes: 2 additions & 2 deletions model_compression_toolkit/core/common/graph/graph_matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, operation: Any):

self.operation = operation

def apply(self, input_node_object: Any) -> bool:
def apply(self, input_node_object: BaseNode) -> bool:
"""
Check if input_node_object matches the matcher condition.
Expand All @@ -47,7 +47,7 @@ def apply(self, input_node_object: Any) -> bool:
return nothing.
"""

if input_node_object.type == self.operation:
if input_node_object.is_match_type(self.operation):
return True


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def create_add_node(add_value: float,
quantization_attr={},
layer_class=TFOpLambda,
op_call_args=[np.array(add_value, dtype=np.float32).reshape([1] * len(input_shape))],
op_call_kwargs={})
op_call_kwargs={},
functional_op=tf.add)
return add_node


Expand Down Expand Up @@ -157,7 +158,8 @@ def create_pad_node(next_node_name: str,
layer_class=TFOpLambda,
op_call_args=[],
op_call_kwargs={'paddings': num_elements_to_pad,
'constant_values': value_to_pad})
'constant_values': value_to_pad},
functional_op=tf.pad)

return pad_node

Expand Down
18 changes: 10 additions & 8 deletions model_compression_toolkit/core/keras/reader/node_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,24 @@

is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray))
is_tensor = lambda x: isinstance(x, KerasTensor)
tf_function_symbols = [TFOpLambda(f).symbol for f in [tf.add, tf.multiply, tf.subtract, tf.divide,
tf.truediv, tf.pow, tf.matmul]]


def get_kwargs2index(tf_func: Callable) -> Dict[str, int]:
def get_kwargs2index(tfoplambda_layer: tf.keras.layers.Layer) -> Dict[str, int]:
"""
Positional weights are saved according to their index in the node's call arguments, so
need to know the function arguments' names in case the weights are in the kwargs.
Args:
tf_func: functional node function.
tfoplambda_layer: TFOpLambda layer.
Returns:
A dictionary with argument number and index: {arg_name: arg_index}.
"""
if tf_func in [tf.add, tf.subtract, tf.divide, tf.truediv, tf.multiply, tf.pow,
tf.matmul, tf.image.crop_and_resize, tf.image.combined_non_max_suppression]:
return {arg_name: i for i, arg_name in enumerate(tf_inspect.getfullargspec(tf_func).args)}
if tfoplambda_layer.function in [tf.add, tf.subtract, tf.divide, tf.truediv, tf.multiply, tf.pow,
tf.matmul, tf.image.crop_and_resize, tf.image.combined_non_max_suppression] or \
tfoplambda_layer.symbol in ['__operators__.add', 'math.add', 'math.multiply', 'linalg.matmul', 'concat']:
return {arg_name: i for i, arg_name in enumerate(tf_inspect.getfullargspec(tfoplambda_layer.function).args)}
else:
return {}

Expand Down Expand Up @@ -110,7 +113,7 @@ def build_node(node: KerasNode,
# a flag to indicate that.
inputs_as_list = __is_functional_inputs_a_list(op_call_args)

kwarg2index = get_kwargs2index(keras_layer.function)
kwarg2index = get_kwargs2index(keras_layer)

# Functional nodes do not have weights, but may have constants in their call_args and\or
# call kwargs. Therefore, we extract these constants and save them in the node's weights as
Expand All @@ -124,8 +127,7 @@ def build_node(node: KerasNode,
# read weights from call args
for i, arg in enumerate(op_call_args[0] if inputs_as_list else op_call_args):
if is_const(arg) or (
keras_layer.function in [tf.add, tf.multiply, tf.subtract, tf.divide, tf.truediv, tf.pow,
tf.matmul] and
keras_layer.symbol in tf_function_symbols and
isinstance(arg, (tuple, list))):
weights.update({i: to_numpy(arg, is_single_tensor=True)})
# remove weights and KerasTensors and weights from op_call_args
Expand Down
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/quantization_prep_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def quantization_preparation_runner(graph: Graph,
fw_info,
core_config.quantization_config) # Mark points for statistics collection

for _data in tqdm(representative_data_gen(), "Statistics Collection:"):
for _data in tqdm(representative_data_gen(), "Statistics Collection"):
mi.infer(_data)

if tb_w is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from keras.src.layers import Conv2D, TFOpLambda, Add, DepthwiseConv2D, Dense
else:
from keras.layers import Conv2D, TFOpLambda, Add, DepthwiseConv2D, Dense
import tensorflow as tf

from tests.keras_tests.exporter_tests.keras_fake_quant.keras_fake_quant_exporter_base_test import \
KerasFakeQuantExporterBaseTest
Expand Down Expand Up @@ -59,7 +58,7 @@ def run_checks(self):
assert self.loaded_model.layers[7].output.ref() == self.loaded_model.layers[9].input.ref()

assert isinstance(self.loaded_model.layers[10], TFOpLambda)
assert self.loaded_model.layers[10].function == tf.add
assert self.loaded_model.layers[10].symbol == 'math.add'
assert self.loaded_model.layers[10].input.ref() == self.loaded_model.layers[8].output.ref()
assert self.loaded_model.layers[10].inbound_nodes[0].call_kwargs['y'].ref() == self.loaded_model.layers[9].output.ref()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
for layer in quantized_model.layers:
if type(layer) in [layers.Conv2D, layers.DepthwiseConv2D, layers.Conv2DTranspose, layers.Dense]:
self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')
elif isinstance(layer, TFOpLambda) and layer.function is tf.add:
elif isinstance(layer, TFOpLambda) and (layer.function is tf.add or layer.symbol == TFOpLambda(tf.add).symbol):
num_adds += 1

# check all "add"s were folded except the one with 2 tensor inputs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=

num_matmuls = 0
for layer in quantized_model.layers:
if isinstance(layer, TFOpLambda) and layer.function is tf.matmul:
if isinstance(layer, TFOpLambda) and layer.symbol is TFOpLambda(tf.matmul).symbol:
num_matmuls += 1

# check all "matmul"s were replaced except the one with 2 tensor inputs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,7 @@ def test_const_representation(self):
ConstRepresentationTest(self, func, c, use_kwrags=True, is_list_input=True).run_test()

ConstRepresentationMultiInputTest(self).run_test()

def test_second_moment(self):
DepthwiseConv2DSecondMomentTest(self).run_test()
# DepthwiseConv2DWithMultiplierSecondMomentTest(self).run_test()
Expand Down
Loading

0 comments on commit 10bf8c5

Please sign in to comment.