Skip to content

Commit

Permalink
feat(JA): Code refactor and moving functions around for code sharing …
Browse files Browse the repository at this point in the history
…with upcoming JA CLI. (#451)

Signed-off-by: David Leong <[email protected]>
  • Loading branch information
leongdl authored Sep 25, 2024
1 parent 9cd8dec commit cd37413
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 96 deletions.
5 changes: 4 additions & 1 deletion src/deadline/client/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@
TelemetryClient,
record_success_fail_telemetry_event,
)
from ._submit_job_bundle import create_job_from_job_bundle, wait_for_create_job_to_complete
from ._submit_job_bundle import (
create_job_from_job_bundle,
wait_for_create_job_to_complete,
)
from ._get_storage_profile_for_queue import get_storage_profile_for_queue

logger = getLogger(__name__)
Expand Down
47 changes: 47 additions & 0 deletions src/deadline/client/api/_job_attachment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

from deadline.client import api
from deadline.client.config import config_file
from deadline.job_attachments.models import AssetRootGroup, AssetRootManifest
from deadline.job_attachments.upload import S3AssetManager, SummaryStatistics


import textwrap
from configparser import ConfigParser
from typing import Callable, Dict, List, Optional, Tuple


def _hash_attachments(
asset_manager: S3AssetManager,
asset_groups: List[AssetRootGroup],
total_input_files: int,
total_input_bytes: int,
print_function_callback: Callable = lambda msg: None,
hashing_progress_callback: Optional[Callable] = None,
config: Optional[ConfigParser] = None,
) -> Tuple[SummaryStatistics, List[AssetRootManifest]]:
"""
Starts the job attachments hashing and handles the progress reporting
callback. Returns a list of the asset manifests of the hashed files.
"""

def _default_update_hash_progress(hashing_metadata: Dict[str, str]) -> bool:
return True

if not hashing_progress_callback:
hashing_progress_callback = _default_update_hash_progress

hashing_summary, manifests = asset_manager.hash_assets_and_create_manifest(
asset_groups=asset_groups,
total_input_files=total_input_files,
total_input_bytes=total_input_bytes,
hash_cache_dir=config_file.get_cache_directory(),
on_preparing_to_submit=hashing_progress_callback,
)
api.get_deadline_cloud_library_telemetry_client(config=config).record_hashing_summary(
hashing_summary
)
print_function_callback("Hashing Summary:")
print_function_callback(textwrap.indent(str(hashing_summary), " "))

