From 27500c2f468873274564e64095f609e9d559229c Mon Sep 17 00:00:00 2001
From: Serge Smertin <259697+nfx@users.noreply.github.com>
Date: Thu, 21 Mar 2024 20:02:47 +0100
Subject: [PATCH] Added initial version of `databricks labs ucx
 migrate-local-code` command (#1067)

The `databricks labs ucx migrate-local-code` command has been added to
facilitate the migration of local code to a Databricks environment. This
initial version is highly experimental: it supports only Python and SQL
files, and it discards code comments and formatting during the automated
transformation. The `.gitignore` file has been updated to keep generated
output files from being committed to the repository. The command helps users
and administrators migrate code, keep it consistent across workspaces, and
make local code compatible with Unity Catalog.
---
 .gitignore                                    |   3 +-
 README.md                                     |  19 +
 labs.yml                                      |   3 +
 pyproject.toml                                |   7 +-
 src/databricks/labs/ucx/cli.py                |  12 +
 src/databricks/labs/ucx/code/__init__.py      |   0
 src/databricks/labs/ucx/code/base.py          |  94 +++++
 src/databricks/labs/ucx/code/files.py         |  68 ++++
 src/databricks/labs/ucx/code/languages.py     |  44 +++
 src/databricks/labs/ucx/code/lsp.py           | 350 ++++++++++++++++++
 src/databricks/labs/ucx/code/notebooks.py     |  32 ++
 src/databricks/labs/ucx/code/pyspark.py       |  61 +++
 src/databricks/labs/ucx/code/queries.py       |  61 +++
 src/databricks/labs/ucx/code/redash.py        |  24 ++
 .../labs/ucx/hive_metastore/table_migrate.py  |  38 +-
 tests/unit/code/__init__.py                   |   0
 tests/unit/code/conftest.py                   |  18 +
 tests/unit/code/test_base.py                  |  40 ++
 tests/unit/code/test_files.py                 |  40 ++
 tests/unit/code/test_languages.py             |  44 +++
 tests/unit/code/test_notebooks.py             |  52 +++
 tests/unit/code/test_pyspark.py               |  65 ++++
 tests/unit/code/test_queries.py               |  44 +++
 23 files changed, 1114 insertions(+), 5 deletions(-)
 create mode 100644 src/databricks/labs/ucx/code/__init__.py
 create mode 100644 src/databricks/labs/ucx/code/base.py
 create mode 100644 src/databricks/labs/ucx/code/files.py
 create mode 100644 src/databricks/labs/ucx/code/languages.py
 create mode 100644 src/databricks/labs/ucx/code/lsp.py
 create mode 100644 src/databricks/labs/ucx/code/notebooks.py
 create mode 100644 src/databricks/labs/ucx/code/pyspark.py
 create mode 100644 src/databricks/labs/ucx/code/queries.py
 create mode 100644 src/databricks/labs/ucx/code/redash.py
 create mode 100644 tests/unit/code/__init__.py
 create mode 100644 tests/unit/code/conftest.py
 create mode 100644 tests/unit/code/test_base.py
 create mode 100644 tests/unit/code/test_files.py
 create mode 100644 tests/unit/code/test_languages.py
 create mode 100644 tests/unit/code/test_notebooks.py
 create mode 100644 tests/unit/code/test_pyspark.py
 create mode 100644 tests/unit/code/test_queries.py

diff --git a/.gitignore b/.gitignore
index 0dab5311ba..bfb8e46e57 100644
--- a/.gitignore
+++ b/.gitignore
@@ -151,4 +151,5 @@ dev/cleanup.py
 
 .python-version
 .databricks-login.json
-*.out
\ No newline at end of file
+*.out
+foo
\ No newline at end of file
diff --git a/README.md b/README.md
index a708bf7a6a..b030b402b5 100644
--- a/README.md
+++ b/README.md
@@ -48,6 +48,8 @@ See [contributing instructions](CONTRIBUTING.md) to help improve this project.
   * [`create-catalogs-schemas` command](#create-catalogs-schemas-command)
   * [`move` command](#move-command)
   * [`alias` command](#alias-command)
+* [Code migration commands](#code-migration-commands)
+  * [`migrate-local-code` command](#migrate-local-code-command)
 * [Cross-workspace installations](#cross-workspace-installations)
   * [`sync-workspace-info` command](#sync-workspace-info-command)
   * [`manual-workspace-info` command](#manual-workspace-info-command)
@@ -625,6 +627,23 @@ It can also be used to debug issues related to table aliasing.
 
 [[back to top](#databricks-labs-ucx)]
 
+# Code migration commands
+
+[[back to top](#databricks-labs-ucx)]
+
+## `migrate-local-code` command
+
+```text
+databricks labs ucx migrate-local-code
+```
+
+**(Experimental)** Once [table migration](#table-migration-commands) is complete, you can run this command to
+migrate all Python and SQL files in the current working directory. This command is highly experimental: at the
+moment it only supports Python and SQL files, and it discards code comments and formatting during the
+automated transformation process.
+
+[[back to top](#databricks-labs-ucx)]
+
 # Cross-workspace installations
 
 When installing UCX across multiple workspaces, administrators need to keep UCX configurations in sync.
diff --git a/labs.yml b/labs.yml
index 017eccbfc6..b370f168e0 100644
--- a/labs.yml
+++ b/labs.yml
@@ -149,3 +149,6 @@ commands:
 
   - name: revert-cluster-remap
     description: Reverting the Re-mapping of the cluster from UC
+
+  - name: migrate-local-code
+    description: (Experimental) Migrate files in the current directory to be more compatible with Unity Catalog.
diff --git a/pyproject.toml b/pyproject.toml
index 4258ccef5e..69760117fa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -154,7 +154,12 @@ branch = true
 parallel = true
 
 [tool.coverage.report]
-omit = ["src/databricks/labs/ucx/mixins/*", "*/working-copy/*", "*/fresh_wheel_file/*"]
+omit = [
+  "src/databricks/labs/ucx/mixins/*",
+  "src/databricks/labs/ucx/code/lsp.py",
+  "*/working-copy/*",
+  "*/fresh_wheel_file/*"
+]
 exclude_lines = [
   "no cov",
   "if __name__ == .__main__.:",
diff --git a/src/databricks/labs/ucx/cli.py b/src/databricks/labs/ucx/cli.py
index 13e74559ac..a15d7e4963 100644
--- a/src/databricks/labs/ucx/cli.py
+++ b/src/databricks/labs/ucx/cli.py
@@ -3,6 +3,7 @@
 import shutil
 import webbrowser
 from collections.abc import Callable
+from pathlib import Path
 
 from databricks.labs.blueprint.cli import App
 from databricks.labs.blueprint.entrypoint import get_logger
@@ -19,6 +20,7 @@
 from databricks.labs.ucx.azure.access import AzureResourcePermissions
 from databricks.labs.ucx.azure.credentials import ServicePrincipalMigration
 from databricks.labs.ucx.azure.locations import ExternalLocationsMigration
+from databricks.labs.ucx.code.files import Files
 from databricks.labs.ucx.config import WorkspaceConfig
 from databricks.labs.ucx.hive_metastore import ExternalLocations, TablesCrawler
 from databricks.labs.ucx.hive_metastore.catalog_schema import CatalogSchema
@@ -547,5 +549,15 @@ def revert_cluster_remap(w: WorkspaceClient, prompts: Prompts):
     cluster_details.revert_cluster_remap(cluster_list, cluster_ids)
 
+
+@ucx.command
+def migrate_local_code(w: WorkspaceClient, prompts: Prompts):
+    """(Experimental) Fix Python and SQL files in the current working directory based on their language."""
+    files = Files.for_cli(w)
+    working_directory = Path.cwd()
+    if not prompts.confirm("Do you want to apply UC migration to all files in the current directory?"):
+        return
+    files.apply(working_directory)
+
 
 if __name__ == "__main__":
     ucx()
diff --git a/src/databricks/labs/ucx/code/__init__.py b/src/databricks/labs/ucx/code/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/databricks/labs/ucx/code/base.py b/src/databricks/labs/ucx/code/base.py
new file mode 100644
index 0000000000..581dd64202
--- /dev/null
+++ b/src/databricks/labs/ucx/code/base.py
@@ -0,0 +1,94 @@
+from abc import abstractmethod
+from collections.abc import Iterable
+from dataclasses import dataclass
+
+# Code mapping between LSP, PyLint, and our own diagnostics:
+# | LSP                       | PyLint     | Our            |
+# |---------------------------|------------|----------------|
+# | Severity.ERROR            | Error      | Failure()      |
+# | Severity.WARN             | Warning    | Advisory()     |
+# | DiagnosticTag.DEPRECATED  | Warning    | Deprecation()  |
+# | Severity.INFO             | Info       | Advice()       |
+# | Severity.HINT             | Convention | Convention()   |
+# | DiagnosticTag.UNNECESSARY | Refactor   | Convention()   |
+
+
+@dataclass
+class Advice:
+    code: str
+    message: str
+    start_line: int
+    start_col: int
+    end_line: int
+    end_col: int
+
+    def replace(
+        self,
+        code: str | None = None,
+        message: str | None = None,
+        start_line: int | None = None,
+        start_col: int | None = None,
+        end_line: int | None = None,
+        end_col: int | None = None,
+    ) -> 'Advice':
+        return self.__class__(
+            code=code if code is not None else self.code,
+            message=message if message is not None else self.message,
+            start_line=start_line if start_line is not None else self.start_line,
+            start_col=start_col if start_col is not None else self.start_col,
+            end_line=end_line if end_line is not None else self.end_line,
+            end_col=end_col if end_col is not None else self.end_col,
+        )
+
+    def as_advisory(self) -> 'Advisory':
+        return Advisory(**self.__dict__)
+
+    def as_failure(self) -> 'Failure':
+        return Failure(**self.__dict__)
+
+    def as_deprecation(self) -> 'Deprecation':
+        return Deprecation(**self.__dict__)
+
+    def as_convention(self) -> 'Convention':
+        return Convention(**self.__dict__)
+
+
+class Advisory(Advice):
+    """A warning that does not prevent the code from running."""
+
+
+class Failure(Advisory):
+    """An error that prevents the code from running."""
+
+
+class Deprecation(Advisory):
+    """An advisory that suggests to replace the code with a newer version."""
+
+
+class Convention(Advice):
+    """A suggestion for a better way to write the code."""
+
+
+class Linter:
+    @abstractmethod
+    def lint(self, code: str) -> Iterable[Advice]: ...
+
+
+class Fixer:
+    @abstractmethod
+    def name(self) -> str: ...
+
+    @abstractmethod
+    def apply(self, code: str) -> str: ...
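+
+# NOTE: a Linter reports problems as Advice values; a Fixer is matched to a
+# diagnostic when Fixer.name() equals the Advice.code (see Languages.fixer).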
+
+
+class SequentialLinter(Linter):
+    def __init__(self, linters: list[Linter]):
+        self._linters = linters
+
+    def lint(self, code: str) -> Iterable[Advice]:
+        for linter in self._linters:
+            yield from linter.lint(code)
diff --git a/src/databricks/labs/ucx/code/files.py b/src/databricks/labs/ucx/code/files.py
new file mode 100644
index 0000000000..62ceab4922
--- /dev/null
+++ b/src/databricks/labs/ucx/code/files.py
@@ -0,0 +1,68 @@
+import logging
+from pathlib import Path
+
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.workspace import Language
+
+from databricks.labs.ucx.code.languages import Languages
+from databricks.labs.ucx.hive_metastore.table_migrate import TablesMigrate
+
+logger = logging.getLogger(__name__)
+
+
+class Files:
+    """The Files class is responsible for fixing code files based on their language."""
+
+    def __init__(self, languages: Languages):
+        self._languages = languages
+        self._extensions = {".py": Language.PYTHON, ".sql": Language.SQL}
+
+    @classmethod
+    def for_cli(cls, ws: WorkspaceClient):
+        tables_migrate = TablesMigrate.for_cli(ws)
+        index = tables_migrate.index()
+        languages = Languages(index)
+        return cls(languages)
+
+    def apply(self, path: Path) -> bool:
+        if path.is_dir():
+            for folder in path.iterdir():
+                self.apply(folder)
+            return True
+        return self._apply_file_fix(path)
+
+    def _apply_file_fix(self, path):
+        """
+        The fix method reads a file, lints it, applies fixes, and writes the fixed code back to the file.
+        """
+        # Check if the file extension is in the list of supported extensions
+        if path.suffix not in self._extensions:
+            return False
+        # Get the language corresponding to the file extension
+        language = self._extensions[path.suffix]
+        # If the language is not supported, return
+        if not language:
+            return False
+        logger.info(f"Analysing {path}")
+        # Get the linter for the language
+        linter = self._languages.linter(language)
+        # Open the file and read the code
+        with path.open("r") as f:
+            code = f.read()
+        applied = False
+        # Lint the code and apply fixes
+        for advice in linter.lint(code):
+            logger.info(f"Found: {advice}")
+            fixer = self._languages.fixer(language, advice.code)
+            if not fixer:
+                continue
+            logger.info(f"Applying fix for {advice}")
+            code = fixer.apply(code)
+            applied = True
+        if not applied:
+            return False
+        # Write the fixed code back to the file
+        with path.open("w") as f:
+            logger.info(f"Overwriting {path}")
+            f.write(code)
+        return True
diff --git a/src/databricks/labs/ucx/code/languages.py b/src/databricks/labs/ucx/code/languages.py
new file mode 100644
index 0000000000..fa1ff810fe
--- /dev/null
+++ b/src/databricks/labs/ucx/code/languages.py
@@ -0,0 +1,44 @@
+from databricks.sdk.service.workspace import Language
+
+from databricks.labs.ucx.code.base import Fixer, Linter, SequentialLinter
+from databricks.labs.ucx.code.pyspark import SparkSql
+from databricks.labs.ucx.code.queries import FromTable
+from databricks.labs.ucx.hive_metastore.table_migrate import Index
+
+
+class Languages:
+    def __init__(self, index: Index):
+        self._index = index
+        from_table = FromTable(index)
+        self._linters = {
+            Language.PYTHON: SequentialLinter([SparkSql(from_table)]),
+            Language.SQL: SequentialLinter([from_table]),
+        }
+        self._fixers: dict[Language, list[Fixer]] = {
+            Language.PYTHON: [SparkSql(from_table)],
+            Language.SQL: [from_table],
+        }
+
+    def is_supported(self, language: Language) -> bool:
+        return language in self._linters and language in self._fixers
+
+    def linter(self, language: Language) -> Linter:
+        if language not in self._linters:
+            raise ValueError(f"Unsupported language: {language}")
+        return self._linters[language]
+
+    def fixer(self, language: Language, diagnostic_code: str) -> Fixer | None:
+        if language not in self._fixers:
+            return None
+        for fixer in self._fixers[language]:
+            if fixer.name() == diagnostic_code:
+                return fixer
+        return None
+
+    def apply_fixes(self, language: Language, code: str) -> str:
+        linter = self.linter(language)
+        for advice in linter.lint(code):
+            fixer = self.fixer(language, advice.code)
+            if fixer:
+                code = fixer.apply(code)
+        return code
diff --git a/src/databricks/labs/ucx/code/lsp.py b/src/databricks/labs/ucx/code/lsp.py
new file mode 100644
index 0000000000..abebd6ad30
--- /dev/null
+++ b/src/databricks/labs/ucx/code/lsp.py
@@ -0,0 +1,350 @@
+import enum
+import functools
+import http.server
+import json
+import logging
+from collections.abc import Sequence
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+from urllib.parse import parse_qsl
+
+from databricks.labs.blueprint.logger import install_logger
+from databricks.sdk.service.workspace import Language
+
+from databricks.labs.ucx.code.base import (
+    Advice,
+    Advisory,
+    Convention,
+    Deprecation,
+    Failure,
+)
+from databricks.labs.ucx.code.languages import Languages
+from databricks.labs.ucx.hive_metastore.table_migrate import Index, MigrationStatus
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Position:
+    line: int
+    character: int
+
+    def as_dict(self) -> dict:
+        return {"line": self.line, "character": self.character}
+
+    @classmethod
+    def from_dict(cls, raw: dict) -> 'Position':
+        return cls(raw['line'], raw['character'])
+
+
+@dataclass
+class Range:
+    start: Position
+    end: Position
+
+    @classmethod
+    def from_dict(cls, raw: dict) -> 'Range':
+        return cls(Position.from_dict(raw['start']), Position.from_dict(raw['end']))
+
+    @classmethod
+    def make(cls, start_line: int, start_character: int, end_line: int, end_character: int) -> 'Range':
+        return cls(start=Position(start_line - 1, start_character), end=Position(end_line - 1, end_character))
+
+    def as_dict(self) -> dict:
+        return {"start": self.start.as_dict(), "end": self.end.as_dict()}
+
+    def fragment(self, code: str) -> str:
+        out = []
+        splitlines = code.splitlines()
+        for line, part in enumerate(splitlines):
+            if line == self.start.line and line == self.end.line:
+                out.append(part[self.start.character : self.end.character])
+            elif line == self.start.line:
+                out.append(part[self.start.character :])
+            elif line == self.end.line:
+                out.append(part[: self.end.character])
+            elif self.start.line < line < self.end.line:
+                out.append(part)
+        return "".join(out)
+
+
+class Severity(enum.IntEnum):
+    ERROR = 1
+    WARN = 2
+    INFO = 3
+    HINT = 4
+
+
+class DiagnosticTag(enum.IntEnum):
+    UNNECESSARY = 1
+    DEPRECATED = 2
+
+
+@dataclass
+class Diagnostic:
+    # the range at which the message applies.
+    range: Range
+
+    # The diagnostic's code, which might appear in the user interface.
+    code: str
+
+    # A human-readable string describing the source of this diagnostic.
+    source: str
+
+    # The diagnostic's message.
+    message: str
+
+    # The diagnostic's severity. Can be omitted. If omitted it is up to the
+    # client to interpret diagnostics as error, warning, info or hint.
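+    # (derived from the Advice subclass in from_advice() below:
+    # Failure -> ERROR, Deprecation/Advisory -> WARN, Convention -> HINT)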
+    severity: Severity
+
+    tags: list[DiagnosticTag] | None = None
+
+    @classmethod
+    def from_advice(cls, advice: Advice) -> 'Diagnostic':
+        severity, tags = cls._severity_and_tags(advice)
+        return cls(
+            range=Range.make(advice.start_line, advice.start_col, advice.end_line, advice.end_col),
+            code=advice.code,
+            source="databricks.labs.ucx",
+            message=advice.message,
+            severity=severity,
+            tags=tags,
+        )
+
+    @classmethod
+    def _severity_and_tags(cls, advice):
+        if isinstance(advice, Convention):
+            return Severity.HINT, [DiagnosticTag.UNNECESSARY]
+        if isinstance(advice, Deprecation):
+            return Severity.WARN, [DiagnosticTag.DEPRECATED]
+        if isinstance(advice, Failure):
+            return Severity.ERROR, []
+        if isinstance(advice, Advisory):
+            return Severity.WARN, []
+        return Severity.INFO, []
+
+    def as_dict(self) -> dict:
+        return {
+            "range": self.range.as_dict(),
+            "code": self.code,
+            "source": self.source,
+            "message": self.message,
+            "severity": self.severity.value if self.severity else Severity.WARN,
+            "tags": [t.value for t in self.tags] if self.tags else [],
+        }
+
+
+@dataclass
+class TextDocumentIdentifier:
+    uri: str
+
+    def as_dict(self) -> dict:
+        return {
+            "uri": self.uri,
+        }
+
+
+@dataclass
+class OptionalVersionedTextDocumentIdentifier:
+    uri: str
+    version: int | None = None
+
+    def as_dict(self) -> dict:
+        return {
+            "uri": self.uri,
+            "version": self.version,
+        }
+
+
+@dataclass
+class TextEdit:
+    range: Range
+    new_text: str
+
+    def as_dict(self) -> dict:
+        return {
+            "range": self.range.as_dict(),
+            "newText": self.new_text,
+        }
+
+
+@dataclass
+class TextDocumentEdit:
+    text_document: OptionalVersionedTextDocumentIdentifier
+    edits: Sequence[TextEdit]
+
+    def as_dict(self) -> dict:
+        return {
+            "textDocument": self.text_document.as_dict(),
+            "edits": [e.as_dict() for e in self.edits],
+        }
+
+
+@dataclass
+class WorkspaceEdit:
+    # we could also carry CreateFile | RenameFile | DeleteFile operations here, but we don't do it for now.
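+    # (the LSP spec also allows a plain `changes` map of uri -> TextEdit[];
+    # we only emit `documentChanges`)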
+    document_changes: Sequence[TextDocumentEdit]
+
+    def as_dict(self) -> dict:
+        return {
+            "documentChanges": [e.as_dict() for e in self.document_changes],
+        }
+
+
+@dataclass
+class CodeAction:
+    title: str
+    edit: WorkspaceEdit
+    is_preferred: bool = False
+
+    def as_dict(self) -> dict:
+        return {
+            "title": self.title,
+            "edit": self.edit.as_dict(),
+            "isPreferred": self.is_preferred,
+        }
+
+
+@dataclass
+class AnalyseResponse:
+    diagnostics: list[Diagnostic]
+
+    def as_dict(self):
+        return {"diagnostics": [d.as_dict() for d in self.diagnostics]}
+
+
+@dataclass
+class QuickFixResponse:
+    code_actions: list[CodeAction]
+
+    def as_dict(self):
+        return {"code_actions": [ca.as_dict() for ca in self.code_actions]}
+
+
+class LspServer:
+    def __init__(self, language_support: Languages):
+        self._languages = language_support
+        self._extensions = {".py": Language.PYTHON, ".sql": Language.SQL}
+
+    def _read(self, file_uri: str):
+        file = Path(file_uri.removeprefix("file://"))
+        if file.suffix not in self._extensions:
+            raise KeyError(f"no language for {file.suffix}")
+        language = self._extensions[file.suffix]
+        with file.open('r', encoding='utf8') as f:
+            return f.read(), language
+
+    def lint(self, file_uri: str):
+        code, language = self._read(file_uri)
+        analyser = self._languages.linter(language)
+        diagnostics = [Diagnostic.from_advice(_) for _ in analyser.lint(code)]
+        return AnalyseResponse(diagnostics)
+
+    def quickfix(self, file_uri: str, code_range: Range, diagnostic_code: str):
+        code, language = self._read(file_uri)
+        fixer = self._languages.fixer(language, diagnostic_code)
+        if not fixer:
+            return QuickFixResponse(code_actions=[])
+        fragment = code_range.fragment(code)
+        apply = fixer.apply(fragment)
+        return QuickFixResponse(
+            code_actions=[
+                CodeAction(
+                    title=f"Replace with: {apply}",
+                    edit=WorkspaceEdit(
+                        document_changes=[
+                            TextDocumentEdit(
+                                text_document=OptionalVersionedTextDocumentIdentifier(file_uri),
+                                edits=[TextEdit(code_range, apply)],
+                            ),
+                        ]
+                    ),
+                    is_preferred=True,
+                )
+            ]
+        )
+
+    def serve(self):
+        server_address = ('localhost', 8000)
+        handler_class = functools.partial(_RequestHandler, self)
+        httpd = http.server.ThreadingHTTPServer(server_address, handler_class)
+        httpd.serve_forever()
+
+
+class _RequestHandler(http.server.BaseHTTPRequestHandler):
+    def __init__(self, lsp_server: LspServer, *args):
+        self._lsp_server = lsp_server
+        super().__init__(*args)
+
+    def log_message(self, fmt: str, *args: Any):  # pylint: disable=arguments-differ
+        logger.debug(fmt % args)  # pylint: disable=logging-not-lazy
+
+    def do_POST(self):  # pylint: disable=invalid-name
+        if not self.path.startswith('/quickfix'):
+            self.send_error(400, 'Wrong input')
+            return
+        self.send_response(200)
+        self.end_headers()
+
+        content_length = int(self.headers['Content-Length'])
+        post_data = self.rfile.read(content_length)
+        raw = json.loads(post_data.decode('utf-8'))
+        logger.debug(f"Received:\n{raw}")
+
+        rng = Range.from_dict(raw['range'])
+        response = self._lsp_server.quickfix(raw['file_uri'], rng, raw['code'])
+
+        raw = json.dumps(response.as_dict()).encode('utf-8')
+        self.wfile.write(raw)
+        # self.wfile.flush()
+
+    def do_GET(self):  # pylint: disable=invalid-name
+        if not self.path.startswith('/lint'):
+            self.send_error(400, 'Wrong input')
+            return
+        parts = self.path.split('?')
+        if len(parts) != 2:
+            self.send_error(400, 'Missing Query')
+            return
+        _, query_string = parts
+        query = dict(parse_qsl(query_string))
+        response = self._lsp_server.lint(query['file_uri'])
+        if not response:
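+            # defensive: lint() currently always returns an AnalyseResponse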
+            self.send_error(404, 'no analyser for file type')
+            return
+
+        self.send_response(200)
+        self.end_headers()
+        raw = json.dumps(response.as_dict()).encode('utf-8')
+        self.wfile.write(raw)
+        self.wfile.flush()
+
+
+if __name__ == '__main__':
+    install_logger()
+    logging.root.setLevel('DEBUG')
+    languages = Languages(
+        Index(
+            [
+                MigrationStatus(
+                    src_schema='old', src_table='things', dst_catalog='brand', dst_schema='new', dst_table='stuff'
+                ),
+                MigrationStatus(
+                    src_schema='other',
+                    src_table='matters',
+                    dst_catalog='some',
+                    dst_schema='certain',
+                    dst_table='issues',
+                ),
+            ]
+        )
+    )
+    lsp = LspServer(languages)
+    lsp.serve()
diff --git a/src/databricks/labs/ucx/code/notebooks.py b/src/databricks/labs/ucx/code/notebooks.py
new file mode 100644
index 0000000000..504f441fbf
--- /dev/null
+++ b/src/databricks/labs/ucx/code/notebooks.py
@@ -0,0 +1,32 @@
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.workspace import ExportFormat, ObjectInfo
+
+from databricks.labs.ucx.code.languages import Languages
+
+
+class Notebooks:
+    def __init__(self, ws: WorkspaceClient, languages: Languages):
+        self._ws = ws
+        self._languages = languages
+
+    def revert(self, object_info: ObjectInfo):
+        if not object_info.path:
+            return False
+        with self._ws.workspace.download(object_info.path + ".bak", format=ExportFormat.SOURCE) as f:
+            code = f.read().decode("utf-8")
+        self._ws.workspace.upload(object_info.path, code.encode("utf-8"))
+        return True
+
+    def apply(self, object_info: ObjectInfo) -> bool:
+        if not object_info.language or not object_info.path:
+            return False
+        if not self._languages.is_supported(object_info.language):
+            return False
+        with self._ws.workspace.download(object_info.path, format=ExportFormat.SOURCE) as f:
+            original_code = f.read().decode("utf-8")
+        new_code = self._languages.apply_fixes(object_info.language, original_code)
+        if new_code == original_code:
+            return False
+        self._ws.workspace.upload(object_info.path + ".bak", original_code.encode("utf-8"))
+        self._ws.workspace.upload(object_info.path, new_code.encode("utf-8"))
+        return True
diff --git a/src/databricks/labs/ucx/code/pyspark.py b/src/databricks/labs/ucx/code/pyspark.py
new file mode 100644
index 0000000000..c2b1f39757
--- /dev/null
+++ b/src/databricks/labs/ucx/code/pyspark.py
@@ -0,0 +1,61 @@
+import ast
+from collections.abc import Iterable
+
+from databricks.labs.ucx.code.base import Advice, Fixer, Linter
+from databricks.labs.ucx.code.queries import FromTable
+
+
+class SparkSql(Linter, Fixer):
+    def __init__(self, from_table: FromTable):
+        self._from_table = from_table
+
+    def name(self) -> str:
+        # this is the same fixer, just in a different language context
+        return self._from_table.name()
+
+    def lint(self, code: str) -> Iterable[Advice]:
+        tree = ast.parse(code)
+        for node in ast.walk(tree):
+            if not isinstance(node, ast.Call):
+                continue
+            if not isinstance(node.func, ast.Attribute):
+                continue
+            if node.func.attr != "sql":
+                continue
+            if len(node.args) != 1:
+                continue
+            first_arg = node.args[0]
+            if not isinstance(first_arg, ast.Constant):
+                # `astroid` library supports inference and parent node lookup,
+                # which makes traversing the AST a bit easier.
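+                # With plain `ast` we only see literal constants, so f-strings
+                # and variables fall through here without analysis.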
+                continue
+            query = first_arg.value
+            for advice in self._from_table.lint(query):
+                yield advice.replace(
+                    start_line=node.lineno,
+                    start_col=node.col_offset,
+                    end_line=node.end_lineno,
+                    end_col=node.end_col_offset,
+                )
+
+    def apply(self, code: str) -> str:
+        tree = ast.parse(code)
+        # we won't be doing it like this in production, but for the sake of the example
+        for node in ast.walk(tree):
+            if not isinstance(node, ast.Call):
+                continue
+            if not isinstance(node.func, ast.Attribute):
+                continue
+            if node.func.attr != "sql":
+                continue
+            if len(node.args) != 1:
+                continue
+            first_arg = node.args[0]
+            if not isinstance(first_arg, ast.Constant):
+                continue
+            query = first_arg.value
+            new_query = self._from_table.apply(query)
+            first_arg.value = new_query
+        return ast.unparse(tree)
diff --git a/src/databricks/labs/ucx/code/queries.py b/src/databricks/labs/ucx/code/queries.py
new file mode 100644
index 0000000000..ead39b69fe
--- /dev/null
+++ b/src/databricks/labs/ucx/code/queries.py
@@ -0,0 +1,61 @@
+from collections.abc import Iterable
+
+import sqlglot
+from sqlglot.expressions import Table
+
+from databricks.labs.ucx.code.base import Advice, Deprecation, Fixer, Linter
+from databricks.labs.ucx.hive_metastore.table_migrate import Index
+
+
+class FromTable(Linter, Fixer):
+    def __init__(self, index: Index):
+        self._index = index
+
+    def name(self) -> str:
+        return 'table-migrate'
+
+    def lint(self, code: str) -> Iterable[Advice]:
+        for statement in sqlglot.parse(code):
+            if not statement:
+                continue
+            for table in statement.find_all(Table):
+                catalog = self._catalog(table)
+                if catalog != 'hive_metastore':
+                    continue
+                dst = self._index.get(table.db, table.name)
+                if not dst:
+                    continue
+                yield Deprecation(
+                    code='table-migrate',
+                    message=f"Table {table.db}.{table.name} is migrated to {dst.destination()} in Unity Catalog",
+                    # SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
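+                    # so we conservatively flag the whole first line of the statement: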
+                    start_line=0,
+                    start_col=0,
+                    end_line=0,
+                    end_col=1024,
+                )
+
+    @staticmethod
+    def _catalog(table):
+        if table.catalog:
+            return table.catalog
+        return 'hive_metastore'
+
+    def apply(self, code: str) -> str:
+        new_statements = []
+        for statement in sqlglot.parse(code):
+            if not statement:
+                continue
+            for old_table in statement.find_all(Table):
+                catalog = self._catalog(old_table)
+                if catalog != 'hive_metastore':
+                    continue
+                dst = self._index.get(old_table.db, old_table.name)
+                if not dst:
+                    continue
+                new_table = Table(catalog=dst.dst_catalog, db=dst.dst_schema, this=dst.dst_table)
+                old_table.replace(new_table)
+            new_sql = statement.sql('databricks')
+            new_statements.append(new_sql)
+        return '; '.join(new_statements)
diff --git a/src/databricks/labs/ucx/code/redash.py b/src/databricks/labs/ucx/code/redash.py
new file mode 100644
index 0000000000..e89b03ea00
--- /dev/null
+++ b/src/databricks/labs/ucx/code/redash.py
@@ -0,0 +1,24 @@
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.sql import Query
+
+from databricks.labs.ucx.code.base import Fixer
+
+
+class Redash:
+    def __init__(self, fixer: Fixer, ws: WorkspaceClient):
+        self._fixer = fixer
+        self._ws = ws
+
+    def fix(self, query: Query):
+        assert query.id is not None
+        assert query.query is not None
+        query.query = self._fixer.apply(query.query)
+        self._ws.queries.update(
+            query.id,
+            data_source_id=query.data_source_id,
+            description=query.description,
+            name=query.name,
+            options=query.options,
+            query=query.query,
+            run_as_role=query.run_as_role,
+        )
diff --git a/src/databricks/labs/ucx/hive_metastore/table_migrate.py b/src/databricks/labs/ucx/hive_metastore/table_migrate.py
index 41331e08b3..57b19c9424 100644
--- a/src/databricks/labs/ucx/hive_metastore/table_migrate.py
+++ b/src/databricks/labs/ucx/hive_metastore/table_migrate.py
@@ -38,6 +38,9 @@ class MigrationStatus:
     dst_table: str | None = None
     update_ts: str | None = None
 
+    def destination(self):
+        return f"{self.dst_catalog}.{self.dst_schema}.{self.dst_table}".lower()
+
 
 class TablesMigrate:
     def __init__(
@@ -74,6 +77,9 @@ def for_cli(cls, ws: WorkspaceClient, product='ucx'):
             table_crawler, grants_crawler, ws, sql_backend, table_mapping, group_manager, migration_status_refresher
         )
 
+    def index(self):
+        return self._migration_status_refresher.index()
+
     def migrate_tables(self, *, what: What | None = None, acl_strategy: AclMigrationWhat | None = None):
         self._init_seen_tables()
         tables_to_migrate = self._tm.get_tables_to_migrate(self._tc)
@@ -273,6 +279,27 @@ def _match_grants(table: Table, grants: Iterable[Grant], migrated_groups: list[M
     return matched_grants
 
 
+class Index:
+    def __init__(self, tables: list[MigrationStatus]):
+        self._tables = tables
+
+    def is_upgraded(self, schema: str, table: str) -> bool:
+        src_schema = schema.lower()
+        src_table = table.lower()
+        for migration_status in self._tables:
+            if migration_status.src_schema == src_schema and migration_status.src_table == src_table:
+                return True
+        return False
+
+    def get(self, schema: str, table: str) -> MigrationStatus | None:
+        src_schema = schema.lower()
+        src_table = table.lower()
+        for migration_status in self._tables:
+            if migration_status.src_schema == src_schema and migration_status.src_table == src_table:
+                return migration_status
+        return None
+
+
 class MigrationStatusRefresher(CrawlerBase[MigrationStatus]):
     def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema, table_crawler: TablesCrawler):
         super().__init__(sbe, "hive_metastore", schema, "migration_status", MigrationStatus)
"hive_metastore", schema, "migration_status", MigrationStatus) @@ -282,6 +309,9 @@ def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema, table_crawler: def snapshot(self) -> Iterable[MigrationStatus]: return self._snapshot(self._try_fetch, self._crawl) + def index(self) -> Index: + return Index(list(self.snapshot())) + def get_seen_tables(self) -> dict[str, str]: seen_tables: dict[str, str] = {} for schema in self._iter_schemas(): @@ -310,12 +340,14 @@ def _crawl(self) -> Iterable[MigrationStatus]: reverse_seen = {v: k for k, v in self.get_seen_tables().items()} timestamp = datetime.datetime.now(datetime.timezone.utc).timestamp() for table in all_tables: + src_schema = table.database.lower() + src_table = table.name.lower() table_migration_status = MigrationStatus( - src_schema=table.database, - src_table=table.name, + src_schema=src_schema, + src_table=src_table, update_ts=str(timestamp), ) - if table.key in reverse_seen and self.is_upgraded(table.database, table.name): + if table.key in reverse_seen and self.is_upgraded(src_schema, src_table): target_table = reverse_seen[table.key] if len(target_table.split(".")) == 3: table_migration_status.dst_catalog = target_table.split(".")[0] diff --git a/tests/unit/code/__init__.py b/tests/unit/code/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/code/conftest.py b/tests/unit/code/conftest.py new file mode 100644 index 0000000000..f0f9fcf163 --- /dev/null +++ b/tests/unit/code/conftest.py @@ -0,0 +1,18 @@ +import pytest + +from databricks.labs.ucx.hive_metastore.table_migrate import Index, MigrationStatus + + +@pytest.fixture +def empty_index(): + return Index([]) + + +@pytest.fixture +def migration_index(): + return Index( + [ + MigrationStatus('old', 'things', dst_catalog='brand', dst_schema='new', dst_table='stuff'), + MigrationStatus('other', 'matters', dst_catalog='some', dst_schema='certain', dst_table='issues'), + ] + ) diff --git a/tests/unit/code/test_base.py b/tests/unit/code/test_base.py new file mode 100644 index 0000000000..a601b8c1f0 --- /dev/null +++ b/tests/unit/code/test_base.py @@ -0,0 +1,40 @@ +from databricks.labs.ucx.code.base import ( + Advice, + Advisory, + Convention, + Deprecation, + Failure, +) + + +def test_message_initialization(): + message = Advice('code1', 'This is a message', 1, 1, 2, 2) + assert message.code == 'code1' + assert message.message == 'This is a message' + assert message.start_line == 1 + assert message.start_col == 1 + assert message.end_line == 2 + assert message.end_col == 2 + + +def test_warning_initialization(): + warning = Advisory('code2', 'This is a warning', 1, 1, 2, 2) + + copy_of = warning.replace(code='code3') + assert copy_of.code == 'code3' + assert isinstance(copy_of, Advisory) + + +def test_error_initialization(): + error = Failure('code3', 'This is an error', 1, 1, 2, 2) + assert isinstance(error, Advice) + + +def test_deprecation_initialization(): + deprecation = Deprecation('code4', 'This is a deprecation', 1, 1, 2, 2) + assert isinstance(deprecation, Advice) + + +def test_convention_initialization(): + convention = Convention('code5', 'This is a convention', 1, 1, 2, 2) + assert isinstance(convention, Advice) diff --git a/tests/unit/code/test_files.py b/tests/unit/code/test_files.py new file mode 100644 index 0000000000..ee87cb50fe --- /dev/null +++ b/tests/unit/code/test_files.py @@ -0,0 +1,40 @@ +from pathlib import Path +from unittest.mock import Mock, create_autospec + +from databricks.sdk.service.workspace import Language + +from 
+from databricks.labs.ucx.code.languages import Languages
+
+
+def test_files_fix_ignores_unsupported_extensions():
+    languages = create_autospec(Languages)
+    files = Files(languages)
+    path = Path('unsupported.ext')
+    assert not files.apply(path)
+
+
+def test_files_fix_reads_supported_extensions():
+    languages = create_autospec(Languages)
+    files = Files(languages)
+    path = Path(__file__)
+    assert not files.apply(path)
+
+
+def test_files_supported_language_no_diagnostics():
+    languages = create_autospec(Languages)
+    languages.linter(Language.PYTHON).lint.return_value = []
+    files = Files(languages)
+    path = Path(__file__)
+    files.apply(path)
+    languages.fixer.assert_not_called()
+
+
+def test_files_supported_language_no_fixer():
+    languages = create_autospec(Languages)
+    languages.linter(Language.PYTHON).lint.return_value = [Mock(code='some-code')]
+    languages.fixer.return_value = None
+    files = Files(languages)
+    path = Path(__file__)
+    files.apply(path)
+    languages.fixer.assert_called_once_with(Language.PYTHON, 'some-code')
diff --git a/tests/unit/code/test_languages.py b/tests/unit/code/test_languages.py
new file mode 100644
index 0000000000..2ddfd0801f
--- /dev/null
+++ b/tests/unit/code/test_languages.py
@@ -0,0 +1,44 @@
+import pytest
+from databricks.sdk.service.workspace import Language
+
+from databricks.labs.ucx.code.base import Fixer, Linter
+from databricks.labs.ucx.code.languages import Languages
+from databricks.labs.ucx.hive_metastore.table_migrate import Index
+
+index = Index([])
+
+
+def test_linter_returns_correct_analyser_for_python():
+    languages = Languages(index)
+    linter = languages.linter(Language.PYTHON)
+    assert isinstance(linter, Linter)
+
+
+def test_linter_returns_correct_analyser_for_sql():
+    languages = Languages(index)
+    linter = languages.linter(Language.SQL)
+    assert isinstance(linter, Linter)
+
+
+def test_linter_raises_error_for_unsupported_language():
+    languages = Languages(index)
+    with pytest.raises(ValueError):
+        languages.linter(Language.R)
+
+
+def test_fixer_returns_correct_fixer_for_python():
+    languages = Languages(index)
+    fixer = languages.fixer(Language.PYTHON, "diagnostic_code")
+    assert isinstance(fixer, Fixer) or fixer is None
+
+
+def test_fixer_returns_correct_fixer_for_sql():
+    languages = Languages(index)
+    fixer = languages.fixer(Language.SQL, "diagnostic_code")
+    assert isinstance(fixer, Fixer) or fixer is None
+
+
+def test_fixer_returns_none_for_unsupported_language():
+    languages = Languages(index)
+    fixer = languages.fixer(Language.SCALA, "diagnostic_code")
+    assert fixer is None
diff --git a/tests/unit/code/test_notebooks.py b/tests/unit/code/test_notebooks.py
new file mode 100644
index 0000000000..8eddc40c10
--- /dev/null
+++ b/tests/unit/code/test_notebooks.py
@@ -0,0 +1,52 @@
+from unittest.mock import create_autospec
+
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.workspace import ExportFormat, Language, ObjectInfo
+
+from databricks.labs.ucx.code.languages import Languages
+from databricks.labs.ucx.code.notebooks import Notebooks
+
+
+def test_notebooks_revert_restores_original_code():
+    ws = create_autospec(WorkspaceClient)
+    ws.workspace.download.return_value.__enter__.return_value.read.return_value = b'original_code'
+    languages = create_autospec(Languages)
+    notebooks = Notebooks(ws, languages)
+    object_info = ObjectInfo(path='path', language=Language.PYTHON)
+    notebooks.revert(object_info)
+    ws.workspace.download.assert_called_with('path.bak', format=ExportFormat.SOURCE)
+    ws.workspace.upload.assert_called_with('path', b'original_code')
+
+
+def test_apply_returns_false_when_language_not_supported():
+    ws = create_autospec(WorkspaceClient)
+    languages = create_autospec(Languages)
+    languages.is_supported.return_value = False
+    notebooks = Notebooks(ws, languages)
+    object_info = ObjectInfo(path='path', language=Language.R)
+    result = notebooks.apply(object_info)
+    assert not result
+
+
+def test_apply_returns_false_when_no_fixes_applied():
+    ws = create_autospec(WorkspaceClient)
+    ws.workspace.download.return_value.__enter__.return_value.read.return_value = b'original_code'
+    languages = create_autospec(Languages)
+    languages.is_supported.return_value = True
+    languages.apply_fixes.return_value = 'original_code'
+    notebooks = Notebooks(ws, languages)
+    object_info = ObjectInfo(path='path', language=Language.PYTHON)
+    assert not notebooks.apply(object_info)
+
+
+def test_apply_returns_true_and_changes_code_when_fixes_applied():
+    ws = create_autospec(WorkspaceClient)
+    ws.workspace.download.return_value.__enter__.return_value.read.return_value = b'original_code'
+    languages = create_autospec(Languages)
+    languages.is_supported.return_value = True
+    languages.apply_fixes.return_value = 'new_code'
+    notebooks = Notebooks(ws, languages)
+    object_info = ObjectInfo(path='path', language=Language.PYTHON)
+    assert notebooks.apply(object_info)
+    ws.workspace.upload.assert_any_call('path.bak', 'original_code'.encode("utf-8"))
+    ws.workspace.upload.assert_any_call('path', 'new_code'.encode("utf-8"))
diff --git a/tests/unit/code/test_pyspark.py b/tests/unit/code/test_pyspark.py
new file mode 100644
index 0000000000..bb470367bf
--- /dev/null
+++ b/tests/unit/code/test_pyspark.py
@@ -0,0 +1,65 @@
+from databricks.labs.ucx.code.base import Deprecation
+from databricks.labs.ucx.code.pyspark import SparkSql
+from databricks.labs.ucx.code.queries import FromTable
+
+
+def test_spark_not_sql(empty_index):
+    ftf = FromTable(empty_index)
+    sqf = SparkSql(ftf)
+
+    assert not list(sqf.lint("print(1)"))
+
+
+def test_spark_sql_no_match(empty_index):
+    ftf = FromTable(empty_index)
+    sqf = SparkSql(ftf)
+
+    old_code = """
+spark.read.csv("s3://bucket/path")
+for i in range(10):
+    result = spark.sql("SELECT * FROM old.things").collect()
+    print(len(result))
+"""
+
+    assert not list(sqf.lint(old_code))
+
+
+def test_spark_sql_match(migration_index):
+    ftf = FromTable(migration_index)
+    sqf = SparkSql(ftf)
+
+    old_code = """
+spark.read.csv("s3://bucket/path")
+for i in range(10):
+    result = spark.sql("SELECT * FROM old.things").collect()
+    print(len(result))
+"""
+    assert [
+        Deprecation(
+            code='table-migrate',
+            message='Table old.things is migrated to brand.new.stuff in Unity Catalog',
+            start_line=4,
+            start_col=13,
+            end_line=4,
+            end_col=50,
+        )
+    ] == list(sqf.lint(old_code))
+
+
+def test_spark_sql_fix(migration_index):
+    ftf = FromTable(migration_index)
+    sqf = SparkSql(ftf)
+
+    old_code = """spark.read.csv("s3://bucket/path")
+for i in range(10):
+    result = spark.sql("SELECT * FROM old.things").collect()
+    print(len(result))
+"""
+    fixed_code = sqf.apply(old_code)
+    assert (
+        fixed_code
+        == """spark.read.csv('s3://bucket/path')
+for i in range(10):
+    result = spark.sql('SELECT * FROM brand.new.stuff').collect()
+    print(len(result))"""
+    )
diff --git a/tests/unit/code/test_queries.py b/tests/unit/code/test_queries.py
new file mode 100644
index 0000000000..462113497e
--- /dev/null
+++ b/tests/unit/code/test_queries.py
@@ -0,0 +1,44 @@
+from databricks.labs.ucx.code.base import Deprecation
+from databricks.labs.ucx.code.queries import FromTable
+
+
+def test_not_migrated_tables_trigger_nothing(empty_index):
+    ftf = FromTable(empty_index)
+
+    old_query = "SELECT * FROM old.things LEFT JOIN hive_metastore.other.matters USING (x) WHERE state > 1 LIMIT 10"
+
+    assert not list(ftf.lint(old_query))
+
+
+def test_migrated_tables_trigger_messages(migration_index):
+    ftf = FromTable(migration_index)
+
+    old_query = "SELECT * FROM old.things LEFT JOIN hive_metastore.other.matters USING (x) WHERE state > 1 LIMIT 10"
+
+    assert [
+        Deprecation(
+            code='table-migrate',
+            message='Table old.things is migrated to brand.new.stuff in Unity Catalog',
+            start_line=0,
+            start_col=0,
+            end_line=0,
+            end_col=1024,
+        ),
+        Deprecation(
+            code='table-migrate',
+            message='Table other.matters is migrated to some.certain.issues in Unity Catalog',
+            start_line=0,
+            start_col=0,
+            end_line=0,
+            end_col=1024,
+        ),
+    ] == list(ftf.lint(old_query))
+
+
+def test_fully_migrated_queries_match(migration_index):
+    ftf = FromTable(migration_index)
+
+    old_query = "SELECT * FROM old.things LEFT JOIN hive_metastore.other.matters USING (x) WHERE state > 1 LIMIT 10"
+    new_query = "SELECT * FROM brand.new.stuff LEFT JOIN some.certain.issues USING (x) WHERE state > 1 LIMIT 10"
+
+    assert ftf.apply(old_query) == new_query
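
Reviewer note (not part of the patch): below is a minimal, self-contained
sketch of how the pieces introduced here compose, mirroring the unit-test
fixtures above. The index contents are illustrative; the APIs used are the
ones added by this patch.

```python
from pathlib import Path

from databricks.labs.ucx.code.files import Files
from databricks.labs.ucx.code.languages import Languages
from databricks.labs.ucx.code.queries import FromTable
from databricks.labs.ucx.hive_metastore.table_migrate import Index, MigrationStatus

# Hypothetical migration status: hive_metastore.old.things was migrated
# to brand.new.stuff in Unity Catalog.
index = Index([MigrationStatus('old', 'things', dst_catalog='brand', dst_schema='new', dst_table='stuff')])

# SQL path: lint a query for references to migrated tables, then rewrite it.
from_table = FromTable(index)
for advice in from_table.lint("SELECT * FROM old.things"):
    print(advice)  # Deprecation(code='table-migrate', ...)
print(from_table.apply("SELECT * FROM old.things"))  # SELECT * FROM brand.new.stuff

# File path: walk a directory tree and rewrite supported .py/.sql files in
# place, as `databricks labs ucx migrate-local-code` does for the current
# working directory.
files = Files(Languages(index))
files.apply(Path.cwd())
```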