Add codepath for computing buckets without int conversion #326
@@ -48,6 +48,7 @@
from nemo_curator.utils.fuzzy_dedup_utils.id_mapping import int_ids_to_str
from nemo_curator.utils.fuzzy_dedup_utils.io_utils import (
    aggregated_anchor_docs_with_bk_read,
    check_empty_buckets,
    get_restart_offsets,
    update_restart_offsets,
)
@@ -261,6 +262,7 @@ def __init__(
        num_hashes: int,
        num_buckets: int,
        buckets_per_shuffle: int = 1,
        false_positive_check: bool = False,
        logger: Union[logging.LoggerAdapter, str] = "./",
        id_fields: Union[str, list] = "id",
        minhash_field: str = "_minhash_signature",
@@ -275,8 +277,9 @@ def __init__(
        num_buckets: Number of bands/buckets to create from the minhash signature.
            Hashes_per_signature = num_hashes / num_buckets
        buckets_per_shuffle: Number of bands/buckets to shuffle concurrently.
            Larger values process larger batches by processing multiple bands
            but might lead to memory pressures and related errors.
        false_positive_check: bool
            If True, writes out buckets in a format compatible with downstream false positive check.
        logger: Existing logger to log to, or a path to a log directory.
        id_field: Columns in the Dataset denoting document ID.
        minhash_field: Column in the Dataset denoting minhash signature.
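To make the new knob concrete, here is a minimal usage sketch of constructing the LSH stage with the flag; the import path and all argument values are assumptions for illustration, not taken from this diff:

from nemo_curator.modules.fuzzy_dedup import LSH  # assumed import path

lsh_stage = LSH(
    cache_dir="./fuzzy_dedup_cache",      # hypothetical cache directory
    num_hashes=260,                       # example values only
    num_buckets=20,
    buckets_per_shuffle=1,
    false_positive_check=False,           # new flag: skip the bucket-id int conversion
    id_fields=["id"],
    minhash_field="_minhash_signature",
)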
@@ -291,6 +294,7 @@ def __init__(
        self.bucket_ranges = self._generate_bucket_ranges(
            self.num_buckets, self.num_hashes
        )
        self.buckets_as_int = false_positive_check

        if cache_dir is None:
            raise ValueError(
@@ -379,10 +383,19 @@ def lsh(
        self,
        write_path: str,
        df: dask_cudf.DataFrame,
    ) -> None:
    ) -> bool:
        """
        Computes buckets and writes them as parquet files to the write_path
        Computes hash buckets for the DataFrame and writes them as parquet files to the specified path.

        Parameters:
            - write_path (str): The directory path to write parquet files.
            - df (dask_cudf.DataFrame): The input DataFrame with minhashes to be bucketed.
        Returns:
            are_buckets_empty: True if buckets were empty (no duplicates found), False otherwise.
        """
        wrote_buckets = False
        are_buckets_empty = True

        meta = self._minhash_to_bucket_meta(df)
        df = df.map_partitions(
            self.minhash_to_buckets,
@@ -391,12 +404,14 @@ def lsh(
        )
        bucket_start_id = 0
        for i in range(0, self.num_buckets, self.buckets_per_shuffle):
            value_vars = [
            bucket_columns = [
                f"_bucket_{i}"
                for i in range(i, min(self.num_buckets, i + self.buckets_per_shuffle))
            ]
            df2 = df.melt(
                id_vars=self.id_fields, value_name="_bucket_id", value_vars=value_vars
                id_vars=self.id_fields,
                value_name="_bucket_id",
                value_vars=bucket_columns,
            )[self.id_fields + ["_bucket_id"]]

            df2 = df2.shuffle(
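As a side note on the melt step above: it reshapes the wide per-band bucket columns into a long (id, _bucket_id) layout, one row per document per band. A small pandas sketch with made-up values behaves the same way:

import pandas as pd

# Two documents hashed into two bands; the bucket values are invented for illustration.
wide = pd.DataFrame(
    {
        "id": ["doc-0", "doc-1"],
        "_bucket_0": ["b0_17", "b0_17"],  # same bucket in band 0 -> candidate duplicates
        "_bucket_1": ["b1_03", "b1_42"],
    }
)

long_df = wide.melt(
    id_vars=["id"],
    value_name="_bucket_id",
    value_vars=["_bucket_0", "_bucket_1"],
)[["id", "_bucket_id"]]

# Keeping only repeated _bucket_id values mirrors the duplicated(keep=False)
# filter applied after the shuffle in the diff.
candidates = long_df[long_df["_bucket_id"].duplicated(keep=False)]
print(candidates)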
@@ -406,40 +421,88 @@ def lsh(
            ).map_partitions(lambda x: x[x["_bucket_id"].duplicated(keep=False)])

            df2 = df2.reset_index(drop=True)
            df2, end_id = self.bucket_id_to_int(
                df2, bucket_col_name="_bucket_id", start_id=bucket_start_id
            # Buckets to Int
            if self.buckets_as_int:
                df2, end_id = self.bucket_id_to_int(
                    df2, bucket_col_name="_bucket_id", start_id=bucket_start_id
                )
                # If bucketing return empty dataframe
                if end_id < bucket_start_id:
                    self._logger.info(
                        f"No duplicate documents found for buckets: {bucket_columns}"
                    )
                    continue
                bucket_start_id = end_id + 1
                are_buckets_empty = False

            wrote_buckets, are_buckets_empty = self._write_bucket_parquet(
                df2,
                write_path,
                wrote_buckets,
                are_buckets_empty,
                bucket_columns,
            )
            # If bucketing return empty dataframe
            if end_id < bucket_start_id:
                continue
            bucket_start_id = end_id + 1

            # Workaround for dtype mismatches with empty partitions
            dtypes = df2.dtypes.to_dict()
            df2 = df2.map_partitions(lambda x: x.astype(dtypes))
        if are_buckets_empty:
            self._logger.info("No duplicate documents found during LSH")
            if os.path.exists(write_path):
                import shutil

            if i == 0:
                if os.path.exists(write_path):
                    warnings.warn(
                        f"Output path {write_path} already exists and will be overwritten"
                    )
                df2.to_parquet(write_path, write_index=False, overwrite=True)
            else:
                df2.to_parquet(write_path, write_index=False, append=True)
                shutil.rmtree(write_path)

            self._logger.info(f"Wrote data for buckets: {value_vars}")
        return are_buckets_empty

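bucket_id_to_int itself is not part of this diff; as a rough illustration of the contract the code above relies on (consecutive integer ids starting at start_id, with end_id < start_id signalling an empty frame), here is a hypothetical pandas sketch, not the real implementation:

import pandas as pd

def bucket_id_to_int_sketch(df, bucket_col_name, start_id):
    # Illustrative stand-in: map each distinct bucket value to a consecutive
    # integer id starting at start_id, and report the last id handed out.
    uniques = df[bucket_col_name].unique()
    mapping = {value: start_id + i for i, value in enumerate(uniques)}
    df = df.assign(**{bucket_col_name: df[bucket_col_name].map(mapping)})
    end_id = start_id + len(uniques) - 1  # end_id < start_id when df is empty
    return df, end_id

df = pd.DataFrame({"id": [1, 2, 3], "_bucket_id": ["a", "a", "b"]})
df, end_id = bucket_id_to_int_sketch(df, "_bucket_id", start_id=0)
print(df["_bucket_id"].tolist(), end_id)  # [0, 0, 1] 1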
    def _write_bucket_parquet(
        self,
        df: dask_cudf.DataFrame,
        write_path: str,
        wrote_buckets: bool,
        are_buckets_empty: bool,
        buckets_to_write: List[str],
    ) -> tuple[bool, bool]:
        """
        Utility function to write the bucketed data to parquet
        handling cases of overwriting and appending as needed.
        """
        if not wrote_buckets:
            if os.path.exists(write_path):
                warnings.warn(
                    f"Output path {write_path} already exists and will be overwritten"
                )
            df.to_parquet(write_path, write_index=False, overwrite=True)
        else:
            df.to_parquet(
                write_path,
                write_index=False,
                overwrite=are_buckets_empty,
                append=not are_buckets_empty,
                ignore_divisions=True,
            )
        # Only check if buckets written so far are empty
        if are_buckets_empty:
            are_buckets_empty = check_empty_buckets(write_path)

Review comment: The reason we need to do this in the first place is that there's no way to know if we're writing out an empty dataframe or not, unless we persist it, or write it out, check the metadata, and then overwrite on the next iteration.

        wrote_buckets = True

        if are_buckets_empty:
            self._logger.info(
                f"No duplicate documents found for buckets: {buckets_to_write}"
            )
        else:
            self._logger.info(f"Wrote data for buckets: {buckets_to_write}")
        return wrote_buckets, are_buckets_empty
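The overwrite/append pattern used in _write_bucket_parquet can be reproduced with plain dask.dataframe as a sketch; the paths and data below are invented, and dask.dataframe is used instead of dask_cudf only so the snippet runs without a GPU (the to_parquet keywords are the same ones used above):

import dask.dataframe as dd
import pandas as pd

out_path = "./bucket_batches.parquet"  # hypothetical output directory

batch_0 = dd.from_pandas(pd.DataFrame({"id": [1, 2], "_bucket_id": [7, 7]}), npartitions=1)
batch_1 = dd.from_pandas(pd.DataFrame({"id": [3, 4], "_bucket_id": [9, 9]}), npartitions=1)

# The first successful write replaces anything already at the path ...
batch_0.to_parquet(out_path, write_index=False, overwrite=True)
# ... and later batches are appended to the same parquet dataset.
batch_1.to_parquet(out_path, write_index=False, append=True, ignore_divisions=True)

print(len(dd.read_parquet(out_path)))  # 4 rows across both batches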

    def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
        df = dataset.df

        write_path = os.path.join(self.cache_dir, "_buckets.parquet")
        t0 = time.time()
        with performance_report_if_with_ts_suffix(self.profile_dir, "lsh-profile"):
            self.lsh(write_path=write_path, df=df)
            empty_result = self.lsh(write_path=write_path, df=df)
        self._logger.info(
            f"Time taken for LSH = {time.time() - t0}s and output written at {write_path}"
        )

        if empty_result:
            return None
        buckets_df = dask_cudf.read_parquet(write_path, split_row_groups=False)
        return DocumentDataset(buckets_df)
@@ -488,6 +551,7 @@ def __init__(
            num_hashes=self.config.num_hashes,
            num_buckets=self.config.num_buckets,
            buckets_per_shuffle=self.config.buckets_per_shuffle,
            false_positive_check=self.config.false_positive_check,
            logger=self._logger,
            id_fields=[self.config.id_field],
            profile_dir=self.config.profile_dir,
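For orientation, the flag enters through the fuzzy-dedup config object; a hypothetical sketch follows, where the import path and constructor keywords are assumptions inferred from the self.config attributes read in this hunk:

from nemo_curator import FuzzyDuplicates, FuzzyDuplicatesConfig  # assumed import path

# Hypothetical construction of the config whose attributes are forwarded above.
config = FuzzyDuplicatesConfig(
    cache_dir="./fuzzy_dedup_cache",   # hypothetical path
    id_field="id",
    num_hashes=260,                    # example values only
    num_buckets=20,
    buckets_per_shuffle=1,
    false_positive_check=False,        # forwarded to LSH(false_positive_check=...)
)
fuzzy_dup = FuzzyDuplicates(config=config)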
@@ -556,6 +620,11 @@ def __call__(self, dataset: DocumentDataset):
        minhashLSH = Sequential([self.minhash, self.lsh])
        buckets_df = minhashLSH(dataset)
        print(f"Stage{stage_num}: Minhash + LSH complete!")
        if buckets_df is None:
            print(
                f"Stage{stage_num}: No potential duplicate documents found during LSH"
            )
            return None

Review discussion: Should this return None or an empty … — I prefer returning … — Makes sense, but then for … — I haven't seen …

        stage_num += 1

        if self.config.false_positive_check:
@@ -740,6 +809,7 @@ def buckets_to_edges(

    def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
        buckets_df = dataset.df
        self._logger.info(f"Starting conversion of LSH Buckets to Graph Edgelist")
        if len(self.id_fields) > 1:
            buckets_df = buckets_df.map_partitions(
                BucketsToEdges._combine_multiple_ids,
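Stepping back from the diff: with this change, callers of the fuzzy-dedup stage need to handle the case where no candidate duplicates exist at all. A hypothetical caller-side sketch, reusing the fuzzy_dup and dataset objects assumed in the earlier sketches:

duplicates = fuzzy_dup(dataset)  # DocumentDataset of bucketed ids, or None
if duplicates is None:
    # Nothing shared a bucket, so there is nothing to remove downstream.
    print("No potential duplicates found")
else:
    duplicates_df = duplicates.df  # proceed with connected components / removal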
@@ -202,3 +202,16 @@ def strip_trailing_sep(path: str):
    Strips a path string of trailing path separators like `/` if any.
    """
    return path.rstrip(os.path.sep)


def check_empty_buckets(bucket_path):
    """
    Inspects parquet metadata of the buckets dataset to check if it's an empty dataset.
    """
    from pyarrow.dataset import dataset

    ds = dataset(bucket_path, format="parquet")
    for fragment in ds.get_fragments():
        if fragment.metadata.num_rows > 0:
            return False
    return True

Review comment on lines +213 to +216: This logic can probably be simplified by using a global metadata file when writing out the parquet dataset.
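A sketch of the simplification floated in the review comment above; it assumes the dataset is written with a global _metadata file (e.g. dask's to_parquet(..., write_metadata_file=True)), which is not something this PR does:

import os
import pyarrow.parquet as pq

def check_empty_buckets_via_global_metadata(bucket_path):
    # Read the dataset-level footer once instead of iterating over fragments.
    metadata = pq.read_metadata(os.path.join(bucket_path, "_metadata"))
    return metadata.num_rows == 0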
Review comment: Variable for tracking if all the buckets were empty.