Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle pandas timestamps #958

Merged
merged 15 commits into from
Dec 1, 2022
37 changes: 36 additions & 1 deletion python/deltalake/schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,42 @@
from typing import Union
from typing import TYPE_CHECKING, Tuple, Union

import pyarrow as pa

if TYPE_CHECKING:
import pandas as pd

from ._internal import ArrayType, Field, MapType, PrimitiveType, Schema, StructType

# Can't implement inheritance (see note in src/schema.rs), so this is next
# best thing.
DataType = Union["PrimitiveType", "MapType", "StructType", "ArrayType"]


def delta_arrow_schema_from_pandas(
    data: "pd.DataFrame",
) -> Tuple[pa.Table, pa.Schema]:
    """
    Infers the schema for the delta table from the Pandas DataFrame.
    Necessary because of issues such as: https://github.com/delta-io/delta-rs/issues/686

    Pandas timestamps default to nanosecond resolution, but Delta Lake
    schemas only support microsecond resolution, so any timestamp field is
    downcast to "us" (preserving its timezone, if any).

    :param data: Data to write.
    :return: A PyArrow Table and the inferred schema for the Delta Table
    """
    table = pa.Table.from_pandas(data)
    schema = table.schema
    schema_out = []
    for field in schema:
        if isinstance(field.type, pa.TimestampType):
            # Downcast ns -> us; keep the original timezone (field.type.tz is
            # None for tz-naive columns) so tz-aware data is not silently
            # reinterpreted as naive.
            f = pa.field(
                name=field.name,
                type=pa.timestamp("us", tz=field.type.tz),
                nullable=field.nullable,
                metadata=field.metadata,
            )
            schema_out.append(f)
        else:
            schema_out.append(field)
    schema = pa.schema(schema_out, metadata=schema.metadata)
    # Cast the already-built table instead of converting the DataFrame to
    # Arrow a second time.
    return table.cast(schema), schema
8 changes: 7 additions & 1 deletion python/deltalake/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
import pyarrow.fs as pa_fs
from pyarrow.lib import RecordBatchReader

from deltalake.schema import delta_arrow_schema_from_pandas

from ._internal import DeltaDataChecker as _DeltaDataChecker
from ._internal import PyDeltaTableError
from ._internal import write_new_deltalake as _write_new_deltalake
Expand Down Expand Up @@ -132,8 +134,12 @@ def write_deltalake(
:param overwrite_schema: If True, allows updating the schema of the table.
:param storage_options: options passed to the native delta filesystem. Unused if 'filesystem' is defined.
"""

if _has_pandas and isinstance(data, pd.DataFrame):
data = pa.Table.from_pandas(data)
if schema is not None:
data = pa.Table.from_pandas(data, schema=schema)
else:
data, schema = delta_arrow_schema_from_pandas(data)

if schema is None:
if isinstance(data, RecordBatchReader):
Expand Down
2 changes: 2 additions & 0 deletions python/stubs/pyarrow/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ float16: Any
float32: Any
float64: Any
dictionary: Any
timestamp: Any
TimestampType: Any

py_buffer: Callable[[bytes], Any]
NativeFile: Any
Expand Down
12 changes: 8 additions & 4 deletions python/tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,16 @@ def test_fails_wrong_partitioning(existing_table: DeltaTable, sample_data: pa.Ta


@pytest.mark.pandas
def test_write_pandas(tmp_path: pathlib.Path, sample_data: pa.Table):
@pytest.mark.parametrize("schema_provided", [True, False])
def test_write_pandas(tmp_path: pathlib.Path, sample_data: pa.Table, schema_provided):
# When a timestamp is converted to Pandas, it gets cast to ns resolution,
# but Delta Lake schemas only support us resolution.
sample_pandas = sample_data.to_pandas().drop(["timestamp"], axis=1)
write_deltalake(str(tmp_path), sample_pandas)

sample_pandas = sample_data.to_pandas()
if schema_provided is True:
schema = sample_data.schema
else:
schema = None
write_deltalake(str(tmp_path), sample_pandas, schema=schema)
delta_table = DeltaTable(str(tmp_path))
df = delta_table.to_pandas()
assert_frame_equal(df, sample_pandas)
Expand Down