From 82273f9867ded25214d7719584f7c2bf011dc7da Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Mon, 24 Jun 2024 06:27:53 +0200 Subject: [PATCH] Fix: Resolving of import paths for some torch functions not working (#535) --- CHANGELOG.rst | 5 +++++ jsonargparse/_util.py | 7 +++++-- jsonargparse_tests/test_util.py | 21 +++++++++++++++++++-- 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 46a65d79..207108c4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,6 +15,11 @@ paths are considered internals and can change in minor and patch releases. v4.31.0 (2024-06-??) -------------------- +Fixed +^^^^^ +- Resolving of import paths for some ``torch`` functions not working (`#535 + `__). + Changed ^^^^^^^ - Now ``--*.help`` output shows options without ``init_args`` (`#533 diff --git a/jsonargparse/_util.py b/jsonargparse/_util.py index 91767910..f07d0873 100644 --- a/jsonargparse/_util.py +++ b/jsonargparse/_util.py @@ -233,7 +233,7 @@ def get_import_path(value: Any) -> Optional[str]: if not path: raise ValueError(f"Not possible to determine the import path for object {value}.") - if qualname and module_path and "." in module_path: + if qualname and module_path and ("." in qualname or "." in module_path): module_parts = module_path.split(".") for num in range(len(module_parts)): module_path = ".".join(module_parts[: num + 1]) @@ -241,7 +241,10 @@ def get_import_path(value: Any) -> Optional[str]: if "." in qualname: obj_name, attr = qualname.rsplit(".", 1) obj = getattr(module, obj_name, None) - if getattr(obj, attr, None) is value: + if getattr(module, attr, None) is value: + path = module_path + "." + attr + break + elif getattr(obj, attr, None) is value: path = module_path + "." + qualname break elif getattr(module, qualname, None) is value: diff --git a/jsonargparse_tests/test_util.py b/jsonargparse_tests/test_util.py index 3c0aef39..c7709ac9 100644 --- a/jsonargparse_tests/test_util.py +++ b/jsonargparse_tests/test_util.py @@ -533,8 +533,12 @@ def test_logger_jsonargparse_debug(): def test_import_object_invalid(): - pytest.raises(ValueError, lambda: import_object(True)) - pytest.raises(ValueError, lambda: import_object("jsonargparse-tests.os")) + with pytest.raises(ValueError) as ctx: + import_object(True) + ctx.match("Expected a dot import path string") + with pytest.raises(ValueError) as ctx: + import_object("jsonargparse-tests.os") + ctx.match("Unexpected import path format") def test_get_import_path(): @@ -548,6 +552,19 @@ def test_get_import_path(): assert get_import_path(MISSING) == "dataclasses.MISSING" +class _StaticMethods: + @staticmethod + def static_method(): + pass + + +static_method = _StaticMethods.static_method + + +def test_get_import_path_static_method_shorthand(): + assert get_import_path(static_method) == f"{__name__}.static_method" + + def unresolvable_import(): pass