Skip to content

Commit

Permalink
Add initial interface for Data comparison framework (#1695)
Browse files Browse the repository at this point in the history
Add initial interface for Data comparison framework
  • Loading branch information
bishwajit-db authored May 20, 2024
1 parent 0045aa1 commit ab03d5c
Show file tree
Hide file tree
Showing 12 changed files with 727 additions and 0 deletions.
Empty file.
105 changes: 105 additions & 0 deletions src/databricks/labs/ucx/recon/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class TableIdentifier:
    """Fully qualified identifier of a table: catalog, schema and table name."""

    catalog: str
    schema: str
    table: str

    @property
    def catalog_escaped(self) -> str:
        """Catalog name wrapped in backticks for safe embedding in SQL."""
        return f"`{self.catalog}`"

    @property
    def schema_escaped(self) -> str:
        """Schema name wrapped in backticks for safe embedding in SQL."""
        return f"`{self.schema}`"

    @property
    def table_escaped(self) -> str:
        """Table name wrapped in backticks for safe embedding in SQL."""
        return f"`{self.table}`"

    @property
    def fqn_escaped(self) -> str:
        """Backtick-escaped fully qualified name: `catalog`.`schema`.`table`.

        NOTE(review): assumes the name parts themselves contain no backticks —
        a backtick inside a part would break the quoting; confirm upstream
        validation if that is possible.
        """
        return f"{self.catalog_escaped}.{self.schema_escaped}.{self.table_escaped}"


@dataclass(frozen=True)
class ColumnMetadata:
    """Immutable (and therefore hashable) name/type pair for a single column."""

    # Column name as reported by the metastore.
    name: str
    # Column data type string, e.g. "int" or "array<string>".
    data_type: str


@dataclass
class TableMetadata:
    """A table's identity together with the metadata of all its columns."""

    identifier: TableIdentifier
    columns: list[ColumnMetadata]

    def get_column_metadata(self, column_name: str) -> ColumnMetadata | None:
        """Return the metadata for *column_name*, or None when no such column exists."""
        return next((col for col in self.columns if col.name == column_name), None)


@dataclass
class DataProfilingResult:
    """Outcome of profiling one table: its row count plus its metadata."""

    # Total number of rows in the profiled table.
    row_count: int
    # Identity and column metadata of the profiled table.
    table_metadata: TableMetadata


@dataclass
class SchemaComparisonEntry:
    """Comparison outcome for one column across the source and target tables."""

    # Column name/type on the source side; None when the column exists only in target.
    source_column: str | None
    source_datatype: str | None
    # Column name/type on the target side; None when the column exists only in source.
    target_column: str | None
    target_datatype: str | None
    # True when the column exists on both sides with identical name and data type.
    is_matching: bool
    # Human-readable mismatch explanation; None for matching columns.
    notes: str | None


@dataclass
class SchemaComparisonResult:
    """Overall schema comparison verdict plus the per-column detail entries."""

    # True only when every per-column entry matched.
    is_matching: bool
    # One entry per distinct column name seen in either table.
    data: list[SchemaComparisonEntry]


@dataclass
class DataComparisonResult:
    """Aggregate row-level comparison counts between source and target tables."""

    source_row_count: int
    target_row_count: int
    # Rows present in the source but with no matching row in the target.
    num_missing_records_in_target: int
    # Rows present in the target but with no matching row in the source.
    num_missing_records_in_source: int


class TableMetadataRetriever(ABC):
    """Strategy interface for retrieving a table's column metadata."""

    @abstractmethod
    def get_metadata(self, entity: TableIdentifier) -> TableMetadata:
        """
        Get metadata for a given table identified by *entity*.
        """


class DataProfiler(ABC):
    """Strategy interface for profiling a table's data (row count + metadata)."""

    @abstractmethod
    def profile_data(self, entity: TableIdentifier) -> DataProfilingResult:
        """
        Profile data for a given table identified by *entity*.
        """


class SchemaComparator(ABC):
    """Strategy interface for comparing the schemas of two tables."""

    @abstractmethod
    def compare_schema(self, source: TableIdentifier, target: TableIdentifier) -> SchemaComparisonResult:
        """
        Compare schema for the two tables *source* and *target*.
        """


class DataComparator(ABC):
    """Strategy interface for comparing the row-level data of two tables."""

    @abstractmethod
    def compare_data(self, source: TableIdentifier, target: TableIdentifier) -> DataComparisonResult:
        """
        Compare data for the two tables *source* and *target*.
        """
114 changes: 114 additions & 0 deletions src/databricks/labs/ucx/recon/data_comparator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from collections.abc import Iterator

from databricks.labs.lsql.backends import SqlBackend
from databricks.labs.lsql.core import Row

from .base import (
DataComparator,
DataComparisonResult,
TableIdentifier,
DataProfiler,
DataProfilingResult,
)


class StandardDataComparator(DataComparator):
    """Compares two tables by full-outer-joining per-row hashes of all columns."""

    DATA_COMPARISON_QUERY_TEMPLATE = """
        WITH compare_results AS (
            SELECT
                CASE
                    WHEN source.hash_value IS NULL AND target.hash_value IS NULL THEN TRUE
                    WHEN source.hash_value IS NULL OR target.hash_value IS NULL THEN FALSE
                    WHEN source.hash_value = target.hash_value THEN TRUE
                    ELSE FALSE
                END AS is_match,
                CASE
                    WHEN target.hash_value IS NULL THEN 1
                    ELSE 0
                END AS num_missing_records_in_target,
                CASE
                    WHEN source.hash_value IS NULL THEN 1
                    ELSE 0
                END AS num_missing_records_in_source
            FROM (
                SELECT {source_hash_expr} AS hash_value
                FROM {source_table_fqn}
            ) AS source
            FULL OUTER JOIN (
                SELECT {target_hash_expr} AS hash_value
                FROM {target_table_fqn}
            ) AS target
            ON source.hash_value = target.hash_value
        )
        SELECT
            COUNT(*) AS total_mismatches,
            COALESCE(SUM(num_missing_records_in_target), 0) AS num_missing_records_in_target,
            COALESCE(SUM(num_missing_records_in_source), 0) AS num_missing_records_in_source
        FROM compare_results
        WHERE is_match IS FALSE;
    """

    def __init__(self, sql_backend: SqlBackend, data_profiler: DataProfiler):
        self._sql_backend = sql_backend
        self._data_profiler = data_profiler

    def compare_data(self, source: TableIdentifier, target: TableIdentifier) -> DataComparisonResult:
        """
        Compare the row-level data of the *source* and *target* tables.

        Both tables are profiled first (row count + column metadata); the
        profiles drive the generated SHA2-hash comparison query.
        Note: exceptions raised while running the SQL query or retrieving
        table metadata are deliberately not caught here — callers are expected
        to handle them in a manner appropriate for their context.
        """
        source_profile = self._data_profiler.profile_data(source)
        target_profile = self._data_profiler.profile_data(target)
        query = StandardDataComparator.build_data_comparison_query(source_profile, target_profile)
        result_rows: Iterator[Row] = self._sql_backend.fetch(query)
        counts = next(result_rows)
        return DataComparisonResult(
            source_row_count=source_profile.row_count,
            target_row_count=target_profile.row_count,
            num_missing_records_in_target=int(counts["num_missing_records_in_target"]),
            num_missing_records_in_source=int(counts["num_missing_records_in_source"]),
        )

    @classmethod
    def build_data_comparison_query(
        cls,
        source_data_profile: DataProfilingResult,
        target_data_profile: DataProfilingResult,
    ) -> str:
        """Render the comparison SQL for the two profiled tables."""
        source_inputs = _build_data_comparison_hash_inputs(source_data_profile)
        target_inputs = _build_data_comparison_hash_inputs(target_data_profile)
        # Each side is reduced to one SHA2-256 hash per row over all columns,
        # pipe-joined so column boundaries stay distinguishable.
        return StandardDataComparator.DATA_COMPARISON_QUERY_TEMPLATE.format(
            source_hash_expr=f"SHA2(CONCAT_WS('|', {', '.join(source_inputs)}), 256)",
            target_hash_expr=f"SHA2(CONCAT_WS('|', {', '.join(target_inputs)}), 256)",
            source_table_fqn=source_data_profile.table_metadata.identifier.fqn_escaped,
            target_table_fqn=target_data_profile.table_metadata.identifier.fqn_escaped,
        )


def _build_data_comparison_hash_inputs(data_profile: DataProfilingResult) -> list[str]:
    """Return one normalized SQL expression per column, ready for hashing.

    Arrays are sorted and serialized to JSON so element order does not affect
    the hash; maps and structs are serialized to JSON as-is. Every expression
    is trimmed and NULL-coalesced to '' so NULLs hash consistently.
    """
    expressions: list[str] = []
    for col in data_profile.table_metadata.columns:
        dtype = col.data_type.lower()
        if dtype.startswith("array"):
            normalized = f"TO_JSON(SORT_ARRAY({col.name}))"
        elif dtype.startswith(("map", "struct")):
            normalized = f"TO_JSON({col.name})"
        else:
            normalized = col.name
        expressions.append(f"COALESCE(TRIM({normalized}), '')")
    return expressions
35 changes: 35 additions & 0 deletions src/databricks/labs/ucx/recon/data_profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from collections.abc import Iterator

from databricks.labs.lsql.backends import SqlBackend
from databricks.labs.lsql.core import Row

from .base import DataProfiler, DataProfilingResult, TableIdentifier, TableMetadataRetriever


class StandardDataProfiler(DataProfiler):
    """Profiles a table by counting its rows and fetching its column metadata."""

    def __init__(self, sql_backend: SqlBackend, metadata_retriever: TableMetadataRetriever):
        self._sql_backend = sql_backend
        self._metadata_retriever = metadata_retriever

    def profile_data(self, entity: TableIdentifier) -> DataProfilingResult:
        """
        Profile the data in the table identified by *entity*.

        Two operations are performed, in order:
        1. the table's row count is queried;
        2. the table's metadata is fetched via the configured TableMetadataRetriever.
        Note: exceptions raised while running the SQL query or retrieving the
        metadata are deliberately not caught here — callers are expected to
        handle them in a manner appropriate for their context.
        """
        total_rows = self._get_table_row_count(entity)
        table_metadata = self._metadata_retriever.get_metadata(entity)
        return DataProfilingResult(total_rows, table_metadata)

    def _get_table_row_count(self, entity: TableIdentifier) -> int:
        """Run COUNT(*) against the table and return the result as an int."""
        count_query = f"SELECT COUNT(*) as row_count FROM {entity.fqn_escaped}"
        result: Iterator[Row] = self._sql_backend.fetch(count_query)
        return int(next(result)[0])
51 changes: 51 additions & 0 deletions src/databricks/labs/ucx/recon/metadata_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from collections.abc import Iterator

from databricks.labs.lsql.backends import SqlBackend
from databricks.labs.lsql.core import Row

from .base import TableIdentifier, TableMetadata, ColumnMetadata, TableMetadataRetriever


class DatabricksTableMetadataRetriever(TableMetadataRetriever):
    """Retrieves column metadata from either the Hive metastore or a UC catalog."""

    def __init__(self, sql_backend: SqlBackend):
        self._sql_backend = sql_backend

    def get_metadata(self, entity: TableIdentifier) -> TableMetadata:
        """
        Retrieve the column metadata for the table identified by *entity*.

        Note: exceptions raised while running the SQL query are deliberately
        not caught here — callers are expected to handle them in a manner
        appropriate for their context.
        """
        schema_query = DatabricksTableMetadataRetriever.build_metadata_query(entity)
        result_rows: Iterator[Row] = self._sql_backend.fetch(schema_query)
        # A set comprehension deduplicates the entries. DESCRIBE output prefixes
        # partition-information section rows with "#", so any col_name starting
        # with "#" is excluded — those are section markers, not real columns.
        unique_columns = {
            ColumnMetadata(str(result_row["col_name"]), str(result_row["data_type"]))
            for result_row in result_rows
            if not str(result_row["col_name"]).startswith("#")
        }
        # Sort by column name so callers always see a deterministic order.
        ordered_columns = sorted(unique_columns, key=lambda col: col.name)
        return TableMetadata(entity, ordered_columns)

    @classmethod
    def build_metadata_query(cls, entity: TableIdentifier) -> str:
        """Return the SQL that lists the table's columns and their types."""
        # Hive metastore tables have no information_schema; fall back to DESCRIBE.
        if entity.catalog == "hive_metastore":
            return f"DESCRIBE TABLE {entity.fqn_escaped}"

        # NOTE(review): entity fields are interpolated directly into the WHERE
        # clause without quoting/escaping — assumes identifiers never contain
        # single quotes; confirm upstream validation.
        query = f"""
            SELECT
                LOWER(column_name) AS col_name,
                full_data_type AS data_type
            FROM
                {entity.catalog_escaped}.information_schema.columns
            WHERE
                LOWER(table_catalog)='{entity.catalog}' AND
                LOWER(table_schema)='{entity.schema}' AND
                LOWER(table_name) ='{entity.table}'
            ORDER BY col_name"""

        return query
63 changes: 63 additions & 0 deletions src/databricks/labs/ucx/recon/schema_comparator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from .base import (
SchemaComparator,
SchemaComparisonEntry,
SchemaComparisonResult,
TableMetadataRetriever,
ColumnMetadata,
TableIdentifier,
)


class StandardSchemaComparator(SchemaComparator):
    """Compares two table schemas column-by-column, full-outer-join style."""

    def __init__(self, metadata_retriever: TableMetadataRetriever):
        self._metadata_retriever = metadata_retriever

    def compare_schema(self, source: TableIdentifier, target: TableIdentifier) -> SchemaComparisonResult:
        """
        Compare the schemas of the *source* and *target* tables.

        Note: exceptions raised while running the SQL query or retrieving the
        table metadata are deliberately not caught here — callers are expected
        to handle them in a manner appropriate for their context.
        """
        entries = self._eval_schema_diffs(source, target)
        overall_match = all(entry.is_matching for entry in entries)
        return SchemaComparisonResult(overall_match, entries)

    def _eval_schema_diffs(self, source: TableIdentifier, target: TableIdentifier) -> list[SchemaComparisonEntry]:
        """Build one comparison entry per distinct column name in either table."""
        source_metadata = self._metadata_retriever.get_metadata(source)
        target_metadata = self._metadata_retriever.get_metadata(target)
        # Union of column names from both sides: every column appears exactly
        # once, mirroring the semantics of a full outer join.
        all_names = {col.name for col in source_metadata.columns} | {
            col.name for col in target_metadata.columns
        }
        return [
            _build_comparison_result_entry(
                source_metadata.get_column_metadata(name),
                target_metadata.get_column_metadata(name),
            )
            for name in sorted(all_names)
        ]


def _build_comparison_result_entry(
    source_col: ColumnMetadata | None,
    target_col: ColumnMetadata | None,
) -> SchemaComparisonEntry:
    """Build one comparison entry from the (possibly absent) column pair.

    A column missing on either side is always a mismatch; otherwise the two
    ColumnMetadata objects are compared for full equality (name + data type).
    """
    if source_col is None or target_col is None:
        is_matching = False
        notes = "Column is missing in " + ("target" if source_col else "source")
    else:
        is_matching = source_col == target_col
        notes = None

    return SchemaComparisonEntry(
        source_column=source_col.name if source_col else None,
        source_datatype=source_col.data_type if source_col else None,
        target_column=target_col.name if target_col else None,
        target_datatype=target_col.data_type if target_col else None,
        is_matching=is_matching,
        notes=notes,
    )
Empty file added tests/unit/recon/__init__.py
Empty file.
26 changes: 26 additions & 0 deletions tests/unit/recon/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest
from databricks.labs.lsql.backends import MockBackend


@pytest.fixture()
def metadata_row_factory():
    """Row factory producing (col_name, data_type) rows, matching metadata queries."""
    yield MockBackend.rows(
        "col_name",
        "data_type",
    )


@pytest.fixture()
def row_count_row_factory():
    """Row factory producing single-column (row_count) rows, matching COUNT(*) queries."""
    yield MockBackend.rows(
        "row_count",
    )


@pytest.fixture()
def data_comp_row_factory():
    """Row factory matching the result shape of the data-comparison query."""
    yield MockBackend.rows(
        "total_mismatches",
        "num_missing_records_in_target",
        "num_missing_records_in_source",
    )
Loading

0 comments on commit ab03d5c

Please sign in to comment.