-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add initial interface for Data comparison framework (#1695)
Add initial interface for Data comparison framework
- Loading branch information
1 parent
0045aa1
commit ab03d5c
Showing
12 changed files
with
727 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass | ||
|
||
|
||
@dataclass
class TableIdentifier:
    """Three-level name (catalog, schema, table) identifying a table."""

    catalog: str
    schema: str
    table: str

    @staticmethod
    def _escape(name: str) -> str:
        """Return ``name`` as a backtick-delimited identifier.

        An embedded backtick is doubled so that it cannot terminate the
        delimited identifier early.
        """
        return "`" + name.replace("`", "``") + "`"

    @property
    def catalog_escaped(self) -> str:
        """Catalog name as a backtick-delimited identifier."""
        return self._escape(self.catalog)

    @property
    def schema_escaped(self) -> str:
        """Schema name as a backtick-delimited identifier."""
        return self._escape(self.schema)

    @property
    def table_escaped(self) -> str:
        """Table name as a backtick-delimited identifier."""
        return self._escape(self.table)

    @property
    def fqn_escaped(self) -> str:
        """Fully qualified name: the three escaped parts joined with dots."""
        return f"{self.catalog_escaped}.{self.schema_escaped}.{self.table_escaped}"
|
||
|
||
@dataclass(frozen=True)
class ColumnMetadata:
    """Name and data type of one table column.

    The dataclass is frozen, so instances are immutable and hashable and can
    be collected in sets.
    """

    # Column name.
    name: str
    # Data type of the column, kept as a string.
    data_type: str
|
||
|
||
@dataclass
class TableMetadata:
    """A table identifier together with the metadata of its columns."""

    identifier: TableIdentifier
    columns: list[ColumnMetadata]

    def get_column_metadata(self, column_name: str) -> ColumnMetadata | None:
        """Return the first column named exactly ``column_name``, or ``None`` if there is none."""
        candidates = (candidate for candidate in self.columns if candidate.name == column_name)
        return next(candidates, None)
|
||
|
||
@dataclass
class DataProfilingResult:
    """Outcome of profiling one table: its row count plus its metadata."""

    # Number of rows in the profiled table.
    row_count: int
    # Identifier and column metadata of the profiled table.
    table_metadata: TableMetadata
|
||
|
||
@dataclass
class SchemaComparisonEntry:
    """Comparison outcome for a single column name across source and target.

    The source or target fields are ``None`` when the column does not exist
    on that side.
    """

    source_column: str | None
    source_datatype: str | None
    target_column: str | None
    target_datatype: str | None
    # True when the column is considered identical on both sides.
    is_matching: bool
    # Optional free-text remark about the entry.
    notes: str | None
|
||
|
||
@dataclass
class SchemaComparisonResult:
    """Overall schema comparison verdict together with the per-column entries."""

    # Overall verdict for the compared schemas.
    is_matching: bool
    # One entry per compared column.
    data: list[SchemaComparisonEntry]
|
||
|
||
@dataclass
class DataComparisonResult:
    """Row counts of both tables and the number of records missing on each side."""

    source_row_count: int
    target_row_count: int
    # Records counted as missing in the target table.
    num_missing_records_in_target: int
    # Records counted as missing in the source table.
    num_missing_records_in_source: int
|
||
|
||
class TableMetadataRetriever(ABC):
    """Interface for components that look up the metadata of a table."""

    @abstractmethod
    def get_metadata(self, entity: TableIdentifier) -> TableMetadata:
        """Return the metadata of the table identified by ``entity``."""
|
||
|
||
class DataProfiler(ABC):
    """Interface for components that profile the data of a table."""

    @abstractmethod
    def profile_data(self, entity: TableIdentifier) -> DataProfilingResult:
        """Return the profiling result for the table identified by ``entity``."""
|
||
|
||
class SchemaComparator(ABC):
    """Interface for components that compare the schemas of two tables."""

    @abstractmethod
    def compare_schema(self, source: TableIdentifier, target: TableIdentifier) -> SchemaComparisonResult:
        """Return the schema comparison of the ``source`` and ``target`` tables."""
|
||
|
||
class DataComparator(ABC):
    """Interface for components that compare the data of two tables."""

    @abstractmethod
    def compare_data(self, source: TableIdentifier, target: TableIdentifier) -> DataComparisonResult:
        """Return the data comparison of the ``source`` and ``target`` tables."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
from collections.abc import Iterator | ||
|
||
from databricks.labs.lsql.backends import SqlBackend | ||
from databricks.labs.lsql.core import Row | ||
|
||
from .base import ( | ||
DataComparator, | ||
DataComparisonResult, | ||
TableIdentifier, | ||
DataProfiler, | ||
DataProfilingResult, | ||
) | ||
|
||
|
||
class StandardDataComparator(DataComparator):
    """Compares the data of two tables through a per-row hash over all columns."""

    # Each row of either table is reduced to a single hash value and the two
    # sides are matched with a FULL OUTER JOIN on that hash. A side whose hash
    # is NULL after the join has no counterpart for that row, which is what the
    # two "missing records" counters add up. Only non-matching rows are aggregated.
    DATA_COMPARISON_QUERY_TEMPLATE = """
    WITH compare_results AS (
        SELECT
            CASE
                WHEN source.hash_value IS NULL AND target.hash_value IS NULL THEN TRUE
                WHEN source.hash_value IS NULL OR target.hash_value IS NULL THEN FALSE
                WHEN source.hash_value = target.hash_value THEN TRUE
                ELSE FALSE
            END AS is_match,
            CASE
                WHEN target.hash_value IS NULL THEN 1
                ELSE 0
            END AS num_missing_records_in_target,
            CASE
                WHEN source.hash_value IS NULL THEN 1
                ELSE 0
            END AS num_missing_records_in_source
        FROM (
            SELECT {source_hash_expr} AS hash_value
            FROM {source_table_fqn}
        ) AS source
        FULL OUTER JOIN (
            SELECT {target_hash_expr} AS hash_value
            FROM {target_table_fqn}
        ) AS target
        ON source.hash_value = target.hash_value
    )
    SELECT
        COUNT(*) AS total_mismatches,
        COALESCE(SUM(num_missing_records_in_target), 0) AS num_missing_records_in_target,
        COALESCE(SUM(num_missing_records_in_source), 0) AS num_missing_records_in_source
    FROM compare_results
    WHERE is_match IS FALSE;
    """

    def __init__(self, sql_backend: SqlBackend, data_profiler: DataProfiler):
        """Store the SQL backend used to run the query and the profiler used to describe both tables."""
        self._sql_backend = sql_backend
        self._data_profiler = data_profiler

    def compare_data(self, source: TableIdentifier, target: TableIdentifier) -> DataComparisonResult:
        """
        Compare the data of the ``source`` and ``target`` tables.

        Both tables are profiled first (row count and metadata); the comparison query built from
        those profiles is then executed and its first row supplies the missing-record counters.

        Note: This method does not handle exceptions raised during the execution of the SQL query or
        the retrieval of the table metadata. These exceptions are expected to be handled by the caller
        in a manner appropriate for their context. ``StopIteration`` propagates if the backend
        yields no row at all.
        """
        source_data_profile = self._data_profiler.profile_data(source)
        target_data_profile = self._data_profiler.profile_data(target)
        # Dispatch through the instance so that a subclass overriding the query
        # builder or the template is honoured.
        comparison_query = self.build_data_comparison_query(
            source_data_profile,
            target_data_profile,
        )
        query_result: Iterator[Row] = self._sql_backend.fetch(comparison_query)
        count_row = next(query_result)
        num_missing_records_in_target = int(count_row["num_missing_records_in_target"])
        num_missing_records_in_source = int(count_row["num_missing_records_in_source"])
        return DataComparisonResult(
            source_row_count=source_data_profile.row_count,
            target_row_count=target_data_profile.row_count,
            num_missing_records_in_target=num_missing_records_in_target,
            num_missing_records_in_source=num_missing_records_in_source,
        )

    @classmethod
    def build_data_comparison_query(
        cls,
        source_data_profile: DataProfilingResult,
        target_data_profile: DataProfilingResult,
    ) -> str:
        """
        Render the comparison query for the two profiled tables.

        The hash expression of each side is a SHA2-256 over the '|'-joined per-column inputs
        derived from that side's column metadata.
        """
        source_table = source_data_profile.table_metadata.identifier
        target_table = target_data_profile.table_metadata.identifier
        source_hash_inputs = _build_data_comparison_hash_inputs(source_data_profile)
        target_hash_inputs = _build_data_comparison_hash_inputs(target_data_profile)
        # Use ``cls`` rather than the concrete class name so that a subclass
        # can supply its own template.
        return cls.DATA_COMPARISON_QUERY_TEMPLATE.format(
            source_hash_expr=f"SHA2(CONCAT_WS('|', {', '.join(source_hash_inputs)}), 256)",
            target_hash_expr=f"SHA2(CONCAT_WS('|', {', '.join(target_hash_inputs)}), 256)",
            source_table_fqn=source_table.fqn_escaped,
            target_table_fqn=target_table.fqn_escaped,
        )
|
||
|
||
def _build_data_comparison_hash_inputs(data_profile: DataProfilingResult) -> list[str]:
    """
    Build one SQL expression per column of the profiled table, for use as row-hash inputs.

    Column names are emitted as backtick-delimited identifiers (embedded backticks doubled) so
    that reserved words and names with special characters still yield valid SQL. Columns whose
    data type starts with ``array`` are sorted and serialised with ``TO_JSON``; ``map`` and
    ``struct`` columns are serialised with ``TO_JSON``. Every expression is wrapped in
    ``COALESCE(TRIM(...), '')`` so that a NULL contributes an empty string.
    """
    inputs = []
    for column in data_profile.table_metadata.columns:
        data_type = column.data_type.lower()
        column_ref = "`" + column.name.replace("`", "``") + "`"

        if data_type.startswith("array"):
            # Sorting makes the serialised value independent of element order.
            transformed_column = f"TO_JSON(SORT_ARRAY({column_ref}))"
        elif data_type.startswith(("map", "struct")):
            transformed_column = f"TO_JSON({column_ref})"
        else:
            transformed_column = column_ref

        inputs.append(f"COALESCE(TRIM({transformed_column}), '')")
    return inputs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from collections.abc import Iterator | ||
|
||
from databricks.labs.lsql.backends import SqlBackend | ||
from databricks.labs.lsql.core import Row | ||
|
||
from .base import DataProfiler, DataProfilingResult, TableIdentifier, TableMetadataRetriever | ||
|
||
|
||
class StandardDataProfiler(DataProfiler):
    """Profiles a table by pairing its row count with its metadata."""

    def __init__(self, sql_backend: SqlBackend, metadata_retriever: TableMetadataRetriever):
        """Store the SQL backend used for counting and the retriever used for metadata."""
        self._sql_backend = sql_backend
        self._metadata_retriever = metadata_retriever

    def profile_data(self, entity: TableIdentifier) -> DataProfilingResult:
        """
        Profile the table identified by ``entity``.

        The row count is queried first, then the table metadata is fetched from the configured
        TableMetadataRetriever; both are returned together.

        Note: This method does not handle exceptions raised during the execution of the SQL query or the
        retrieval of the table metadata. These exceptions are expected to be handled by the caller
        in a manner appropriate for their context.
        """
        total_rows = self._get_table_row_count(entity)
        metadata = self._metadata_retriever.get_metadata(entity)
        return DataProfilingResult(row_count=total_rows, table_metadata=metadata)

    def _get_table_row_count(self, entity: TableIdentifier) -> int:
        """Run a ``COUNT(*)`` against the table and return the first column of the first row as an int."""
        count_query = f"SELECT COUNT(*) as row_count FROM {entity.fqn_escaped}"
        rows: Iterator[Row] = self._sql_backend.fetch(count_query)
        first_row = next(rows)
        return int(first_row[0])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from collections.abc import Iterator | ||
|
||
from databricks.labs.lsql.backends import SqlBackend | ||
from databricks.labs.lsql.core import Row | ||
|
||
from .base import TableIdentifier, TableMetadata, ColumnMetadata, TableMetadataRetriever | ||
|
||
|
||
class DatabricksTableMetadataRetriever(TableMetadataRetriever):
    """Retrieves column metadata of a table by querying it through a SQL backend."""

    def __init__(self, sql_backend: SqlBackend):
        """Store the SQL backend used to run the metadata query."""
        self._sql_backend = sql_backend

    def get_metadata(self, entity: TableIdentifier) -> TableMetadata:
        """
        Retrieve the metadata of the table identified by ``entity``.

        Returns a TableMetadata whose columns are de-duplicated, lowercased by name and sorted
        by name.

        Note: This method does not handle exceptions raised during the execution of the SQL query.
        These exceptions are expected to be handled by the caller in a manner appropriate for
        their context.
        """
        # Dispatch through the instance so that a subclass overriding the query builder is honoured.
        schema_query = self.build_metadata_query(entity)
        query_result: Iterator[Row] = self._sql_backend.fetch(schema_query)
        # A set comprehension de-duplicates the column metadata entries.
        # Partition information is typically prefixed with a # symbol, so any column name
        # starting with # is excluded.
        # Names are lowercased here because the information_schema query already lowercases them
        # while DESCRIBE TABLE does not; without this, the same column could differ by case only
        # depending on which query produced it.
        columns = {
            ColumnMetadata(str(row["col_name"]).lower(), str(row["data_type"]))
            for row in query_result
            if not str(row["col_name"]).startswith("#")
        }
        # Sort by column name to give a consistent order.
        return TableMetadata(entity, sorted(columns, key=lambda x: x.name))

    @staticmethod
    def _lowercase_sql_literal(value: str) -> str:
        """Lowercase ``value`` and backslash-escape it for use inside a single-quoted SQL string."""
        return value.lower().replace("\\", "\\\\").replace("'", "\\'")

    @classmethod
    def build_metadata_query(cls, entity: TableIdentifier) -> str:
        """
        Build the SQL query that lists the columns of ``entity``.

        Tables in the ``hive_metastore`` catalog are described with ``DESCRIBE TABLE``; all others
        are looked up in the catalog's ``information_schema.columns``.
        """
        if entity.catalog.lower() == "hive_metastore":
            return f"DESCRIBE TABLE {entity.fqn_escaped}"

        # The predicates compare LOWER(...) columns, so the literals must be lowercased as well;
        # escaping keeps a quote or backslash in a name from breaking the string literal.
        catalog_literal = cls._lowercase_sql_literal(entity.catalog)
        schema_literal = cls._lowercase_sql_literal(entity.schema)
        table_literal = cls._lowercase_sql_literal(entity.table)

        query = f"""
    SELECT
        LOWER(column_name) AS col_name,
        full_data_type AS data_type
    FROM
        {entity.catalog_escaped}.information_schema.columns
    WHERE
        LOWER(table_catalog)='{catalog_literal}' AND
        LOWER(table_schema)='{schema_literal}' AND
        LOWER(table_name) ='{table_literal}'
    ORDER BY col_name"""

        return query
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from .base import ( | ||
SchemaComparator, | ||
SchemaComparisonEntry, | ||
SchemaComparisonResult, | ||
TableMetadataRetriever, | ||
ColumnMetadata, | ||
TableIdentifier, | ||
) | ||
|
||
|
||
class StandardSchemaComparator(SchemaComparator):
    """Compares the schemas of two tables column by column, matching columns by name."""

    def __init__(self, metadata_retriever: TableMetadataRetriever):
        """Store the retriever used to fetch the metadata of both tables."""
        self._metadata_retriever = metadata_retriever

    def compare_schema(self, source: TableIdentifier, target: TableIdentifier) -> SchemaComparisonResult:
        """
        Compare the schemas of the ``source`` and ``target`` tables.

        The overall result matches only when every per-column entry matches.

        Note: This method does not handle exceptions raised during the execution of the SQL query or the
        retrieval of the table metadata. These exceptions are expected to be handled by the caller in a
        manner appropriate for their context.
        """
        entries = self._eval_schema_diffs(source, target)
        overall_match = all(entry.is_matching for entry in entries)
        return SchemaComparisonResult(overall_match, entries)

    def _eval_schema_diffs(self, source: TableIdentifier, target: TableIdentifier) -> list[SchemaComparisonEntry]:
        """Return one comparison entry per column name found in either table, ordered by name."""
        source_metadata = self._metadata_retriever.get_metadata(source)
        target_metadata = self._metadata_retriever.get_metadata(target)
        # Union of the column names of both tables, so that columns present on
        # only one side are still reported (similar to a full outer join).
        names_in_source = {column.name for column in source_metadata.columns}
        names_in_target = {column.name for column in target_metadata.columns}
        return [
            _build_comparison_result_entry(
                source_metadata.get_column_metadata(field_name),
                target_metadata.get_column_metadata(field_name),
            )
            for field_name in sorted(names_in_source | names_in_target)
        ]
|
||
|
||
def _build_comparison_result_entry(
    source_col: ColumnMetadata | None,
    target_col: ColumnMetadata | None,
) -> SchemaComparisonEntry:
    """
    Build the comparison entry for one column.

    When both sides are present the entry matches if the two column metadata objects are equal,
    and carries no note. Otherwise the entry does not match and the note names the side on which
    the column is missing.
    """
    if source_col and target_col:
        matched, remark = source_col == target_col, None
    else:
        missing_side = "target" if source_col else "source"
        matched, remark = False, "Column is missing in " + missing_side

    source_name, source_type = (source_col.name, source_col.data_type) if source_col else (None, None)
    target_name, target_type = (target_col.name, target_col.data_type) if target_col else (None, None)

    return SchemaComparisonEntry(
        source_column=source_name,
        source_datatype=source_type,
        target_column=target_name,
        target_datatype=target_type,
        is_matching=matched,
        notes=remark,
    )
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import pytest | ||
from databricks.labs.lsql.backends import MockBackend | ||
|
||
|
||
@pytest.fixture()
def metadata_row_factory():
    """Yield a MockBackend row factory with the columns ``col_name`` and ``data_type``."""
    factory = MockBackend.rows("col_name", "data_type")
    yield factory
|
||
|
||
@pytest.fixture()
def row_count_row_factory():
    """Yield a MockBackend row factory with the single column ``row_count``."""
    factory = MockBackend.rows("row_count")
    yield factory
|
||
|
||
@pytest.fixture()
def data_comp_row_factory():
    """Yield a MockBackend row factory with the columns returned by the data comparison query."""
    column_names = (
        "total_mismatches",
        "num_missing_records_in_target",
        "num_missing_records_in_source",
    )
    factory = MockBackend.rows(*column_names)
    yield factory
Oops, something went wrong.