diff --git a/.gitignore b/.gitignore index b46da6e..73ec146 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ *.py[cod] *$py.class +.idea +.editorconfig + build/ dist/ dsub_libs/ diff --git a/README.md b/README.md index d89f924..8fb4770 100644 --- a/README.md +++ b/README.md @@ -724,7 +724,7 @@ The image below illustrates this: By default, `dsub` will use the [default Compute Engine service account](https://cloud.google.com/compute/docs/access/service-accounts#default_service_account) as the authorized service account on the VM instance. You can choose to specify -the email address of another service acount using `--service-account`. +the email address of another service account using `--service-account`. By default, `dsub` will grant the following access scopes to the service account: diff --git a/dsub/_dsub_version.py b/dsub/_dsub_version.py index 3feda53..9e7dad2 100644 --- a/dsub/_dsub_version.py +++ b/dsub/_dsub_version.py @@ -26,4 +26,4 @@ 0.1.3.dev0 -> 0.1.3 -> 0.1.4.dev0 -> ... """ -DSUB_VERSION = '0.4.10' +DSUB_VERSION = '0.4.11.dev0' diff --git a/dsub/providers/google_batch.py b/dsub/providers/google_batch.py index ca4133e..5d79e7a 100644 --- a/dsub/providers/google_batch.py +++ b/dsub/providers/google_batch.py @@ -22,17 +22,17 @@ import os import sys import textwrap -from typing import Dict, List, Set +from typing import Dict, List, Set, MutableSequence from . import base from . import google_base from . import google_batch_operations from . import google_utils +from .google_batch_operations import build_compute_resource, build_accelerators from ..lib import job_model from ..lib import param_util from ..lib import providers_util - # pylint: disable=g-import-not-at-top try: from google.cloud import batch_v1 @@ -278,12 +278,12 @@ def get_field(self, field: str, default: str = None): if self._job_descriptor: value = self._job_descriptor.job_metadata.get(field) elif field in [ - 'job-id', - 'job-name', - 'task-id', - 'task-attempt', - 'user-id', - 'dsub-version', + 'job-id', + 'job-name', + 'task-id', + 'task-attempt', + 'user-id', + 'dsub-version', ]: value = google_batch_operations.get_label(self._op, field) elif field == 'task-status': @@ -298,31 +298,31 @@ def get_field(self, field: str, default: str = None): elif field in ['envs', 'labels']: if self._job_descriptor: items = providers_util.get_job_and_task_param( - self._job_descriptor.job_params, - self._job_descriptor.task_descriptors[0].task_params, - field, + self._job_descriptor.job_params, + self._job_descriptor.task_descriptors[0].task_params, + field, ) value = {item.name: item.value for item in items} elif field in [ - 'inputs', - 'outputs', - 'input-recursives', - 'output-recursives', + 'inputs', + 'outputs', + 'input-recursives', + 'output-recursives', ]: if self._job_descriptor: value = {} items = providers_util.get_job_and_task_param( - self._job_descriptor.job_params, - self._job_descriptor.task_descriptors[0].task_params, - field, + self._job_descriptor.job_params, + self._job_descriptor.task_descriptors[0].task_params, + field, ) value.update({item.name: item.value for item in items}) elif field == 'mounts': if self._job_descriptor: items = providers_util.get_job_and_task_param( - self._job_descriptor.job_params, - self._job_descriptor.task_descriptors[0].task_params, - field, + self._job_descriptor.job_params, + self._job_descriptor.task_descriptors[0].task_params, + field, ) value = {item.name: item.value for item in items} elif field == 'provider': @@ -386,9 +386,9 @@ def _operation_status(self): return 
'FAILURE' raise ValueError( - 'Status for operation {} could not be determined'.format( - self._op['name'] - ) + 'Status for operation {} could not be determined'.format( + self._op['name'] + ) ) def _operation_status_message(self): @@ -425,7 +425,7 @@ class GoogleBatchJobProvider(google_utils.GoogleJobProviderBase): """dsub provider implementation managing Jobs on Google Cloud.""" def __init__( - self, dry_run: bool, project: str, location: str, credentials=None + self, dry_run: bool, project: str, location: str, credentials=None ): self._dry_run = dry_run self._location = location @@ -448,10 +448,10 @@ def _get_logging_env(self, logging_uri, user_project, include_filter_script): logging_prefix = logging_uri[: -len('.log')] env = { - 'LOGGING_PATH': '{}.log'.format(logging_prefix), - 'STDOUT_PATH': '{}-stdout.log'.format(logging_prefix), - 'STDERR_PATH': '{}-stderr.log'.format(logging_prefix), - 'USER_PROJECT': user_project, + 'LOGGING_PATH': '{}.log'.format(logging_prefix), + 'STDOUT_PATH': '{}-stdout.log'.format(logging_prefix), + 'STDERR_PATH': '{}-stderr.log'.format(logging_prefix), + 'USER_PROJECT': user_project, } if include_filter_script: env[_LOG_FILTER_VAR] = repr(_LOG_FILTER_PYTHON) @@ -459,10 +459,10 @@ def _get_logging_env(self, logging_uri, user_project, include_filter_script): return env def _create_batch_request( - self, - task_view: job_model.JobDescriptor, - job_id, - all_envs: List[batch_v1.types.Environment], + self, + task_view: job_model.JobDescriptor, + job_id, + all_envs: List[batch_v1.types.Environment], ): job_metadata = task_view.job_metadata job_params = task_view.job_params @@ -473,18 +473,18 @@ def _create_batch_request( # Set up VM-specific variables datadisk_volume = google_batch_operations.build_volume( - disk=google_utils.DATA_DISK_NAME, path=_VOLUME_MOUNT_POINT + disk=google_utils.DATA_DISK_NAME, path=_VOLUME_MOUNT_POINT ) # Set up the task labels # pylint: disable=g-complex-comprehension labels = { - label.name: label.value if label.value else '' - for label in google_base.build_pipeline_labels( - job_metadata, task_metadata - ) - | job_params['labels'] - | task_params['labels'] + label.name: label.value if label.value else '' + for label in google_base.build_pipeline_labels( + job_metadata, task_metadata + ) + | job_params['labels'] + | task_params['labels'] } # pylint: enable=g-complex-comprehension @@ -495,34 +495,34 @@ def _create_batch_request( user_action = 3 continuous_logging_cmd = _CONTINUOUS_LOGGING_CMD.format( - log_msg_fn=google_utils.LOG_MSG_FN, - gsutil_cp_fn=google_utils.GSUTIL_CP_FN, - log_filter_var=_LOG_FILTER_VAR, + log_msg_fn=google_utils.LOG_MSG_FN, + gsutil_cp_fn=google_utils.GSUTIL_CP_FN, + log_filter_var=_LOG_FILTER_VAR, + log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, + python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, + logging_dir=_LOGGING_DIR, + log_file_path=_LOG_FILE_PATH, + log_cp=_LOG_CP.format( log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, - python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, - logging_dir=_LOGGING_DIR, log_file_path=_LOG_FILE_PATH, - log_cp=_LOG_CP.format( - log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, - log_file_path=_LOG_FILE_PATH, - user_action=user_action, - ), - log_interval=job_resources.log_interval or '60s', + user_action=user_action, + ), + log_interval=job_resources.log_interval or '60s', ) logging_cmd = _FINAL_LOGGING_CMD.format( - log_msg_fn=google_utils.LOG_MSG_FN, - gsutil_cp_fn=google_utils.GSUTIL_CP_FN, - log_filter_var=_LOG_FILTER_VAR, + 
log_msg_fn=google_utils.LOG_MSG_FN, + gsutil_cp_fn=google_utils.GSUTIL_CP_FN, + log_filter_var=_LOG_FILTER_VAR, + log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, + python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, + logging_dir=_LOGGING_DIR, + log_file_path=_LOG_FILE_PATH, + log_cp=_LOG_CP.format( log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, - python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, - logging_dir=_LOGGING_DIR, log_file_path=_LOG_FILE_PATH, - log_cp=_LOG_CP.format( - log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, - log_file_path=_LOG_FILE_PATH, - user_action=user_action, - ), + user_action=user_action, + ), ) # Set up command and environments for the prepare, localization, user, @@ -531,131 +531,139 @@ def _create_batch_request( user_project = task_view.job_metadata['user-project'] or '' prepare_command = google_utils.PREPARE_CMD.format( - log_msg_fn=google_utils.LOG_MSG_FN, - mk_runtime_dirs=google_utils.make_runtime_dirs_command( - _SCRIPT_DIR, _TMP_DIR, _WORKING_DIR - ), - script_var=google_utils.SCRIPT_VARNAME, - python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, - script_path=script_path, - mk_io_dirs=google_utils.MK_IO_DIRS, + log_msg_fn=google_utils.LOG_MSG_FN, + mk_runtime_dirs=google_utils.make_runtime_dirs_command( + _SCRIPT_DIR, _TMP_DIR, _WORKING_DIR + ), + script_var=google_utils.SCRIPT_VARNAME, + python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, + script_path=script_path, + mk_io_dirs=google_utils.MK_IO_DIRS, ) # pylint: disable=line-too-long continuous_logging_env = google_batch_operations.build_environment( - self._get_logging_env( - task_resources.logging_path.uri, user_project, True - ) + self._get_logging_env( + task_resources.logging_path.uri, user_project, True + ) ) final_logging_env = google_batch_operations.build_environment( - self._get_logging_env( - task_resources.logging_path.uri, user_project, False - ) + self._get_logging_env( + task_resources.logging_path.uri, user_project, False + ) ) # Build the list of runnables (aka actions) runnables = [] runnables.append( - # logging - google_batch_operations.build_runnable( - run_in_background=True, - always_run=False, - image_uri=google_utils.CLOUD_SDK_IMAGE, - environment=continuous_logging_env, - entrypoint='/bin/bash', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], - commands=['-c', continuous_logging_cmd], - ) + # logging + google_batch_operations.build_runnable( + run_in_background=True, + always_run=False, + image_uri=google_utils.CLOUD_SDK_IMAGE, + environment=continuous_logging_env, + entrypoint='/bin/bash', + volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], + commands=['-c', continuous_logging_cmd], + ) ) runnables.append( - # prepare - google_batch_operations.build_runnable( - run_in_background=False, - always_run=False, - image_uri=google_utils.CLOUD_SDK_IMAGE, - environment=None, - entrypoint='/bin/bash', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], - commands=['-c', prepare_command], - ) + # prepare + google_batch_operations.build_runnable( + run_in_background=False, + always_run=False, + image_uri=google_utils.CLOUD_SDK_IMAGE, + environment=None, + entrypoint='/bin/bash', + volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], + commands=['-c', prepare_command], + ) ) runnables.append( - # localization - google_batch_operations.build_runnable( - run_in_background=False, - always_run=False, - image_uri=google_utils.CLOUD_SDK_IMAGE, - environment=None, - entrypoint='/bin/bash', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], - 
commands=[ - '-c', - google_utils.LOCALIZATION_CMD.format( - log_msg_fn=google_utils.LOG_MSG_FN, - recursive_cp_fn=google_utils.GSUTIL_RSYNC_FN, - cp_fn=google_utils.GSUTIL_CP_FN, - cp_loop=google_utils.LOCALIZATION_LOOP, - ), - ], - ) + # localization + google_batch_operations.build_runnable( + run_in_background=False, + always_run=False, + image_uri=google_utils.CLOUD_SDK_IMAGE, + environment=None, + entrypoint='/bin/bash', + volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], + commands=[ + '-c', + google_utils.LOCALIZATION_CMD.format( + log_msg_fn=google_utils.LOG_MSG_FN, + recursive_cp_fn=google_utils.GSUTIL_RSYNC_FN, + cp_fn=google_utils.GSUTIL_CP_FN, + cp_loop=google_utils.LOCALIZATION_LOOP, + ), + ], + ) ) + # user-command volumes + user_command_volumes = [f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'] + if job_resources.accelerator_type is not None: + user_command_volumes.extend([ + "/var/lib/nvidia/lib64:/usr/local/nvidia/lib64", + "/var/lib/nvidia/bin:/usr/local/nvidia/bin" + ]) + runnables.append( - # user-command - google_batch_operations.build_runnable( - run_in_background=False, - always_run=False, - image_uri=job_resources.image, - environment=None, - entrypoint='/usr/bin/env', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], - commands=[ - 'bash', - '-c', - google_utils.USER_CMD.format( - tmp_dir=_TMP_DIR, - working_dir=_WORKING_DIR, - user_script=script_path, - ), - ], - ) + # user-command + google_batch_operations.build_runnable( + run_in_background=False, + always_run=False, + image_uri=job_resources.image, + environment=None, + entrypoint='/usr/bin/env', + volumes=user_command_volumes, + commands=[ + 'bash', + '-c', + google_utils.USER_CMD.format( + tmp_dir=_TMP_DIR, + working_dir=_WORKING_DIR, + user_script=script_path, + ), + ], + ) ) runnables.append( - # delocalization - google_batch_operations.build_runnable( - run_in_background=False, - always_run=False, - image_uri=google_utils.CLOUD_SDK_IMAGE, - environment=None, - entrypoint='/bin/bash', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}:ro'], - commands=[ - '-c', - google_utils.LOCALIZATION_CMD.format( - log_msg_fn=google_utils.LOG_MSG_FN, - recursive_cp_fn=google_utils.GSUTIL_RSYNC_FN, - cp_fn=google_utils.GSUTIL_CP_FN, - cp_loop=google_utils.DELOCALIZATION_LOOP, - ), - ], - ) + # delocalization + google_batch_operations.build_runnable( + run_in_background=False, + always_run=False, + image_uri=google_utils.CLOUD_SDK_IMAGE, + environment=None, + entrypoint='/bin/bash', + volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}:ro'], + commands=[ + '-c', + google_utils.LOCALIZATION_CMD.format( + log_msg_fn=google_utils.LOG_MSG_FN, + recursive_cp_fn=google_utils.GSUTIL_RSYNC_FN, + cp_fn=google_utils.GSUTIL_CP_FN, + cp_loop=google_utils.DELOCALIZATION_LOOP, + ), + ], + ) ) runnables.append( - # final_logging - google_batch_operations.build_runnable( - run_in_background=False, - always_run=True, - image_uri=google_utils.CLOUD_SDK_IMAGE, - environment=final_logging_env, - entrypoint='/bin/bash', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], - commands=['-c', logging_cmd], - ), + # final_logging + google_batch_operations.build_runnable( + run_in_background=False, + always_run=True, + image_uri=google_utils.CLOUD_SDK_IMAGE, + environment=final_logging_env, + entrypoint='/bin/bash', + volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], + commands=['-c', logging_cmd], + ), ) # Prepare the VM (resources) configuration. 
The InstancePolicy describes an @@ -663,39 +671,69 @@ def _create_batch_request( # describes when, where, and how compute resources should be allocated # for the Job. disk = google_batch_operations.build_persistent_disk( - size_gb=job_resources.disk_size, - disk_type=job_resources.disk_type or job_model.DEFAULT_DISK_TYPE, + size_gb=job_resources.disk_size, + disk_type=job_resources.disk_type or job_model.DEFAULT_DISK_TYPE, ) attached_disk = google_batch_operations.build_attached_disk( - disk=disk, device_name=google_utils.DATA_DISK_NAME + disk=disk, device_name=google_utils.DATA_DISK_NAME ) + instance_policy = google_batch_operations.build_instance_policy( - attached_disk + disks=attached_disk, + machine_type=job_resources.machine_type, + accelerators=build_accelerators( + accelerator_type=job_resources.accelerator_type, + accelerator_count=job_resources.accelerator_count, + ) if job_resources.accelerator_type is not None else None ) + ipt = google_batch_operations.build_instance_policy_or_template( - instance_policy + instance_policy=instance_policy, + install_gpu_drivers=True if job_resources.accelerator_type is not None else False + ) - allocation_policy = google_batch_operations.build_allocation_policy([ipt]) + + service_account = google_batch_operations.build_service_account( + service_account_email=job_resources.service_account) + + network_policy = google_batch_operations.build_network_policy( + network=job_resources.network, + subnetwork=job_resources.subnetwork, + no_external_ip_address=job_resources.use_private_address, + ) + + allocation_policy = google_batch_operations.build_allocation_policy( + ipts=[ipt], + service_account=service_account, + network_policy=network_policy, + ) + logs_policy = google_batch_operations.build_logs_policy( - batch_v1.LogsPolicy.Destination.PATH, _BATCH_LOG_FILE_PATH + batch_v1.LogsPolicy.Destination.PATH, _BATCH_LOG_FILE_PATH + ) + + compute_resource = build_compute_resource( + cpu_milli=job_resources.min_cores * 1000, + memory_mib=job_resources.min_ram * 1024, + boot_disk_mib=job_resources.boot_disk_size * 1024 ) # Bring together the task definition(s) and build the Job request. 
task_spec = google_batch_operations.build_task_spec( - runnables=runnables, volumes=[datadisk_volume] + runnables=runnables, volumes=[datadisk_volume], compute_resource=compute_resource ) task_group = google_batch_operations.build_task_group( - task_spec, all_envs, task_count=len(all_envs), task_count_per_node=1 + task_spec, all_envs, task_count=len(all_envs), task_count_per_node=1 ) job = google_batch_operations.build_job( - [task_group], allocation_policy, labels, logs_policy + [task_group], allocation_policy, labels, logs_policy ) job_request = batch_v1.CreateJobRequest( - parent=f'projects/{self._project}/locations/{self._location}', - job=job, - job_id=job_id, + parent=f'projects/{self._project}/locations/{self._location}', + job=job, + job_id=job_id, ) # pylint: enable=line-too-long return job_request @@ -708,7 +746,7 @@ def _submit_batch_job(self, request) -> str: return op.get_field('task-id') def _create_env_for_task( - self, task_view: job_model.JobDescriptor + self, task_view: job_model.JobDescriptor ) -> Dict[str, str]: job_params = task_view.job_params task_params = task_view.task_descriptors[0].task_params @@ -723,39 +761,39 @@ def _create_env_for_task( mounts = job_params['mounts'] prepare_env = self._get_prepare_env( - script, task_view, inputs, outputs, mounts, _DATA_MOUNT_POINT + script, task_view, inputs, outputs, mounts, _DATA_MOUNT_POINT ) localization_env = self._get_localization_env( - inputs, user_project, _DATA_MOUNT_POINT + inputs, user_project, _DATA_MOUNT_POINT ) user_environment = self._build_user_environment( - envs, inputs, outputs, mounts, _DATA_MOUNT_POINT + envs, inputs, outputs, mounts, _DATA_MOUNT_POINT ) delocalization_env = self._get_delocalization_env( - outputs, user_project, _DATA_MOUNT_POINT + outputs, user_project, _DATA_MOUNT_POINT ) # This merges all the envs into one dict. Need to use this syntax because # of python3.6. In python3.9 we'd prefer to use | operator. all_env = { - **prepare_env, - **localization_env, - **user_environment, - **delocalization_env, + **prepare_env, + **localization_env, + **user_environment, + **delocalization_env, } return all_env def submit_job( - self, - job_descriptor: job_model.JobDescriptor, - skip_if_output_present: bool, + self, + job_descriptor: job_model.JobDescriptor, + skip_if_output_present: bool, ) -> Dict[str, any]: # Validate task data and resources. param_util.validate_submit_args_or_fail( - job_descriptor, - provider_name=_PROVIDER_NAME, - input_providers=_SUPPORTED_INPUT_PROVIDERS, - output_providers=_SUPPORTED_OUTPUT_PROVIDERS, - logging_providers=_SUPPORTED_LOGGING_PROVIDERS, + job_descriptor, + provider_name=_PROVIDER_NAME, + input_providers=_SUPPORTED_INPUT_PROVIDERS, + output_providers=_SUPPORTED_OUTPUT_PROVIDERS, + logging_providers=_SUPPORTED_LOGGING_PROVIDERS, ) # Prepare and submit jobs. 
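Note on the environment merge in `_create_env_for_task` above: dsub still targets Python 3.6, hence the `**` unpacking. A minimal sketch of the Python 3.9+ equivalent mentioned in the code comment, using the same dict names:

# Equivalent merge with the dict union operator (Python 3.9+); as with the
# ** unpacking above, later mappings win on duplicate keys.
all_env = prepare_env | localization_env | user_environment | delocalization_env
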
@@ -781,89 +819,89 @@ def submit_job( # If this is a dry-run, emit all the pipeline request objects if self._dry_run: print( - json.dumps(requests, indent=2, sort_keys=True, separators=(',', ': ')) + json.dumps(requests, indent=2, sort_keys=True, separators=(',', ': ')) ) return { - 'job-id': job_id, - 'user-id': job_descriptor.job_metadata.get('user-id'), - 'task-id': [task_id for task_id in launched_tasks if task_id], + 'job-id': job_id, + 'user-id': job_descriptor.job_metadata.get('user-id'), + 'task-id': [task_id for task_id in launched_tasks if task_id], } def delete_jobs( - self, - user_ids, - job_ids, - task_ids, - labels, - create_time_min=None, - create_time_max=None, + self, + user_ids, + job_ids, + task_ids, + labels, + create_time_min=None, + create_time_max=None, ): """Kills the operations associated with the specified job or job.task. - Args: - user_ids: List of user ids who "own" the job(s) to cancel. - job_ids: List of job_ids to cancel. - task_ids: List of task-ids to cancel. - labels: List of LabelParam, each must match the job(s) to be canceled. - create_time_min: a timezone-aware datetime value for the earliest create - time of a task, inclusive. - create_time_max: a timezone-aware datetime value for the most recent - create time of a task, inclusive. - - Returns: - A list of tasks canceled and a list of error messages. - """ + Args: + user_ids: List of user ids who "own" the job(s) to cancel. + job_ids: List of job_ids to cancel. + task_ids: List of task-ids to cancel. + labels: List of LabelParam, each must match the job(s) to be canceled. + create_time_min: a timezone-aware datetime value for the earliest create + time of a task, inclusive. + create_time_max: a timezone-aware datetime value for the most recent + create time of a task, inclusive. + + Returns: + A list of tasks canceled and a list of error messages. + """ # Look up the job(s) tasks = list( - self.lookup_job_tasks( - {'RUNNING'}, - user_ids=user_ids, - job_ids=job_ids, - task_ids=task_ids, - labels=labels, - create_time_min=create_time_min, - create_time_max=create_time_max, - ) + self.lookup_job_tasks( + {'RUNNING'}, + user_ids=user_ids, + job_ids=job_ids, + task_ids=task_ids, + labels=labels, + create_time_min=create_time_min, + create_time_max=create_time_max, + ) ) print('Found %d tasks to delete.' % len(tasks)) return google_base.cancel( - self._batch_handler_def(), self._operations_cancel_api_def(), tasks + self._batch_handler_def(), self._operations_cancel_api_def(), tasks ) def lookup_job_tasks( - self, - statuses: Set[str], - user_ids=None, - job_ids=None, - job_names=None, - task_ids=None, - task_attempts=None, - labels=None, - create_time_min=None, - create_time_max=None, - max_tasks=0, - page_size=0, + self, + statuses: Set[str], + user_ids=None, + job_ids=None, + job_names=None, + task_ids=None, + task_attempts=None, + labels=None, + create_time_min=None, + create_time_max=None, + max_tasks=0, + page_size=0, ): client = batch_v1.BatchServiceClient() # TODO: Batch API has no 'done' filter like lifesciences API. # Need to figure out how to filter for jobs that are completed. 
empty_statuses = set() ops_filter = self._build_query_filter( - empty_statuses, - user_ids, - job_ids, - job_names, - task_ids, - task_attempts, - labels, - create_time_min, - create_time_max, + empty_statuses, + user_ids, + job_ids, + job_names, + task_ids, + task_attempts, + labels, + create_time_min, + create_time_max, ) # Initialize request argument(s) request = batch_v1.ListJobsRequest( - parent=f'projects/{self._project}/locations/{self._location}', - filter=ops_filter, + parent=f'projects/{self._project}/locations/{self._location}', + filter=ops_filter, ) # Make the request diff --git a/dsub/providers/google_batch_operations.py b/dsub/providers/google_batch_operations.py index 01f92c0..ab072d8 100644 --- a/dsub/providers/google_batch_operations.py +++ b/dsub/providers/google_batch_operations.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Utility routines for constructing a Google Batch API request.""" -from typing import List, Optional, Dict +from typing import List, Optional, Dict, MutableSequence +from google.cloud.batch_v1 import ServiceAccount, AllocationPolicy # pylint: disable=g-import-not-at-top try: @@ -20,6 +21,8 @@ except ImportError: # TODO: Remove conditional import when batch library is available from . import batch_dummy as batch_v1 + + # pylint: enable=g-import-not-at-top @@ -45,8 +48,8 @@ def get_environment(op: batch_v1.types.Job) -> Dict[str, str]: def is_done(op: batch_v1.types.Job) -> bool: """Return whether the operation has been marked done.""" return op.status.state in [ - batch_v1.types.job.JobStatus.State.SUCCEEDED, - batch_v1.types.job.JobStatus.State.FAILED, + batch_v1.types.job.JobStatus.State.SUCCEEDED, + batch_v1.types.job.JobStatus.State.FAILED, ] @@ -78,18 +81,18 @@ def _pad_timestamps(ts: str) -> str: def get_update_time(op: batch_v1.types.Job) -> Optional[str]: """Return the update time string of the operation.""" - update_time = op.update_time + update_time = op.update_time.ToDatetime() if op.update_time else None if update_time: - return _pad_timestamps(op.update_time.rfc3339()) + return update_time.isoformat('T') + 'Z' # Representing the datetime object in rfc3339 format else: return None def get_create_time(op: batch_v1.types.Job) -> Optional[str]: """Return the create time string of the operation.""" - create_time = op.create_time + create_time = op.create_time.ToDatetime() if op.create_time else None if create_time: - return _pad_timestamps(op.create_time.rfc3339()) + return create_time.isoformat('T') + 'Z' else: return None @@ -99,10 +102,10 @@ def get_status_events(op: batch_v1.types.Job): def build_job( - task_groups: List[batch_v1.types.TaskGroup], - allocation_policy: batch_v1.types.AllocationPolicy, - labels: Dict[str, str], - logs_policy: batch_v1.types.LogsPolicy, + task_groups: List[batch_v1.types.TaskGroup], + allocation_policy: batch_v1.types.AllocationPolicy, + labels: Dict[str, str], + logs_policy: batch_v1.types.LogsPolicy, ) -> batch_v1.types.Job: job = batch_v1.Job() job.task_groups = task_groups @@ -112,13 +115,44 @@ def build_job( return job +def build_compute_resource(cpu_milli: int, memory_mib: int, boot_disk_mib: int) -> batch_v1.types.ComputeResource: + """Build a ComputeResource object for a Batch request. + + Args: + cpu_milli (int): Number of milliCPU units + memory_mib (int): Amount of memory in Mebibytes (MiB) + boot_disk_mib (int): The boot disk size in Mebibytes (MiB) + + Returns: + A ComputeResource object. 
+ """ + compute_resource = batch_v1.ComputeResource( + cpu_milli=cpu_milli, + memory_mib=memory_mib, + boot_disk_mib=boot_disk_mib + ) + return compute_resource + + def build_task_spec( - runnables: List[batch_v1.types.task.Runnable], - volumes: List[batch_v1.types.Volume], + runnables: List[batch_v1.types.task.Runnable], + volumes: List[batch_v1.types.Volume], + compute_resource: batch_v1.types.ComputeResource, ) -> batch_v1.types.TaskSpec: + """Build a TaskSpec object for a Batch request. + + Args: + runnables (List[Runnable]): List of Runnable objects + volumes (List[Volume]): List of Volume objects + compute_resource (ComputeResource): The compute resources to use + + Returns: + A TaskSpec object. + """ task_spec = batch_v1.TaskSpec() task_spec.runnables = runnables task_spec.volumes = volumes + task_spec.compute_resource = compute_resource return task_spec @@ -129,10 +163,10 @@ def build_environment(env_vars: Dict[str, str]): def build_task_group( - task_spec: batch_v1.types.TaskSpec, - task_environments: List[batch_v1.types.Environment], - task_count: int, - task_count_per_node: int, + task_spec: batch_v1.types.TaskSpec, + task_environments: List[batch_v1.types.Environment], + task_count: int, + task_count_per_node: int, ) -> batch_v1.types.TaskGroup: """Build a TaskGroup object for a Batch request. @@ -154,7 +188,7 @@ def build_task_group( def build_container( - image_uri: str, entrypoint: str, volumes: List[str], commands: List[str] + image_uri: str, entrypoint: str, volumes: List[str], commands: List[str] ) -> batch_v1.types.task.Runnable.Container: container = batch_v1.types.task.Runnable.Container() container.image_uri = image_uri @@ -165,13 +199,13 @@ def build_container( def build_runnable( - run_in_background: bool, - always_run: bool, - environment: batch_v1.types.Environment, - image_uri: str, - entrypoint: str, - volumes: List[str], - commands: List[str], + run_in_background: bool, + always_run: bool, + environment: batch_v1.types.Environment, + image_uri: str, + entrypoint: str, + volumes: List[str], + commands: List[str], ) -> batch_v1.types.task.Runnable: """Build a Runnable object for a Batch request. 
@@ -213,24 +247,52 @@ def build_volume(disk: str, path: str) -> batch_v1.types.Volume:
   return volume
 
 
+def build_network_policy(network: str, subnetwork: str,
+                         no_external_ip_address: bool) -> batch_v1.types.job.AllocationPolicy.NetworkPolicy:
+  network_policy = AllocationPolicy.NetworkPolicy(
+      network_interfaces=[
+          AllocationPolicy.NetworkInterface(
+              network=network,
+              subnetwork=subnetwork,
+              no_external_ip_address=no_external_ip_address,
+          )
+      ]
+  )
+  return network_policy
+
+
+def build_service_account(service_account_email: str) -> batch_v1.ServiceAccount:
+  service_account = ServiceAccount(
+      email=service_account_email
+  )
+  return service_account
+
+
 def build_allocation_policy(
-    ipts: List[batch_v1.types.AllocationPolicy.InstancePolicyOrTemplate],
+        ipts: List[batch_v1.types.AllocationPolicy.InstancePolicyOrTemplate],
+        service_account: batch_v1.ServiceAccount,
+        network_policy: batch_v1.types.job.AllocationPolicy.NetworkPolicy
 ) -> batch_v1.types.AllocationPolicy:
   allocation_policy = batch_v1.AllocationPolicy()
   allocation_policy.instances = ipts
+  allocation_policy.service_account = service_account
+  allocation_policy.network = network_policy
+
   return allocation_policy
 
 
 def build_instance_policy_or_template(
-    instance_policy: batch_v1.types.AllocationPolicy.InstancePolicy,
+        instance_policy: batch_v1.types.AllocationPolicy.InstancePolicy,
+        install_gpu_drivers: bool
 ) -> batch_v1.types.AllocationPolicy.InstancePolicyOrTemplate:
   ipt = batch_v1.AllocationPolicy.InstancePolicyOrTemplate()
   ipt.policy = instance_policy
+  ipt.install_gpu_drivers = install_gpu_drivers
   return ipt
 
 
 def build_logs_policy(
-    destination: batch_v1.types.LogsPolicy.Destination, logs_path: str
+        destination: batch_v1.types.LogsPolicy.Destination, logs_path: str
 ) -> batch_v1.types.LogsPolicy:
   logs_policy = batch_v1.LogsPolicy()
   logs_policy.destination = destination
@@ -240,15 +302,20 @@ def build_logs_policy(
 
 
 def build_instance_policy(
-    disks: List[batch_v1.types.AllocationPolicy.AttachedDisk],
+        disks: List[batch_v1.types.AllocationPolicy.AttachedDisk],
+        machine_type: str,
+        accelerators: MutableSequence[batch_v1.types.AllocationPolicy.Accelerator]
 ) -> batch_v1.types.AllocationPolicy.InstancePolicy:
   instance_policy = batch_v1.AllocationPolicy.InstancePolicy()
   instance_policy.disks = [disks]
+  instance_policy.machine_type = machine_type
+  instance_policy.accelerators = accelerators
+
   return instance_policy
 
 
 def build_attached_disk(
-    disk: batch_v1.types.AllocationPolicy.Disk, device_name: str
+        disk: batch_v1.types.AllocationPolicy.Disk, device_name: str
 ) -> batch_v1.types.AllocationPolicy.AttachedDisk:
   attached_disk = batch_v1.AllocationPolicy.AttachedDisk()
   attached_disk.new_disk = disk
@@ -257,9 +324,22 @@ def build_attached_disk(
 
 
 def build_persistent_disk(
-    size_gb: int, disk_type: str
+        size_gb: int, disk_type: str
 ) -> batch_v1.types.AllocationPolicy.Disk:
   disk = batch_v1.AllocationPolicy.Disk()
   disk.type = disk_type
   disk.size_gb = size_gb
   return disk
+
+
+def build_accelerators(
+    accelerator_type,
+    accelerator_count
+) -> MutableSequence[batch_v1.types.AllocationPolicy.Accelerator]:
+  accelerators = []
+  accelerator = batch_v1.AllocationPolicy.Accelerator()
+  accelerator.count = accelerator_count
+  accelerator.type = accelerator_type
+  accelerators.append(accelerator)
+
+  return accelerators
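
Taken together, the builders above assemble the GPU, networking, and service-account portions of the AllocationPolicy roughly as follows. This is an illustrative sketch, not part of the change: the accelerator type, machine type, network paths, and service-account email are placeholder values, and `attached_disk` stands in for the result of `build_attached_disk`.

accelerators = build_accelerators(
    accelerator_type='nvidia-tesla-t4', accelerator_count=1
)
instance_policy = build_instance_policy(
    disks=attached_disk,
    machine_type='n1-standard-8',
    accelerators=accelerators,
)
ipt = build_instance_policy_or_template(
    instance_policy=instance_policy,
    install_gpu_drivers=True,  # mirrors accelerator_type being set
)
allocation_policy = build_allocation_policy(
    ipts=[ipt],
    service_account=build_service_account(
        'runner@my-project.iam.gserviceaccount.com'),
    network_policy=build_network_policy(
        network='global/networks/default',
        subnetwork='regions/us-central1/subnetworks/default',
        no_external_ip_address=True,  # corresponds to --use-private-address
    ),
)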