Skip to content

Commit

Permalink
Fix: Resolving of import paths for some torch functions not working (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mauvilsa authored Jun 24, 2024
1 parent 2bcbd48 commit 82273f9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 4 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ paths are considered internals and can change in minor and patch releases.
v4.31.0 (2024-06-??)
--------------------

Fixed
^^^^^
- Resolving of import paths for some ``torch`` functions not working (`#535
<https://github.com/omni-us/jsonargparse/pull/535>`__).

Changed
^^^^^^^
- Now ``--*.help`` output shows options without ``init_args`` (`#533
Expand Down
7 changes: 5 additions & 2 deletions jsonargparse/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,15 +233,18 @@ def get_import_path(value: Any) -> Optional[str]:
if not path:
raise ValueError(f"Not possible to determine the import path for object {value}.")

if qualname and module_path and "." in module_path:
if qualname and module_path and ("." in qualname or "." in module_path):
module_parts = module_path.split(".")
for num in range(len(module_parts)):
module_path = ".".join(module_parts[: num + 1])
module = import_module(module_path)
if "." in qualname:
obj_name, attr = qualname.rsplit(".", 1)
obj = getattr(module, obj_name, None)
if getattr(obj, attr, None) is value:
if getattr(module, attr, None) is value:
path = module_path + "." + attr
break
elif getattr(obj, attr, None) is value:
path = module_path + "." + qualname
break
elif getattr(module, qualname, None) is value:
Expand Down
21 changes: 19 additions & 2 deletions jsonargparse_tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,8 +533,12 @@ def test_logger_jsonargparse_debug():


def test_import_object_invalid():
pytest.raises(ValueError, lambda: import_object(True))
pytest.raises(ValueError, lambda: import_object("jsonargparse-tests.os"))
with pytest.raises(ValueError) as ctx:
import_object(True)
ctx.match("Expected a dot import path string")
with pytest.raises(ValueError) as ctx:
import_object("jsonargparse-tests.os")
ctx.match("Unexpected import path format")


def test_get_import_path():
Expand All @@ -548,6 +552,19 @@ def test_get_import_path():
assert get_import_path(MISSING) == "dataclasses.MISSING"


class _StaticMethods:
@staticmethod
def static_method():
pass


static_method = _StaticMethods.static_method


def test_get_import_path_static_method_shorthand():
assert get_import_path(static_method) == f"{__name__}.static_method"


def unresolvable_import():
pass

Expand Down

0 comments on commit 82273f9

Please sign in to comment.