Skip to content

Commit

Permalink
bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
venkatajagannath committed Jun 7, 2024
1 parent 19e6406 commit 9a85335
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
path: |
~/.cache/pip
.nox
key: unit-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.airflow-version }}-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('cosmos/__init__.py') }}
key: unit-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.airflow-version }}-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('anyscale_provider/__init__.py') }}

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
Expand Down Expand Up @@ -120,7 +120,7 @@ jobs:
- name: Test Cosmos against Airflow ${{ matrix.airflow-version }} and Python ${{ matrix.python-version }}
run: |
hatch run tests.py${{ matrix.python-version }}-${{ matrix.airflow-version }}:integration_test
hatch run tests.py${{ matrix.python-version }}-${{ matrix.airflow-version }}:test-integration
env:
AIRFLOW_HOME: /home/runner/work/astro-provider-anyscale/anyscale_provider/
AIRFLOW_CONN_AIRFLOW_DB: postgres://postgres:[email protected]:5432/postgres
Expand Down
17 changes: 9 additions & 8 deletions anyscale_provider/triggers/anyscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
from functools import partial
from datetime import datetime, timedelta
from typing import Any, Dict, AsyncIterator, Literal
from typing import Any, Dict, AsyncIterator, Tuple

from anyscale.job.models import JobState
from anyscale.service.models import ServiceState
Expand Down Expand Up @@ -34,7 +34,7 @@ class AnyscaleJobTrigger(BaseTrigger):
:raises AirflowException: If no job_id is provided or an error occurs during polling.
"""

def __init__(self, conn_id: str, job_id: str, job_start_time: float, poll_interval: int = 60, timeout: int = 3600) -> None:
def __init__(self, conn_id: str, job_id: str, job_start_time: float, poll_interval: int = 60, timeout: int = 3600):
super().__init__()
self.conn_id = conn_id
self.job_id = job_id
Expand All @@ -48,7 +48,7 @@ def hook(self) -> AnyscaleHook:
"""Return an instance of the AnyscaleHook."""
return AnyscaleHook(conn_id=self.conn_id)

def serialize(self) -> tuple[str, dict[str, Any]]:
def serialize(self) -> Tuple[str, Dict[str, Any]]:
return ("anyscale_provider.triggers.anyscale.AnyscaleJobTrigger", {
"conn_id": self.conn_id,
"job_id": self.job_id,
Expand Down Expand Up @@ -99,7 +99,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
})

def get_current_status(self, job_id: str) -> str:
return self.hook.get_job_status(job_id=job_id).state
job_status = self.hook.get_job_status(job_id=job_id).state
return str(job_status)

def is_terminal_status(self, job_id: str) -> bool:
job_status = self.get_current_status(job_id)
Expand Down Expand Up @@ -134,7 +135,7 @@ def __init__(self,
expected_state: str,
canary_percent: float,
poll_interval: int = 60,
timeout: int = 600) -> None:
timeout: int = 600):
super().__init__()
self.conn_id = conn_id
self.service_name = service_name
Expand All @@ -149,7 +150,7 @@ def hook(self) -> AnyscaleHook:
"""Return an instance of the AnyscaleHook."""
return AnyscaleHook(conn_id=self.conn_id)

def serialize(self) -> tuple[str, dict[str, Any]]:
def serialize(self) -> Tuple[str, Dict[str, Any]]:
return ("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger", {
"conn_id": self.conn_id,
"service_name": self.service_name,
Expand Down Expand Up @@ -201,9 +202,9 @@ def get_current_status(self, service_name: str) -> str:
service_status = self.hook.get_service_status(service_name)

if self.canary_percent is None or 0.0 < self.canary_percent < 100.0:
return service_status.canary_version.state
return str(service_status.canary_version.state)
else:
return service_status.state
return str(service_status.state)

def check_current_status(self, service_name: str) -> bool:
job_status = self.get_current_status(service_name)
Expand Down

0 comments on commit 9a85335

Please sign in to comment.