
PR additions
edenlum committed Nov 30, 2023
1 parent 67bfc16 commit 9c69278
Showing 2 changed files with 15 additions and 25 deletions.
27 changes: 6 additions & 21 deletions
@@ -1,4 +1,4 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,32 +14,20 @@
 # ==============================================================================
 from torch import nn
 import torch.nn.functional as F
-import copy
 
 from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
 from model_compression_toolkit.core import common
-from model_compression_toolkit.core.common import BaseNode, Graph, FrameworkInfo
+from model_compression_toolkit.core.common import BaseNode, Graph
 from model_compression_toolkit.core.pytorch.constants import *
-
-
-def functional_batch_norm_matcher() -> NodeOperationMatcher:
-    """
-    Function generates matchers for matching:
-    functional.batch_norm.
-    Returns:
-        Matcher for batch_norm node.
-    """
-    bn_node = NodeOperationMatcher(F.batch_norm)
-    return bn_node
+from model_compression_toolkit.logger import Logger
 
 
 class FunctionalBatchNorm(common.BaseSubstitution):
     """
     Replace functional batch_norm with BatchNorm2d.
     """
 
-    def __init__(self, ):
+    def __init__(self):
         """
         Matches: functional batch_norm
         """
@@ -59,7 +47,7 @@ def get_attributes_from_inputs(self, graph: Graph, node: BaseNode) -> dict:
             gamma = None
             beta = None
         else:
-            raise ValueError(f'functional batch_norm is expected to have 3 or 5 inputs, got {len(input_nodes)}')
+            Logger.error(f'functional batch_norm is expected to have 3 or 5 inputs, got {len(input_nodes)}')
 
         return {GAMMA: gamma,
                 BETA: beta,
@@ -73,7 +61,7 @@ def substitute(self,
         Substitute functional.batch_norm and its inputs with BatchNorm2d.
         Args:
             graph: Graph we apply the substitution on.
-            node: nodes that match the pattern in the substitution init.
+            node: node that match the pattern in the substitution init.
         Returns:
             Graph after applying the substitution.
@@ -93,9 +81,6 @@
         num_nodes_before_substitution = len(graph.nodes)
         num_edges_before_substitution = len(graph.edges)
 
-        # new_batchnorm2d.prior_info = copy.deepcopy(node.prior_info)
-        # new_batchnorm2d.candidates_quantization_cfg = copy.deepcopy(node.candidates_quantization_cfg)
-
        batch_norm_consts = graph.get_prev_nodes(node)[1:]
        for const in batch_norm_consts:
            graph.remove_edge(const, node)
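The substitution above rests on a numerical equivalence: a functional.batch_norm node arrives with either 3 inputs (the activation plus running mean and variance) or 5 inputs (those plus gamma and beta), and an nn.BatchNorm2d module in eval mode carrying the same tensors computes exactly the same output. A minimal standalone sketch of the 5-input case (illustrative only, not MCT code):

# Illustrative sketch (not MCT code): functional batch_norm vs. an equivalent BatchNorm2d.
import torch
from torch import nn
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
mean, var = torch.randn(3), 1 + torch.rand(3)   # fake running stats; variance kept positive
gamma, beta = torch.randn(3), torch.randn(3)    # affine parameters

# Functional form, the pattern matched by NodeOperationMatcher(F.batch_norm).
y_func = F.batch_norm(x, mean, var, weight=gamma, bias=beta)

# Module form that the substitution puts in its place.
bn = nn.BatchNorm2d(3)
with torch.no_grad():
    bn.running_mean.copy_(mean)
    bn.running_var.copy_(var)
    bn.weight.copy_(gamma)
    bn.bias.copy_(beta)
bn.eval()  # eval mode uses the stored statistics, matching training=False above

assert torch.allclose(y_func, bn(x), atol=1e-6)

Once the call is a module, it can flow through the same BatchNorm folding path that the test change below exercises.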
13 changes: 9 additions & 4 deletions tests/pytorch_tests/model_tests/test_feature_models_runner.py
@@ -137,14 +137,19 @@ def test_bn_folding(self):
         """
         This test checks the BatchNorm folding feature.
         """
+        assert False
 
         def batch_norm_wrapper(channels):
             return partial(nn.functional.batch_norm,
-                           running_mean=torch.zeros(channels, device='cuda'),
-                           running_var=torch.ones(channels, device='cuda'))
+                           running_mean=torch.randn(channels, device='cuda'),
+                           running_var=1+torch.randn(channels, device='cuda'))
 
-        for bn_layer in [nn.BatchNorm2d, batch_norm_wrapper]:
+        def batch_norm_wrapper_with_bias(channels):
+            return partial(nn.functional.batch_norm,
+                           running_mean=torch.randn(channels, device='cuda'),
+                           running_var=1+torch.randn(channels, device='cuda'),
+                           bias=torch.randn(channels, device='cuda'))
+
+        for bn_layer in [nn.BatchNorm2d, batch_norm_wrapper, batch_norm_wrapper_with_bias]:
             BNFoldingNetTest(self, nn.Conv2d(3, 2, kernel_size=1), bn_layer).run_test()
             BNFoldingNetTest(self, nn.Conv2d(3, 3, kernel_size=3, groups=3)).run_test()  # DW-Conv test
             BNFoldingNetTest(self, nn.ConvTranspose2d(3, 2, kernel_size=(2, 1))).run_test()
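A note on the partial wrappers added above: they give nn.functional.batch_norm the same one-argument constructor shape as nn.BatchNorm2d, so a single loop can hand BNFoldingNetTest either a module class or a functional call with frozen statistics. A hypothetical sketch of the pattern (the TinyNet module is illustrative, not from the test suite; CPU tensors for portability):

import torch
from torch import nn
from functools import partial

def batch_norm_wrapper(channels):
    # Mirrors the test's wrapper; 1 + torch.rand keeps the fake running variance
    # strictly positive (1+torch.randn can occasionally dip below zero).
    return partial(nn.functional.batch_norm,
                   running_mean=torch.randn(channels),
                   running_var=1 + torch.rand(channels))

class TinyNet(nn.Module):  # hypothetical stand-in for the net BNFoldingNetTest builds
    def __init__(self, bn_layer):
        super().__init__()
        self.conv = nn.Conv2d(3, 2, kernel_size=1)
        self.bn = bn_layer(2)  # an nn.BatchNorm2d module or a partial of F.batch_norm

    def forward(self, x):
        return self.bn(self.conv(x))

for bn_layer in [nn.BatchNorm2d, batch_norm_wrapper]:
    y = TinyNet(bn_layer)(torch.randn(1, 3, 4, 4))  # both variants are called the same way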
