diff --git a/src/databricks/labs/ucx/framework/dashboards.py b/src/databricks/labs/ucx/framework/dashboards.py
index c08ebeb2ae..4879b522e5 100644
--- a/src/databricks/labs/ucx/framework/dashboards.py
+++ b/src/databricks/labs/ucx/framework/dashboards.py
@@ -8,6 +8,7 @@

 from databricks.sdk import WorkspaceClient
 from databricks.sdk.core import DatabricksError
+from databricks.sdk.service import workspace
 from databricks.sdk.service.sql import (
     AccessControl,
     ObjectTypePlural,
@@ -16,7 +17,8 @@
     WidgetOptions,
     WidgetPosition,
 )
-from databricks.sdk.service.workspace import ImportFormat
+
+from databricks.labs.ucx.framework.install_state import InstallState

 logger = logging.getLogger(__name__)

@@ -30,16 +32,8 @@ class SimpleQuery:
     widget: dict[str, str]

     @property
-    def query_key(self):
-        return f"{self.dashboard_ref}_{self.name}:query_id"
-
-    @property
-    def viz_key(self):
-        return f"{self.dashboard_ref}_{self.name}:viz_id"
-
-    @property
-    def widget_key(self):
-        return f"{self.dashboard_ref}_{self.name}:widget_id"
+    def key(self):
+        return f"{self.dashboard_ref}_{self.name}"

     @property
     def viz_type(self) -> str:
@@ -79,6 +73,7 @@ class DashboardFromFiles:
     def __init__(
         self,
         ws: WorkspaceClient,
+        state: InstallState,
         local_folder: Path,
         remote_folder: str,
         name_prefix: str,
@@ -91,19 +86,14 @@ def __init__(
         self._name_prefix = name_prefix
         self._query_text_callback = query_text_callback
         self._warehouse_id = warehouse_id
-        self._state = {}
+        self._state = state
         self._pos = 0

-    @property
-    def _query_state(self):
-        return f"{self._remote_folder}/state.json"
-
     def dashboard_link(self, dashboard_ref: str):
-        dashboard_id = self._state[f"{dashboard_ref}:dashboard_id"]
+        dashboard_id = self._state.dashboards[dashboard_ref]
         return f"{self._ws.config.host}/sql/dashboards/{dashboard_id}"

     def create_dashboards(self) -> dict:
-        dashboards = {}
         queries_per_dashboard = {}
         # Iterate over dashboards for each step, represented as first-level folders
         step_folders = [f for f in self._local_folder.glob("*") if f.is_dir()]
@@ -127,9 +117,8 @@ def create_dashboards(self) -> dict:
                     self._install_viz(query)
                     self._install_widget(query, dashboard_ref)
                 queries_per_dashboard[dashboard_ref] = desired_queries
-                dashboards[dashboard_ref] = self._state[f"{dashboard_ref}:dashboard_id"]
         self._store_query_state(queries_per_dashboard)
-        return dashboards
+        return self._state.dashboards

     def validate(self):
         step_folders = [f for f in self._local_folder.glob("*") if f.is_dir()]
@@ -148,13 +137,13 @@ def validate(self):
                         raise AssertionError(msg) from err

     def _install_widget(self, query: SimpleQuery, dashboard_ref: str):
-        dashboard_id = self._state[f"{dashboard_ref}:dashboard_id"]
+        dashboard_id = self._state.dashboards[dashboard_ref]
         widget_options = self._get_widget_options(query)
         # widgets are cleaned up every dashboard redeploy
         widget = self._ws.dashboard_widgets.create(
-            dashboard_id, widget_options, 1, visualization_id=self._state[query.viz_key]
+            dashboard_id, widget_options, 1, visualization_id=self._state.viz[query.key]
         )
-        self._state[query.widget_key] = widget.id
+        self._state.widgets[query.key] = widget.id

     def _get_widget_options(self, query: SimpleQuery):
         self._pos += 1
@@ -170,11 +159,12 @@ def _get_widget_options(self, query: SimpleQuery):
         )
         return widget_options

-    def _installed_query_state(self):
+    def _state_pre_v06(self):
         try:
-            self._state = json.load(self._ws.workspace.download(self._query_state))
+            query_state = f"{self._remote_folder}/state.json"
+            state = json.load(self._ws.workspace.download(query_state))
             to_remove = []
-            for k, v in self._state.items():
+            for k, v in state.items():
                 if k.endswith("dashboard_id"):
                     continue
                 if not k.endswith("query_id"):
@@ -184,50 +174,80 @@
                 except DatabricksError:
                     to_remove.append(k)
             for key in to_remove:
-                del self._state[key]
+                del state[key]
+            return state
         except DatabricksError as err:
             if err.error_code != "RESOURCE_DOES_NOT_EXIST":
                 raise err
             self._ws.workspace.mkdirs(self._remote_folder)
+            return {}
         except JSONDecodeError:
-            logger.warning(f"JSON state file corrupt: {self._query_state}")
-            self._state = {}  # noop
-        object_info = self._ws.workspace.get_status(self._remote_folder)
+            return {}
+
+    def _remote_folder_object(self) -> workspace.ObjectInfo:
+        try:
+            return self._ws.workspace.get_status(self._remote_folder)
+        except DatabricksError as err:
+            if err.error_code != "RESOURCE_DOES_NOT_EXIST":
+                raise err
+            self._ws.workspace.mkdirs(self._remote_folder)
+            return self._remote_folder_object()
+
+    def _installed_query_state(self):
+        if not self._state.dashboards:
+            for k, v in self._state_pre_v06().items():
+                prefix, suffix = k.split(":", 2)
+                match suffix:
+                    case "dashboard_id":
+                        self._state.dashboards[prefix] = v
+                    case "query_id":
+                        self._state.queries[prefix] = v
+                    case "viz_id":
+                        self._state.viz[prefix] = v
+                    case "widget_id":
+                        self._state.widgets[prefix] = v
+        object_info = self._remote_folder_object()
         parent = f"folders/{object_info.object_id}"
         return parent

-    def _store_query_state(self, queries: dict[str, list[SimpleQuery]]):
-        desired_keys = []
-        for ref, qrs in queries.items():
-            desired_keys.append(f"{ref}:dashboard_id")
-            for query in qrs:
-                desired_keys.append(query.query_key)
-                desired_keys.append(query.viz_key)
-                desired_keys.append(query.widget_key)
-        destructors = {
-            "query_id": self._ws.queries.delete,
-            "viz_id": self._ws.query_visualizations.delete,
-            "widget_id": self._ws.dashboard_widgets.delete,
-        }
-        new_state = {}
-        for k, v in self._state.items():
-            if k in desired_keys:
-                new_state[k] = v
-                continue
-            name = k if ":" not in k else k.split(":")[-1]
-            if name not in destructors:
-                continue
+    def _store_query_state(self, queries_per_dashboard: dict[str, list[SimpleQuery]]):
+        query_refs = set()
+        dashboard_refs = queries_per_dashboard.keys()
+        for queries in queries_per_dashboard.values():
+            for query in queries:
+                query_refs.add(query.key)
+
+        def silent_destroy(fn, object_id):
             try:
-                destructors[name](v)
+                fn(object_id)
             except DatabricksError as err:
-                logger.info(f"Failed to delete {name}-{v} --- {err.error_code}")
-        state_dump = json.dumps(new_state, indent=2).encode("utf8")
-        self._ws.workspace.upload(self._query_state, state_dump, format=ImportFormat.AUTO, overwrite=True)
+                logger.info(f"Failed to delete {object_id} --- {err.error_code}")
+
+        for ref, object_id in self._state.dashboards.items():
+            if ref in dashboard_refs:
+                continue
+            silent_destroy(self._ws.dashboards.delete, object_id)
+
+        for ref, object_id in self._state.queries.items():
+            if ref in query_refs:
+                continue
+            silent_destroy(self._ws.queries.delete, object_id)
+
+        for ref, object_id in self._state.viz.items():
+            if ref in query_refs:
+                continue
+            silent_destroy(self._ws.query_visualizations.delete, object_id)
+
+        for ref, object_id in self._state.widgets.items():
+            if ref in query_refs:
+                continue
+            silent_destroy(self._ws.dashboard_widgets.delete, object_id)
+
+        self._state.save()

     def _install_dashboard(self, dashboard_name: str, parent_folder_id: str, dashboard_ref: str):
-        dashboard_id = f"{dashboard_ref}:dashboard_id"
f"{dashboard_ref}:dashboard_id" - if dashboard_id in self._state: - for widget in self._ws.dashboards.get(self._state[dashboard_id]).widgets: + if dashboard_ref in self._state.dashboards: + for widget in self._ws.dashboards.get(self._state.dashboards[dashboard_ref]).widgets: self._ws.dashboard_widgets.delete(widget.id) return dash = self._ws.dashboards.create(dashboard_name, run_as_role=RunAsRole.VIEWER, parent=parent_folder_id) @@ -236,7 +256,7 @@ def _install_dashboard(self, dashboard_name: str, parent_folder_id: str, dashboa dash.id, access_control_list=[AccessControl(group_name="users", permission_level=PermissionLevel.CAN_VIEW)], ) - self._state[dashboard_id] = dash.id + self._state.dashboards[dashboard_ref] = dash.id def _desired_queries(self, local_folder: Path, dashboard_ref: str) -> list[SimpleQuery]: desired_queries = [] @@ -257,10 +277,10 @@ def _desired_queries(self, local_folder: Path, dashboard_ref: str) -> list[Simpl def _install_viz(self, query: SimpleQuery): viz_args = self._get_viz_options(query) - if query.viz_key in self._state: - return self._ws.query_visualizations.update(self._state[query.viz_key], **viz_args) - viz = self._ws.query_visualizations.create(self._state[query.query_key], **viz_args) - self._state[query.viz_key] = viz.id + if query.key in self._state.viz: + return self._ws.query_visualizations.update(self._state.viz[query.key], **viz_args) + viz = self._ws.query_visualizations.create(self._state.queries[query.key], **viz_args) + self._state.viz[query.key] = viz.id def _get_viz_options(self, query: SimpleQuery): viz_types = {"table": self._table_viz_args, "counter": self._counter_viz_args} @@ -276,8 +296,8 @@ def _install_query(self, query: SimpleQuery, dashboard_name: str, data_source_id "name": f"{dashboard_name} - {query.name}", "query": query.query, } - if query.query_key in self._state: - return self._ws.queries.update(self._state[query.query_key], **query_meta) + if query.key in self._state.queries: + return self._ws.queries.update(self._state.queries[query.key], **query_meta) deployed_query = self._ws.queries.create(parent=parent, run_as_role=RunAsRole.VIEWER, **query_meta) self._ws.dbsql_permissions.set( @@ -285,7 +305,7 @@ def _install_query(self, query: SimpleQuery, dashboard_name: str, data_source_id deployed_query.id, access_control_list=[AccessControl(group_name="users", permission_level=PermissionLevel.CAN_RUN)], ) - self._state[query.query_key] = deployed_query.id + self._state.queries[query.key] = deployed_query.id @staticmethod def _table_viz_args( diff --git a/src/databricks/labs/ucx/framework/install_state.py b/src/databricks/labs/ucx/framework/install_state.py new file mode 100644 index 0000000000..6ee25665ca --- /dev/null +++ b/src/databricks/labs/ucx/framework/install_state.py @@ -0,0 +1,45 @@ +import json +import logging +from json import JSONDecodeError + +from databricks.sdk import WorkspaceClient +from databricks.sdk.core import DatabricksError +from databricks.sdk.service.workspace import ImportFormat + +logger = logging.getLogger(__name__) + + +class InstallState: + def __init__(self, ws: WorkspaceClient, install_folder: str, version: int = 1): + self._ws = ws + self._state_file = f"{install_folder}/state.json" + self._version = version + self._state = {} + + def __getattr__(self, item): + if not self._state: + self._state = self._load() + if item not in self._state["resources"]: + self._state["resources"][item] = {} + return self._state["resources"][item] + + def _load(self): + default_state = {"$version": self._version, 
"resources": {}} + try: + raw = json.load(self._ws.workspace.download(self._state_file)) + version = raw.get("$version", None) + if version != self._version: + msg = f"expected state $version={self._version}, got={version}" + raise ValueError(msg) + return raw + except DatabricksError as err: + if err.error_code == "RESOURCE_DOES_NOT_EXIST": + return default_state + raise err + except JSONDecodeError: + logger.warning(f"JSON state file corrupt: {self._state_file}") + return default_state + + def save(self): + state_dump = json.dumps(self._state, indent=2).encode("utf8") + self._ws.workspace.upload(self._state_file, state_dump, format=ImportFormat.AUTO, overwrite=True) diff --git a/src/databricks/labs/ucx/install.py b/src/databricks/labs/ucx/install.py index 284ba00233..4c14f26927 100644 --- a/src/databricks/labs/ucx/install.py +++ b/src/databricks/labs/ucx/install.py @@ -25,6 +25,7 @@ from databricks.labs.ucx.__about__ import __version__ from databricks.labs.ucx.config import GroupsConfig, WorkspaceConfig from databricks.labs.ucx.framework.dashboards import DashboardFromFiles +from databricks.labs.ucx.framework.install_state import InstallState from databricks.labs.ucx.framework.tasks import _TASKS, Task from databricks.labs.ucx.hive_metastore.hms_lineage import HiveMetastoreLineageEnabler from databricks.labs.ucx.runtime import main @@ -99,6 +100,7 @@ def __init__(self, ws: WorkspaceClient, *, prefix: str = "ucx", promtps: bool = self._this_file = Path(__file__) self._override_clusters = None self._dashboards = {} + self._state = InstallState(ws, self._install_folder) def run(self): logger.info(f"Installing UCX v{self._version}") @@ -168,7 +170,7 @@ def run_for_config( return workspace_installer def run_workflow(self, step: str): - job_id = self._deployed_steps[step] + job_id = self._state.jobs[step] logger.debug(f"starting {step} job: {self._ws.config.host}#job/{job_id}") job_run_waiter = self._ws.jobs.run_now(job_id) try: @@ -195,6 +197,7 @@ def _create_dashboards(self): local_query_files = self._find_project_root() / "src/databricks/labs/ucx/queries" dash = DashboardFromFiles( self._ws, + state=self._state, local_folder=local_query_files, remote_folder=f"{self._install_folder}/queries", name_prefix=self._name("UCX "), @@ -262,7 +265,7 @@ def _current_config(self): return self._config def _name(self, name: str) -> str: - return f"[{self._prefix.upper()}][{self._short_name}] {name}" + return f"[{self._prefix.upper()}] {name}" def _configure_inventory_database(self): counter = 0 @@ -380,9 +383,11 @@ def _write_config(self): self._ws.workspace.upload(self.config_file, config_bytes, format=ImportFormat.AUTO) def _create_jobs(self): + if not self._state.jobs: + for step, job_id in self._deployed_steps_pre_v06().items(): + self._state.jobs[step] = job_id logger.debug(f"Creating jobs from tasks in {main.__name__}") remote_wheel = self._upload_wheel() - self._deployed_steps = self.deployed_steps() desired_steps = {t.workflow for t in _TASKS.values()} wheel_runner = None @@ -392,22 +397,37 @@ def _create_jobs(self): settings = self._job_settings(step_name, remote_wheel) if self._override_clusters: settings = self._apply_cluster_overrides(settings, self._override_clusters, wheel_runner) - if step_name in self._deployed_steps: - job_id = self._deployed_steps[step_name] + if step_name in self._state.jobs: + job_id = self._state.jobs[step_name] logger.info(f"Updating configuration for step={step_name} job_id={job_id}") self._ws.jobs.reset(job_id, jobs.JobSettings(**settings)) else: 
logger.info(f"Creating new job configuration for step={step_name}") - self._deployed_steps[step_name] = self._ws.jobs.create(**settings).job_id + job_id = self._ws.jobs.create(**settings).job_id + self._state.jobs[step_name] = job_id - for step_name, job_id in self._deployed_steps.items(): + for step_name, job_id in self._state.jobs.items(): if step_name not in desired_steps: logger.info(f"Removing job_id={job_id}, as it is no longer needed") self._ws.jobs.delete(job_id) + self._state.save() self._create_readme() self._create_debug(remote_wheel) + def _deployed_steps_pre_v06(self): + deployed_steps = {} + logger.debug(f"Fetching all jobs to determine already deployed steps for app={self._app}") + for j in self._ws.jobs.list(): + tags = j.settings.tags + if tags is None: + continue + if tags.get(TAG_APP, None) != self._app: + continue + step = tags.get(TAG_STEP, "_") + deployed_steps[step] = j.job_id + return deployed_steps + @staticmethod def _sorted_tasks() -> list[Task]: return sorted(_TASKS.values(), key=lambda x: x.task_id) @@ -428,10 +448,10 @@ def _create_readme(self): "All jobs are defined with necessary cluster configurations and DBR versions.\n", ] for step_name in self._step_list(): - if step_name not in self._deployed_steps: + if step_name not in self._state.jobs: logger.warning(f"Skipping step '{step_name}' since it was not deployed.") continue - job_id = self._deployed_steps[step_name] + job_id = self._state.jobs[step_name] dashboard_link = "" dashboards_per_step = [d for d in self._dashboards.keys() if d.startswith(step_name)] for dash in dashboards_per_step: @@ -459,7 +479,7 @@ def _create_readme(self): url = self.notebook_link(path) logger.info(f"Created README notebook with job overview: {url}") msg = "Open job overview in README notebook in your home directory ?" 
-        if self._prompts and self._question(msg, default="yes") == "yes":
+        if self._prompts and self._question(msg, default="no") == "yes":
             webbrowser.open(url)

     def _replace_inventory_variable(self, text: str) -> str:
@@ -469,7 +489,7 @@ def _create_debug(self, remote_wheel: str):
         readme_link = self.notebook_link(f"{self._install_folder}/README.py")
         job_links = ", ".join(
             f"[{self._name(step_name)}]({self._ws.config.host}#job/{job_id})"
-            for step_name, job_id in self._deployed_steps.items()
+            for step_name, job_id in self._state.jobs.items()
         )
         path = f"{self._install_folder}/DEBUG.py"
         logger.debug(f"Created debug notebook: {self.notebook_link(path)}")
@@ -548,23 +568,12 @@ def _job_settings(self, step_name: str, dbfs_path: str):
         version = self._version if not self._ws.config.is_gcp else self._version.replace("+", "-")
         return {
             "name": self._name(step_name),
-            "tags": {TAG_APP: self._app, TAG_STEP: step_name, "version": f"v{version}"},
+            "tags": {TAG_APP: self._app, "version": f"v{version}"},
             "job_clusters": self._job_clusters({t.job_cluster for t in tasks}),
             "email_notifications": email_notifications,
             "tasks": [self._job_task(task, dbfs_path) for task in tasks],
         }

-    @staticmethod
-    def _apply_cluster_overrides(settings: dict[str, any], overrides: dict[str, str]) -> dict:
-        settings["job_clusters"] = [_ for _ in settings["job_clusters"] if _.job_cluster_key not in overrides]
-        for job_task in settings["tasks"]:
-            if job_task.job_cluster_key is None:
-                continue
-            if job_task.job_cluster_key in overrides:
-                job_task.existing_cluster_id = overrides[job_task.job_cluster_key]
-                job_task.job_cluster_key = None
-        return settings
-
     def _upload_wheel_runner(self, remote_wheel: str):
         # TODO: we have to be doing this workaround until ES-897453 is solved in the platform
         path = f"{self._install_folder}/wheels/wheel-test-runner-{self._version}.py"
@@ -773,18 +782,6 @@ def _cluster_node_type(self, spec: compute.ClusterSpec) -> compute.ClusterSpec:
         )
         return replace(spec, gcp_attributes=compute.GcpAttributes(availability=compute.GcpAvailability.ON_DEMAND_GCP))

-    def deployed_steps(self):
-        deployed_steps = {}
-        logger.debug(f"Fetching all jobs to determine already deployed steps for app={self._app}")
-        for j in self._ws.jobs.list():
-            tags = j.settings.tags
-            if tags is None:
-                continue
-            if tags.get(TAG_APP, None) != self._app:
-                continue
-            deployed_steps[tags.get(TAG_STEP, "_")] = j.job_id
-        return deployed_steps
-
     def _instance_profiles(self):
         return {"No Instance Profile": None} | {
             profile.instance_profile_arn: profile.instance_profile_arn for profile in self._ws.instance_profiles.list()
@@ -821,7 +818,7 @@ def _get_ext_hms_conf_from_policy(cluster_policy):

     def latest_job_status(self) -> list[dict]:
         latest_status = []
-        for step, job_id in self.deployed_steps().items():
+        for step, job_id in self._state.jobs.items():
             job_runs = list(self._ws.jobs.list_runs(job_id=job_id, limit=1))
             latest_status.append(
                 {
diff --git a/tests/unit/assessment/test_dashboard.py b/tests/unit/assessment/test_dashboard.py
index 1ebfe141ba..d3489b3ad5 100644
--- a/tests/unit/assessment/test_dashboard.py
+++ b/tests/unit/assessment/test_dashboard.py
@@ -12,6 +12,7 @@

 from databricks.labs.ucx.config import GroupsConfig, WorkspaceConfig
 from databricks.labs.ucx.framework.dashboards import DashboardFromFiles
+from databricks.labs.ucx.framework.install_state import InstallState
 from databricks.labs.ucx.install import WorkspaceInstaller


@@ -33,6 +34,7 @@ def test_dashboard(mocker):
     local_query_files = installer._find_project_root() / "src/databricks/labs/ucx/queries"
"src/databricks/labs/ucx/queries" dash = DashboardFromFiles( ws, + InstallState(ws, "/users/not_a_real_user"), local_folder=local_query_files, remote_folder="/users/not_a_real_user/queries", name_prefix="Assessment", diff --git a/tests/unit/test_install.py b/tests/unit/test_install.py index c5607b0a50..faa81dbeab 100644 --- a/tests/unit/test_install.py +++ b/tests/unit/test_install.py @@ -28,10 +28,12 @@ from databricks.labs.ucx.config import GroupsConfig, WorkspaceConfig from databricks.labs.ucx.framework.dashboards import DashboardFromFiles +from databricks.labs.ucx.framework.install_state import InstallState from databricks.labs.ucx.install import WorkspaceInstaller -def mock_ws(mocker): +@pytest.fixture +def ws(mocker): ws = mocker.patch("databricks.sdk.WorkspaceClient.__init__") ws.current_user.me = lambda: iam.User(user_name="me@example.com", groups=[iam.ComplexValue(display="admins")]) @@ -45,21 +47,21 @@ def mock_ws(mocker): ws.data_sources.list = lambda: [DataSource(id="bcd", warehouse_id="abc")] ws.warehouses.list = lambda **_: [EndpointInfo(id="abc", warehouse_type=EndpointInfoWarehouseType.PRO)] ws.dashboards.create.return_value = Dashboard(id="abc") + ws.jobs.create.return_value = jobs.CreateResponse(job_id="abc") ws.queries.create.return_value = Query(id="abc") ws.query_visualizations.create.return_value = Visualization(id="abc") ws.dashboard_widgets.create.return_value = Widget(id="abc") return ws -def test_replace_clusters_for_integration_tests(mocker): - ws = mock_ws(mocker) +def test_replace_clusters_for_integration_tests(ws): return_value = WorkspaceInstaller.run_for_config( ws, WorkspaceConfig(inventory_database="a", groups=GroupsConfig(auto=True)), override_clusters={"main": "abc"} ) assert return_value -def test_run_workflow_creates_proper_failure(mocker): +def test_run_workflow_creates_proper_failure(ws, mocker): def run_now(job_id): assert "bar" == job_id @@ -71,7 +73,6 @@ def result(): waiter.run_id = "qux" return waiter - ws = mock_ws(mocker) ws.jobs.run_now = run_now ws.jobs.get_run.return_value = jobs.Run( state=jobs.RunState(state_message="Stuff happens."), @@ -85,7 +86,7 @@ def result(): ) ws.jobs.get_run_output.return_value = jobs.RunOutput(error="does not compute", error_trace="# goes to stderr") installer = WorkspaceInstaller(ws) - installer._deployed_steps = {"foo": "bar"} + installer._state.jobs = {"foo": "bar"} with pytest.raises(OperationFailed) as failure: installer.run_workflow("foo") @@ -93,16 +94,14 @@ def result(): assert "Stuff happens: stuff: does not compute" == str(failure.value) -def test_install_database_happy(mocker, tmp_path): - ws = mocker.Mock() +def test_install_database_happy(ws, mocker, tmp_path): install = WorkspaceInstaller(ws) mocker.patch("builtins.input", return_value="ucx") res = install._configure_inventory_database() assert "ucx" == res -def test_install_database_unhappy(mocker, tmp_path): - ws = mocker.Mock() +def test_install_database_unhappy(ws, mocker, tmp_path): install = WorkspaceInstaller(ws) mocker.patch("builtins.input", return_value="main.ucx") @@ -110,21 +109,18 @@ def test_install_database_unhappy(mocker, tmp_path): install._configure_inventory_database() -def test_build_wheel(mocker, tmp_path): - ws = mocker.Mock() +def test_build_wheel(ws, tmp_path): install = WorkspaceInstaller(ws) whl = install._build_wheel(str(tmp_path)) assert os.path.exists(whl) -def test_save_config(mocker): +def test_save_config(ws, mocker): def not_found(_): raise DatabricksError(error_code="RESOURCE_DOES_NOT_EXIST") 
mocker.patch("builtins.input", return_value="42") - ws = mocker.Mock() - ws.current_user.me = lambda: iam.User(user_name="me@example.com", groups=[iam.ComplexValue(display="admins")]) - ws.config.host = "https://foo" + ws.workspace.get_status = not_found ws.warehouses.list = lambda **_: [ EndpointInfo(id="abc", warehouse_type=EndpointInfoWarehouseType.PRO, state=State.RUNNING) @@ -153,14 +149,12 @@ def not_found(_): ) -def test_save_config_with_error(mocker): +def test_save_config_with_error(ws, mocker): def not_found(_): raise DatabricksError(error_code="RAISED_FOR_TESTING") mocker.patch("builtins.input", return_value="42") - ws = mocker.Mock() - ws.current_user.me = lambda: iam.User(user_name="me@example.com", groups=[iam.ComplexValue(display="admins")]) - ws.config.host = "https://foo" + ws.workspace.get_status = not_found ws.cluster_policies.list = lambda: [] @@ -170,7 +164,7 @@ def not_found(_): assert str(e_info.value.error_code) == "RAISED_FOR_TESTING" -def test_save_config_auto_groups(mocker): +def test_save_config_auto_groups(ws, mocker): def not_found(_): raise DatabricksError(error_code="RESOURCE_DOES_NOT_EXIST") @@ -181,9 +175,7 @@ def mock_question(text: str, *, default: str | None = None) -> str: return "42" mocker.patch("builtins.input", return_value="42") - ws = mocker.Mock() - ws.current_user.me = lambda: iam.User(user_name="me@example.com", groups=[iam.ComplexValue(display="admins")]) - ws.config.host = "https://foo" + ws.workspace.get_status = not_found ws.warehouses.list = lambda **_: [ EndpointInfo(id="abc", warehouse_type=EndpointInfoWarehouseType.PRO, state=State.RUNNING) @@ -212,7 +204,7 @@ def mock_question(text: str, *, default: str | None = None) -> str: ) -def test_save_config_strip_group_names(mocker): +def test_save_config_strip_group_names(ws, mocker): def not_found(_): raise DatabricksError(error_code="RESOURCE_DOES_NOT_EXIST") @@ -223,9 +215,7 @@ def mock_question(text: str, *, default: str | None = None) -> str: return "42" mocker.patch("builtins.input", return_value="42") - ws = mocker.Mock() - ws.current_user.me = lambda: iam.User(user_name="me@example.com", groups=[iam.ComplexValue(display="admins")]) - ws.config.host = "https://foo" + ws.workspace.get_status = not_found ws.warehouses.list = lambda **_: [ EndpointInfo(id="abc", warehouse_type=EndpointInfoWarehouseType.PRO, state=State.RUNNING) @@ -257,7 +247,7 @@ def mock_question(text: str, *, default: str | None = None) -> str: ) -def test_save_config_with_glue(mocker): +def test_save_config_with_glue(ws, mocker): policy_def = b""" { "aws_attributes.instance_profile_arn": { @@ -289,9 +279,7 @@ def mock_choice_from_dict(text: str, choices: dict[str, Any]) -> Any: return "abc" mocker.patch("builtins.input", return_value="42") - ws = mocker.Mock() - ws.current_user.me = lambda: iam.User(user_name="me@example.com", groups=[iam.ComplexValue(display="admins")]) - ws.config.host = "https://foo" + ws.workspace.get_status = not_found ws.warehouses.list = lambda **_: [ EndpointInfo(id="abc", warehouse_type=EndpointInfoWarehouseType.PRO, state=State.RUNNING) @@ -324,17 +312,13 @@ def mock_choice_from_dict(text: str, choices: dict[str, Any]) -> Any: ) -def test_main_with_existing_conf_does_not_recreate_config(mocker): +def test_main_with_existing_conf_does_not_recreate_config(ws, mocker): mocker.patch("builtins.input", return_value="yes") mock_file = MagicMock() mocker.patch("builtins.open", return_value=mock_file) mocker.patch("base64.b64encode") webbrowser_open = mocker.patch("webbrowser.open") - ws = 
mocker.patch("databricks.sdk.WorkspaceClient.__init__") - ws.current_user.me = lambda: iam.User(user_name="me@example.com", groups=[iam.ComplexValue(display="admins")]) - ws.config.host = "https://foo" - ws.config.is_aws = True config_bytes = yaml.dump(WorkspaceConfig(inventory_database="a", groups=GroupsConfig(auto=True)).as_dict()).encode( "utf8" ) @@ -354,39 +338,34 @@ def test_main_with_existing_conf_does_not_recreate_config(mocker): # ws.workspace.mkdirs.assert_called_with("/Users/me@example.com/.ucx") -def test_query_metadata(mocker): - ws = mocker.Mock() +def test_query_metadata(ws, mocker): install = WorkspaceInstaller(ws) local_query_files = install._find_project_root() / "src/databricks/labs/ucx/queries" - DashboardFromFiles(ws, local_query_files, "any", "any").validate() + DashboardFromFiles(ws, InstallState(ws, "any"), local_query_files, "any", "any").validate() -def test_choices_out_of_range(mocker): - ws = mocker.Mock() +def test_choices_out_of_range(ws, mocker): install = WorkspaceInstaller(ws) mocker.patch("builtins.input", return_value="42") with pytest.raises(ValueError): install._choice("foo", ["a", "b"]) -def test_choices_not_a_number(mocker): - ws = mocker.Mock() +def test_choices_not_a_number(ws, mocker): install = WorkspaceInstaller(ws) mocker.patch("builtins.input", return_value="two") with pytest.raises(ValueError): install._choice("foo", ["a", "b"]) -def test_choices_happy(mocker): - ws = mocker.Mock() +def test_choices_happy(ws, mocker): install = WorkspaceInstaller(ws) mocker.patch("builtins.input", return_value="1") res = install._choice("foo", ["a", "b"]) assert "b" == res -def test_step_list(mocker): - ws = mocker.Mock() +def test_step_list(ws, mocker): from databricks.labs.ucx.framework.tasks import Task tasks = [ @@ -402,13 +381,10 @@ def test_step_list(mocker): assert steps[0] == "wl_1" and steps[1] == "wl_2" -def test_create_readme(mocker): +def test_create_readme(ws, mocker): mocker.patch("builtins.input", return_value="yes") webbrowser_open = mocker.patch("webbrowser.open") - ws = mocker.Mock() - ws.current_user.me = lambda: iam.User(user_name="me@example.com", groups=[iam.ComplexValue(display="admins")]) - ws.config.host = "https://foo" config_bytes = yaml.dump(WorkspaceConfig(inventory_database="a", groups=GroupsConfig(auto=True)).as_dict()).encode( "utf8" ) @@ -432,13 +408,8 @@ def test_create_readme(mocker): _, args, kwargs = ws.mock_calls[0] assert args[0] == "/Users/me@example.com/.ucx/README.py" - import re - - p = re.compile(".*wl_1.*n3.*n1.*wl_2.*n2.*") - assert p.match(str(args[1])) - -def test_replace_pydoc(mocker): +def test_replace_pydoc(): from databricks.labs.ucx.framework.tasks import _remove_extra_indentation doc = _remove_extra_indentation( @@ -454,8 +425,7 @@ def test_replace_pydoc(mocker): ) -def test_global_init_script_already_exists_enabled(mocker): - ws = mocker.Mock() +def test_global_init_script_already_exists_enabled(ws, mocker): ginit_scripts = [ GlobalInitScriptDetails( created_at=1695045723722, @@ -491,8 +461,7 @@ def test_global_init_script_already_exists_enabled(mocker): install._install_spark_config_for_hms_lineage() -def test_global_init_script_already_exists_disabled(mocker): - ws = mocker.Mock() +def test_global_init_script_already_exists_disabled(ws, mocker): ginit_scripts = [ GlobalInitScriptDetails( created_at=1695045723722, @@ -528,8 +497,7 @@ def test_global_init_script_already_exists_disabled(mocker): install._install_spark_config_for_hms_lineage() -def test_global_init_script_exists_disabled_not_enabled(mocker): 
-    ws = mocker.Mock()
+def test_global_init_script_exists_disabled_not_enabled(ws, mocker):
     ginit_scripts = [
         GlobalInitScriptDetails(
             created_at=1695045723722,
@@ -566,9 +534,8 @@
     install._install_spark_config_for_hms_lineage()

 @patch("builtins.open", new_callable=MagicMock)
-@patch("base64.b64encode")
 @patch("builtins.input", new_callable=MagicMock)
-def test_global_init_script_create_new(mock_open, mocker, mock_input):
+def test_global_init_script_create_new(mock_open, mock_input, ws):
     expected_content = """if [[ $DB_IS_DRIVER = "TRUE" ]]; then
   driver_conf=${DB_HOME}/driver/conf/spark-branch.conf
   if [ ! -e $driver_conf ] ; then
@@ -584,7 +551,7 @@ def test_global_init_script_create_new(mock_open, mocker, mock_input):
     mock_file = MagicMock()
     mock_file.read.return_value = expected_content
     mock_open.return_value = mock_file
-    ws = mocker.Mock()
-    install = WorkspaceInstaller(ws)
     mock_input.return_value = "yes"
+
+    install = WorkspaceInstaller(ws)
     install._install_spark_config_for_hms_lineage()
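
Editor's note, not part of the patch: the InstallState class added above keeps every deployed resource id in a single versioned state.json shaped like {"$version": 1, "resources": {"jobs": {...}, "dashboards": {...}}}, and attribute access (state.jobs, state.dashboards, ...) lazily creates the per-resource dict. The sketch below imitates that load/mutate/save round-trip against a local file instead of ws.workspace.download/upload; the class name, file path, and keys are illustrative stand-ins, not the library API.

# sketch_install_state.py -- standalone imitation of the InstallState pattern, under the
# assumption that a local JSON file stands in for the workspace-backed state.json.
import json
from pathlib import Path


class LocalInstallState:
    """Versioned {"$version": N, "resources": {...}} store, mirroring the diff's InstallState."""

    def __init__(self, state_file: Path, version: int = 1):
        self._state_file = state_file
        self._version = version
        self._state = None

    def __getattr__(self, item):
        # Any unknown attribute (jobs, dashboards, queries, ...) lazily loads the file
        # and returns a mutable dict for that resource type.
        if self._state is None:
            self._state = self._load()
        return self._state["resources"].setdefault(item, {})

    def _load(self):
        default = {"$version": self._version, "resources": {}}
        if not self._state_file.exists():
            return default
        raw = json.loads(self._state_file.read_text())
        if raw.get("$version") != self._version:
            raise ValueError(f"expected $version={self._version}, got={raw.get('$version')}")
        return raw

    def save(self):
        # Persist the whole document, like InstallState.save() uploading state.json.
        self._state_file.write_text(json.dumps(self._state, indent=2))


if __name__ == "__main__":
    state = LocalInstallState(Path("state.json"))
    state.jobs["assessment"] = "123"            # same shape the installer uses for job ids
    state.dashboards["assessment_main"] = "abc"
    state.save()
    print(LocalInstallState(Path("state.json")).jobs)  # {'assessment': '123'}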