[PT FE] Support torch==2.6.0 (#29196)
### Details:
 - *Support `torch==2.6.0`*

### Tickets:
 - *CVS-162009*

---------

Signed-off-by: Maxim Vafin <[email protected]>
mvafin authored Feb 28, 2025
1 parent 3444a4a commit bd266dc
Showing 7 changed files with 46 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/job_pytorch_layer_tests.yml
@@ -30,7 +30,7 @@ env:
jobs:
PyTorch_Layer_Tests:
name: PyTorch Layer Tests
timeout-minutes: 40
timeout-minutes: 50
runs-on: ${{ inputs.runner }}
container: ${{ fromJSON(inputs.container) }}
defaults:
28 changes: 28 additions & 0 deletions src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py
@@ -166,6 +166,8 @@ class TorchFXPythonDecoder (BaseFXDecoder):
Decoder for PyTorch FX GraphModule and Node objects to OpenVINO IR.
"""

_decomp_table = None

def __init__(self, pt_module, fx_gm=None, nodes=None,
mark_node_callback=None, input_shapes=[], input_types=[], dynamic_shapes=False):
super().__init__(mark_node_callback)
@@ -230,6 +232,32 @@ def __init__(self, pt_module, fx_gm=None, nodes=None,
self.input_types.append(
BaseFXDecoder.get_type_for_value(arg))

@classmethod
def from_exported_program(cls, exported_program: torch.export.ExportedProgram) -> 'TorchFXPythonDecoder':
"""
Create a TorchFXPythonDecoder instance from an exported PyTorch program.
"""
from packaging import version
if version.parse(torch.__version__) >= version.parse("2.6"):
if cls._decomp_table is None:
from torch.export.decomp_utils import CustomDecompTable
from openvino.frontend.pytorch.torchdynamo.decompositions import ops_to_not_decompose
cls._decomp_table = CustomDecompTable()
for op in ops_to_not_decompose():
try:
cls._decomp_table.pop(op)
except KeyError as e:
logging.warning("Operation %s not found in decomp table", op, exc_info=e)
exported_program = exported_program.run_decompositions(cls._decomp_table)
elif version.parse(torch.__version__) >= version.parse("2.2"):
from torch._decomp import get_decompositions
from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list
decomp = get_decompositions(get_export_decomposition_list())
exported_program = exported_program.run_decompositions(decomp_table=decomp)
gm = exported_program.module()
logger.debug(gm.code)
return cls(gm, dynamic_shapes=True)

@staticmethod
def get_found_shape(value) -> str:
# If input is a tensor, read the shape from meta data
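As context (not part of the diff), a minimal sketch of how the new `from_exported_program` classmethod can be exercised directly; the toy module and input shapes below are purely illustrative:

```python
# Illustrative only: exercises TorchFXPythonDecoder.from_exported_program added above.
import torch
from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# torch.export.export produces an ExportedProgram.
exported = torch.export.export(Add(), (torch.randn(2, 3), torch.randn(2, 3)))
# The classmethod runs the version-appropriate decompositions and wraps the
# resulting GraphModule in a decoder with dynamic shapes enabled.
decoder = TorchFXPythonDecoder.from_exported_program(exported)
```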
src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/decompositions.py
@@ -297,3 +297,11 @@ def get_export_decomposition_list():
except ImportError:
pass
return decomp


def ops_to_not_decompose():
# List of operations that shouldn't be decomposed
return [
torch.ops.aten.col2im.default,
torch.ops.aten.upsample_nearest2d.default,
]
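
A small sketch of what the torch >= 2.6 branch in `from_exported_program` effectively does with this list (the ops shown are exactly the ones returned by `ops_to_not_decompose`); assumes torch 2.6 is installed:

```python
import torch
from torch.export.decomp_utils import CustomDecompTable

# CustomDecompTable starts from the default decomposition set used by
# ExportedProgram.run_decompositions.
table = CustomDecompTable()
for op in (torch.ops.aten.col2im.default, torch.ops.aten.upsample_nearest2d.default):
    try:
        table.pop(op)  # keep the op whole; OpenVINO converts it directly
    except KeyError:
        pass  # op was not present in the default table
# exported_program.run_decompositions(table) then leaves these ops intact.
```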
6 changes: 0 additions & 6 deletions tests/layer_tests/pytorch_tests/test_arange.py
@@ -109,7 +109,6 @@ def forward(self, x, y, z, d):

@pytest.mark.nightly
@pytest.mark.precommit_torch_export
@pytest.mark.precommit_fx_backend
@pytest.mark.parametrize("dtype", [None,
skip_if_export("float32"),
skip_if_export("float64"),
@@ -124,7 +123,6 @@ def test_arange_end_only(self, dtype, end, use_out, ie_device, precision, ir_ver
kwargs_to_prepare_input={"end": end})

@pytest.mark.nightly
@pytest.mark.precommit_fx_backend
@pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8"])
@pytest.mark.parametrize("start,end", [(0, 1), (-1, 1), (1, 5), (0.5, 2.5)])
def test_arange_start_end(self, dtype, end, start, ie_device, precision, ir_version):
@@ -133,7 +131,6 @@ def test_arange_start_end(self, dtype, end, start, ie_device, precision, ir_vers

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_fx_backend
@pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8"])
@pytest.mark.parametrize("start,end,step", [(0, 1, 1), (-2, 1, 1.25), (1, -5, -1), (1, 10, 2), (-1, -5, -2)])
def test_arange_start_end_step(self, dtype, end, start, step, ie_device, precision, ir_version):
@@ -142,7 +139,6 @@ def test_arange_start_end_step(self, dtype, end, start, step, ie_device, precisi

@pytest.mark.nightly
@pytest.mark.precommit_torch_export
@pytest.mark.precommit_fx_backend
@pytest.mark.parametrize("dtype", [skip_check(None),
skip_if_export("float32"),
skip_if_export("float64"),
@@ -156,7 +152,6 @@ def test_arange_end_only_with_prim_dtype(self, dtype, end, ie_device, precision,
kwargs_to_prepare_input={"end": end, "ref_dtype": dtype})

@pytest.mark.nightly
@pytest.mark.precommit_fx_backend
@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8"])
@pytest.mark.parametrize("start,end", [(0, 1), (-1, 1), (1, 5), (0.5, 2.5)])
def test_arange_start_end_with_prim_dtype(self, dtype, end, start, ie_device, precision, ir_version):
@@ -165,7 +160,6 @@ def test_arange_start_end_with_prim_dtype(self, dtype, end, start, ie_device, pr

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_fx_backend
@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8"])
@pytest.mark.parametrize("start,end,step", [(0, 1, 1), (-2, 1, 1.25), (1, -5, -1), (1, 10, 2), (-1, -5, -2)])
def test_arange_start_end_step_with_prim_dtype(self, dtype, end, start, step, ie_device, precision, ir_version):
10 changes: 4 additions & 6 deletions tests/layer_tests/pytorch_tests/test_trilu.py
@@ -35,16 +35,15 @@ def forward(self, x):

return aten_trilu(pt_op, diagonal), ref_net, f"aten::{op}"

@pytest.mark.parametrize("input_shape", [(5, 5), (6, 4), (4, 6)])
@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8", "uint8", "bool"])
@pytest.mark.parametrize("diagonal", [0, 1, 2, -1, -2])
@pytest.mark.parametrize("op", ["triu", "tril"])
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_fx_backend
def test_trilu(self, input_shape, dtype, diagonal, op, ie_device, precision, ir_version):
def test_trilu(self, dtype, diagonal, op, ie_device, precision, ir_version):
self._test(*self.create_model(op, diagonal), ie_device, precision, ir_version,
kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype})
kwargs_to_prepare_input={"shape": (4, 6), "dtype": dtype})


class TestTriuTrilTensor(PytorchLayerTest):
@@ -84,13 +83,12 @@ def triu_(self, x):

return aten_trilu(op, diagonal), ref_net, f"aten::{op}"

@pytest.mark.parametrize("input_shape", [(5, 5), (6, 4), (4, 6)])
@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8", "uint8", "bool"])
@pytest.mark.parametrize("diagonal", [0, 1, 2, -1, -2])
@pytest.mark.parametrize("op", ["triu", "tril", "triu_", "tril_"])
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_fx_backend
def test_trilu(self, input_shape, dtype, diagonal, op, ie_device, precision, ir_version):
def test_trilu(self, dtype, diagonal, op, ie_device, precision, ir_version):
self._test(*self.create_model(op, diagonal), ie_device, precision, ir_version,
kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype})
kwargs_to_prepare_input={"shape": (4, 6), "dtype": dtype})
6 changes: 3 additions & 3 deletions tests/requirements_pytorch
@@ -3,13 +3,13 @@
# optimum still requires numpy<2.0.0
numpy==1.26.4; python_version < "3.12"
numpy==2.1.1; python_version >= "3.12"
torch==2.5.1; platform_system != "Darwin" or platform_machine != "x86_64"
torch==2.6.0; platform_system != "Darwin" or platform_machine != "x86_64"
torch==2.2.2; platform_system == "Darwin" and platform_machine == "x86_64"
--extra-index-url https://download.pytorch.org/whl/cpu

torchvision==0.20.1; platform_system != "Darwin" or platform_machine != "x86_64"
torchvision==0.21.0; platform_system != "Darwin" or platform_machine != "x86_64"
torchvision==0.17.2; platform_system == "Darwin" and platform_machine == "x86_64"
torchaudio==2.5.1; platform_system != "Darwin" or platform_machine != "x86_64"
torchaudio==2.6.0; platform_system != "Darwin" or platform_machine != "x86_64"
torchaudio==2.2.2; platform_system == "Darwin" and platform_machine == "x86_64"
# before updating transformers version, make sure no tests (esp. sdpa2pa) are failing
transformers==4.47.1
tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py
@@ -21,22 +21,6 @@ def extract_module_extensions(args):
return {extension.module: extension for extension in extensions if isinstance(extension, ModuleExtension)}


def get_decoder_for_exported_program(model):
from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
import torch

from packaging import version
if version.parse(torch.__version__) >= version.parse("2.2"):
from torch._decomp import get_decompositions
from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list
decomp = get_decompositions(get_export_decomposition_list())
model = model.run_decompositions(decomp_table=decomp)
gm = model.module()
log.debug(gm.code)
decoder = TorchFXPythonDecoder(gm, dynamic_shapes=True)
return decoder


def get_pytorch_decoder(model, example_inputs, args):
try:
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
@@ -65,7 +49,7 @@ def get_pytorch_decoder(model, example_inputs, args):
inputs = prepare_torch_inputs(example_inputs)
if not isinstance(model, (TorchScriptPythonDecoder, TorchFXPythonDecoder)):
if hasattr(torch, "export") and isinstance(model, (torch.export.ExportedProgram)):
decoder = get_decoder_for_exported_program(model)
decoder = TorchFXPythonDecoder.from_exported_program(model)
else:
decoder = TorchScriptPythonDecoder(
model,
@@ -123,7 +107,7 @@ def get_pytorch_decoder_for_model_on_disk(argv, args):
try:
exported_program = torch.export.load(input_model)
if hasattr(torch, "export") and isinstance(exported_program, (torch.export.ExportedProgram)):
argv.input_model = get_decoder_for_exported_program(exported_program)
argv.input_model = TorchFXPythonDecoder.from_exported_program(exported_program)
argv.framework = 'pytorch'
return True
except:
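For context, a hedged end-to-end sketch (not from the commit) of the path these helpers serve: an ExportedProgram, whether created in memory or loaded from disk, is now wrapped via `TorchFXPythonDecoder.from_exported_program`:

```python
import torch
import openvino as ov

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

ep = torch.export.export(Add(), (torch.randn(2, 3), torch.randn(2, 3)))
torch.export.save(ep, "add.pt2")  # on-disk ExportedProgram

# convert_model accepts an ExportedProgram; internally get_pytorch_decoder
# now calls TorchFXPythonDecoder.from_exported_program for this input type.
ov_model = ov.convert_model(torch.export.load("add.pt2"))
```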