return hashing_summary, manifests
28 changes: 18 additions & 10 deletions src/deadline/client/api/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,25 @@
of the Deadline-configured IAM credentials.
"""
from __future__ import annotations
from functools import lru_cache

import logging
from configparser import ConfigParser
from contextlib import contextmanager
from enum import Enum
from functools import lru_cache
from typing import Optional

import boto3 # type: ignore[import]
import botocore
from botocore.client import BaseClient # type: ignore[import]
from botocore.credentials import CredentialProvider, RefreshableCredentials
from botocore.exceptions import ( # type: ignore[import]
ClientError,
ProfileNotFound,
)

from botocore.session import get_session as get_botocore_session
import botocore
from .. import version


from .. import version
from ..config import get_setting
from ..exceptions import DeadlineOperationError

Expand Down Expand Up @@ -104,6 +104,18 @@ def invalidate_boto3_session_cache() -> None:
_get_queue_user_boto3_session.cache_clear()


def get_default_client_config() -> botocore.config.Config:
"""
Gets the default botocore Config object to use with `boto3 sessions`.
This method adds user agent version and submitter context into botocore calls.
"""
user_agent_extra = f"app/deadline-client#{version}"
if session_context.get("submitter-name"):
user_agent_extra += f" submitter/{session_context['submitter-name']}"
session_config = botocore.config.Config(user_agent_extra=user_agent_extra)
return session_config


def get_boto3_client(service_name: str, config: Optional[ConfigParser] = None) -> BaseClient:
"""
Gets a client from the boto3 session returned by `get_boto3_session`.
Expand All @@ -115,12 +127,8 @@ def get_boto3_client(service_name: str, config: Optional[ConfigParser] = None) -
config (ConfigParser, optional): If provided, the AWS Deadline Cloud config to use.
"""

user_agent_extra = f"app/deadline-client#{version}"
if session_context.get("submitter-name"):
user_agent_extra += f" submitter/{session_context['submitter-name']}"
session_config = botocore.config.Config(user_agent_extra=user_agent_extra)
session = get_boto3_session(config=config)
return session.client(service_name, config=session_config)
return session.client(service_name, config=get_default_client_config())


def get_credentials_source(config: Optional[ConfigParser] = None) -> AwsCredentialsSource:
Expand Down
43 changes: 4 additions & 39 deletions src/deadline/client/api/_submit_job_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from configparser import ConfigParser
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from botocore.client import BaseClient # type: ignore[import]
from botocore.client import BaseClient

from deadline.client.api._job_attachment import _hash_attachments # type: ignore[import]

from .. import api
from ..exceptions import DeadlineOperationError, CreateJobWaiterCanceled
Expand All @@ -35,12 +37,11 @@
from ...job_attachments.exceptions import MisconfiguredInputsError
from ...job_attachments.models import (
JobAttachmentsFileSystem,
AssetRootGroup,
AssetRootManifest,
AssetUploadGroup,
JobAttachmentS3Settings,
)
from ...job_attachments.progress_tracker import SummaryStatistics, ProgressReportMetadata
from ...job_attachments.progress_tracker import ProgressReportMetadata
from ...job_attachments.upload import S3AssetManager
from ._session import session_context

Expand Down Expand Up @@ -403,42 +404,6 @@ def wait_for_create_job_to_complete(
)


def _hash_attachments(
asset_manager: S3AssetManager,
asset_groups: list[AssetRootGroup],
total_input_files: int,
total_input_bytes: int,
print_function_callback: Callable = lambda msg: None,
hashing_progress_callback: Optional[Callable] = None,
config: Optional[ConfigParser] = None,
) -> Tuple[SummaryStatistics, List[AssetRootManifest]]:
"""
Starts the job attachments hashing and handles the progress reporting
callback. Returns a list of the asset manifests of the hashed files.
"""

def _default_update_hash_progress(hashing_metadata: Dict[str, str]) -> bool:
return True

if not hashing_progress_callback:
hashing_progress_callback = _default_update_hash_progress

hashing_summary, manifests = asset_manager.hash_assets_and_create_manifest(
asset_groups=asset_groups,
total_input_files=total_input_files,
total_input_bytes=total_input_bytes,
hash_cache_dir=config_file.get_cache_directory(),
on_preparing_to_submit=hashing_progress_callback,
)
api.get_deadline_cloud_library_telemetry_client(config=config).record_hashing_summary(
hashing_summary
)
print_function_callback("Hashing Summary:")
print_function_callback(textwrap.indent(str(hashing_summary), " "))

return hashing_summary, manifests


@api.record_success_fail_telemetry_event(metric_name="cli_asset_upload") # type: ignore
def _upload_attachments(
asset_manager: S3AssetManager,
Expand Down
46 changes: 46 additions & 0 deletions src/deadline/client/cli/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,27 @@
"_handle_error",
"_apply_cli_options_to_config",
"_cli_object_repr",
"_ProgressBarCallbackManager",
]

import sys
from configparser import ConfigParser
from typing import Any, Callable, Optional, Set

import click
from contextlib import ExitStack
from deadline.job_attachments.progress_tracker import ProgressReportMetadata

from ..config import config_file
from ..exceptions import DeadlineOperationError
from ..job_bundle import deadline_yaml_dump
from ._groups._sigint_handler import SigIntHandler

_PROMPT_WHEN_COMPLETE = "PROMPT_WHEN_COMPLETE"

# Set up the signal handler for handling Ctrl + C interruptions.
sigint_handler = SigIntHandler()


def _prompt_at_completion(ctx: click.Context):
"""
Expand Down Expand Up @@ -182,3 +189,42 @@ def _cli_object_repr(obj: Any):
# strings to end with "\n".
obj = _fix_multiline_strings(obj)
return deadline_yaml_dump(obj)


class _ProgressBarCallbackManager:
"""
Manages creation, update, and deletion of a progress bar. On first call of the callback, the progress bar is created. The progress bar is closed
on the final call (100% completion)
"""

BAR_NOT_CREATED = 0
BAR_CREATED = 1
BAR_CLOSED = 2

def __init__(self, length: int, label: str):
self._length = length
self._label = label
self._bar_status = self.BAR_NOT_CREATED
self._exit_stack = ExitStack()

def callback(self, upload_metadata: ProgressReportMetadata) -> bool:
if self._bar_status == self.BAR_CLOSED:
# from multithreaded execution this can be called after completion somtimes.
return sigint_handler.continue_operation
elif self._bar_status == self.BAR_NOT_CREATED:
# Note: click doesn't export the return type of progressbar(), so we suppress mypy warnings for
# not annotating the type of hashing_progress.
self._upload_progress = click.progressbar(length=self._length, label=self._label) # type: ignore[var-annotated]
self._exit_stack.enter_context(self._upload_progress)
self._bar_status = self.BAR_CREATED

