Skip to content

Commit

Permalink
Support faster processing using pandas or polars functions in `Iterab…
Browse files Browse the repository at this point in the history
…leDataset.map()` (#7370)

* add pandas and polars formatting in iterabledataset

* fix tests

* docs

* fix ci

* add tests
  • Loading branch information
lhoestq authored Jan 30, 2025
1 parent fb91fd3 commit c2b7303
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 44 deletions.
2 changes: 1 addition & 1 deletion docs/source/process.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ The [`~Dataset.with_format`] function also changes the format of a column, excep

<Tip>

🤗 Datasets also provides support for other common data formats such as NumPy, Pandas, and JAX. Check out the [Using Datasets with TensorFlow](https://huggingface.co/docs/datasets/master/en/use_with_tensorflow#using-totfdataset) guide for more details on how to efficiently create a TensorFlow dataset.
🤗 Datasets also provides support for other common data formats such as NumPy, TensorFlow, JAX, Arrow, Pandas and Polars. Check out the [Using Datasets with TensorFlow](https://huggingface.co/docs/datasets/master/en/use_with_tensorflow#using-totfdataset) guide for more details on how to efficiently create a TensorFlow dataset.

</Tip>

Expand Down
6 changes: 3 additions & 3 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2457,7 +2457,7 @@ def formatted_as(
Args:
type (`str`, *optional*):
Output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`.
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`.
`None` means `__getitem__`` returns python objects (default).
columns (`List[str]`, *optional*):
Columns to format in the output.
Expand Down Expand Up @@ -2491,7 +2491,7 @@ def set_format(
Args:
type (`str`, *optional*):
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`.
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`.
`None` means `__getitem__` returns python objects (default).
columns (`List[str]`, *optional*):
Columns to format in the output.
Expand Down Expand Up @@ -2644,7 +2644,7 @@ def with_format(
Args:
type (`str`, *optional*):
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`.
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`.
`None` means `__getitem__` returns python objects (default).
columns (`List[str]`, *optional*):
Columns to format in the output.
Expand Down
9 changes: 4 additions & 5 deletions src/datasets/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def formatted_as(
Args:
type (`str`, *optional*):
Output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`.
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`.
`None` means `__getitem__` returns python objects (default).
columns (`List[str]`, *optional*):
Columns to format in the output.
Expand Down Expand Up @@ -563,7 +563,7 @@ def set_format(
Args:
type (`str`, *optional*):
Output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`.
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`.
`None` means `__getitem__` returns python objects (default).
columns (`List[str]`, *optional*):
Columns to format in the output.
Expand Down Expand Up @@ -670,7 +670,7 @@ def with_format(
Args:
type (`str`, *optional*):
Output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`.
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`.
`None` means `__getitem__` returns python objects (default).
columns (`List[str]`, *optional*):
Columns to format in the output.
Expand Down Expand Up @@ -1821,12 +1821,11 @@ def with_format(
) -> "IterableDatasetDict":
"""
Return a dataset with the specified format.
The 'pandas' format is currently not implemented.
Args:
type (`str`, *optional*):
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'arrow', 'jax']`.
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`.
`None` means it returns python objects (default).
Example:
Expand Down
1 change: 1 addition & 0 deletions src/datasets/formatting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Formatter,
PandasFormatter,
PythonFormatter,
TableFormatter,
TensorFormatter,
format_table,
query_table,
Expand Down
15 changes: 13 additions & 2 deletions src/datasets/formatting/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,15 @@ def recursive_tensorize(self, data_struct: dict):
raise NotImplementedError


class ArrowFormatter(Formatter[pa.Table, pa.Array, pa.Table]):
class TableFormatter(Formatter[RowFormat, ColumnFormat, BatchFormat]):
table_type: str
column_type: str


class ArrowFormatter(TableFormatter[pa.Table, pa.Array, pa.Table]):
table_type = "arrow table"
column_type = "arrow array"

def format_row(self, pa_table: pa.Table) -> pa.Table:
return self.simple_arrow_extractor().extract_row(pa_table)

Expand Down Expand Up @@ -465,7 +473,10 @@ def format_batch(self, pa_table: pa.Table) -> Mapping:
return batch


class PandasFormatter(Formatter[pd.DataFrame, pd.Series, pd.DataFrame]):
class PandasFormatter(TableFormatter[pd.DataFrame, pd.Series, pd.DataFrame]):
table_type = "pandas dataframe"
column_type = "pandas series"

def format_row(self, pa_table: pa.Table) -> pd.DataFrame:
row = self.pandas_arrow_extractor().extract_row(pa_table)
row = self.pandas_features_decoder.decode_row(row)
Expand Down
8 changes: 5 additions & 3 deletions src/datasets/formatting/polars_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import sys
from collections.abc import Mapping
from functools import partial
from typing import TYPE_CHECKING, Optional

Expand All @@ -23,7 +22,7 @@
from ..features import Features
from ..features.features import decode_nested_example
from ..utils.py_utils import no_op_if_value_is_null
from .formatting import BaseArrowExtractor, TensorFormatter
from .formatting import BaseArrowExtractor, TableFormatter


if TYPE_CHECKING:
Expand Down Expand Up @@ -98,7 +97,10 @@ def decode_batch(self, batch: "pl.DataFrame") -> "pl.DataFrame":
return self.decode_row(batch)


class PolarsFormatter(TensorFormatter[Mapping, "pl.DataFrame", Mapping]):
class PolarsFormatter(TableFormatter["pl.DataFrame", "pl.Series", "pl.DataFrame"]):
table_type = "polars dataframe"
column_type = "polars series"

def __init__(self, features=None, **np_array_kwargs):
super().__init__(features=features)
self.np_array_kwargs = np_array_kwargs
Expand Down
Loading

0 comments on commit c2b7303

Please sign in to comment.