diff --git a/.gitignore b/.gitignore index 776885c858..ff22decf25 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,6 @@ tlaplus/*.toolbox/*/MC.cfg tlaplus/*.toolbox/*/[0-9]*-[0-9]*-[0-9]*-[0-9]*-[0-9]*-[0-9]*/ /.idea .vscode -.env \ No newline at end of file +.env +**/.DS_Store +**/.python-version \ No newline at end of file diff --git a/README.adoc b/README.adoc index db8c98e744..aa902a8303 100644 --- a/README.adoc +++ b/README.adoc @@ -73,7 +73,7 @@ link:https://github.com/rajasekarv/vega[vega], etc. It also provides bindings to | High-level file writer | -| +| link:https://github.com/delta-io/delta-rs/issues/542[#542] | | Optimize diff --git a/python/deltalake/__init__.py b/python/deltalake/__init__.py index aeb999c97a..eaa3c39c9c 100644 --- a/python/deltalake/__init__.py +++ b/python/deltalake/__init__.py @@ -2,3 +2,4 @@ from .deltalake import PyDeltaTableError, RawDeltaTable, rust_core_version from .schema import DataType, Field, Schema from .table import DeltaTable, Metadata +from .writer import write_deltalake diff --git a/python/deltalake/schema.py b/python/deltalake/schema.py index 9ac5123098..9363b6d097 100644 --- a/python/deltalake/schema.py +++ b/python/deltalake/schema.py @@ -205,7 +205,7 @@ def pyarrow_datatype_from_dict(json_dict: Dict[str, Any]) -> pyarrow.DataType: key, pyarrow.list_( pyarrow.field( - "element", pyarrow.struct([pyarrow_field_from_dict(value_type)]) + "entries", pyarrow.struct([pyarrow_field_from_dict(value_type)]) ) ), ) @@ -218,7 +218,7 @@ def pyarrow_datatype_from_dict(json_dict: Dict[str, Any]) -> pyarrow.DataType: elif type_class == "list": field = json_dict["children"][0] element_type = pyarrow_datatype_from_dict(field) - return pyarrow.list_(pyarrow.field("element", element_type)) + return pyarrow.list_(pyarrow.field("item", element_type)) elif type_class == "struct": fields = [pyarrow_field_from_dict(field) for field in json_dict["children"]] return pyarrow.struct(fields) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 1c8cc8b4e7..59096c38db 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -1,7 +1,7 @@ import json import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union import pyarrow import pyarrow.fs as pa_fs @@ -63,6 +63,11 @@ def __str__(self) -> str: ) +class ProtocolVersions(NamedTuple): + min_reader_version: int + min_writer_version: int + + @dataclass(init=False) class DeltaTable: """Create a DeltaTable instance.""" @@ -219,6 +224,9 @@ def metadata(self) -> Metadata: """ return self._metadata + def protocol(self) -> ProtocolVersions: + return ProtocolVersions(*self._table.protocol_versions()) + def history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]: """ Run the history command on the DeltaTable. 
diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py
new file mode 100644
index 0000000000..7c38887301
--- /dev/null
+++ b/python/deltalake/writer.py
@@ -0,0 +1,243 @@
+import json
+import uuid
+from dataclasses import dataclass
+from datetime import date, datetime
+from decimal import Decimal
+from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Union
+
+import pandas as pd
+import pyarrow as pa
+import pyarrow.dataset as ds
+import pyarrow.fs as pa_fs
+from pyarrow.lib import RecordBatchReader
+from typing_extensions import Literal
+
+from .deltalake import PyDeltaTableError
+from .deltalake import write_new_deltalake as _write_new_deltalake
+from .table import DeltaTable
+
+
+class DeltaTableProtocolError(PyDeltaTableError):
+    pass
+
+
+@dataclass
+class AddAction:
+    path: str
+    size: int
+    partition_values: Mapping[str, Optional[str]]
+    modification_time: int
+    data_change: bool
+    stats: str
+
+
+def write_deltalake(
+    table_or_uri: Union[str, DeltaTable],
+    data: Union[
+        pd.DataFrame,
+        pa.Table,
+        pa.RecordBatch,
+        Iterable[pa.RecordBatch],
+        RecordBatchReader,
+    ],
+    schema: Optional[pa.Schema] = None,
+    partition_by: Optional[List[str]] = None,
+    filesystem: Optional[pa_fs.FileSystem] = None,
+    mode: Literal["error", "append", "overwrite", "ignore"] = "error",
+) -> None:
+    """Write to a Delta Lake table (Experimental)
+
+    If the table does not already exist, it will be created.
+
+    This function currently only supports writer protocol version 1. If
+    attempting to write to an existing table with a higher min_writer_version,
+    this function will raise DeltaTableProtocolError.
+
+    :param table_or_uri: URI of a table or a DeltaTable object.
+    :param data: Data to write. If passing an iterable, the schema must also be given.
+    :param schema: Optional schema to write.
+    :param partition_by: List of columns to partition the table by. Only required
+        when creating a new table.
+    :param filesystem: Optional filesystem to pass to PyArrow. If not provided, it
+        will be inferred from the URI.
+    :param mode: How to handle existing data. Default is to error if the table
+        already exists. If 'append', will add new data. If 'overwrite', will
+        replace the table with new data. If 'ignore', will not write anything if
+        the table already exists.
+    """
+    if isinstance(data, pd.DataFrame):
+        data = pa.Table.from_pandas(data)
+
+    if schema is None:
+        if isinstance(data, RecordBatchReader):
+            schema = data.schema
+        elif isinstance(data, Iterable):
+            raise ValueError("You must provide schema if data is Iterable")
+        else:
+            schema = data.schema
+
+    if isinstance(table_or_uri, str):
+        table = try_get_deltatable(table_or_uri)
+        table_uri = table_or_uri
+    else:
+        table = table_or_uri
+        table_uri = table._table.table_uri()
+
+    # TODO: Pass through filesystem once it is complete
+    # if filesystem is None:
+    #     filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri))
+
+    if table:  # already exists
+        if mode == "error":
+            raise AssertionError("DeltaTable already exists.")
+        elif mode == "ignore":
+            return
+
+        current_version = table.version()
+
+        if partition_by:
+            assert partition_by == table.metadata().partition_columns
+
+        if table.protocol().min_writer_version > 1:
+            raise DeltaTableProtocolError(
+                "This table's min_writer_version is "
+                f"{table.protocol().min_writer_version}, "
+                "but this method only supports version 1."
+            )
+    else:  # creating a new table
+        current_version = -1
+
+        # TODO: Don't allow writing to non-empty directory
+        # Blocked on: Finish filesystem implementation in fs.py
+        # assert len(filesystem.get_file_info(pa_fs.FileSelector(table_uri, allow_not_found=True))) == 0
+
+    if partition_by:
+        partition_schema = pa.schema([schema.field(name) for name in partition_by])
+        partitioning = ds.partitioning(partition_schema, flavor="hive")
+    else:
+        partitioning = None
+
+    add_actions: List[AddAction] = []
+
+    def visitor(written_file: Any) -> None:
+        partition_values = get_partitions_from_path(table_uri, written_file.path)
+        stats = get_file_stats_from_metadata(written_file.metadata)
+
+        add_actions.append(
+            AddAction(
+                written_file.path,
+                written_file.metadata.serialized_size,
+                partition_values,
+                int(datetime.now().timestamp()),
+                True,
+                json.dumps(stats, cls=DeltaJSONEncoder),
+            )
+        )
+
+    ds.write_dataset(
+        data,
+        base_dir=table_uri,
+        basename_template=f"{current_version + 1}-{uuid.uuid4()}-{{i}}.parquet",
+        format="parquet",
+        partitioning=partitioning,
+        # It will not accept a schema if using a RBR
+        schema=schema if not isinstance(data, RecordBatchReader) else None,
+        file_visitor=visitor,
+        existing_data_behavior="overwrite_or_ignore",
+    )
+
+    if table is None:
+        _write_new_deltalake(table_uri, schema, add_actions, mode, partition_by or [])
+    else:
+        table._table.create_write_transaction(
+            add_actions,
+            mode,
+            partition_by or [],
+        )
+
+
+class DeltaJSONEncoder(json.JSONEncoder):
+    def default(self, obj: Any) -> Any:
+        if isinstance(obj, bytes):
+            return obj.decode("unicode_escape")
+        elif isinstance(obj, date):
+            return obj.isoformat()
+        elif isinstance(obj, datetime):
+            return obj.isoformat()
+        elif isinstance(obj, Decimal):
+            return str(obj)
+        # Let the base class default method raise the TypeError
+        return json.JSONEncoder.default(self, obj)
+
+
+def try_get_deltatable(table_uri: str) -> Optional[DeltaTable]:
+    try:
+        return DeltaTable(table_uri)
+    except PyDeltaTableError as err:
+        if "Not a Delta table" not in str(err):
+            raise
+        return None
+
+
+def get_partitions_from_path(base_path: str, path: str) -> Dict[str, str]:
+    path = path.split(base_path, maxsplit=1)[1]
+    parts = path.split("/")
+    parts.pop()  # remove filename
+    out = {}
+    for part in parts:
+        if part == "":
+            continue
+        key, value = part.split("=", maxsplit=1)
+        out[key] = value
+    return out
+
+
+def get_file_stats_from_metadata(
+    metadata: Any,
+) -> Dict[str, Union[int, Dict[str, Any]]]:
+    stats = {
+        "numRecords": metadata.num_rows,
+        "minValues": {},
+        "maxValues": {},
+        "nullCount": {},
+    }
+
+    def iter_groups(metadata: Any) -> Iterator[Any]:
+        for i in range(metadata.num_row_groups):
+            yield metadata.row_group(i)
+
+    for column_idx in range(metadata.num_columns):
+        name = metadata.row_group(0).column(column_idx).path_in_schema
+        # If stats missing, then we can't know aggregate stats
+        if all(
+            group.column(column_idx).is_stats_set for group in iter_groups(metadata)
+        ):
+            stats["nullCount"][name] = sum(
+                group.column(column_idx).statistics.null_count
+                for group in iter_groups(metadata)
+            )
+
+            # I assume for now this is based on data type, and thus is
+            # consistent between groups
+            if metadata.row_group(0).column(column_idx).statistics.has_min_max:
+                # Min and Max are recorded in physical type, not logical type
+                # https://stackoverflow.com/questions/66753485/decoding-parquet-min-max-statistics-for-decimal-type
+                # TODO: Add logic to decode physical type for DATE, DECIMAL
+                logical_type = (
+                    metadata.row_group(0)
+                    .column(column_idx)
+                    .statistics.logical_type.type
+                )
+                #
+                if logical_type not in ["STRING", "INT", "TIMESTAMP", "NONE"]:
+                    continue
+                # import pdb; pdb.set_trace()
+                stats["minValues"][name] = min(
+                    group.column(column_idx).statistics.min
+                    for group in iter_groups(metadata)
+                )
+                stats["maxValues"][name] = max(
+                    group.column(column_idx).statistics.max
+                    for group in iter_groups(metadata)
+                )
+    return stats
diff --git a/python/docs/source/api_reference.rst b/python/docs/source/api_reference.rst
index 0fcb6579c3..09659ebc10 100644
--- a/python/docs/source/api_reference.rst
+++ b/python/docs/source/api_reference.rst
@@ -7,6 +7,11 @@ DeltaTable
 .. automodule:: deltalake.table
     :members:
 
+Writing DeltaTables
+-------------------
+
+.. autofunction:: deltalake.write_deltalake
+
 DeltaSchema
 -----------
 
diff --git a/python/docs/source/conf.py b/python/docs/source/conf.py
index 5fbd59eb17..bb6c11237b 100644
--- a/python/docs/source/conf.py
+++ b/python/docs/source/conf.py
@@ -42,7 +42,12 @@ def get_release_version() -> str:
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-extensions = ["sphinx_rtd_theme", "sphinx.ext.autodoc", "edit_on_github"]
+extensions = [
+    "sphinx_rtd_theme",
+    "sphinx.ext.autodoc",
+    "sphinx.ext.intersphinx",
+    "edit_on_github",
+]
 autodoc_typehints = "description"
 nitpicky = True
 nitpick_ignore = [
@@ -52,6 +57,7 @@ def get_release_version() -> str:
     ("py:class", "pyarrow.lib.DataType"),
     ("py:class", "pyarrow.lib.Field"),
     ("py:class", "pyarrow.lib.NativeFile"),
+    ("py:class", "pyarrow.lib.RecordBatchReader"),
     ("py:class", "pyarrow._fs.FileSystem"),
     ("py:class", "pyarrow._fs.FileInfo"),
     ("py:class", "pyarrow._fs.FileSelector"),
@@ -84,3 +90,10 @@ def get_release_version() -> str:
 edit_on_github_project = "delta-io/delta-rs"
 edit_on_github_branch = "main"
 page_source_prefix = "python/docs/source"
+
+
+intersphinx_mapping = {
+    "pyarrow": ("https://arrow.apache.org/docs/", None),
+    "pyspark": ("https://spark.apache.org/docs/latest/api/python/", None),
+    "pandas": ("https://pandas.pydata.org/docs/", None),
+}
diff --git a/python/docs/source/usage.rst b/python/docs/source/usage.rst
index e72aa802c2..249eee01c2 100644
--- a/python/docs/source/usage.rst
+++ b/python/docs/source/usage.rst
@@ -328,4 +328,32 @@ Optimizing tables is not currently supported.
 Writing Delta Tables
 --------------------
 
-Writing Delta tables is not currently supported.
+.. py:currentmodule:: deltalake
+
+.. warning::
+    The writer is currently *experimental*. Please use it on test data first, not
+    on production data. Report any issues at https://github.com/delta-io/delta-rs/issues.
+
+For overwrites and appends, use :py:func:`write_deltalake`. If the table does not
+already exist, it will be created. The ``data`` parameter will accept a Pandas
+DataFrame, a PyArrow Table, or an iterator of PyArrow Record Batches.
+
+.. code-block:: python
+
+    >>> from deltalake.writer import write_deltalake
+    >>> df = pd.DataFrame({'x': [1, 2, 3]})
+    >>> write_deltalake('path/to/table', df)
+
+.. note::
+    :py:func:`write_deltalake` accepts a Pandas DataFrame, but will convert it to
+    an Arrow table before writing. See caveats in :doc:`pyarrow:python/pandas`.
+
+By default, writes create a new table and error if it already exists. This is
+controlled by the ``mode`` parameter, which mirrors the behavior of Spark's
+:py:meth:`pyspark.sql.DataFrameWriter.saveAsTable` DataFrame method. To overwrite, pass in ``mode='overwrite'`` and
+to append, pass in ``mode='append'``:
+
+.. code-block:: python
+
+    >>> write_deltalake('path/to/table', df, mode='overwrite')
+    >>> write_deltalake('path/to/table', df, mode='append')
diff --git a/python/pyproject.toml b/python/pyproject.toml
index bd02df5790..7b096de6f5 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -36,7 +36,8 @@ devel = [
     "sphinx",
     "sphinx-rtd-theme",
     "toml",
-    "pandas"
+    "pandas",
+    "typing-extensions"
 ]
 
 [project.urls]
diff --git a/python/src/lib.rs b/python/src/lib.rs
index fcab8b11b0..64762a6beb 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -3,19 +3,28 @@ extern crate pyo3;
 
 use chrono::{DateTime, FixedOffset, Utc};
-use deltalake::action::Stats;
-use deltalake::action::{ColumnCountStat, ColumnValueStat};
+use deltalake::action;
+use deltalake::action::Action;
+use deltalake::action::{ColumnCountStat, ColumnValueStat, DeltaOperation, SaveMode, Stats};
 use deltalake::arrow::datatypes::Schema as ArrowSchema;
+use deltalake::get_backend_for_uri;
 use deltalake::partitions::PartitionFilter;
 use deltalake::storage;
+use deltalake::DeltaDataTypeLong;
+use deltalake::DeltaDataTypeTimestamp;
+use deltalake::DeltaTableMetaData;
+use deltalake::DeltaTransactionOptions;
 use deltalake::{arrow, StorageBackend};
 use pyo3::create_exception;
 use pyo3::exceptions::PyException;
+use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 use pyo3::types::{PyBytes, PyTuple, PyType};
 use std::collections::HashMap;
 use std::collections::HashSet;
 use std::convert::TryFrom;
+use std::time::SystemTime;
+use std::time::UNIX_EPOCH;
 
 create_exception!(deltalake, PyDeltaTableError, PyException);
@@ -145,6 +154,13 @@ impl RawDeltaTable {
         })
     }
 
+    pub fn protocol_versions(&self) -> PyResult<(i32, i32)> {
+        Ok((
+            self._table.get_min_reader_version(),
+            self._table.get_min_writer_version(),
+        ))
+    }
+
     pub fn load_version(&mut self, version: deltalake::DeltaDataTypeVersion) -> PyResult<()> {
         rt()?
             .block_on(self._table.load_version(version))
@@ -272,6 +288,50 @@ impl RawDeltaTable {
             })
             .collect()
     }
+
+    fn create_write_transaction(
+        &mut self,
+        add_actions: Vec<PyAddAction>,
+        mode: &str,
+        partition_by: Vec<String>,
+    ) -> PyResult<()> {
+        let mode = save_mode_from_str(mode)?;
+
+        let mut actions: Vec<Action> = add_actions
+            .iter()
+            .map(|add| Action::add(add.into()))
+            .collect();
+
+        if let SaveMode::Overwrite = mode {
+            // Remove all current files
+            for old_add in self._table.get_state().files().iter() {
+                let remove_action = Action::remove(action::Remove {
+                    path: old_add.path.clone(),
+                    deletion_timestamp: Some(current_timestamp()),
+                    data_change: true,
+                    extended_file_metadata: Some(old_add.tags.is_some()),
+                    partition_values: Some(old_add.partition_values.clone()),
+                    size: Some(old_add.size),
+                    tags: old_add.tags.clone(),
+                });
+                actions.push(remove_action);
+            }
+        }
+
+        let mut transaction = self
+            ._table
+            .create_transaction(Some(DeltaTransactionOptions::new(3)));
+        transaction.add_actions(actions);
+        rt()?
+            .block_on(transaction.commit(Some(DeltaOperation::Write {
+                mode,
+                partitionBy: Some(partition_by),
+                predicate: None,
+            })))
+            .map_err(PyDeltaTableError::from_raw)?;
+
+        Ok(())
+    }
 }
 
 fn json_value_to_py(value: &serde_json::Value, py: Python) -> PyObject {
@@ -409,12 +469,96 @@ fn rust_core_version() -> &'static str {
     deltalake::crate_version()
 }
 
+fn save_mode_from_str(value: &str) -> PyResult<SaveMode> {
+    match value {
+        "append" => Ok(SaveMode::Append),
+        "overwrite" => Ok(SaveMode::Overwrite),
+        "error" => Ok(SaveMode::ErrorIfExists),
+        "ignore" => Ok(SaveMode::Ignore),
+        _ => Err(PyValueError::new_err("Invalid save mode")),
+    }
+}
+
+fn current_timestamp() -> DeltaDataTypeTimestamp {
+    let start = SystemTime::now();
+    let since_the_epoch = start
+        .duration_since(UNIX_EPOCH)
+        .expect("Time went backwards");
+    since_the_epoch.as_millis().try_into().unwrap()
+}
+
+#[derive(FromPyObject)]
+pub struct PyAddAction {
+    path: String,
+    size: DeltaDataTypeLong,
+    partition_values: HashMap<String, Option<String>>,
+    modification_time: DeltaDataTypeTimestamp,
+    data_change: bool,
+    stats: Option<String>,
+}
+
+impl From<&PyAddAction> for action::Add {
+    fn from(action: &PyAddAction) -> Self {
+        action::Add {
+            path: action.path.clone(),
+            size: action.size,
+            partition_values: action.partition_values.clone(),
+            partition_values_parsed: None,
+            modification_time: action.modification_time,
+            data_change: action.data_change,
+            stats: action.stats.clone(),
+            stats_parsed: None,
+            tags: None,
+        }
+    }
+}
+
+#[pyfunction]
+fn write_new_deltalake(
+    table_uri: String,
+    schema: ArrowSchema,
+    add_actions: Vec<PyAddAction>,
+    _mode: &str,
+    partition_by: Vec<String>,
+) -> PyResult<()> {
+    let mut table = deltalake::DeltaTable::new(
+        &table_uri,
+        get_backend_for_uri(&table_uri).map_err(PyDeltaTableError::from_storage)?,
+        deltalake::DeltaTableConfig::default(),
+    )
+    .map_err(PyDeltaTableError::from_raw)?;
+
+    let metadata = DeltaTableMetaData::new(
+        None,
+        None,
+        None,
+        (&schema).try_into()?,
+        partition_by,
+        HashMap::new(),
+    );
+
+    let fut = table.create(
+        metadata,
+        action::Protocol {
+            min_reader_version: 1,
+            min_writer_version: 1, // TODO: Make sure we comply with protocol
+        },
+        None, // TODO
+        Some(add_actions.iter().map(|add| add.into()).collect()),
+    );
+
+    rt()?.block_on(fut).map_err(PyDeltaTableError::from_raw)?;
+
+    Ok(())
+}
+
 #[pymodule]
 // module name need to match project name
 fn deltalake(py: Python, m: &PyModule) -> PyResult<()> {
     env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("warn")).init();
 
     m.add_function(pyo3::wrap_pyfunction!(rust_core_version, m)?)?;
+    m.add_function(pyo3::wrap_pyfunction!(write_new_deltalake, m)?)?;
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
diff --git a/python/stubs/deltalake/deltalake.pyi b/python/stubs/deltalake/deltalake.pyi
index f420569e5d..b872e61bd3 100644
--- a/python/stubs/deltalake/deltalake.pyi
+++ b/python/stubs/deltalake/deltalake.pyi
@@ -1,6 +1,13 @@
-from typing import Any, Callable
+from typing import Any, Callable, List
+
+import pyarrow as pa
+
+from deltalake.writer import AddAction
 
 RawDeltaTable: Any
-PyDeltaTableError: Any
 rust_core_version: Callable[[], str]
 DeltaStorageFsBackend: Any
+
+write_new_deltalake: Callable[[str, pa.Schema, List[AddAction], str, List[str]], None]
+
+class PyDeltaTableError(BaseException): ...
diff --git a/python/stubs/pyarrow/__init__.pyi b/python/stubs/pyarrow/__init__.pyi index 3b28a2a714..c7cef34ba9 100644 --- a/python/stubs/pyarrow/__init__.pyi +++ b/python/stubs/pyarrow/__init__.pyi @@ -2,6 +2,7 @@ from typing import Any, Callable Schema: Any Table: Any +RecordBatch: Any Field: Any DataType: Any schema: Any diff --git a/python/stubs/pyarrow/dataset.pyi b/python/stubs/pyarrow/dataset.pyi index 5d9683dee1..d06f843246 100644 --- a/python/stubs/pyarrow/dataset.pyi +++ b/python/stubs/pyarrow/dataset.pyi @@ -5,3 +5,4 @@ dataset: Any partitioning: Any FileSystemDataset: Any ParquetFileFormat: Any +write_dataset: Any diff --git a/python/stubs/pyarrow/lib.pyi b/python/stubs/pyarrow/lib.pyi new file mode 100644 index 0000000000..fc97dea727 --- /dev/null +++ b/python/stubs/pyarrow/lib.pyi @@ -0,0 +1,3 @@ +from typing import Any + +RecordBatchReader: Any diff --git a/python/tests/__init__.py b/python/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/tests/data/date_partitioned_df/_delta_log/.00000000000000000000.json.crc b/python/tests/data/date_partitioned_df/_delta_log/.00000000000000000000.json.crc deleted file mode 100644 index f141a1d1b7..0000000000 Binary files a/python/tests/data/date_partitioned_df/_delta_log/.00000000000000000000.json.crc and /dev/null differ diff --git a/python/tests/data/date_partitioned_df/_delta_log/00000000000000000000.json b/python/tests/data/date_partitioned_df/_delta_log/00000000000000000000.json deleted file mode 100644 index 9c01cf24d8..0000000000 --- a/python/tests/data/date_partitioned_df/_delta_log/00000000000000000000.json +++ /dev/null @@ -1,4 +0,0 @@ -{"protocol":{"minReaderVersion":1,"minWriterVersion":2}} -{"metaData":{"id":"588135b2-b298-4d9f-aab6-6dd9bf90d575","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"id\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"date\",\"type\":\"date\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":["date"],"configuration":{},"createdTime":1645893400586}} -{"add":{"path":"date=2021-01-01/part-00000-6ae76612-3903-4007-b5ec-511b932751ea.c000.snappy.parquet","partitionValues":{"date":"2021-01-01"},"size":500,"modificationTime":1645893404567,"dataChange":true}} -{"commitInfo":{"timestamp":1645893404671,"operation":"WRITE","operationParameters":{"mode":"ErrorIfExists","partitionBy":"[\"date\"]"},"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"5","numOutputBytes":"500"},"engineInfo":"Apache-Spark/3.2.1 Delta-Lake/1.1.0"}} diff --git a/python/tests/data/date_partitioned_df/date=2021-01-01/.part-00000-6ae76612-3903-4007-b5ec-511b932751ea.c000.snappy.parquet.crc b/python/tests/data/date_partitioned_df/date=2021-01-01/.part-00000-6ae76612-3903-4007-b5ec-511b932751ea.c000.snappy.parquet.crc deleted file mode 100644 index 3271be8603..0000000000 Binary files a/python/tests/data/date_partitioned_df/date=2021-01-01/.part-00000-6ae76612-3903-4007-b5ec-511b932751ea.c000.snappy.parquet.crc and /dev/null differ diff --git a/python/tests/data/date_partitioned_df/date=2021-01-01/part-00000-6ae76612-3903-4007-b5ec-511b932751ea.c000.snappy.parquet b/python/tests/data/date_partitioned_df/date=2021-01-01/part-00000-6ae76612-3903-4007-b5ec-511b932751ea.c000.snappy.parquet deleted file mode 100644 index c59e5b21d8..0000000000 Binary files 
a/python/tests/data/date_partitioned_df/date=2021-01-01/part-00000-6ae76612-3903-4007-b5ec-511b932751ea.c000.snappy.parquet and /dev/null differ diff --git a/python/tests/data/timestamp_partitioned_df/_delta_log/.00000000000000000000.json.crc b/python/tests/data/timestamp_partitioned_df/_delta_log/.00000000000000000000.json.crc deleted file mode 100644 index 2fcd6c71fa..0000000000 Binary files a/python/tests/data/timestamp_partitioned_df/_delta_log/.00000000000000000000.json.crc and /dev/null differ diff --git a/python/tests/data/timestamp_partitioned_df/_delta_log/00000000000000000000.json b/python/tests/data/timestamp_partitioned_df/_delta_log/00000000000000000000.json deleted file mode 100644 index 5bb061bf67..0000000000 --- a/python/tests/data/timestamp_partitioned_df/_delta_log/00000000000000000000.json +++ /dev/null @@ -1,4 +0,0 @@ -{"protocol":{"minReaderVersion":1,"minWriterVersion":2}} -{"metaData":{"id":"e79e0060-0670-46ed-9e93-9dff7ac96b07","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"id\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"date\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":["date"],"configuration":{},"createdTime":1645893413372}} -{"add":{"path":"date=2021-01-01%2000%253A00%253A00/part-00000-6177a755-69ce-44ea-bbfa-0479c4c1e704.c000.snappy.parquet","partitionValues":{"date":"2021-01-01 00:00:00"},"size":500,"modificationTime":1645893413662,"dataChange":true}} -{"commitInfo":{"timestamp":1645893413668,"operation":"WRITE","operationParameters":{"mode":"ErrorIfExists","partitionBy":"[\"date\"]"},"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"5","numOutputBytes":"500"},"engineInfo":"Apache-Spark/3.2.1 Delta-Lake/1.1.0"}} diff --git a/python/tests/data/timestamp_partitioned_df/date=2021-01-01 00%3A00%3A00/.part-00000-6177a755-69ce-44ea-bbfa-0479c4c1e704.c000.snappy.parquet.crc b/python/tests/data/timestamp_partitioned_df/date=2021-01-01 00%3A00%3A00/.part-00000-6177a755-69ce-44ea-bbfa-0479c4c1e704.c000.snappy.parquet.crc deleted file mode 100644 index 3271be8603..0000000000 Binary files a/python/tests/data/timestamp_partitioned_df/date=2021-01-01 00%3A00%3A00/.part-00000-6177a755-69ce-44ea-bbfa-0479c4c1e704.c000.snappy.parquet.crc and /dev/null differ diff --git a/python/tests/data/timestamp_partitioned_df/date=2021-01-01 00%3A00%3A00/part-00000-6177a755-69ce-44ea-bbfa-0479c4c1e704.c000.snappy.parquet b/python/tests/data/timestamp_partitioned_df/date=2021-01-01 00%3A00%3A00/part-00000-6177a755-69ce-44ea-bbfa-0479c4c1e704.c000.snappy.parquet deleted file mode 100644 index c59e5b21d8..0000000000 Binary files a/python/tests/data/timestamp_partitioned_df/date=2021-01-01 00%3A00%3A00/part-00000-6177a755-69ce-44ea-bbfa-0479c4c1e704.c000.snappy.parquet and /dev/null differ diff --git a/python/tests/test_schema.py b/python/tests/test_schema.py index 78eed79385..296a5d8c5f 100644 --- a/python/tests/test_schema.py +++ b/python/tests/test_schema.py @@ -232,9 +232,7 @@ def test_schema_pyarrow_types(): } ) assert pyarrow_field.name == field_name - assert pyarrow_field.type == pyarrow.list_( - pyarrow.field("element", pyarrow.int32()) - ) + assert pyarrow_field.type == pyarrow.list_(pyarrow.field("item", pyarrow.int32())) assert pyarrow_field.metadata == metadata assert pyarrow_field.nullable is False @@ -276,7 +274,7 @@ def test_schema_pyarrow_types(): pyarrow.int32(), pyarrow.list_( pyarrow.field( - "element", + 
"entries", pyarrow.struct( [pyarrow.field("val", pyarrow.int32(), False, metadata)] ), diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py index ebc0ca1c1e..a5ab208a48 100644 --- a/python/tests/test_table_read.py +++ b/python/tests/test_table_read.py @@ -95,20 +95,6 @@ def test_read_partitioned_table_to_dict(): assert dt.to_pyarrow_dataset().to_table().to_pydict() == expected -@pytest.mark.parametrize( - "name,partition_type", [("date", pa.date32()), ("timestamp", pa.timestamp("us"))] -) -def test_read_date_partitioned_table(name, partition_type): - table_path = f"tests/data/{name}_partitioned_df" - dt = DeltaTable(table_path) - table = dt.to_pyarrow_table() - assert table["date"].type == partition_type - date_expected = pa.array([date(2021, 1, 1)] * 5).cast(partition_type) - assert table["date"] == pa.chunked_array([date_expected]) - id_expected = pa.array(range(5)) - assert table["id"] == pa.chunked_array([id_expected]) - - def test_read_partitioned_table_with_partitions_filters_to_dict(): table_path = "../rust/tests/data/delta-0.8.0-partitioned" dt = DeltaTable(table_path) @@ -227,6 +213,14 @@ def test_read_partitioned_table_metadata(): assert metadata.configuration == {} +def test_read_partitioned_table_protocol(): + table_path = "../rust/tests/data/delta-0.8.0-partitioned" + dt = DeltaTable(table_path) + protocol = dt.protocol() + assert protocol.min_reader_version == 1 + assert protocol.min_writer_version == 2 + + def test_history_partitioned_table_metadata(): table_path = "../rust/tests/data/delta-0.8.0-partitioned" dt = DeltaTable(table_path) diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py new file mode 100644 index 0000000000..b87806aafc --- /dev/null +++ b/python/tests/test_writer.py @@ -0,0 +1,294 @@ +import json +import os +import pathlib +import sys +from datetime import date, datetime, timedelta +from decimal import Decimal +from unittest.mock import Mock + +import pyarrow as pa +import pyarrow.compute as pc +import pytest +from pandas.testing import assert_frame_equal +from pyarrow.lib import RecordBatchReader + +from deltalake import DeltaTable, write_deltalake +from deltalake.table import ProtocolVersions +from deltalake.writer import DeltaTableProtocolError + + +def _is_old_glibc_version(): + if "CS_GNU_LIBC_VERSION" in os.confstr_names: + version = os.confstr("CS_GNU_LIBC_VERSION").split(" ")[1] + return version < "2.28" + else: + return False + + +if sys.platform == "win32": + pytest.skip("Writer isn't yet supported on Windows", allow_module_level=True) + +if _is_old_glibc_version(): + pytest.skip( + "Writer isn't yet supported on Linux with glibc < 2.28", allow_module_level=True + ) + + +@pytest.fixture() +def sample_data(): + nrows = 5 + return pa.table( + { + "utf8": pa.array([str(x) for x in range(nrows)]), + "int64": pa.array(list(range(nrows)), pa.int64()), + "int32": pa.array(list(range(nrows)), pa.int32()), + "int16": pa.array(list(range(nrows)), pa.int16()), + "int8": pa.array(list(range(nrows)), pa.int8()), + "float32": pa.array([float(x) for x in range(nrows)], pa.float32()), + "float64": pa.array([float(x) for x in range(nrows)], pa.float64()), + "bool": pa.array([x % 2 == 0 for x in range(nrows)]), + "binary": pa.array([str(x).encode() for x in range(nrows)]), + "decimal": pa.array([Decimal("10.000") + x for x in range(nrows)]), + "date32": pa.array( + [date(2022, 1, 1) + timedelta(days=x) for x in range(nrows)] + ), + "timestamp": pa.array( + [datetime(2022, 1, 1) + timedelta(hours=x) for x in range(nrows)] 
+ ), + "struct": pa.array([{"x": x, "y": str(x)} for x in range(nrows)]), + "list": pa.array([list(range(x + 1)) for x in range(nrows)]), + # NOTE: https://github.com/apache/arrow-rs/issues/477 + #'map': pa.array([[(str(y), y) for y in range(x)] for x in range(nrows)], pa.map_(pa.string(), pa.int64())), + } + ) + + +@pytest.fixture() +def existing_table(tmp_path: pathlib.Path, sample_data: pa.Table): + path = str(tmp_path) + write_deltalake(path, sample_data) + return DeltaTable(path) + + +@pytest.mark.skip(reason="Waiting on #570") +def test_handle_existing(tmp_path: pathlib.Path, sample_data: pa.Table): + # if uri points to a non-empty directory that isn't a delta table, error + tmp_path + p = tmp_path / "hello.txt" + p.write_text("hello") + + with pytest.raises(OSError) as exception: + write_deltalake(str(tmp_path), sample_data, mode="overwrite") + + assert "directory is not empty" in str(exception) + + +def test_roundtrip_basic(tmp_path: pathlib.Path, sample_data: pa.Table): + write_deltalake(str(tmp_path), sample_data) + + assert ("0" * 20 + ".json") in os.listdir(tmp_path / "_delta_log") + + delta_table = DeltaTable(str(tmp_path)) + assert delta_table.pyarrow_schema() == sample_data.schema + + table = delta_table.to_pyarrow_table() + assert table == sample_data + + +@pytest.mark.parametrize( + "column", + [ + "utf8", + "int64", + "int32", + "int16", + "int8", + "float32", + "float64", + "bool", + "binary", + "date32", + ], +) +def test_roundtrip_partitioned( + tmp_path: pathlib.Path, sample_data: pa.Table, column: str +): + write_deltalake(str(tmp_path), sample_data, partition_by=[column]) + + delta_table = DeltaTable(str(tmp_path)) + assert delta_table.pyarrow_schema() == sample_data.schema + + table = delta_table.to_pyarrow_table() + table = table.take(pc.sort_indices(table["int64"])) + assert table == sample_data + + +def test_roundtrip_multi_partitioned(tmp_path: pathlib.Path, sample_data: pa.Table): + write_deltalake(str(tmp_path), sample_data, partition_by=["int32", "bool"]) + + delta_table = DeltaTable(str(tmp_path)) + assert delta_table.pyarrow_schema() == sample_data.schema + + table = delta_table.to_pyarrow_table() + table = table.take(pc.sort_indices(table["int64"])) + assert table == sample_data + + +def test_write_modes(tmp_path: pathlib.Path, sample_data: pa.Table): + path = str(tmp_path) + + write_deltalake(path, sample_data) + assert DeltaTable(path).to_pyarrow_table() == sample_data + + with pytest.raises(AssertionError): + write_deltalake(path, sample_data, mode="error") + + write_deltalake(path, sample_data, mode="ignore") + assert ("0" * 19 + "1.json") not in os.listdir(tmp_path / "_delta_log") + + write_deltalake(path, sample_data, mode="append") + expected = pa.concat_tables([sample_data, sample_data]) + assert DeltaTable(path).to_pyarrow_table() == expected + + write_deltalake(path, sample_data, mode="overwrite") + assert DeltaTable(path).to_pyarrow_table() == sample_data + + +def test_writer_with_table(existing_table: DeltaTable, sample_data: pa.Table): + write_deltalake(existing_table, sample_data, mode="overwrite") + existing_table.update_incremental() + assert existing_table.to_pyarrow_table() == sample_data + + +def test_fails_wrong_partitioning(existing_table: DeltaTable, sample_data: pa.Table): + with pytest.raises(AssertionError): + write_deltalake( + existing_table, sample_data, mode="append", partition_by="int32" + ) + + +def test_write_pandas(tmp_path: pathlib.Path, sample_data: pa.Table): + # When timestamp is converted to Pandas, it gets casted to 
ns resolution, + # but Delta Lake schemas only support us resolution. + sample_pandas = sample_data.to_pandas().drop(["timestamp"], axis=1) + write_deltalake(str(tmp_path), sample_pandas) + + delta_table = DeltaTable(str(tmp_path)) + df = delta_table.to_pandas() + assert_frame_equal(df, sample_pandas) + + +def test_write_iterator( + tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table +): + batches = existing_table.to_pyarrow_dataset().to_batches() + with pytest.raises(ValueError): + write_deltalake(str(tmp_path), batches, mode="overwrite") + + write_deltalake(str(tmp_path), batches, schema=sample_data.schema, mode="overwrite") + assert DeltaTable(str(tmp_path)).to_pyarrow_table() == sample_data + + +def test_write_recordbatchreader( + tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table +): + batches = existing_table.to_pyarrow_dataset().to_batches() + reader = RecordBatchReader.from_batches(sample_data.schema, batches) + + write_deltalake(str(tmp_path), reader, mode="overwrite") + assert DeltaTable(str(tmp_path)).to_pyarrow_table() == sample_data + + +def test_writer_partitioning(tmp_path: pathlib.Path): + test_strings = ["a=b", "hello world", "hello%20world"] + data = pa.table( + {"p": pa.array(test_strings), "x": pa.array(range(len(test_strings)))} + ) + + write_deltalake(str(tmp_path), data) + + assert DeltaTable(str(tmp_path)).to_pyarrow_table() == data + + +def get_stats(table: DeltaTable): + log_path = table._table.table_uri() + "/_delta_log/" + ("0" * 20 + ".json") + + # Should only have single add entry + for line in open(log_path, "r").readlines(): + log_entry = json.loads(line) + + if "add" in log_entry: + return json.loads(log_entry["add"]["stats"]) + else: + raise AssertionError("No add action found!") + + +def test_writer_stats(existing_table: DeltaTable, sample_data: pa.Table): + stats = get_stats(existing_table) + + assert stats["numRecords"] == sample_data.num_rows + + assert all(null_count == 0 for null_count in stats["nullCount"].values()) + + expected_mins = { + "utf8": "0", + "int64": 0, + "int32": 0, + "int16": 0, + "int8": 0, + "float32": 0.0, + "float64": 0.0, + "bool": False, + "binary": "0", + # TODO: Writer needs special decoding for decimal and date32. 
+        #'decimal': '10.000',
+        # "date32": '2022-01-01',
+        "timestamp": "2022-01-01T00:00:00",
+        "struct.x": 0,
+        "struct.y": "0",
+        "list.list.item": 0,
+    }
+    assert stats["minValues"] == expected_mins
+
+    expected_maxs = {
+        "utf8": "4",
+        "int64": 4,
+        "int32": 4,
+        "int16": 4,
+        "int8": 4,
+        "float32": 4.0,
+        "float64": 4.0,
+        "bool": True,
+        "binary": "4",
+        #'decimal': '40.000',
+        # "date32": '2022-01-04',
+        "timestamp": "2022-01-01T04:00:00",
+        "struct.x": 4,
+        "struct.y": "4",
+        "list.list.item": 4,
+    }
+    assert stats["maxValues"] == expected_maxs
+
+
+def test_writer_null_stats(tmp_path: pathlib.Path):
+    data = pa.table(
+        {
+            "int32": pa.array([1, None, 2, None], pa.int32()),
+            "float64": pa.array([1.0, None, None, None], pa.float64()),
+            "str": pa.array([None] * 4, pa.string()),
+        }
+    )
+    path = str(tmp_path)
+    write_deltalake(path, data)
+
+    table = DeltaTable(path)
+    stats = get_stats(table)
+
+    expected_nulls = {"int32": 2, "float64": 3, "str": 4}
+    assert stats["nullCount"] == expected_nulls
+
+
+def test_writer_fails_on_protocol(existing_table: DeltaTable, sample_data: pa.Table):
+    existing_table.protocol = Mock(return_value=ProtocolVersions(1, 2))
+    with pytest.raises(DeltaTableProtocolError):
+        write_deltalake(existing_table, sample_data, mode="overwrite")
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 2382506037..3e71a90585 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -26,7 +26,10 @@ lazy_static = "1"
 percent-encoding = "2"
 
 # HTTP Client
-reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "stream"], optional = true}
+reqwest = { version = "0.11", default-features = false, features = [
+    "rustls-tls",
+    "stream",
+], optional = true }
 
 # Azure
 azure_core = { version = "0.1", optional = true }
diff --git a/rust/src/delta.rs b/rust/src/delta.rs
index 085f076a33..d67637c65f 100644
--- a/rust/src/delta.rs
+++ b/rust/src/delta.rs
@@ -1293,6 +1293,7 @@ impl DeltaTable {
         metadata: DeltaTableMetaData,
         protocol: action::Protocol,
         commit_info: Option<Map<String, Value>>,
+        add_actions: Option<Vec<action::Add>>,
     ) -> Result<(), DeltaTableError> {
         let meta = action::MetaData::try_from(metadata)?;
 
@@ -1307,11 +1308,16 @@ impl DeltaTable {
             Value::Number(serde_json::Number::from(Utc::now().timestamp_millis())),
         );
 
-        let actions = vec![
+        let mut actions = vec![
             Action::commitInfo(enriched_commit_info),
             Action::protocol(protocol),
             Action::metaData(meta),
         ];
+        if let Some(add_actions) = add_actions {
+            for add_action in add_actions {
+                actions.push(Action::add(add_action));
+            }
+        };
 
         let mut transaction = self.create_transaction(None);
         transaction.add_actions(actions.clone());
@@ -1812,7 +1818,7 @@ mod tests {
             serde_json::Value::String("test user".to_string()),
         );
         // Action
-        dt.create(delta_md.clone(), protocol.clone(), Some(commit_info))
+        dt.create(delta_md.clone(), protocol.clone(), Some(commit_info), None)
             .await
             .unwrap();
diff --git a/rust/src/delta_arrow.rs b/rust/src/delta_arrow.rs
index d2f1bd9ce9..4df46fc48e 100644
--- a/rust/src/delta_arrow.rs
+++ b/rust/src/delta_arrow.rs
@@ -7,6 +7,7 @@ use arrow::datatypes::{
 use arrow::error::ArrowError;
 use lazy_static::lazy_static;
 use regex::Regex;
+use std::collections::HashMap;
 use std::convert::TryFrom;
 
 impl TryFrom<&schema::Schema> for ArrowSchema {
@@ -162,6 +163,99 @@ impl TryFrom<&schema::SchemaDataType> for ArrowDataType {
     }
 }
 
+impl TryFrom<&ArrowSchema> for schema::Schema {
+    type Error = ArrowError;
+    fn try_from(arrow_schema: &ArrowSchema) -> Result<Self, ArrowError> {
+        let new_fields: Result<Vec<schema::SchemaField>, _> = arrow_schema
+            .fields()
+            .iter()
+            .map(|field| field.try_into())
+            .collect();
+        Ok(schema::Schema::new(new_fields?))
+    }
+}
+
+impl TryFrom<&ArrowField> for schema::SchemaField {
+    type Error = ArrowError;
+    fn try_from(arrow_field: &ArrowField) -> Result<Self, ArrowError> {
+        Ok(schema::SchemaField::new(
+            arrow_field.name().clone(),
+            arrow_field.data_type().try_into()?,
+            arrow_field.is_nullable(),
+            arrow_field
+                .metadata()
+                .as_ref()
+                .map_or_else(HashMap::new, |m| {
+                    m.iter()
+                        .map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
+                        .collect()
+                }),
+        ))
+    }
+}
+
+impl TryFrom<&ArrowDataType> for schema::SchemaDataType {
+    type Error = ArrowError;
+    fn try_from(arrow_datatype: &ArrowDataType) -> Result<Self, ArrowError> {
+        match arrow_datatype {
+            ArrowDataType::Utf8 => Ok(schema::SchemaDataType::primitive("string".to_string())),
+            ArrowDataType::Int64 => Ok(schema::SchemaDataType::primitive("long".to_string())), // undocumented type
+            ArrowDataType::Int32 => Ok(schema::SchemaDataType::primitive("integer".to_string())),
+            ArrowDataType::Int16 => Ok(schema::SchemaDataType::primitive("short".to_string())),
+            ArrowDataType::Int8 => Ok(schema::SchemaDataType::primitive("byte".to_string())),
+            ArrowDataType::Float32 => Ok(schema::SchemaDataType::primitive("float".to_string())),
+            ArrowDataType::Float64 => Ok(schema::SchemaDataType::primitive("double".to_string())),
+            ArrowDataType::Boolean => Ok(schema::SchemaDataType::primitive("boolean".to_string())),
+            ArrowDataType::Binary => Ok(schema::SchemaDataType::primitive("binary".to_string())),
+            ArrowDataType::Decimal(p, s) => Ok(schema::SchemaDataType::primitive(format!(
+                "decimal({},{})",
+                p, s
+            ))),
+            ArrowDataType::Date32 => Ok(schema::SchemaDataType::primitive("date".to_string())),
+            ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => {
+                Ok(schema::SchemaDataType::primitive("timestamp".to_string()))
+            }
+            ArrowDataType::Struct(fields) => {
+                let converted_fields: Result<Vec<schema::SchemaField>, _> =
+                    fields.iter().map(|field| field.try_into()).collect();
+                Ok(schema::SchemaDataType::r#struct(
+                    schema::SchemaTypeStruct::new(converted_fields?),
+                ))
+            }
+            ArrowDataType::List(field) => {
+                Ok(schema::SchemaDataType::array(schema::SchemaTypeArray::new(
+                    Box::new((*field).data_type().try_into()?),
+                    (*field).is_nullable(),
+                )))
+            }
+            ArrowDataType::FixedSizeList(field, _) => {
+                Ok(schema::SchemaDataType::array(schema::SchemaTypeArray::new(
+                    Box::new((*field).data_type().try_into()?),
+                    (*field).is_nullable(),
+                )))
+            }
+            ArrowDataType::Map(field, _) => {
+                if let ArrowDataType::Struct(struct_fields) = field.data_type() {
+                    let key_type = struct_fields[0].data_type().try_into()?;
+                    let value_type = struct_fields[1].data_type().try_into()?;
+                    let value_type_nullable = struct_fields[1].is_nullable();
+                    Ok(schema::SchemaDataType::map(schema::SchemaTypeMap::new(
+                        Box::new(key_type),
+                        Box::new(value_type),
+                        value_type_nullable,
+                    )))
+                } else {
+                    panic!("DataType::Map should contain a struct field child");
+                }
+            }
+            s => Err(ArrowError::SchemaError(format!(
+                "Invalid data type for Delta Lake: {}",
+                s
+            ))),
+        }
+    }
+}
+
 /// Returns an arrow schema representing the delta log for use in checkpoints
 ///
 /// # Arguments
diff --git a/rust/tests/adls_gen2_table_test.rs b/rust/tests/adls_gen2_table_test.rs
index 2f3b95efef..b1ad2f3d6b 100644
--- a/rust/tests/adls_gen2_table_test.rs
+++ b/rust/tests/adls_gen2_table_test.rs
@@ -88,7 +88,7 @@ mod adls_gen2_table {
         let (metadata, protocol) = table_info();
 
         // Act 1
-        dt.create(metadata.clone(), protocol.clone(), None)
+        dt.create(metadata.clone(), protocol.clone(), None, None)
             .await
             .unwrap();
diff --git a/rust/tests/concurrent_writes_test.rs b/rust/tests/concurrent_writes_test.rs
index 14d378f929..be7851197c 100644
--- a/rust/tests/concurrent_writes_test.rs
+++ b/rust/tests/concurrent_writes_test.rs
@@ -83,7 +83,7 @@ async fn concurrent_writes_azure() {
         min_writer_version: 2,
     };
 
-    dt.create(metadata.clone(), protocol.clone(), None)
+    dt.create(metadata.clone(), protocol.clone(), None, None)
         .await
         .unwrap();
 
diff --git a/rust/tests/fs_common/mod.rs b/rust/tests/fs_common/mod.rs
index 63c1b2b017..80984eab04 100644
--- a/rust/tests/fs_common/mod.rs
+++ b/rust/tests/fs_common/mod.rs
@@ -53,7 +53,7 @@ pub async fn create_test_table(
         min_reader_version: 1,
         min_writer_version: 2,
     };
-    table.create(md, protocol, None).await.unwrap();
+    table.create(md, protocol, None, None).await.unwrap();
     table
 }
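
Taken together, the patch exposes a small Python-facing writer API: ``write_deltalake`` (now exported from ``deltalake/__init__.py``), the new ``DeltaTable.protocol()`` accessor, and the Rust-side ``write_new_deltalake`` / ``create_write_transaction`` bindings they call into. The sketch below is not part of the patch; it is an illustrative example of how these pieces are expected to fit together, using a made-up local path and toy data, and relying only on names introduced in this diff::

    import pandas as pd

    from deltalake import DeltaTable, write_deltalake

    # Create a new table; a Pandas DataFrame is converted to an Arrow table internally.
    df = pd.DataFrame({"x": [1, 2, 3]})
    write_deltalake("/tmp/example_table", df)  # mode="error" is the default

    # Append to or overwrite the same table by switching the mode.
    write_deltalake("/tmp/example_table", df, mode="append")
    write_deltalake("/tmp/example_table", df, mode="overwrite")

    # The new protocol() accessor reports (min_reader_version, min_writer_version);
    # write_deltalake refuses tables whose min_writer_version is greater than 1.
    dt = DeltaTable("/tmp/example_table")
    assert dt.protocol().min_writer_version == 1
    print(dt.to_pyarrow_table().to_pydict())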