Speedup assessment workflow by making DBFS root table size calculation parallel (#2745)

DBFS root table sizes were previously calculated one table at a time; they are now computed in parallel.
nfx authored Sep 25, 2024
1 parent 1951f37 commit 1dd1a12
Showing 5 changed files with 26 additions and 22 deletions.
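
The gist of the change in table_size.py: instead of computing each DBFS root table size inline in the crawl loop, the crawler now builds a list of functools.partial callables and hands them to Threads.strict from databricks-labs-blueprint, which runs them on a thread pool. A minimal Python sketch of that fan-out pattern — the _fetch_size helper and the table names below are illustrative, not part of this commit:

    from functools import partial

    from databricks.labs.blueprint.parallel import Threads


    def _fetch_size(table_name: str) -> int:
        # stand-in for the real per-table work (ANALYZE + Spark statistics lookup)
        return len(table_name) * 1_024


    def crawl_sizes(table_names: list[str]) -> list[int]:
        # one zero-argument callable per table; nothing runs until Threads picks them up
        tasks = [partial(_fetch_size, name) for name in table_names]
        # Threads.strict executes the tasks concurrently and raises if any of them fail
        return list(Threads.strict("DBFS root table sizes", tasks))


    if __name__ == "__main__":
        print(crawl_sizes(["hive_metastore.db1.t1", "hive_metastore.db1.t2"]))
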
2 changes: 1 addition & 1 deletion src/databricks/labs/ucx/contexts/workflow_task.py
@@ -71,7 +71,7 @@ def pipelines_crawler(self):
 
     @cached_property
     def table_size_crawler(self):
-        return TableSizeCrawler(self.sql_backend, self.inventory_database)
+        return TableSizeCrawler(self.sql_backend, self.inventory_database, self.config.include_databases)
 
     @cached_property
     def policies_crawler(self):
39 changes: 21 additions & 18 deletions src/databricks/labs/ucx/hive_metastore/table_size.py
@@ -1,12 +1,15 @@
 import logging
 from collections.abc import Iterable
 from dataclasses import dataclass
+from functools import partial
 
+from databricks.labs.blueprint.parallel import Threads
 from databricks.labs.lsql.backends import SqlBackend
 
 from databricks.labs.ucx.framework.crawlers import CrawlerBase
 from databricks.labs.ucx.framework.utils import escape_sql_identifier
 from databricks.labs.ucx.hive_metastore import TablesCrawler
+from databricks.labs.ucx.hive_metastore.tables import Table
 
 logger = logging.getLogger(__name__)
 
@@ -40,43 +43,43 @@ def _crawl(self) -> Iterable[TableSize]:
         """Crawls and lists tables using table crawler
         Identifies DBFS root tables and calculates the size for these.
         """
+        tasks = []
         for table in self._tables_crawler.snapshot():
             if not table.kind == "TABLE":
                 continue
             if not table.is_dbfs_root:
                 continue
-            size_in_bytes = self._safe_get_table_size(table.key)
-            if size_in_bytes is None:
-                continue  # table does not exist anymore or is corrupted
-
-            yield TableSize(
-                catalog=table.catalog, database=table.database, name=table.name, size_in_bytes=size_in_bytes
-            )
+            tasks.append(partial(self._safe_get_table_size, table))
+        return Threads.strict('DBFS root table sizes', tasks)
 
     def _try_fetch(self) -> Iterable[TableSize]:
         """Tries to load table information from the database or throws TABLE_OR_VIEW_NOT_FOUND error"""
         for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
             yield TableSize(*row)
 
-    def _safe_get_table_size(self, table_full_name: str) -> int | None:
-        logger.debug(f"Evaluating {table_full_name} table size.")
+    def _safe_get_table_size(self, table: Table) -> TableSize | None:
+        logger.debug(f"Evaluating {table.key} table size.")
         try:
             # refresh table statistics to avoid stale stats in HMS
-            self._backend.execute(f"ANALYZE table {escape_sql_identifier(table_full_name)} compute STATISTICS NOSCAN")
-            # pylint: disable-next=protected-access
-            return self._spark._jsparkSession.table(table_full_name).queryExecution().analyzed().stats().sizeInBytes()
+            self._backend.execute(f"ANALYZE table {table.safe_sql_key} compute STATISTICS NOSCAN")
+            jvm_df = self._spark._jsparkSession.table(table.safe_sql_key)  # pylint: disable=protected-access
+            size_in_bytes = jvm_df.queryExecution().analyzed().stats().sizeInBytes()
+            return TableSize(
+                catalog=table.catalog,
+                database=table.database,
+                name=table.name,
+                size_in_bytes=size_in_bytes,
+            )
         except Exception as e:  # pylint: disable=broad-exception-caught
             if "[TABLE_OR_VIEW_NOT_FOUND]" in str(e) or "[DELTA_TABLE_NOT_FOUND]" in str(e):
-                logger.warning(f"Failed to evaluate {table_full_name} table size. Table not found.")
+                logger.warning(f"Failed to evaluate {table.key} table size. Table not found.")
                 return None
             if "[DELTA_INVALID_FORMAT]" in str(e):
-                logger.warning(
-                    f"Unable to read Delta table {table_full_name}, please check table structure and try again."
-                )
+                logger.warning(f"Unable to read Delta table {table.key}, please check table structure and try again.")
                 return None
             if "[DELTA_MISSING_TRANSACTION_LOG]" in str(e):
-                logger.warning(f"Delta table {table_full_name} is corrupted: missing transaction log.")
+                logger.warning(f"Delta table {table.key} is corrupt: missing transaction log.")
                 return None
-            logger.error(f"Failed to evaluate {table_full_name} table size: ", exc_info=True)
+            logger.error(f"Failed to evaluate {table.key} table size: ", exc_info=True)
 
         return None
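
The per-table lookup itself keeps the same shape: refresh the statistics with ANALYZE ... COMPUTE STATISTICS NOSCAN so the Hive metastore does not serve stale numbers, then read sizeInBytes from the analyzed plan via the JVM session. A standalone sketch of that lookup, assuming an active SparkSession named spark and leaving out the error handling shown above:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()


    def dbfs_root_table_size(full_name: str):
        # refresh statistics first so the metastore returns current numbers
        spark.sql(f"ANALYZE TABLE {full_name} COMPUTE STATISTICS NOSCAN")
        # reach into the JVM session to read sizeInBytes from the analyzed plan's statistics
        jvm_df = spark._jsparkSession.table(full_name)  # pylint: disable=protected-access
        return jvm_df.queryExecution().analyzed().stats().sizeInBytes()


    print(dbfs_root_table_size("hive_metastore.db1.table1"))
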
2 changes: 1 addition & 1 deletion src/databricks/labs/ucx/hive_metastore/tables.py
@@ -340,7 +340,7 @@ class MigrationCount:
     what_count: dict[What, int]
 
 
-class TablesCrawler(CrawlerBase):
+class TablesCrawler(CrawlerBase[Table]):
     def __init__(self, backend: SqlBackend, schema, include_databases: list[str] | None = None):
         """
         Initializes a TablesCrawler instance.
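
The only change to TablesCrawler is the type parameter: CrawlerBase is generic over the record type a crawler snapshots, so annotating the subclass as CrawlerBase[Table] lets snapshot() be typed as yielding Table rows. A minimal sketch of that generics pattern — the class bodies are made up for illustration; the real CrawlerBase carries the SQL backend, schema, and fetch logic:

    from collections.abc import Iterable
    from dataclasses import dataclass
    from typing import Generic, TypeVar

    Result = TypeVar("Result")


    class CrawlerBase(Generic[Result]):
        """Base class parameterized by the record type each crawler yields."""

        def snapshot(self) -> Iterable[Result]:
            raise NotImplementedError


    @dataclass
    class Table:
        catalog: str
        database: str
        name: str


    class TablesCrawler(CrawlerBase[Table]):
        def snapshot(self) -> Iterable[Table]:  # now checked against the Table type argument
            return [Table("hive_metastore", "db1", "table1")]
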
3 changes: 2 additions & 1 deletion src/databricks/labs/ucx/source_code/known.json
@@ -1830,6 +1830,7 @@
   "databricks-labs-ucx": {
     "databricks.labs.ucx": []
   },
+  "databricks-pydabs": {},
   "databricks-sdk": {
     "databricks.sdk": []
   },
@@ -29921,4 +29922,4 @@
     "zipp.compat.py310": [],
     "zipp.glob": []
   }
-}
+}
2 changes: 1 addition & 1 deletion tests/unit/hive_metastore/test_table_size.py
@@ -137,7 +137,7 @@ def test_table_size_when_table_corrupted(mocker, caplog):
     results = tsc.snapshot()
 
     assert len(results) == 0
-    assert "Delta table hive_metastore.db1.table1 is corrupted: missing transaction log" in caplog.text
+    assert "Delta table hive_metastore.db1.table1 is corrupt: missing transaction log" in caplog.text
 
 
 def test_table_size_when_delta_invalid_format_error(mocker, caplog):
