Skip to content

Commit

Permalink
integration test and bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
venkatajagannath committed Jun 7, 2024
1 parent 1705aa8 commit 2606b24
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 26 deletions.
4 changes: 3 additions & 1 deletion anyscale_provider/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
__version__ = "1.0.0"

def get_provider_info():
from typing import Any, Dict, Optional

def get_provider_info() -> Dict[str,Any]:
return {
"package-name": "astro-provider-anyscale", # Required
"name": "Anyscale", # Required
Expand Down
4 changes: 2 additions & 2 deletions anyscale_provider/hooks/anyscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def submit_job(self, config: JobConfig) -> str:

def deploy_service(self, config: ServiceConfig,
in_place: bool = False,
canary_percent: Optional[int] = None,
max_surge_percent: Optional[int] = None) -> str:
canary_percent: Optional[float] = None,
max_surge_percent: Optional[float] = None) -> str:
self.log.info("Deploying a service with configuration: {}".format(config))
service_id: str = self.sdk.service.deploy(config=config,
in_place=in_place,
Expand Down
3 changes: 1 addition & 2 deletions anyscale_provider/operators/anyscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def __init__(self,
self.max_retries = max_retries

self.job_id: Optional[str] = None
self.created_at: Optional[float] = None

self.fields: Dict[str, Any] = {
"name": name,
Expand Down Expand Up @@ -123,7 +122,7 @@ def execute(self, context: Context) -> Optional[str]:
# Submit the job to Anyscale
job_config = JobConfig(**self.fields)
self.job_id = self.hook.submit_job(job_config)
self.created_at = time.time()
self.created_at: float = time.time()
self.log.info(f"Submitted Anyscale job with ID: {self.job_id}")

current_status = self.get_current_status(self.job_id)
Expand Down
13 changes: 7 additions & 6 deletions anyscale_provider/triggers/anyscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class AnyscaleJobTrigger(BaseTrigger):
"""

def __init__(self, conn_id: str, job_id: str, job_start_time: float, poll_interval: int = 60, timeout: int = 3600):
super().__init__()
super().__init__() # type: ignore[no-untyped-call]
self.conn_id = conn_id
self.job_id = job_id
self.job_start_time = job_start_time
Expand Down Expand Up @@ -136,7 +136,7 @@ def __init__(self,
canary_percent: Optional[float],
poll_interval: int = 60,
timeout: int = 600):
super().__init__()
super().__init__() # type: ignore[no-untyped-call]
self.conn_id = conn_id
self.service_name = service_name
self.expected_state = expected_state
Expand Down Expand Up @@ -200,11 +200,12 @@ async def run(self) -> AsyncIterator[TriggerEvent]:

def get_current_status(self, service_name: str) -> str:
service_status = self.hook.get_service_status(service_name)

if self.canary_percent is None or 0.0 < self.canary_percent < 100.0:
return str(service_status.canary_version.state)
else:

if self.canary_percent is None or self.canary_percent == 100.0:
return str(service_status.state)
else:
if service_status.canary_version and service_status.canary_version.state:
return str(service_status.canary_version.state)

def check_current_status(self, service_name: str) -> bool:
job_status = self.get_current_status(service_name)
Expand Down
33 changes: 18 additions & 15 deletions tests/dags/test_dag_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,14 @@
from contextlib import contextmanager
import pytest
from pathlib import Path
from airflow.models import DagBag
from airflow.models import DagBag, Connection
from airflow.utils.db import create_default_connections
from airflow.utils.session import provide_session
from airflow.utils.session import create_session
from airflow.models import Connection
from airflow.utils.session import provide_session, create_session

import utils as test_utils

EXAMPLE_DAGS_DIR = Path(__file__).parent.parent.parent / "anyscale_provider/example_dags"


def get_dags(dag_folder=None):
# Generate a tuple of dag_id, <DAG objects> in the DagBag
dag_bag = DagBag(dag_folder=dag_folder, include_examples=False) if dag_folder else DagBag(include_examples=False)
Expand All @@ -31,16 +28,22 @@ def strip_path_prefix(path):

return dags_info

@pytest.mark.integration
@pytest.mark.parametrize("dag_id,dag, fileloc", get_dags(EXAMPLE_DAGS_DIR), ids=[x[2] for x in get_dags()])
def test_dag_runs(dag_id, dag, fileloc):

@pytest.fixture(scope="module")
def setup_airflow_db():
# Initialize the database
os.system('airflow db init')
with create_session() as session:
conn = Connection(conn_id="anyscale_conn",
conn_type="anyscale",
password=os.getenv("ANYSCALE_CLI_TOKEN"))
# Add anyscale connection
conn = Connection(
conn_id="anyscale_conn",
conn_type="anyscale",
extra=f'{{"ANYSCALE_CLI_TOKEN": "{os.environ["ANYSCALE_CLI_TOKEN"]}"}}'
)
session.add(conn)
session.commit() # Ensure the connection is committed to the database
session.commit()

# Run the example dags
test_utils.run_dag(dag)
@pytest.mark.integration
@pytest.mark.parametrize("dag_id,dag, fileloc", get_dags(EXAMPLE_DAGS_DIR), ids=[x[2] for x in get_dags(EXAMPLE_DAGS_DIR)])
def test_dag_runs(setup_airflow_db, dag_id, dag, fileloc):
# Run the example dags
test_utils.run_dag(dag)

0 comments on commit 2606b24

Please sign in to comment.