total_progress = int(upload_metadata.progress)
new_progress = total_progress - self._upload_progress.pos
if new_progress > 0:
self._upload_progress.update(new_progress)

if total_progress == self._length or not sigint_handler.continue_operation:
self._bar_status = self.BAR_CLOSED
self._exit_stack.close()

return sigint_handler.continue_operation
43 changes: 1 addition & 42 deletions src/deadline/client/cli/_groups/bundle_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from typing import Any, Optional

import click
from contextlib import ExitStack
from botocore.exceptions import ClientError

from deadline.client import api
Expand All @@ -22,11 +21,10 @@
MisconfiguredInputsError,
)
from deadline.job_attachments.models import AssetUploadGroup, JobAttachmentsFileSystem
from deadline.job_attachments.progress_tracker import ProgressReportMetadata
from deadline.job_attachments._utils import _human_readable_file_size

from ...exceptions import DeadlineOperationError, CreateJobWaiterCanceled
from .._common import _apply_cli_options_to_config, _handle_error
from .._common import _apply_cli_options_to_config, _handle_error, _ProgressBarCallbackManager
from ._sigint_handler import SigIntHandler

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -356,42 +354,3 @@ def _print_response(
click.echo(f"Job ID: {job_id}")
else:
click.echo("Job submission canceled.")


class _ProgressBarCallbackManager:
"""
Manages creation, update, and deletion of a progress bar. On first call of the callback, the progress bar is created. The progress bar is closed
on the final call (100% completion)
"""

BAR_NOT_CREATED = 0
BAR_CREATED = 1
BAR_CLOSED = 2

def __init__(self, length: int, label: str):
self._length = length
self._label = label
self._bar_status = self.BAR_NOT_CREATED
self._exit_stack = ExitStack()

def callback(self, upload_metadata: ProgressReportMetadata) -> bool:
if self._bar_status == self.BAR_CLOSED:
# from multithreaded execution this can be called after completion somtimes.
return sigint_handler.continue_operation
elif self._bar_status == self.BAR_NOT_CREATED:
# Note: click doesn't export the return type of progressbar(), so we suppress mypy warnings for
# not annotating the type of hashing_progress.
self._upload_progress = click.progressbar(length=self._length, label=self._label) # type: ignore[var-annotated]
self._exit_stack.enter_context(self._upload_progress)
self._bar_status = self.BAR_CREATED

total_progress = int(upload_metadata.progress)
new_progress = total_progress - self._upload_progress.pos
if new_progress > 0:
self._upload_progress.update(new_progress)

if total_progress == self._length or not sigint_handler.continue_operation:
self._bar_status = self.BAR_CLOSED
self._exit_stack.close()

return sigint_handler.continue_operation
39 changes: 39 additions & 0 deletions src/deadline/client/cli/_groups/click_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

import click
import json
import typing as t


class ClickLogger:
"""
Wrapper around click that is JSON aware. Users can instantiate this as a
replacement for using `click.echo`. A helper JSON function is also provided
to output JSON.
"""

def __init__(self, is_json: bool):
self._is_json = is_json

def echo(
self,
message: t.Optional[t.Any] = None,
file: t.Optional[t.IO[t.Any]] = None,
nl: bool = True,
err: bool = False,
color: t.Optional[bool] = None,
):
if not self._is_json:
click.echo(message, file, nl, err, color)

def json(
self,
message: t.Optional[dict] = None,
file: t.Optional[t.IO[t.Any]] = None,
nl: bool = True,
err: bool = False,
color: t.Optional[bool] = None,
indent=None,
):
if self._is_json:
click.echo(json.dumps(obj=message, indent=indent), file, nl, err, color)
4 changes: 3 additions & 1 deletion src/deadline/job_attachments/_glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ def _glob_paths(
) -> List[str]:
"""
Glob routine that supports Unix style pathname pattern expansion for includes and excludes.
This function will recursively list all files of path, including all files globbed by include and removing all files marked by exclude.
path: Root path to glob.
include: Optional, pattern syntax for files to include.
exlucde: Optional, pattern syntax for files to exclude.
exclude: Optional, pattern syntax for files to exclude.
return: List of files found based on supplied glob patterns.
"""
include_files: List[Optional[str]] = []
for input_glob in include:
Expand Down
Loading

0 comments on commit cd37413

Please sign in to comment.