Skip to content

Commit

Permalink
type check updates
Browse files Browse the repository at this point in the history
  • Loading branch information
venkatajagannath committed Jun 7, 2024
1 parent b93b9e2 commit 1705aa8
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 40 deletions.
18 changes: 9 additions & 9 deletions anyscale_provider/hooks/anyscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

import anyscale
from anyscale import Anyscale
from anyscale.job.models import JobConfig
from anyscale.job.models import JobStatus, JobState
from anyscale.service.models import ServiceConfig, ServiceStatus, ServiceVersionState, ServiceState
from anyscale.job.models import JobConfig, JobStatus
from anyscale.service.models import ServiceConfig, ServiceStatus, ServiceState

from airflow.hooks.base import BaseHook # Adjusted import based on Airflow's newer version
from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -135,18 +134,18 @@ def get_ui_field_behaviour(cls) -> Dict[str, Any]:

def submit_job(self, config: JobConfig) -> str:
self.log.info("Creating a job with configuration: {}".format(config))
job_id = self.sdk.job.submit(config=config)
job_id: str = self.sdk.job.submit(config=config)
return job_id

def deploy_service(self, config: ServiceConfig,
in_place: bool = False,
canary_percent: Optional[int] = None,
max_surge_percent: Optional[int] = None) -> str:
self.log.info("Deploying a service with configuration: {}".format(config))
service_id = self.sdk.service.deploy(config=config,
in_place=in_place,
canary_percent=canary_percent,
max_surge_percent=max_surge_percent)
service_id: str = self.sdk.service.deploy(config=config,
in_place=in_place,
canary_percent=canary_percent,
max_surge_percent=max_surge_percent)
return service_id

def get_job_status(self, job_id: str) -> JobStatus:
Expand Down Expand Up @@ -177,4 +176,5 @@ def terminate_service(self, service_id: str, time_delay: int) -> bool:
return True

def get_logs(self, job_id: str) -> str:
return self.sdk.job.get_logs(job_id=job_id)
logs: str = self.sdk.job.get_logs(job_id=job_id)
return logs
58 changes: 29 additions & 29 deletions anyscale_provider/operators/anyscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ def __init__(self,
conn_id: str,
name: str,
image_uri: str,
compute_config: Union[ComputeConfig, dict, str],
compute_config: Union[ComputeConfig, Dict[str, Any], str],
working_dir: str,
entrypoint: str,
excludes: Optional[List[str]] = None,
requirements: Optional[Union[str, List[str]]] = None,
env_vars: Optional[Dict[str, str]] = None,
py_modules: Optional[List[str]] = None,
max_retries: int = 1,
*args, **kwargs: Any) -> None:
*args: Any, **kwargs: Any) -> None:
super(SubmitAnyscaleJob, self).__init__(*args, **kwargs)
self.conn_id = conn_id
self.name = name
Expand All @@ -78,8 +78,8 @@ def __init__(self,
self.entrypoint = entrypoint
self.max_retries = max_retries

self.job_id: str = None
self.created_at: float = None
self.job_id: Optional[str] = None
self.created_at: Optional[float] = None

self.fields: Dict[str, Any] = {
"name": name,
Expand Down Expand Up @@ -152,9 +152,9 @@ def defer_job_polling(self, job_id: str) -> None:
method_name="execute_complete")

def get_current_status(self, job_id: str) -> str:
return self.hook.get_job_status(job_id=job_id).state
return str(self.hook.get_job_status(job_id=job_id).state)

def execute_complete(self, context: Context, event: TriggerEvent) -> None:
def execute_complete(self, context: Context, event: Any) -> None:
current_job_id = event["job_id"]

if event["status"] == JobState.FAILED:
Expand Down Expand Up @@ -193,34 +193,34 @@ class RolloutAnyscaleService(BaseOperator):
:param logging_config: Optional. Logging configuration for the service. Defaults to None.
:param ray_gcs_external_storage_config: Optional. Ray GCS external storage configuration. Defaults to None.
:param in_place: Optional. Flag for in-place updates. Defaults to False.
:param canary_percent: Optional. Percentage of canary deployment. Defaults to None.
:param max_surge_percent: Optional. Maximum percentage of surge during deployment. Defaults to None.
:param canary_percent: Optional[float]. Percentage of canary deployment. Defaults to None.
:param max_surge_percent: Optional[float]. Maximum percentage of surge during deployment. Defaults to None.
:raises ValueError: If service name or applications list is not provided.
:raises AirflowException: If the SDK is not available or the service deployment fails.
"""

def __init__(self,
conn_id: str,
name: str,
image_uri: str,
compute_config: Union[ComputeConfig, dict, str],
applications: List[Dict[str, Any]],
working_dir: str,
containerfile: Optional[str] = None,
excludes: Optional[List[str]] = None,
requirements: Optional[Union[str, List[str]]] = None,
env_vars: Optional[Dict[str, str]] = None,
py_modules: Optional[List[str]] = None,
query_auth_token_enabled: bool = False,
http_options: Optional[Dict[str, Any]] = None,
grpc_options: Optional[Dict[str, Any]] = None,
logging_config: Optional[Dict[str, Any]] = None,
ray_gcs_external_storage_config: Optional[Union[RayGCSExternalStorageConfig, dict]] = None,
in_place: bool = False,
canary_percent: Optional[float] = None,
max_surge_percent: Optional[int] = None,
**kwargs: Any) -> None:
conn_id: str,
name: str,
image_uri: str,
compute_config: Union[ComputeConfig, Dict[str, Any], str],
applications: List[Dict[str, Any]],
working_dir: str,
containerfile: Optional[str] = None,
excludes: Optional[List[str]] = None,
requirements: Optional[Union[str, List[str]]] = None,
env_vars: Optional[Dict[str, str]] = None,
py_modules: Optional[List[str]] = None,
query_auth_token_enabled: bool = False,
http_options: Optional[Dict[str, Any]] = None,
grpc_options: Optional[Dict[str, Any]] = None,
logging_config: Optional[Dict[str, Any]] = None,
ray_gcs_external_storage_config: Optional[Union[RayGCSExternalStorageConfig, Dict[str, Any]]] = None,
in_place: bool = False,
canary_percent: Optional[float] = None,
max_surge_percent: Optional[float] = None,
**kwargs: Any) -> None:
super().__init__(**kwargs)
self.conn_id = conn_id

Expand Down Expand Up @@ -285,7 +285,7 @@ def execute(self, context: Context) -> Optional[str]:
self.log.info(f"Service rollout id: {service_id}")
return service_id

def execute_complete(self, context: Context, event: TriggerEvent) -> None:
def execute_complete(self, context: Context, event: Any) -> None:
self.log.info(f"Execution completed...")
service_id = event["service_name"]

Expand Down
4 changes: 2 additions & 2 deletions anyscale_provider/triggers/anyscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
from functools import partial
from datetime import datetime, timedelta
from typing import Any, Dict, AsyncIterator, Tuple
from typing import Any, Dict, AsyncIterator, Tuple, Optional

from anyscale.job.models import JobState
from anyscale.service.models import ServiceState
Expand Down Expand Up @@ -133,7 +133,7 @@ def __init__(self,
conn_id: str,
service_name: str,
expected_state: str,
canary_percent: float,
canary_percent: Optional[float],
poll_interval: int = 60,
timeout: int = 600):
super().__init__()
Expand Down

0 comments on commit 1705aa8

Please sign in to comment.