Skip to content

Commit

Permalink
Add trigger tests (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajastro authored Jun 26, 2024
1 parent 8aa0a0f commit e4fb38e
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 2 deletions.
1 change: 1 addition & 0 deletions codespell-ignore-words.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
assertIn
asend
153 changes: 151 additions & 2 deletions tests/triggers/test_anyscale_triggers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import unittest
from unittest.mock import MagicMock, PropertyMock, patch
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch

from airflow.exceptions import AirflowNotFoundException
from anyscale.job.models import JobConfig, JobState, JobStatus
from airflow.triggers.base import TriggerEvent
from anyscale.job.models import JobConfig, JobRunStatus, JobState, JobStatus
from anyscale.service.models import ServiceState

from anyscale_provider.triggers.anyscale import AnyscaleJobTrigger, AnyscaleServiceTrigger
Expand Down Expand Up @@ -141,6 +142,77 @@ async def test_anyscale_run_trigger(self, mocked_sleep, mocked_get_job_logs, moc
self.assertEqual(result.payload["message"], "Job 1234 completed with status JobState.SUCCEEDED.")
self.assertEqual(result.payload["job_id"], "1234")

@patch("anyscale_provider.triggers.anyscale.AnyscaleHook.get_job_status")
@patch("anyscale_provider.triggers.anyscale.AnyscaleJobTrigger._is_terminal_state")
def test_run_success(self, mock_terminal_state, mock_hook):
trigger = AnyscaleJobTrigger(conn_id="test_conn", job_id="test_job", poll_interval=1, fetch_logs=False)
mock_terminal_state.return_value = True
mock_hook.return_value = JobStatus(
id="test_job", state=JobState.SUCCEEDED, name="", config=JobConfig(entrypoint="122"), runs=[]
)

async def run_test():
generator = trigger.run()
result = await generator.asend(None)
assert result == TriggerEvent(
{"state": "SUCCEEDED", "message": "Job test_job completed with state SUCCEEDED.", "job_id": "test_job"}
)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(run_test())

@patch("anyscale_provider.triggers.anyscale.asyncio.get_event_loop")
@patch("anyscale_provider.triggers.anyscale.AnyscaleHook.get_job_status")
@patch("anyscale_provider.triggers.anyscale.AnyscaleJobTrigger._is_terminal_state")
def test_run_success_fetch_log(self, mock_terminal_state, mock_hook, mock_asyncio_loop):
trigger = AnyscaleJobTrigger(conn_id="test_conn", job_id="test_job", poll_interval=1, fetch_logs=True)
mock_terminal_state.return_value = True
mock_hook.return_value = JobStatus(
id="test_job",
state=JobState.SUCCEEDED,
name="",
config=JobConfig(entrypoint="122"),
runs=[JobRunStatus(name="test", state="SUCCEEDED")],
)
mock_loop = AsyncMock()
mock_asyncio_loop.return_value = mock_loop
mock_loop.run_in_executor.side_effect = "hello\n"

async def run_test():
generator = trigger.run()
result = await generator.asend(None)
assert result == TriggerEvent(
{"state": "SUCCEEDED", "message": "Job test_job completed with state SUCCEEDED.", "job_id": "test_job"}
)
mock_asyncio_loop.assert_called_once()
mock_loop.run_in_executor.assert_called_once()
mock_loop.run_in_executor.return_value = []

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(run_test())

@patch("anyscale_provider.triggers.anyscale.AnyscaleHook.get_job_status")
@patch("anyscale_provider.triggers.anyscale.AnyscaleJobTrigger._is_terminal_state")
def test_run_error(self, mock_terminal_state, mock_hook):
trigger = AnyscaleJobTrigger(conn_id="test_conn", job_id="test_job", poll_interval=1, fetch_logs=False)
mock_terminal_state.return_value = True
mock_hook.return_value = JobStatus(
id="test_job", state=JobState.FAILED, name="", config=JobConfig(entrypoint="122"), runs=[]
)

async def run_test():
generator = trigger.run()
result = await generator.asend(None)
assert result == TriggerEvent(
{"state": "FAILED", "message": "Job test_job completed with state FAILED.", "job_id": "test_job"}
)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(run_test())


class TestAnyscaleServiceTrigger(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -285,6 +357,83 @@ def test_get_current_status_canary_100_percent(self, mock_get_service_status):
# Ensure the mock was called correctly
mock_get_service_status.assert_called_once_with("AstroService")

@patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._get_current_state")
@patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._check_current_state")
def test_run_success(self, mock_check_current_state, mock_get_current_state):
trigger = AnyscaleServiceTrigger(
conn_id="default_conn",
service_name="AstroService",
expected_state=ServiceState.RUNNING,
canary_percent=100.0,
)
mock_check_current_state.return_value = False
mock_get_current_state.return_value = ServiceState.RUNNING

async def run_test():
generator = trigger.run()
result = await generator.asend(None)
assert result == TriggerEvent(
{
"state": ServiceState.RUNNING,
"message": "Service deployment succeeded",
"service_name": "AstroService",
}
)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(run_test())

@patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._get_current_state")
@patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._check_current_state")
def test_run_failure(self, mock_check_current_state, mock_get_current_state):
trigger = AnyscaleServiceTrigger(
conn_id="default_conn",
service_name="AstroService",
expected_state=ServiceState.RUNNING,
canary_percent=100.0,
)
mock_check_current_state.return_value = False
mock_get_current_state.return_value = ServiceState.UNKNOWN

async def run_test():
generator = trigger.run()
result = await generator.asend(None)
assert result == TriggerEvent(
{
"state": ServiceState.SYSTEM_FAILURE,
"message": "Service AstroService entered an unexpected state: UNKNOWN",
"service_name": "AstroService",
}
)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(run_test())

@patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._get_current_state")
@patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._check_current_state")
def test_run_service_exception(self, mock_check_current_state, mock_get_current_state):
trigger = AnyscaleServiceTrigger(
conn_id="default_conn",
service_name="AstroService",
expected_state=ServiceState.RUNNING,
canary_percent=100.0,
)
mock_check_current_state.return_value = False
mock_get_current_state.side_effect = Exception("Unknown error")

async def run_test():
generator = trigger.run()
result = await generator.asend(None)
assert result == TriggerEvent(
{"state": ServiceState.SYSTEM_FAILURE, "message": "Unknown error", "service_name": "AstroService"}
)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(run_test())


if __name__ == "__main__":
unittest.main()

0 comments on commit e4fb38e

Please sign in to comment.