Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check disconnected inputs in torch #1360

Merged
merged 1 commit into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/common/graph/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def remove_node(self,
if node_to_remove in self.get_inputs(): # If node is in the graph's inputs, the inputs should be updated
if new_graph_inputs is None:
Logger.critical(
f'{node_to_remove.name} s among the graph inputs; however, it cannot be removed without providing a new input.') # pragma: no cover
f'{node_to_remove.name} is among the graph inputs; however, it cannot be removed without providing a new input.') # pragma: no cover
self.set_inputs(new_graph_inputs)

# Make sure there are no connected edges left to the node before removing it.
Expand Down
3 changes: 3 additions & 0 deletions model_compression_toolkit/core/pytorch/reader/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,5 +152,8 @@ def model_reader(model: torch.nn.Module,
logging.info("Start Model Reading...")
fx_model = fx_graph_module_generation(model, representative_data_gen, to_tensor)
graph = build_graph(fx_model, to_numpy)
disconnected_inputs = [n.name for n in graph.get_inputs() if not graph.out_edges(n)]
if disconnected_inputs:
raise ValueError(f'The network contains disconnected input(s): {disconnected_inputs}.')
graph = remove_broken_nodes_from_graph(graph)
return graph
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
import unittest
import torch
import numpy as np
from model_compression_toolkit.core.pytorch import pytorch_implementation

from model_compression_toolkit.core.pytorch.reader.reader import fx_graph_module_generation
from model_compression_toolkit.core.pytorch.pytorch_implementation import to_torch_tensor
from model_compression_toolkit.core.pytorch.pytorch_implementation import to_torch_tensor, PytorchImplementation


class BadFxModel(torch.nn.Module):
Expand All @@ -37,16 +38,32 @@ def forward(self, inputs, flag=False):
return x


class TestGraphReading(unittest.TestCase):
def data_gen():
yield [np.zeros((1, 3, 20, 20))]


def test_graph_reading(self):
class TestGraphReading(unittest.TestCase):
def test_fx_tracer_error(self):
model = BadFxModel()
try:
graph = fx_graph_module_generation(model,
lambda : np.zeros((1, 3, 20, 20)),
to_torch_tensor)
except Exception as e:
self.assertEqual(str(e).split('\n')[0], 'Error parsing model with torch.fx')

with self.assertRaises(Exception) as e:
fx_graph_module_generation(model,
data_gen,
to_torch_tensor)
self.assertEqual(str(e.exception).split('\n')[0], 'Error parsing model with torch.fx')

def test_disconnected_input(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 5, 3)

def forward(self, x, unused=None):
return self.conv(x)

with self.assertRaises(ValueError) as e:
PytorchImplementation().model_reader(Model(), data_gen)
self.assertEqual(str(e.exception), r"The network contains disconnected input(s): ['unused'].")


if __name__ == '__main__':
Expand Down