diff --git a/src/acbm/utils.py b/src/acbm/utils.py index df8bdfc..91469e1 100644 --- a/src/acbm/utils.py +++ b/src/acbm/utils.py @@ -59,6 +59,8 @@ def households_with_common_travel_days( .apply( lambda common_days: [day for day in common_days if day in days] if common_days is not None + and common_days != {pd.NA} + and common_days != {np.nan} else [] ) .apply(lambda common_days: common_days if common_days else pd.NA) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0a90147..565ce31 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,9 +8,26 @@ def nts_trips(): return pd.DataFrame.from_dict( { - "IndividualID": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - "HouseholdID": [1, 1, 1, 2, 2, 2, 3, 3, 3, 3], - "TravDay": [1, 1, 1, 2, 3, 2, 3, 3, 3, 3], + "IndividualID": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "HouseholdID": [1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5], + "TravDay": [ + 1, + 1, + 1, + 2, + 3, + 2, + 3, + 3, + 3, + 3, + pd.NA, + pd.NA, + pd.NA, + pd.NA, + pd.NA, + 4, + ], } ) @@ -19,3 +36,4 @@ def test_households_with_common_travel_days(nts_trips): assert households_with_common_travel_days(nts_trips, [1]) == [1] assert households_with_common_travel_days(nts_trips, [1, 2]) == [1] assert households_with_common_travel_days(nts_trips, [1, 3]) == [1, 3] + assert households_with_common_travel_days(nts_trips, [1, 3, 4]) == [1, 3]