Selectively overwrite data with python (#1101)
# Description
Currently the high-level Python writer doesn't support partial partition overwrite.
This PR enables the use of partition filters when writing data.

The functionality is similar to:
https://docs.databricks.com/delta/selective-overwrite.html

The logic checks that the data being written contains only partitions that pass the filters.

# Documentation
```python
write_deltalake(
    delta_path,
    sample_data,
    mode="overwrite",
    partitions_filters=[("partition_a", ">", "1")],
)
```
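
For example (a minimal sketch; `table_path` and the sample tables below are illustrative and not part of this PR), data written with a partition filter must not contain rows for existing partitions that the filter excludes, otherwise the writer raises a `ValueError`:

```python
import pyarrow as pa
from deltalake import write_deltalake

# Hypothetical table partitioned by "partition_a" (illustrative path and data).
data = pa.table(
    {
        "partition_a": pa.array(["1", "1", "2", "2"]),
        "value": pa.array([1, 2, 3, 4]),
    }
)
write_deltalake("table_path", data, mode="overwrite", partition_by=["partition_a"])

# Overwrite only partition "1"; the new data contains rows for that partition only.
new_data = pa.table(
    {
        "partition_a": pa.array(["1", "1"]),
        "value": pa.array([10, 20]),
    }
)
write_deltalake(
    "table_path",
    new_data,
    mode="overwrite",
    partitions_filters=[("partition_a", "=", "1")],
)

# Passing the original `data` (which still contains partition "2") with the same
# filter would raise ValueError, because partition "2" already exists in the table
# but is excluded by the filter.
```

This mirrors the alignment check added in `check_data_is_aligned_with_partition_filtering` below.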

---------

Co-authored-by: Ilya Moshkov <[email protected]>
ismoshkov and Ilya Moshkov authored Feb 25, 2023
1 parent 1b617e4 commit ac1ce57
Showing 3 changed files with 263 additions and 6 deletions.
32 changes: 32 additions & 0 deletions python/deltalake/writer.py
@@ -85,6 +85,7 @@ def write_deltalake(
    configuration: Optional[Mapping[str, Optional[str]]] = None,
    overwrite_schema: bool = False,
    storage_options: Optional[Dict[str, str]] = None,
    partitions_filters: Optional[List[Tuple[str, str, Any]]] = None,
) -> None:
    """Write to a Delta Lake table (Experimental)
@@ -132,6 +133,7 @@ def write_deltalake(
    :param configuration: A map containing configuration options for the metadata action.
    :param overwrite_schema: If True, allows updating the schema of the table.
    :param storage_options: options passed to the native delta filesystem. Unused if 'filesystem' is defined.
    :param partitions_filters: the partition filters that will be used for partition overwrite.
    """
    if _has_pandas and isinstance(data, pd.DataFrame):
        if schema is not None:
@@ -179,6 +181,8 @@ def write_deltalake(

        if partition_by:
            assert partition_by == table.metadata().partition_columns
        else:
            partition_by = table.metadata().partition_columns

        if table.protocol().min_writer_version > MAX_SUPPORTED_WRITER_VERSION:
            raise DeltaTableProtocolError(
@@ -224,8 +228,35 @@ def visitor(written_file: Any) -> None:
        invariants = table.schema().invariants
        checker = _DeltaDataChecker(invariants)

        def check_data_is_aligned_with_partition_filtering(
            batch: pa.RecordBatch,
        ) -> None:
            if table is None:
                return
            existed_partitions = table._table.get_active_partitions()
            allowed_partitions = table._table.get_active_partitions(partitions_filters)
            for column_index, column_name in enumerate(batch.schema.names):
                if column_name in table.metadata().partition_columns:
                    for value in batch.column(column_index).unique():
                        partition = (
                            column_name,
                            json.dumps(value.as_py(), cls=DeltaJSONEncoder),
                        )
                        if (
                            partition not in allowed_partitions
                            and partition in existed_partitions
                        ):
                            raise ValueError(
                                f"Data should be aligned with partitioning. "
                                f"Partition '{column_name}'='{value}' should be filtered out from data."
                            )

        def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch:
            checker.check_batch(batch)

            if mode == "overwrite" and partitions_filters:
                check_data_is_aligned_with_partition_filtering(batch)

            return batch

        if isinstance(data, RecordBatchReader):
@@ -277,6 +308,7 @@ def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch:
            mode,
            partition_by or [],
            schema,
            partitions_filters,
        )

37 changes: 35 additions & 2 deletions python/src/lib.rs
@@ -364,12 +364,36 @@ impl RawDeltaTable {
            .collect()
    }

    fn get_active_partitions(
        &mut self,
        partitions_filters: Option<Vec<(&str, &str, PartitionFilterValue)>>,
    ) -> PyResult<HashSet<(String, Option<String>)>> {
        let converted_filters = convert_partition_filters(partitions_filters.unwrap_or_default())
            .map_err(PyDeltaTableError::from_raw)?;

        let add_actions = self
            ._table
            .get_state()
            .get_active_add_actions_by_partitions(&converted_filters)
            .map_err(PyDeltaTableError::from_raw)?;
        let active_partitions = add_actions
            .flat_map(|add| {
                add.partition_values
                    .iter()
                    .map(|i| (i.0.to_owned(), i.1.to_owned()))
                    .collect::<Vec<_>>()
            })
            .collect::<HashSet<_>>();
        Ok(active_partitions)
    }

    fn create_write_transaction(
        &mut self,
        add_actions: Vec<PyAddAction>,
        mode: &str,
        partition_by: Vec<String>,
        schema: PyArrowType<ArrowSchema>,
        partitions_filters: Option<Vec<(&str, &str, PartitionFilterValue)>>,
    ) -> PyResult<()> {
        let mode = save_mode_from_str(mode)?;
        let schema: Schema = (&schema.0)
@@ -388,8 +412,17 @@ impl RawDeltaTable {

        match mode {
            SaveMode::Overwrite => {
                // Remove all current files
                for old_add in self._table.get_state().files().iter() {
                let converted_filters =
                    convert_partition_filters(partitions_filters.unwrap_or_default())
                        .map_err(PyDeltaTableError::from_raw)?;

                let add_actions = self
                    ._table
                    .get_state()
                    .get_active_add_actions_by_partitions(&converted_filters)
                    .map_err(PyDeltaTableError::from_raw)?;

                for old_add in add_actions {
                    let remove_action = Action::remove(action::Remove {
                        path: old_add.path.clone(),
                        deletion_timestamp: Some(current_timestamp()),
200 changes: 196 additions & 4 deletions python/tests/test_writer.py
@@ -3,8 +3,8 @@
import os
import pathlib
import random
from datetime import datetime
from typing import Dict, Iterable, List
from datetime import date, datetime
from typing import Any, Dict, Iterable, List
from unittest.mock import Mock

import pyarrow as pa
@@ -14,7 +14,7 @@
from pyarrow.dataset import ParquetFileFormat, ParquetReadOptions
from pyarrow.lib import RecordBatchReader

from deltalake import DeltaTable, write_deltalake
from deltalake import DeltaTable, PyDeltaTableError, write_deltalake
from deltalake.table import ProtocolVersions
from deltalake.writer import DeltaTableProtocolError, try_get_table_and_table_uri

@@ -193,7 +193,7 @@ def test_roundtrip_null_partition(tmp_path: pathlib.Path, sample_data: pa.Table)
    )
    write_deltalake(tmp_path, sample_data, partition_by=["utf8_with_nulls"])

    delta_table = DeltaTable(str(tmp_path))
    delta_table = DeltaTable(tmp_path)
    assert delta_table.schema().to_pyarrow() == sample_data.schema

    table = delta_table.to_pyarrow_table()
@@ -533,3 +533,195 @@ def test_try_get_table_and_table_uri(tmp_path: pathlib.Path):
    # table_or_uri with invalid parameter type
    with pytest.raises(ValueError):
        try_get_table_and_table_uri(None, None)


@pytest.mark.parametrize(
    "value_1,value_2,value_type,filter_string",
    [
        (1, 2, pa.int64(), "1"),
        (False, True, pa.bool_(), "false"),
        (date(2022, 1, 1), date(2022, 1, 2), pa.date32(), "2022-01-01"),
    ],
)
def test_partition_overwrite(
    tmp_path: pathlib.Path,
    value_1: Any,
    value_2: Any,
    value_type: pa.DataType,
    filter_string: str,
):
    sample_data = pa.table(
        {
            "p1": pa.array(["1", "1", "2", "2"], pa.string()),
            "p2": pa.array([value_1, value_2, value_1, value_2], value_type),
            "val": pa.array([1, 1, 1, 1], pa.int64()),
        }
    )
    write_deltalake(tmp_path, sample_data, mode="overwrite", partition_by=["p1", "p2"])

    delta_table = DeltaTable(tmp_path)
    assert (
        delta_table.to_pyarrow_table().sort_by(
            [("p1", "ascending"), ("p2", "ascending")]
        )
        == sample_data
    )

    sample_data = pa.table(
        {
            "p1": pa.array(["1", "1"], pa.string()),
            "p2": pa.array([value_2, value_1], value_type),
            "val": pa.array([2, 2], pa.int64()),
        }
    )
    expected_data = pa.table(
        {
            "p1": pa.array(["1", "1", "2", "2"], pa.string()),
            "p2": pa.array([value_1, value_2, value_1, value_2], value_type),
            "val": pa.array([2, 2, 1, 1], pa.int64()),
        }
    )
    write_deltalake(
        tmp_path,
        sample_data,
        mode="overwrite",
        partitions_filters=[("p1", "=", "1")],
    )

    delta_table.update_incremental()
    assert (
        delta_table.to_pyarrow_table().sort_by(
            [("p1", "ascending"), ("p2", "ascending")]
        )
        == expected_data
    )

    sample_data = pa.table(
        {
            "p1": pa.array(["1", "2"], pa.string()),
            "p2": pa.array([value_2, value_2], value_type),
            "val": pa.array([3, 3], pa.int64()),
        }
    )
    expected_data = pa.table(
        {
            "p1": pa.array(["1", "1", "2", "2"], pa.string()),
            "p2": pa.array([value_1, value_2, value_1, value_2], value_type),
            "val": pa.array([2, 3, 1, 3], pa.int64()),
        }
    )

    write_deltalake(
        tmp_path,
        sample_data,
        mode="overwrite",
        partitions_filters=[("p2", ">", filter_string)],
    )
    delta_table.update_incremental()
    assert (
        delta_table.to_pyarrow_table().sort_by(
            [("p1", "ascending"), ("p2", "ascending")]
        )
        == expected_data
    )


@pytest.fixture()
def sample_data_for_partitioning() -> pa.Table:
    return pa.table(
        {
            "p1": pa.array(["1", "1", "2", "2"], pa.string()),
            "p2": pa.array([1, 2, 1, 2], pa.int64()),
            "val": pa.array([1, 1, 1, 1], pa.int64()),
        }
    )


def test_partition_overwrite_unfiltered_data_fails(
    tmp_path: pathlib.Path, sample_data_for_partitioning: pa.Table
):
    write_deltalake(
        tmp_path,
        sample_data_for_partitioning,
        mode="overwrite",
        partition_by=["p1", "p2"],
    )
    with pytest.raises(ValueError):
        write_deltalake(
            tmp_path,
            sample_data_for_partitioning,
            mode="overwrite",
            partitions_filters=[("p2", "=", "1")],
        )


def test_partition_overwrite_with_new_partition(
    tmp_path: pathlib.Path, sample_data_for_partitioning: pa.Table
):
    write_deltalake(
        tmp_path,
        sample_data_for_partitioning,
        mode="overwrite",
        partition_by=["p1", "p2"],
    )

    new_sample_data = pa.table(
        {
            "p1": pa.array(["2", "1"], pa.string()),
            "p2": pa.array([3, 2], pa.int64()),
            "val": pa.array([2, 2], pa.int64()),
        }
    )
    expected_data = pa.table(
        {
            "p1": pa.array(["1", "1", "2", "2"], pa.string()),
            "p2": pa.array([1, 2, 1, 3], pa.int64()),
            "val": pa.array([1, 2, 1, 2], pa.int64()),
        }
    )
    write_deltalake(
        tmp_path,
        new_sample_data,
        mode="overwrite",
        partitions_filters=[("p2", "=", "2")],
    )
    delta_table = DeltaTable(tmp_path)
    assert (
        delta_table.to_pyarrow_table().sort_by(
            [("p1", "ascending"), ("p2", "ascending")]
        )
        == expected_data
    )


def test_partition_overwrite_with_non_partitioned_data(
    tmp_path: pathlib.Path, sample_data_for_partitioning: pa.Table
):
    write_deltalake(tmp_path, sample_data_for_partitioning, mode="overwrite")

    with pytest.raises(PyDeltaTableError):
        write_deltalake(
            tmp_path,
            sample_data_for_partitioning,
            mode="overwrite",
            partitions_filters=[("p1", "=", "1")],
        )


def test_partition_overwrite_with_wrong_partition(
    tmp_path: pathlib.Path, sample_data_for_partitioning: pa.Table
):
    write_deltalake(
        tmp_path,
        sample_data_for_partitioning,
        mode="overwrite",
        partition_by=["p1", "p2"],
    )

    with pytest.raises(PyDeltaTableError):
        write_deltalake(
            tmp_path,
            sample_data_for_partitioning,
            mode="overwrite",
            partitions_filters=[("p999", "=", "1")],
        )
