Skip to content

Commit

Permalink
check disconnected inputs in torch (#1360)
Browse files Browse the repository at this point in the history
  • Loading branch information
irenaby authored Feb 18, 2025
1 parent 0ec73ad commit 3dfdbfe
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 10 deletions.
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/common/graph/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def remove_node(self,
if node_to_remove in self.get_inputs(): # If node is in the graph's inputs, the inputs should be updated
if new_graph_inputs is None:
Logger.critical(
f'{node_to_remove.name} s among the graph inputs; however, it cannot be removed without providing a new input.') # pragma: no cover
f'{node_to_remove.name} is among the graph inputs; however, it cannot be removed without providing a new input.') # pragma: no cover
self.set_inputs(new_graph_inputs)

# Make sure there are no connected edges left to the node before removing it.
Expand Down
3 changes: 3 additions & 0 deletions model_compression_toolkit/core/pytorch/reader/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,5 +152,8 @@ def model_reader(model: torch.nn.Module,
logging.info("Start Model Reading...")
fx_model = fx_graph_module_generation(model, representative_data_gen, to_tensor)
graph = build_graph(fx_model, to_numpy)
disconnected_inputs = [n.name for n in graph.get_inputs() if not graph.out_edges(n)]
if disconnected_inputs:
raise ValueError(f'The network contains disconnected input(s): {disconnected_inputs}.')
graph = remove_broken_nodes_from_graph(graph)
return graph
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
import unittest
import torch
import numpy as np
from model_compression_toolkit.core.pytorch import pytorch_implementation

from model_compression_toolkit.core.pytorch.reader.reader import fx_graph_module_generation
from model_compression_toolkit.core.pytorch.pytorch_implementation import to_torch_tensor
from model_compression_toolkit.core.pytorch.pytorch_implementation import to_torch_tensor, PytorchImplementation


class BadFxModel(torch.nn.Module):
Expand All @@ -37,16 +38,32 @@ def forward(self, inputs, flag=False):
return x


class TestGraphReading(unittest.TestCase):
def data_gen():
yield [np.zeros((1, 3, 20, 20))]


def test_graph_reading(self):
class TestGraphReading(unittest.TestCase):
def test_fx_tracer_error(self):
model = BadFxModel()
try:
graph = fx_graph_module_generation(model,
lambda : np.zeros((1, 3, 20, 20)),
to_torch_tensor)
except Exception as e:
self.assertEqual(str(e).split('\n')[0], 'Error parsing model with torch.fx')

with self.assertRaises(Exception) as e:
fx_graph_module_generation(model,
data_gen,
to_torch_tensor)
self.assertEqual(str(e.exception).split('\n')[0], 'Error parsing model with torch.fx')

def test_disconnected_input(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 5, 3)

def forward(self, x, unused=None):
return self.conv(x)

with self.assertRaises(ValueError) as e:
PytorchImplementation().model_reader(Model(), data_gen)
self.assertEqual(str(e.exception), r"The network contains disconnected input(s): ['unused'].")


if __name__ == '__main__':
Expand Down

0 comments on commit 3dfdbfe

Please sign in to comment.