Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve parsing of PyTorch functional layers arguments #923

Merged
merged 6 commits into from
Jan 18, 2024
2 changes: 2 additions & 0 deletions model_compression_toolkit/core/pytorch/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,14 @@
# String keys naming layer attributes / weights.
BIAS = 'bias'
# GAMMA and BETA reuse the same 'weight' / 'bias' strings as WEIGHT and BIAS.
GAMMA = 'weight'
BETA = 'bias'
WEIGHT = 'weight'
# NOTE(review): presumably the running statistics of normalization layers - confirm against usage.
MOVING_MEAN = 'running_mean'
MOVING_VARIANCE = 'running_var'
# Attribute keys, each followed by a numeric value for it.
EPSILON = 'eps'
EPSILON_VAL = 1e-5
MOMENTUM = 'momentum'
MOMENTUM_VAL = 0.1
NORMALIZED_SHAPE = 'normalized_shape'
DIM = 'dim'
IN_CHANNELS = 'in_channels'
OUT_CHANNELS = 'out_channels'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, List

from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common import BaseNode, Graph
from model_compression_toolkit.core.pytorch.constants import *
from model_compression_toolkit.logger import Logger


class FunctionalLayerNorm(common.BaseSubstitution):
    """
    Replace functional layer_norm with LayerNorm.

    The weight/bias values reaching F.layer_norm through input nodes are moved into the
    weights of a new nn.LayerNorm node, and the nodes that supplied them are removed.
    """

    def __init__(self):
        """
        Matches: functional layer_norm
        """
        ln_node = NodeOperationMatcher(F.layer_norm)
        super().__init__(matcher_instance=ln_node)

    def get_attributes_from_inputs(self, graph: Graph, node: BaseNode, normalized_shape: [Tuple, List, int]) -> dict:
        """
        Parse layer_norm(input, normalized_shape, weight=None, bias=None)
        Args:
            graph: Graph we apply the substitution on.
            node: Node that match the pattern in the substitution init.
            normalized_shape: nn.LayerNorm "normalized_shape" argument

        Returns:
            A dictionary mapping GAMMA and BETA to the weight and bias values of the layer.
            An argument that is not fed by an input node gets its default value
            (ones for the weight, zeros for the bias).
        """

        # Get input nodes (sorted by sink index, so index 0 is the normalized input itself).
        input_nodes = graph.get_prev_nodes(node, sink_index_sorted=True)

        # Define default weight and bias
        w0 = np.ones(normalized_shape)  # Default value in case weight is not given
        b0 = np.zeros(normalized_shape)  # Default value in case bias is not given

        # An argument that does not appear in the node's framework attributes is treated as
        # being supplied by one of the node's input nodes.
        # NOTE(review): this relies on the graph builder recording non-tensor arguments
        # (including defaults) in framework_attr - confirm against the model parser.
        has_weight = WEIGHT not in node.framework_attr
        has_bias = BIAS not in node.framework_attr

        # The weight (if fed by a node) follows the input; the bias follows the weight,
        # or directly follows the input when there is no weight node.
        weight_input_ind = 1 if has_weight else 0
        bias_input_ind = weight_input_ind + 1

        return {
            GAMMA: list(input_nodes[weight_input_ind].weights.values())[0] if has_weight else w0,
            BETA: list(input_nodes[bias_input_ind].weights.values())[0] if has_bias else b0
        }

    def substitute(self,
                   graph: Graph,
                   node: BaseNode) -> Graph:
        """
        Substitute functional.layer_norm and its inputs with LayerNorm.
        Args:
            graph: Graph we apply the substitution on.
            node: node that match the pattern in the substitution init.

        Returns:
            Graph after applying the substitution.
        """
        # The size of the last dimension of the first input is used as normalized_shape.
        # NOTE(review): the normalized_shape passed to F.layer_norm is not read here - confirm
        # that normalizing over the last dimension only is the intended scope.
        normalized_shape = node.input_shape[0][-1]

        ln_node_weights = self.get_attributes_from_inputs(graph, node, normalized_shape)
        if not ln_node_weights:
            return graph
        # NOTE(review): epsilon is always set to EPSILON_VAL - confirm an 'eps' passed to
        # F.layer_norm does not need to be carried over.
        new_layernorm = BaseNode(name=node.name + '_into_LayerNorm',
                                 framework_attr={NORMALIZED_SHAPE: normalized_shape,
                                                 EPSILON: EPSILON_VAL,
                                                 },
                                 input_shape=node.output_shape,
                                 output_shape=node.output_shape,
                                 weights=ln_node_weights,
                                 layer_class=nn.LayerNorm)

        num_nodes_before_substitution = len(graph.nodes)
        num_edges_before_substitution = len(graph.edges)

        # Use the same sink-index ordering as get_attributes_from_inputs, so that index 0 is
        # the normalized input and everything after it is a weight/bias node to remove.
        layer_norm_consts = graph.get_prev_nodes(node, sink_index_sorted=True)[1:]
        for const in layer_norm_consts:
            graph.remove_edge(const, node)
            graph.remove_node(const)

        graph.replace_node(node, new_layernorm)

        # Sanity check: exactly one node and one edge were removed per weight/bias node.
        assert num_nodes_before_substitution - len(graph.nodes) == len(layer_norm_consts)
        assert num_edges_before_substitution - len(graph.edges) == len(layer_norm_consts)

        return graph
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
pytorch_batchnorm_refusing
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.functional_batch_norm import \
FunctionalBatchNorm
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.functional_layer_norm import \
FunctionalLayerNorm
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.linear_collapsing import \
pytorch_linear_collapsing
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.multi_head_attention_decomposition \
Expand Down Expand Up @@ -246,7 +248,8 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
MultiHeadAttentionDecomposition(),
PermuteCallMethod(),
ConstantHolderConv(fw_info),
FunctionalBatchNorm()]
FunctionalBatchNorm(),
FunctionalLayerNorm()]

