
Commit 7d255d0
Merge branch 'main' into add-splits-type
jp1924 authored Feb 21, 2025
2 parents c7dbf4f + 14233c0 commit 7d255d0
Showing 5 changed files with 72 additions and 39 deletions.
docs/source/use_with_pandas.mdx (1 addition, 1 deletion)
@@ -79,5 +79,5 @@ And you can use [`Dataset.to_pandas`] to export a Dataset to a Pandas DataFrame:
 
 
 ```python
-df = Dataset.from_pandas(ds)
+df = Dataset.to_pandas()
 ```
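Note that `to_pandas` is an instance method, so in running code it is called on a dataset object. A minimal round-trip sketch for context (the DataFrame contents are illustrative):

```python
import pandas as pd
from datasets import Dataset

df_in = pd.DataFrame({"text": ["a", "b"], "label": [0, 1]})

ds = Dataset.from_pandas(df_in)  # classmethod: DataFrame -> Dataset
df_out = ds.to_pandas()          # instance method: Dataset -> DataFrame

print(df_out.equals(df_in))      # True for simple column types like these
```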
setup.py (1 addition, 1 deletion)
@@ -233,7 +233,7 @@
 
 setup(
     name="datasets",
-    version="3.3.2.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="3.3.3.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     description="HuggingFace community-driven open-source library of datasets",
     long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
src/datasets/__init__.py (1 addition, 1 deletion)
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "3.3.2.dev0"
+__version__ = "3.3.3.dev0"
 
 from .arrow_dataset import Dataset
 from .arrow_reader import ReadInstruction
src/datasets/arrow_dataset.py (25 additions, 9 deletions)
@@ -808,7 +808,7 @@ def from_pandas(
         Important: a dataset created with from_pandas() lives in memory
         and therefore doesn't have an associated cache directory.
-        This may change in the feature, but in the meantime if you
+        This may change in the future, but in the meantime if you
         want to reduce memory usage you should write it back on disk
         and reload using e.g. save_to_disk / load_from_disk.
@@ -908,7 +908,7 @@ def from_dict(
         Important: a dataset created with from_dict() lives in memory
         and therefore doesn't have an associated cache directory.
-        This may change in the feature, but in the meantime if you
+        This may change in the future, but in the meantime if you
         want to reduce memory usage you should write it back on disk
         and reload using e.g. save_to_disk / load_from_disk.
@@ -973,7 +973,7 @@ def from_list(
         Important: a dataset created with from_list() lives in memory
         and therefore doesn't have an associated cache directory.
-        This may change in the feature, but in the meantime if you
+        This may change in the future, but in the meantime if you
         want to reduce memory usage you should write it back on disk
         and reload using e.g. save_to_disk / load_from_disk.
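The advice in these docstrings, concretely: write the in-memory dataset to disk and reload it, so it becomes a memory-mapped Arrow dataset with a cache directory. A minimal sketch (the path is illustrative):

```python
from datasets import Dataset, load_from_disk

ds = Dataset.from_dict({"a": [1, 2, 3]})  # lives entirely in memory, no cache dir
ds.save_to_disk("my_dataset")             # materialize it on disk
ds = load_from_disk("my_dataset")         # reload as a memory-mapped dataset
```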
@@ -3177,6 +3177,8 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
                         transformed_shards[rank] = content
                     else:
                         pbar.update(content)
+            pool.close()
+            pool.join()
             # Avoids PermissionError on Windows (the error: https://github.com/huggingface/datasets/actions/runs/4026734820/jobs/6921621805)
             for kwargs in kwargs_per_job:
                 del kwargs["shard"]
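Closing and joining the pool before the shard references are released guarantees that every worker process has exited, so no child still holds a handle on the shard files; on Windows an open handle in a live worker is exactly what turns later cleanup into the PermissionError the comment links to. A minimal sketch of the shutdown pattern (function and file names are illustrative, not the library's code):

```python
from multiprocessing import Pool

def process_shard(path: str) -> str:
    # stand-in for per-shard work that opens the file at `path`
    return path.upper()

if __name__ == "__main__":
    shard_paths = ["shard-0.arrow", "shard-1.arrow"]
    pool = Pool(processes=2)
    try:
        results = pool.map(process_shard, shard_paths)
    finally:
        pool.close()  # no more work will be submitted
        pool.join()   # block until every worker has exited
    # only now is it safe (notably on Windows) to delete the shard files
```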
@@ -3431,14 +3433,19 @@ def init_buffer_and_writer():
             )
             return buf_writer, writer, tmp_file
 
+        tasks: List[asyncio.Task] = []
+        if inspect.iscoroutinefunction(function):
+            try:
+                loop = asyncio.get_running_loop()
+            except RuntimeError:
+                loop = asyncio.new_event_loop()
+        else:
+            loop = None
+
         def iter_outputs(shard_iterable):
+            nonlocal tasks, loop
             if inspect.iscoroutinefunction(function):
                 indices: Union[List[int], List[List[int]]] = []
-                tasks: List[asyncio.Task] = []
-                try:
-                    loop = asyncio.get_running_loop()
-                except RuntimeError:
-                    loop = asyncio.new_event_loop()
                 for i, example in shard_iterable:
                     indices.append(i)
                     tasks.append(loop.create_task(async_apply_function(example, i, offset=offset)))
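Hoisting `tasks` and the event loop out of `iter_outputs` (with `nonlocal` pointing back at them) is what lets cleanup code outside the generator see which tasks are still pending. A self-contained sketch of the overall pattern, not the library's code: schedule one task per input on an explicit loop, then yield results in submission order:

```python
import asyncio
from typing import List

async def apply(x: int) -> int:
    # stand-in for the user's async map function
    await asyncio.sleep(0.01)
    return x * x

# No loop is running at module level, so own a private one,
# mirroring the new_event_loop() fallback in the hunk above.
loop = asyncio.new_event_loop()
tasks: List[asyncio.Task] = [loop.create_task(apply(x)) for x in range(5)]

def iter_results():
    # drain in submission order so outputs line up with inputs
    while tasks:
        yield loop.run_until_complete(tasks[0])
        tasks.pop(0)

for result in iter_results():
    print(result)
loop.close()
```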
@@ -3455,7 +3462,8 @@ def iter_outputs(shard_iterable):
                 while tasks and tasks[0].done():
                     yield indices.pop(0), tasks.pop(0).result()
                 while tasks:
-                    yield indices.pop(0), loop.run_until_complete(tasks.pop(0))
+                    yield indices[0], loop.run_until_complete(tasks[0])
+                    indices.pop(0), tasks.pop(0)
             else:
                 for i, example in shard_iterable:
                     yield i, apply_function(example, i, offset=offset)
@@ -3540,6 +3548,14 @@ def iter_outputs(shard_iterable):
                     tmp_file.close()
                     if os.path.exists(tmp_file.name):
                         os.remove(tmp_file.name)
+            if loop:
+                logger.debug(f"Canceling {len(tasks)} async tasks.")
+                for task in tasks:
+                    task.cancel(msg="KeyboardInterrupt")
+                try:
+                    loop.run_until_complete(asyncio.gather(*tasks))
+                except asyncio.CancelledError:
+                    logger.debug("Tasks canceled.")
             raise
 
         yield rank, False, num_examples_progress_update
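Together with the pop-after-wait change in the previous hunk (`tasks[0]` is awaited before it is popped), an interrupt that lands inside `run_until_complete` leaves the pending tasks in `tasks`, so this handler can still cancel them and drain the loop before re-raising. A standalone sketch of that teardown (names are illustrative; a deterministic exception stands in for Ctrl-C, and the final `raise` propagates it just as the hunk does):

```python
import asyncio
from typing import List

async def fetch(i: int) -> int:
    # stand-in for slow async work; task 0 fails immediately to
    # simulate an interrupt while the other tasks are still pending
    if i == 0:
        raise RuntimeError("boom")
    await asyncio.sleep(3600)
    return i

loop = asyncio.new_event_loop()
tasks: List[asyncio.Task] = [loop.create_task(fetch(i)) for i in range(3)]

try:
    for task in list(tasks):
        print(loop.run_until_complete(task))    # raises on the failing task
except (Exception, KeyboardInterrupt):
    print(f"canceling {len(tasks)} tasks")
    for task in tasks:
        task.cancel(msg="interrupted")          # the msg= parameter needs Python 3.9+
    # drain the loop so every cancellation is processed before re-raising
    loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
    raise
finally:
    loop.close()
```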
src/datasets/iterable_dataset.py (44 additions, 27 deletions)
@@ -1171,15 +1171,20 @@ async def async_apply_function(key_example, indices):
                 processed_inputs = await self.function(*fn_args, *additional_args, **fn_kwargs)
                 return prepare_outputs(key_example, inputs, processed_inputs)
 
+        tasks: List[asyncio.Task] = []
+        if inspect.iscoroutinefunction(self.function):
+            try:
+                loop = asyncio.get_running_loop()
+            except RuntimeError:
+                loop = asyncio.new_event_loop()
+        else:
+            loop = None
+
         def iter_outputs():
+            nonlocal tasks, loop
             inputs_iterator = iter_batched_inputs() if self.batched else iter_inputs()
             if inspect.iscoroutinefunction(self.function):
                 indices: Union[List[int], List[List[int]]] = []
-                tasks: List[asyncio.Task] = []
-                try:
-                    loop = asyncio.get_running_loop()
-                except RuntimeError:
-                    loop = asyncio.new_event_loop()
                 for i, key_example in inputs_iterator:
                     indices.append(i)
                     tasks.append(loop.create_task(async_apply_function(key_example, i)))
@@ -1196,36 +1201,48 @@ def iter_outputs():
                 while tasks and tasks[0].done():
                     yield indices.pop(0), tasks.pop(0).result()
                 while tasks:
-                    yield indices.pop(0), loop.run_until_complete(tasks.pop(0))
+                    yield indices[0], loop.run_until_complete(tasks[0])
+                    indices.pop(0), tasks.pop(0)
             else:
                 for i, key_example in inputs_iterator:
                     yield i, apply_function(key_example, i)
 
-        if self.batched:
-            if self._state_dict:
-                self._state_dict["previous_state"] = self.ex_iterable.state_dict()
-                self._state_dict["num_examples_since_previous_state"] = 0
-                self._state_dict["previous_state_example_idx"] = current_idx
-            for key, transformed_batch in iter_outputs():
-                # yield one example at a time from the transformed batch
-                for example in _batch_to_examples(transformed_batch):
-                    current_idx += 1
-                    if self._state_dict:
-                        self._state_dict["num_examples_since_previous_state"] += 1
-                    if num_examples_to_skip > 0:
-                        num_examples_to_skip -= 1
-                        continue
-                    yield key, example
-        else:
-            for key, transformed_example in iter_outputs():
-                current_idx += 1
-                if self._state_dict:
-                    self._state_dict["previous_state_example_idx"] += 1
-                yield key, transformed_example
+        try:
+            if self.batched:
+                if self._state_dict:
+                    self._state_dict["previous_state"] = self.ex_iterable.state_dict()
+                    self._state_dict["num_examples_since_previous_state"] = 0
+                    self._state_dict["previous_state_example_idx"] = current_idx
+                for key, transformed_batch in iter_outputs():
+                    # yield one example at a time from the transformed batch
+                    for example in _batch_to_examples(transformed_batch):
+                        current_idx += 1
+                        if self._state_dict:
+                            self._state_dict["num_examples_since_previous_state"] += 1
+                        if num_examples_to_skip > 0:
+                            num_examples_to_skip -= 1
+                            continue
+                        yield key, example
+                    if self._state_dict:
+                        self._state_dict["previous_state"] = self.ex_iterable.state_dict()
+                        self._state_dict["num_examples_since_previous_state"] = 0
+                        self._state_dict["previous_state_example_idx"] = current_idx
+            else:
+                for key, transformed_example in iter_outputs():
+                    current_idx += 1
+                    if self._state_dict:
+                        self._state_dict["previous_state_example_idx"] += 1
+                    yield key, transformed_example
+        except (Exception, KeyboardInterrupt):
+            if loop:
+                logger.debug(f"Canceling {len(tasks)} async tasks.")
+                for task in tasks:
+                    task.cancel(msg="KeyboardInterrupt")
+                try:
+                    loop.run_until_complete(asyncio.gather(*tasks))
+                except asyncio.CancelledError:
+                    logger.debug("Tasks canceled.")
+            raise
 
     def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[Tuple[Key, pa.Table]]:
         formatter: TableFormatter = get_formatter(self.formatting.format_type) if self.formatting else ArrowFormatter()
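For the user-facing effect of the two files above: `map` accepts async functions, and with this change an interruption (or any error) mid-stream cancels the in-flight tasks instead of leaking them. A hedged usage sketch (the function body is a stand-in for real async work such as an HTTP call):

```python
import asyncio
from datasets import Dataset

async def enrich(example):
    await asyncio.sleep(0.1)                  # stand-in for an async request
    return {"text": example["text"].upper()}

ds = Dataset.from_dict({"text": ["a", "b", "c"]}).to_iterable_dataset()
for example in ds.map(enrich):                # Ctrl-C here now cancels pending tasks
    print(example)
```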
