Skip to content

Commit

Permalink
Avoid accidental namedtuple conversion in apply_to_collection (#210)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <[email protected]>
  • Loading branch information
GdoongMathew and carmocca authored Dec 20, 2023
1 parent f7a0693 commit 990ba45
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 26 deletions.
6 changes: 3 additions & 3 deletions src/lightning_utilities/core/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ def apply_to_collection(
# fast path for the most common cases:
if isinstance(data, dtype): # single element
return function(data, *args, **kwargs)
if isinstance(data, list) and all(isinstance(x, dtype) for x in data): # 1d homogeneous list
if data.__class__ is list and all(isinstance(x, dtype) for x in data): # 1d homogeneous list
return [function(x, *args, **kwargs) for x in data]
if isinstance(data, tuple) and all(isinstance(x, dtype) for x in data): # 1d homogeneous tuple
if data.__class__ is tuple and all(isinstance(x, dtype) for x in data): # 1d homogeneous tuple
return tuple(function(x, *args, **kwargs) for x in data)
if isinstance(data, dict) and all(isinstance(x, dtype) for x in data.values()): # 1d homogeneous dict
if data.__class__ is dict and all(isinstance(x, dtype) for x in data.values()): # 1d homogeneous dict
return {k: function(v, *args, **kwargs) for k, v in data.items()}
# slow path for everything else
return _apply_to_collection_slow(
Expand Down
64 changes: 41 additions & 23 deletions tests/unittests/core/test_apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,15 @@ def __eq__(self, o: object) -> bool:
return self.dummy == o.dummy


def test_recursive_application_to_collection():
ntc = namedtuple("Foo", ["bar"])
# Minimal dict subclass, defined at module level so it can be used as a
# parametrized case when checking that a custom mapping type is returned as-is.
class _CustomCollection(dict):
def __init__(self, initial_dict) -> None:
super().__init__(initial_dict)


ntc = namedtuple("Foo", ["bar", "baz"])


def test_recursive_application_to_collection():
model_example = ModelExample(
example_ids=["i-1", "i-2", "i-3"],
feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=torch.tensor([4.0, 5.0, 6.0])),
Expand All @@ -112,7 +118,7 @@ def test_recursive_application_to_collection():
"a": torch.tensor([1.0]), # Tensor
"b": [torch.tensor([2.0])], # list
"c": (torch.tensor([100.0]),), # tuple
"d": ntc(bar=5.0), # named tuple
"d": ntc(bar=5.0, baz=10.0), # named tuple
"f": "this_is_a_dummy_str", # string
"g": 12.0, # number
"h": Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=torch.tensor([4.0, 5.0, 6.0])), # dataclass
Expand All @@ -132,7 +138,7 @@ def test_recursive_application_to_collection():
"a": torch.tensor([2.0]),
"b": [torch.tensor([4.0])],
"c": (torch.tensor([200.0]),),
"d": ntc(bar=10),
"d": ntc(bar=10, baz=20),
"f": "this_is_a_dummy_str",
"g": 24.0,
"h": Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=torch.tensor([8.0, 10.0, 12.0])),
Expand Down Expand Up @@ -201,25 +207,37 @@ def _assert_dataclass_reduction(actual, expected, dataclass_type: str = ""):
WithClassAndInitVar.class_var, torch.tensor(0)
), f"Reduction of a {dataclass_type} dataclass should not change the class var"

# mapping support
reduced = apply_to_collection({"a": 1, "b": 2}, int, lambda x: str(x))
assert reduced == {"a": "1", "b": "2"}
reduced = apply_to_collection(OrderedDict([("b", 2), ("a", 1)]), int, lambda x: str(x))
assert reduced == OrderedDict([("b", "2"), ("a", "1")])

# custom mappings
class _CustomCollection(dict):
def __init__(self, initial_dict) -> None:
super().__init__(initial_dict)

to_reduce = _CustomCollection({"a": 1, "b": 2, "c": 3})
reduced = apply_to_collection(to_reduce, int, lambda x: str(x))
assert reduced == _CustomCollection({"a": "1", "b": "2", "c": "3"})

# defaultdict
to_reduce = defaultdict(int, {"a": 1, "b": 2, "c": 3})
reduced = apply_to_collection(to_reduce, int, lambda x: str(x))
assert reduced == defaultdict(int, {"a": "1", "b": "2", "c": "3"})

# Each case is (input collection, expected result after mapping every int to str).
@pytest.mark.parametrize(
("ori", "target"),
[
# plain dict
(
{"a": 1, "b": 2, "c": 3},
{"a": "1", "b": "2", "c": "3"},
),
# OrderedDict
(
OrderedDict([("b", 2), ("a", 1), ("c", 3)]),
OrderedDict([("b", "2"), ("a", "1"), ("c", "3")]),
),
# user-defined dict subclass
(
_CustomCollection({"a": 1, "b": 2, "c": 3}),
_CustomCollection({"a": "1", "b": "2", "c": "3"}),
),
# defaultdict
(
defaultdict(int, {"a": 1, "b": 2, "c": 3}),
defaultdict(int, {"a": "1", "b": "2", "c": "3"}),
),
# namedtuple: the result must be the same namedtuple class, not a plain tuple
(
ntc(bar=5, baz=5),
ntc(bar="5", baz="5"),
),
],
)
def test_application_to_collection_return_type(ori, target):
# The result must equal the expected collection AND have exactly the same
# container type as the expected value (checked with `type(...) is type(...)`).
reduced = apply_to_collection(ori, int, lambda x: str(x))
assert reduced == target
assert type(reduced) is type(target)


def test_apply_to_collection_include_none():
Expand Down

0 comments on commit 990ba45

Please sign in to comment.