Skip to content

Commit

Permalink
[PT FE] Inherit signature from forward while patching (#27413)
Browse files Browse the repository at this point in the history
### Details:
 - *Inherit signature from forward while patching*

### Tickets:
 - *ticket-id*

---------

Signed-off-by: Maxim Vafin <[email protected]>
  • Loading branch information
mvafin authored Nov 7, 2024
1 parent 98f6ea5 commit 5c83460
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# flake8: noqa
# mypy: ignore-errors

import functools
import logging
import torch
from openvino.frontend.pytorch import ModuleExtension
Expand Down Expand Up @@ -70,6 +71,8 @@ def new_forward(*args, **kwargs):
Trampoline.stashed_kwargs = kwargs
return extension.convert(m, Trampoline.apply, *args, **kwargs)

# make signature of new_forward same as of forward
new_forward = functools.wraps(m.forward)(new_forward)
setattr(m, orig_forward_name, m.forward)
m.forward = new_forward

Expand Down
4 changes: 4 additions & 0 deletions tests/layer_tests/py_frontend_tests/test_torch_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,7 @@ def test_patched_16bit_model_converts():
from openvino.frontend.pytorch import patch_model
from openvino import convert_model, compile_model
import copy
import inspect
from transformers.pytorch_utils import Conv1D

class ModelWithLinear(torch.nn.Module):
Expand Down Expand Up @@ -716,6 +717,9 @@ def forward(self, x1, x2):
model_fp16 = copy.deepcopy(model_ref).half()

patch_model.__make_16bit_traceable(model_fp16)
# verify torch.nn.Linear signature after patching
signature = inspect.signature(model_ref.branch1[0].forward).parameters
assert ["input"] == list(signature)
# the approach with patching only works for node with no grad
with torch.no_grad():
converted_model = convert_model(model_fp16, example_input=example)
Expand Down

0 comments on commit 5c83460

Please sign in to comment.