Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytorch MatMul Substitution #1292

Closed
wants to merge 36 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
5e5741a
add substitution for functional linear
itai-berman Nov 10, 2024
29df8e5
change test and substitution to support kwargs input, update output s…
itai-berman Nov 18, 2024
6500bf8
add matmul decomposition substituion and tests
itai-berman Dec 12, 2024
1d3ba74
fix license
itai-berman Dec 12, 2024
d30b544
add substitution for functional linear
itai-berman Nov 10, 2024
37ac563
change test and substitution to support kwargs input, update output s…
itai-berman Nov 18, 2024
09f6df2
add matmul decomposition substituion and tests
itai-berman Dec 12, 2024
fd790bb
fix license
itai-berman Dec 12, 2024
96cd728
Merge remote-tracking branch 'origin/pytorch_matmul_substitution' int…
itai-berman Dec 12, 2024
ef3270b
restricting matplotlib to version < 3.10.0 to fix an issue with tenso…
ofirgo Dec 19, 2024
b125f96
Refactor Target Platform Capabilities - Phase 3 (#1297)
lior-dikstein Dec 22, 2024
3c536f2
TPC attach 2 framework (#1296)
ofirgo Dec 23, 2024
ceaf820
Replace max tensor with max cut (#1295)
elad-c Dec 25, 2024
ae8e60c
Fix bug (#1303)
elad-c Dec 29, 2024
dcbfc5c
Move splitting ops from default to qpreserving configs in TPCv4. (#1304)
elad-c Dec 30, 2024
afed6e3
Refactor Target Platform Capabilities - Phase 4 (#1301)
lior-dikstein Dec 31, 2024
5b2a55d
add substitution for functional linear
itai-berman Nov 10, 2024
14b6db7
change test and substitution to support kwargs input, update output s…
itai-berman Nov 18, 2024
b677a5a
add matmul decomposition substituion and tests
itai-berman Dec 12, 2024
8b42444
fix license
itai-berman Dec 12, 2024
8403eca
add substitution for functional linear
itai-berman Nov 10, 2024
b198bec
change test and substitution to support kwargs input, update output s…
itai-berman Nov 18, 2024
9fa5710
fix license
itai-berman Dec 12, 2024
9a480b3
add substitution for functional linear
itai-berman Nov 10, 2024
728cd90
change test and substitution to support kwargs input, update output s…
itai-berman Nov 18, 2024
f535add
fix license
itai-berman Dec 12, 2024
aba2cb0
add substitution for functional linear
itai-berman Nov 10, 2024
99f7694
change test and substitution to support kwargs input, update output s…
itai-berman Nov 18, 2024
98d78a4
fix license
itai-berman Dec 12, 2024
bdc9057
add tests for different input dimensions with only float tpc
itai-berman Jan 2, 2025
6e8246e
skip unnecessary nodes, add expand to support broadcasting
itai-berman Jan 2, 2025
8ae4bed
Merge remote-tracking branch 'origin/pytorch_matmul_substitution' int…
itai-berman Jan 2, 2025
3c3d99b
set use is close validation flag True in test
itai-berman Jan 2, 2025
40f038d
move substitution after scaled dot product
itai-berman Jan 5, 2025
05f0021
change split to unbind and cat to stack due to converter's supported …
itai-berman Jan 5, 2025
65fde0f
fix test to run both function and operator
itai-berman Jan 5, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions model_compression_toolkit/core/common/fusion/graph_fuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ def create_fused_graph(self, graph: Graph) -> Dict[str, str]:
The fusion process involves:
1. Creating new fused nodes to represent these groups.
2. Updating the graph structure to replace the original nodes with fused nodes.
3. Maintaining mapping mapping of original node names to their fused node names.
3. Maintaining mapping of original node names to their fused node names.

Args:
graph: Graph to sue its nodes.
graph: Graph to fuse its nodes.

Returns:
Mapping of original node names to their fused node names
Expand All @@ -54,7 +54,8 @@ def create_fused_graph(self, graph: Graph) -> Dict[str, str]:
fused_nodes_mapping[node.name] = new_fused_node.name
return fused_nodes_mapping

def _create_fused_node(self, nodes: List[BaseNode]) -> BaseNode:
@staticmethod
def _create_fused_node(nodes: List[BaseNode]) -> BaseNode:
"""
Create a new node that represents the fusion of the given nodes.

Expand All @@ -79,10 +80,10 @@ def _create_fused_node(self, nodes: List[BaseNode]) -> BaseNode:

return fused_node

def _replace_nodes_with_fused_node(self,
graph: Graph,
nodes_to_fuse: List[BaseNode],
fused_node: BaseNode):
@staticmethod
def _replace_nodes_with_fused_node(graph: Graph,
nodes_to_fuse: List[BaseNode],
fused_node: BaseNode):
"""
Replace the specified nodes in the graph with a new fused node.

Expand Down
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/common/graph/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def set_tpc(self,
Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. '
' Please add the custom layer to Target Platform Capabilities (TPC), or file a feature '
'request or an issue if you believe this should be supported.') # pragma: no cover
if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(tpc).quantization_config_list]):
if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(tpc).quantization_configurations]):
Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.') # pragma: no cover

self.tpc = tpc
Expand Down
6 changes: 3 additions & 3 deletions model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,12 +582,12 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
"""
# Filter quantization config options that don't match the graph.
_base_config = node_qc_options.base_config
_node_qc_options = node_qc_options.quantization_config_list
_node_qc_options = node_qc_options.quantization_configurations
if len(next_nodes):
next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
for qc_opts in next_nodes_qc_options
for op_cfg in qc_opts.quantization_config_list])
for op_cfg in qc_opts.quantization_configurations])

# Filter node's QC options that match next nodes input bit-width.
_node_qc_options = [_option for _option in _node_qc_options
Expand All @@ -599,7 +599,7 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
for qc_opt in next_nodes_qc_options]):
# base_config activation bits doesn't match next node supported input bit-width -> replace with
# a qco from quantization_config_list with maximum activation bit-width.
# a qco from quantization_configurations with maximum activation bit-width.
if len(_node_qc_options) > 0:
output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
_base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ def compute_graph_max_cut(memory_graph: MemoryGraph,
estimate = (u_bound + l_bound) / 2
schedule, max_cut_size, cuts = max_cut_astar.solve(estimate_factor=estimate, iter_limit=astar_n_iter)
if schedule is None:
return last_result
l_bound = estimate
else:
u_bound = min(estimate, max_cut_size)
last_result = (schedule, max_cut_size, cuts)

next_u_bound = min(estimate, max_cut_size)
last_result = (schedule, max_cut_size, cuts)

if l_bound * (1 + eps) >= next_u_bound:
return last_result
if l_bound * (1 + eps) >= u_bound:
return last_result

it += 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ def solve(self, estimate_factor: float, iter_limit: int = 500) -> Tuple[List[Bas
cut_route = routes[next_cut]

if next_cut == self.target_cut:
# TODO maxcut: Why do we filter the cuts (cut_route) but not the max cut size (cut_sost).
# This is a mismatch between max_cut and max(cuts).
# Also, unfiltered cut_route seems perfect, including input and output tensor sizes of current op.
return self._remove_dummys_from_path(cut_route[0].op_order), cut_cost,\
list(set([self._remove_dummys_from_cut(self.clean_memory_for_next_step(c)) for c in cut_route]))

Expand All @@ -178,7 +181,8 @@ def solve(self, estimate_factor: float, iter_limit: int = 500) -> Tuple[List[Bas
cost = self.accumulate(cut_cost, c.memory_size())
if c not in open_list:
self._update_expanded_node(c, cost, cut_route, open_list, costs, routes)
elif self.ordering(cost, costs[c]):
# TODO maxcut: this isn't covered in the coverage test. check if needed and remove no cover
elif self.ordering(cost, costs[c]): # pragma: no cover
# If we already saw this cut during the search with a larger cost, then we want to update the order
# of the schedule in the cut
# Remove call - removes the cut with the same memory elements but different ordering from open
Expand All @@ -187,7 +191,8 @@ def solve(self, estimate_factor: float, iter_limit: int = 500) -> Tuple[List[Bas
self._update_expanded_node(c, cost, cut_route, open_list, costs, routes)

# Halt or No Solution
return None, 0, None
# TODO maxcut: this isn't covered in the coverage test. check if needed and remove no cover
return None, 0, None # pragma: no cover

@staticmethod
def _update_expanded_node(cut: Cut, cost: float, route: List[Cut], open_list: List[Cut],
Expand Down Expand Up @@ -223,8 +228,7 @@ def _get_cut_to_expand(self, open_list: List[Cut], costs: Dict[Cut, float], rout

"""
ordered_cuts_list = sorted(open_list,
key=lambda c: (self.accumulate(costs[c], self.estimate(c, estimate_factor)), len(routes[c])),
reverse=False)
key=lambda c: (self.accumulate(costs[c], self.estimate(c, estimate_factor)), -len(routes[c])))

assert len(ordered_cuts_list) > 0
return ordered_cuts_list[0]
Expand Down Expand Up @@ -349,7 +353,8 @@ def ordering(cost_1, cost_2) -> bool:
Returns: True if the first cost is smaller than the second one, else otherwise.

"""
return cost_1 < cost_2
# TODO maxcut: this isn't covered in the coverage test. check if needed and remove no cover
return cost_1 < cost_2 # pragma: no cover

def estimate(self, cut: Cut, estimate_factor: float) -> float:
"""
Expand Down Expand Up @@ -377,9 +382,10 @@ def get_init_estimate_factor(memory_graph: MemoryGraph) -> float:
Returns: An initial estimate value.

"""
l_bound = memory_graph.memory_lbound_single_op
u_bound = 2 * sum([t.total_size for t in memory_graph.b_nodes]) - l_bound
return (u_bound + l_bound) / 2
# TODO maxcut: this isn't covered in the coverage test. check if needed and remove no cover
l_bound = memory_graph.memory_lbound_single_op # pragma: no cover
u_bound = 2 * sum([t.total_size for t in memory_graph.b_nodes]) - l_bound # pragma: no cover
return (u_bound + l_bound) / 2 # pragma: no cover

@staticmethod
def _remove_dummys_from_path(path: List[BaseNode]) -> List[BaseNode]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ def __init__(self, shape: Tuple[Any], node_name: str, node_output_index: int, in
init_size_to_zero: Whether to initialize the memory tensor size to 0 or not.
"""

self.shape = shape[1:] # remove batch size (first element) from output shape
# remove batch size (first element) from output shape. If the shape is a list then remove the first
# axis. If shape a vector (e.g. output of size) then set the shape minus 1 to ignore the batch value.
if len(shape) == 1:
self.shape = [] if shape[0] is None else [shape[0] - 1]
else:
self.shape = shape[1:]
# The total size of a tensor is considered to be the number of elements in the tensor
self.total_size = self._get_tensor_total_size() if not init_size_to_zero else 0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
from typing import List
from operator import getitem

from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.graph.edge import EDGE_SOURCE_INDEX
Expand Down Expand Up @@ -45,7 +46,8 @@ def __init__(self, model_graph: Graph):
tensor_to_node = []

for n in nodes:
n_outputs = [n.output_shape] if isinstance(n.output_shape, tuple) else n.output_shape
n_outputs = n.output_shape if isinstance(n.output_shape[0], (tuple, list)) else [n.output_shape]

out_edges = model_graph.out_edges(n, sort_by_attr=EDGE_SOURCE_INDEX)

for i, ot in enumerate(n_outputs):
Expand All @@ -54,7 +56,16 @@ def __init__(self, model_graph: Graph):
# Add memory tensor as current node's output
node_to_tensor.append((n, memory_tensor))

ot_edges = [oe for oe in out_edges if oe.source_index == i]
# TODO maxcut: refactor this code. it handles split->getitem generated by fx.
ot_edges = []
for oe in out_edges:
if oe.sink_node.type is getitem and len(oe.sink_node.op_call_args) == 1 and isinstance(oe.sink_node.op_call_args[0], int):
source_index = oe.sink_node.op_call_args[0]
else:
source_index = oe.source_index
if source_index == i:
ot_edges.append(oe)

for oe in ot_edges:
# Add current memory tensor as input to current node's successors
tensor_to_node.append((memory_tensor, oe.sink_node))
Expand All @@ -71,6 +82,7 @@ def __init__(self, model_graph: Graph):
inputs_tensors_memory = [sum([t.total_size for t in self.operation_node_children(n)])
for n in nodes if n in model_graph.get_inputs()]

# TODO maxcut: why both inputs and outputs of each nodes, while the A* solves for node outputs only???
nodes_total_memory = [sum([t.total_size for t in self.operation_node_children(n)] +
[t.total_size for t in self.operation_node_parents(n)])
for n in nodes if n not in model_graph.get_inputs()]
Expand Down
Loading
Loading