Increased memory for ucx clusters (#1366)

nkvuong authored Apr 24, 2024
1 parent 6f7696f commit 491f792

Showing 5 changed files with 12 additions and 10 deletions.
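The change itself is small and uniform: every call to select_node_type(local_disk=True) gains min_memory_gb=16, so the installer's cluster policy, the integration-test fixtures (clusters, instance pools, job clusters, DLT pipeline clusters) and the matching unit-test mocks all provision nodes with at least 16 GB of RAM. A minimal sketch of the API difference, assuming a configured databricks-sdk WorkspaceClient (the returned node type IDs depend on the workspace's cloud and available SKUs):

    from databricks.sdk import WorkspaceClient

    ws = WorkspaceClient()  # assumes credentials are available in the environment

    # Before this commit: smallest node type that has a local disk, regardless of memory.
    old_node = ws.clusters.select_node_type(local_disk=True)

    # After this commit: additionally require at least 16 GB of RAM per node.
    new_node = ws.clusters.select_node_type(local_disk=True, min_memory_gb=16)

    print(old_node, new_node)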
src/databricks/labs/ucx/installer/policy.py (4 changes: 3 additions & 1 deletion)
@@ -9,6 +9,7 @@
from databricks.sdk.service import compute
from databricks.sdk.service.sql import GetWorkspaceWarehouseConfigResponse


logger = logging.getLogger(__name__)


@@ -89,9 +90,10 @@ def _get_instance_pool_id(self) -> str | None:

     def _definition(self, conf: dict, instance_profile: str | None, instance_pool_id: str | None) -> str:
         latest_lts_dbr = self._ws.clusters.select_spark_version(latest=True, long_term_support=True)
+        node_type_id = self._ws.clusters.select_node_type(local_disk=True, min_memory_gb=16)
         policy_definition = {
             "spark_version": self._policy_config(latest_lts_dbr),
-            "node_type_id": self._policy_config(self._ws.clusters.select_node_type(local_disk=True)),
+            "node_type_id": self._policy_config(node_type_id),
         }
         for key, value in conf.items():
             policy_definition[f"spark_conf.{key}"] = self._policy_config(value)
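With this hunk, _definition() resolves the node type once, with the new 16 GB floor, and pins it in the generated cluster policy next to the latest LTS Spark version. A sketch of the resulting policy-definition shape, assuming _policy_config() wraps each value as a fixed policy element (the node type and Spark version shown are placeholders, not guaranteed outputs):

    # Illustrative only: concrete values depend on the workspace.
    policy_definition = {
        "spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
        "node_type_id": {"type": "fixed", "value": "Standard_D4s_v3"},  # local disk, >= 16 GB RAM
        # plus one "spark_conf.<key>" entry per item in conf
    }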
src/databricks/labs/ucx/mixins/fixtures.py (10 changes: 5 additions & 5 deletions)
@@ -697,7 +697,7 @@ def create(
             kwargs["spark_conf"] = {"spark.databricks.cluster.profile": "singleNode", "spark.master": "local[*]"}
             kwargs["custom_tags"] = {"ResourceClass": "SingleNode"}
         if "instance_pool_id" not in kwargs:
-            kwargs["node_type_id"] = ws.clusters.select_node_type(local_disk=True)
+            kwargs["node_type_id"] = ws.clusters.select_node_type(local_disk=True, min_memory_gb=16)

         return ws.clusters.create(
             cluster_name=cluster_name,
@@ -738,7 +738,7 @@ def create(*, instance_pool_name=None, node_type_id=None, **kwargs):
         if instance_pool_name is None:
             instance_pool_name = f"sdk-{make_random(4)}"
         if node_type_id is None:
-            node_type_id = ws.clusters.select_node_type(local_disk=True)
+            node_type_id = ws.clusters.select_node_type(local_disk=True, min_memory_gb=16)
         return ws.instance_pools.create(instance_pool_name, node_type_id, **kwargs)

     yield from factory("instance pool", create, lambda item: ws.instance_pools.delete(item.instance_pool_id))
@@ -761,7 +761,7 @@ def create(**kwargs):
             description=make_random(4),
             new_cluster=compute.ClusterSpec(
                 num_workers=1,
-                node_type_id=ws.clusters.select_node_type(local_disk=True),
+                node_type_id=ws.clusters.select_node_type(local_disk=True, min_memory_gb=16),
                 spark_version=ws.clusters.select_spark_version(latest=True),
                 spark_conf=task_spark_conf,
             ),
@@ -776,7 +776,7 @@ def create(**kwargs):
             description=make_random(4),
             new_cluster=compute.ClusterSpec(
                 num_workers=1,
-                node_type_id=ws.clusters.select_node_type(local_disk=True),
+                node_type_id=ws.clusters.select_node_type(local_disk=True, min_memory_gb=16),
                 spark_version=ws.clusters.select_spark_version(latest=True),
             ),
             notebook_task=jobs.NotebookTask(notebook_path=make_notebook()),
@@ -817,7 +817,7 @@ def create(**kwargs) -> pipelines.CreatePipelineResponse:
         if "clusters" not in kwargs:
             kwargs["clusters"] = [
                 pipelines.PipelineCluster(
-                    node_type_id=ws.clusters.select_node_type(local_disk=True),
+                    node_type_id=ws.clusters.select_node_type(local_disk=True, min_memory_gb=16),
                     label="default",
                     num_workers=1,
                     custom_tags={
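The same one-argument change is repeated in each fixture that creates compute: test clusters, instance pools, job clusters and DLT pipeline clusters. A hypothetical refactoring (not part of this commit) could centralize that default so a future memory bump only touches one place:

    # Hypothetical helper, for illustration only; the constant and function names are invented.
    from databricks.sdk import WorkspaceClient

    UCX_TEST_MIN_MEMORY_GB = 16

    def default_node_type(ws: WorkspaceClient) -> str:
        """Smallest node type with a local disk and at least 16 GB of RAM."""
        return ws.clusters.select_node_type(local_disk=True, min_memory_gb=UCX_TEST_MIN_MEMORY_GB)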
tests/integration/test_installation.py (2 changes: 1 addition & 1 deletion)
@@ -149,7 +149,7 @@ def test_job_cluster_policy(ws, installation_ctx):

     spark_version = ws.clusters.select_spark_version(latest=True, long_term_support=True)
     assert policy_definition["spark_version"]["value"] == spark_version
-    assert policy_definition["node_type_id"]["value"] == ws.clusters.select_node_type(local_disk=True)
+    assert policy_definition["node_type_id"]["value"] == ws.clusters.select_node_type(local_disk=True, min_memory_gb=16)
     if ws.config.is_azure:
         assert (
             policy_definition["azure_attributes.availability"]["value"]
tests/unit/installer/test_policy.py (4 changes: 2 additions & 2 deletions)
@@ -28,7 +28,7 @@ def common():

     w.cluster_policies.list.return_value = [policy]
     w.clusters.select_spark_version = lambda **_: "14.2.x-scala2.12"
-    w.clusters.select_node_type = lambda local_disk: "Standard_F4s"
+    w.clusters.select_node_type = lambda **_: "Standard_F4s"
     w.current_user.me = lambda: iam.User(user_name="[email protected]", groups=[iam.ComplexValue(display="admins")])
     prompts = MockPrompts(
         {
@@ -210,7 +210,7 @@ def test_update_job_policy():
         '123',
         new_cluster=ClusterSpec(
             num_workers=1,
-            node_type_id=ws.clusters.select_node_type(local_disk=True),
+            node_type_id=ws.clusters.select_node_type(local_disk=True, min_memory_gb=16),
             spark_version=ws.clusters.select_spark_version(latest=True),
         ),
     )
tests/unit/test_install.py (2 changes: 1 addition & 1 deletion)
@@ -142,7 +142,7 @@ def download(path: str) -> io.StringIO | io.BytesIO:
     workspace_client.clusters.list.return_value = mock_clusters()
     workspace_client.cluster_policies.create.return_value = CreatePolicyResponse(policy_id="foo")
     workspace_client.clusters.select_spark_version = lambda **_: "14.2.x-scala2.12"
-    workspace_client.clusters.select_node_type = lambda local_disk: "Standard_F4s"
+    workspace_client.clusters.select_node_type = lambda **_: "Standard_F4s"
     workspace_client.workspace.download = download

     return workspace_client
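The two unit-test mocks change shape rather than value: select_node_type is now called with two keyword arguments, so a lambda that only accepts local_disk would raise a TypeError, while a **kwargs lambda keeps returning the stubbed node type. A quick illustration:

    # Why the mock signature had to change.
    old_mock = lambda local_disk: "Standard_F4s"
    new_mock = lambda **_: "Standard_F4s"

    new_mock(local_disk=True, min_memory_gb=16)    # returns "Standard_F4s"
    # old_mock(local_disk=True, min_memory_gb=16)  # TypeError: unexpected keyword argument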
