Skip to content

Commit

Permalink
Merge pull request #1802 from microsoft/chuyang/splitters
Browse files Browse the repository at this point in the history
Optimized Python splitters
  • Loading branch information
miguelgfierro authored Jul 29, 2022
2 parents 3ccc487 + d80401f commit bd4a38a
Showing 1 changed file with 22 additions and 25 deletions.
47 changes: 22 additions & 25 deletions recommenders/datasets/python_splitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,36 +87,33 @@ def _do_stratification(
col_item=col_item,
)

# Split by each group and aggregate splits together.
splits = []

# For chronological splitting, sort rows by timestamp before grouping; for random splitting, group as-is and let the per-group split shuffle.
df_grouped = (
data.sort_values(col_timestamp).groupby(split_by_column)
if is_random is False
else data.groupby(split_by_column)
)

for _, group in df_grouped:
group_splits = split_pandas_data_with_ratios(
group, ratio, shuffle=is_random, seed=seed
)
if is_random:
np.random.seed(seed)
data["random"] = np.random.rand(data.shape[0])
order_by = "random"
else:
order_by = col_timestamp

# Concatenate the list of split dataframes.
concat_group_splits = pd.concat(group_splits)
data = data.sort_values([split_by_column, order_by])

splits.append(concat_group_splits)
groups = data.groupby(split_by_column)

# Concatenate splits for all the groups together.
splits_all = pd.concat(splits)
data["count"] = groups[split_by_column].transform("count")
data["rank"] = groups.cumcount() + 1

# Take split by split_index
splits_list = [
splits_all[splits_all["split_index"] == x].drop("split_index", axis=1)
for x in range(len(ratio))
]
if is_random:
data = data.drop("random", axis=1)

return splits_list
splits = []
prev_threshold = None
for threshold in np.cumsum(ratio):
condition = data["rank"] <= round(threshold * data["count"])
if prev_threshold is not None:
condition &= data["rank"] > round(prev_threshold * data["count"])
splits.append(data[condition].drop(["rank", "count"], axis=1))
prev_threshold = threshold

return splits


def python_chrono_split(
Expand Down

0 comments on commit bd4a38a

Please sign in to comment.