def get_substitutions_pre_statistics_collection(self,
quant_config: QuantizationConfig
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a feature-model test or a layer test? If it is a layer test, there is a library for layer tests, so maybe add it there.

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch
from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest

"""
These tests check the torch.nn.functional.layer_norm operation,
with and without the optional weight and bias arguments.
"""
class LayerNormNet(torch.nn.Module):
    """
    Small network that applies torch.nn.functional.layer_norm over a channels-last view of
    its input, optionally passing a weight and/or a bias parameter.
    """

    def __init__(self, has_weight=None, has_bias=None):
        super().__init__()

        self.has_weight = has_weight
        self.has_bias = has_bias
        # All four parameters are always created; forward() picks the ones it needs.
        self.bias0 = torch.nn.Parameter(torch.rand(3))
        self.weight0 = torch.nn.Parameter(torch.rand(3))
        self.bias1 = torch.nn.Parameter(torch.rand(3))
        self.weight1 = torch.nn.Parameter(torch.rand(3))

    def forward(self, x):
        layer_norm = torch.nn.functional.layer_norm

        # Two transposes move dim 1 (the channels) to the last position.
        x = torch.transpose(x, 1, 3)
        x = torch.transpose(x, 1, 2)

        # Layer normalization along the last dimension, one call per argument combination.
        if self.has_weight:
            if self.has_bias:
                return layer_norm(x, normalized_shape=(3,), weight=self.weight0, bias=self.bias0)
            return layer_norm(x, normalized_shape=(3,), weight=self.weight1)

        if self.has_bias:
            return layer_norm(x, normalized_shape=(3,), bias=self.bias1)

        return layer_norm(x, normalized_shape=(3,))



class LayerNormNetTest(BasePytorchTest):
    """
    Feature test built around LayerNormNet, i.e. torch.nn.functional.layer_norm
    called with or without the weight and bias arguments.
    """

    def __init__(self, unit_test, has_weight=None, has_bias=None):
        super().__init__(unit_test)
        # Forwarded unchanged to LayerNormNet when the network is created.
        self.has_weight, self.has_bias = has_weight, has_bias

    def create_inputs_shape(self):
        # A single input: val_batch_size x 3 x 32 x 32.
        single_input_shape = [self.val_batch_size, 3, 32, 32]
        return [single_input_shape]

    def create_feature_network(self, input_shape):
        # input_shape is not used; the network is defined only by the two flags.
        return LayerNormNet(self.has_weight, self.has_bias)
10 changes: 10 additions & 0 deletions tests/pytorch_tests/model_tests/test_feature_models_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import model_compression_toolkit as mct
from model_compression_toolkit.gptq.common.gptq_config import RoundingType
from tests.pytorch_tests.model_tests.feature_models.add_net_test import AddNetTest
from tests.pytorch_tests.model_tests.feature_models.layer_norm_net_test import LayerNormNetTest
from tests.pytorch_tests.model_tests.feature_models.conv2d_replacement_test import DwConv2dReplacementTest
from tests.pytorch_tests.model_tests.feature_models.mixed_precision_bops_test import MixedPrecisionBopsBasicTest, \
MixedPrecisionBopsAllWeightsLayersTest, MixedPrecisionWeightsOnlyBopsTest, MixedPrecisionActivationOnlyBopsTest, \
Expand Down Expand Up @@ -121,6 +122,15 @@ def test_add_net(self):
"""
AddNetTest(self).run_test()

def test_layer_norm_net(self):
    """
    Run the nn.functional.layer_norm feature test for every combination of
    passing / omitting the weight and bias arguments.
    """
    # Same order as before: (weight, bias) = TT, TF, FT, FF.
    combinations = ((True, True), (True, False), (False, True), (False, False))
    for with_weight, with_bias in combinations:
        LayerNormNetTest(self, has_weight=with_weight, has_bias=with_bias).run_test()

def test_assert_net(self):
"""
This tests check that the assert operation is being
Expand Down
Loading