# [PT FE] Extract signature for FX scenario (#27014)
### Details:
 - *Extract signature from `forward`*
 - *Remove `gpt2` from hf tests as it was moved to `test_llm.py`*

### Tickets:
 - *ticket-id*
mvafin authored Oct 11, 2024
1 parent f7081a7 commit e2b09ea
Showing 5 changed files with 32 additions and 19 deletions.
5 changes: 5 additions & 0 deletions src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py

```diff
@@ -6,6 +6,7 @@
 
 import logging
 import torch
+import inspect
 
 from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
 from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
@@ -77,6 +78,10 @@ def __init__(self, pt_module, fx_gm=None, nodes=None, mark_node_callback=None, i
             if not input_types or len(input_types) == 0:
                 self.input_types = found_types
 
+            if hasattr(pt_module, "forward"):
+                input_params = inspect.signature(pt_module.forward).parameters
+                self._input_signature = list(input_params)
+
         elif issubclass(type(pt_module), torch.fx.Node):
 
             self._nodes = nodes  # passed from outer context
```
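For reference, a minimal standalone sketch (not part of the commit) of the mechanism the new lines rely on: `inspect.signature` on a module's bound `forward` yields its parameter names in order.

```python
import inspect

import torch


class Example(torch.nn.Module):
    def forward(self, a, b):
        return a + b


# inspect.signature on a bound method omits `self`, so the listed
# parameter names are exactly the user-facing input names.
params = inspect.signature(Example().forward).parameters
print(list(params))  # ['a', 'b']
```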
16 changes: 16 additions & 0 deletions tests/layer_tests/py_frontend_tests/test_torch_decoder.py
```diff
@@ -667,3 +667,19 @@ def test_pytorch_decoder_can_convert_scripted_function():
     scripted = torch.jit.script(f)
     model = convert_model(scripted, input=[Type.f32, Type.f32])
     assert model is not None
+
+
+@pytest.mark.precommit
+def test_pytorch_fx_decoder_extracts_signature():
+    from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
+
+    class TestModel(torch.nn.Module):
+        def forward(self, a, b):
+            return a["x"] + a["y"] + b
+
+    example = ({"x": torch.tensor(1), "y": torch.tensor(2)}, torch.tensor(3))
+    em = torch.export.export(TestModel(), example)
+    nc_decoder = TorchFXPythonDecoder(em.module())
+    assert nc_decoder.get_input_signature_name(0) == "a"
+    assert nc_decoder.get_input_signature_name(1) == "b"
+    assert nc_decoder._input_signature == ["a", "b"]
```
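This test passes because `torch.export` preserves `forward`'s parameter names on the module recovered from the `ExportedProgram`. A hedged standalone sketch of that property:

```python
import inspect

import torch


class TestModel(torch.nn.Module):
    def forward(self, a, b):
        return a["x"] + a["y"] + b


example = ({"x": torch.tensor(1), "y": torch.tensor(2)}, torch.tensor(3))
em = torch.export.export(TestModel(), example)

# The unflattened module keeps forward's original parameter names,
# which is what TorchFXPythonDecoder reads via inspect.signature.
print(list(inspect.signature(em.module().forward).parameters))  # ['a', 'b']

# The recovered module still runs eagerly with the original inputs.
print(em.module()(*example))  # expected: tensor(6)
```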
1 change: 0 additions & 1 deletion tests/model_hub_tests/pytorch/hf_transformers_models
```diff
@@ -79,7 +79,6 @@ git,microsoft/git-large-coco,xfail,Trace error: We don't have an op for aten::fu
 glpn,vinvino02/glpn-nyu
 gpt_bigcode,bigcode/tiny_starcoder_py
 gpt_neo,EleutherAI/gpt-neo-2.7B
-gpt2,openai-community/gpt2
 gptsan-japanese,Tanrei/GPTSAN-japanese,xfail,Unsupported op aten::index_put_ prim::TupleConstruct prim::TupleUnpack
 graphormer,clefourrier/graphormer-base-pcqm4mv2,xfail,Load error: GraphormerForGraphClassification.forward() missing 6 required positional arguments: 'input_edges' 'attn_bias' 'in_degree' 'out_degree' 'spatial_pos' and 'attn_edge_type'
 grounding-dino,IDEA-Research/grounding-dino-base,xfail,Trace error: op->outputs().size() == 1 INTERNAL ASSERT FAILED
```
9 changes: 8 additions & 1 deletion tests/model_hub_tests/pytorch/test_hf_transformers.py
```diff
@@ -520,13 +520,20 @@ def load_model_with_default_class(name, **kwargs):
         ("bert-base-uncased", "bert"),
         ("google/flan-t5-base", "t5"),
         ("google/tapas-large-finetuned-wtq", "tapas"),
-        ("gpt2", "gpt2"),
         ("openai/clip-vit-large-patch14", "clip"),
     ])
     @pytest.mark.precommit
     def test_convert_model_precommit(self, name, type, ie_device):
         self.run(model_name=name, model_link=type, ie_device=ie_device)
 
+    @pytest.mark.parametrize("name,type", [("bert-base-uncased", "bert"),
+                                           ("openai/clip-vit-large-patch14", "clip"),
+                                           ])
+    @pytest.mark.precommit
+    def test_convert_model_precommit_export(self, name, type, ie_device):
+        self.mode = "export"
+        self.run(model_name=name, model_link=type, ie_device=ie_device)
+
     @pytest.mark.parametrize("type,name,mark,reason",
                              get_models_list(os.path.join(os.path.dirname(__file__), "hf_transformers_models")))
     @pytest.mark.nightly
```
20 changes: 3 additions & 17 deletions tests/model_hub_tests/pytorch/torch_utils.py
```diff
@@ -68,28 +68,14 @@ def convert_model_impl(self, model_obj):
         from packaging import version
 
         model_obj.eval()
-        graph = None
         if isinstance(self.example, dict):
-            from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
-
             # need to infer before export to initialize everything, otherwise it will be initialized with FakeTensors
             pt_res = model_obj(**self.example)
-            graph = export(model_obj, tuple(), self.example)
-            if version.parse(torch.__version__) >= version.parse("2.2"):
-                from torch._decomp import get_decompositions
-                from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list
-                decomp = get_decompositions(get_export_decomposition_list())
-                graph = graph.run_decompositions(decomp_table=decomp)
-
-            gm = graph.module()
-            print(gm.code)
-
-            decoder = TorchFXPythonDecoder(gm)
-            decoder._input_signature = list(self.example.keys())
-            ov_model = convert_model(decoder, verbose=True)
+            graph = export(model_obj, args=tuple(), kwargs=self.example)
         else:
             pt_res = model_obj(*self.example)
             graph = export(model_obj, self.example)
-            ov_model = convert_model(graph, verbose=True)
+        ov_model = convert_model(graph, verbose=True)
 
         if isinstance(pt_res, dict):
             for i, k in enumerate(pt_res.keys()):
```
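With the signature extracted inside the decoder, dict examples no longer need a hand-built `TorchFXPythonDecoder` with a manually assigned `_input_signature`. A self-contained sketch of the simplified flow (the helper name `convert_exported` is illustrative; `convert_model` consuming an `ExportedProgram` directly follows the diff above):

```python
import torch
from torch.export import export

from openvino import convert_model


def convert_exported(model: torch.nn.Module, example):
    # Mirror of the simplified flow above: dict examples go through kwargs,
    # tuples through args; convert_model consumes the ExportedProgram directly.
    model.eval()
    if isinstance(example, dict):
        # infer once before export so state is initialized with real tensors
        model(**example)
        graph = export(model, args=tuple(), kwargs=example)
    else:
        model(*example)
        graph = export(model, example)
    return convert_model(graph, verbose=True)
```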
