From e4fb38e9455f3385f8c48192ad54941c53816b40 Mon Sep 17 00:00:00 2001 From: Pankaj Singh <98807258+pankajastro@users.noreply.github.com> Date: Wed, 26 Jun 2024 20:06:00 +0530 Subject: [PATCH] Add trigger tests (#32) --- codespell-ignore-words.txt | 1 + tests/triggers/test_anyscale_triggers.py | 153 ++++++++++++++++++++++- 2 files changed, 152 insertions(+), 2 deletions(-) diff --git a/codespell-ignore-words.txt b/codespell-ignore-words.txt index bf52b4c..cb06a30 100644 --- a/codespell-ignore-words.txt +++ b/codespell-ignore-words.txt @@ -1 +1,2 @@ assertIn +asend diff --git a/tests/triggers/test_anyscale_triggers.py b/tests/triggers/test_anyscale_triggers.py index 040e77d..10e13ac 100644 --- a/tests/triggers/test_anyscale_triggers.py +++ b/tests/triggers/test_anyscale_triggers.py @@ -1,9 +1,10 @@ import asyncio import unittest -from unittest.mock import MagicMock, PropertyMock, patch +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch from airflow.exceptions import AirflowNotFoundException -from anyscale.job.models import JobConfig, JobState, JobStatus +from airflow.triggers.base import TriggerEvent +from anyscale.job.models import JobConfig, JobRunStatus, JobState, JobStatus from anyscale.service.models import ServiceState from anyscale_provider.triggers.anyscale import AnyscaleJobTrigger, AnyscaleServiceTrigger @@ -141,6 +142,77 @@ async def test_anyscale_run_trigger(self, mocked_sleep, mocked_get_job_logs, moc self.assertEqual(result.payload["message"], "Job 1234 completed with status JobState.SUCCEEDED.") self.assertEqual(result.payload["job_id"], "1234") + @patch("anyscale_provider.triggers.anyscale.AnyscaleHook.get_job_status") + @patch("anyscale_provider.triggers.anyscale.AnyscaleJobTrigger._is_terminal_state") + def test_run_success(self, mock_terminal_state, mock_hook): + trigger = AnyscaleJobTrigger(conn_id="test_conn", job_id="test_job", poll_interval=1, fetch_logs=False) + mock_terminal_state.return_value = True + mock_hook.return_value = JobStatus( + id="test_job", state=JobState.SUCCEEDED, name="", config=JobConfig(entrypoint="122"), runs=[] + ) + + async def run_test(): + generator = trigger.run() + result = await generator.asend(None) + assert result == TriggerEvent( + {"state": "SUCCEEDED", "message": "Job test_job completed with state SUCCEEDED.", "job_id": "test_job"} + ) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(run_test()) + + @patch("anyscale_provider.triggers.anyscale.asyncio.get_event_loop") + @patch("anyscale_provider.triggers.anyscale.AnyscaleHook.get_job_status") + @patch("anyscale_provider.triggers.anyscale.AnyscaleJobTrigger._is_terminal_state") + def test_run_success_fetch_log(self, mock_terminal_state, mock_hook, mock_asyncio_loop): + trigger = AnyscaleJobTrigger(conn_id="test_conn", job_id="test_job", poll_interval=1, fetch_logs=True) + mock_terminal_state.return_value = True + mock_hook.return_value = JobStatus( + id="test_job", + state=JobState.SUCCEEDED, + name="", + config=JobConfig(entrypoint="122"), + runs=[JobRunStatus(name="test", state="SUCCEEDED")], + ) + mock_loop = AsyncMock() + mock_asyncio_loop.return_value = mock_loop + mock_loop.run_in_executor.side_effect = "hello\n" + + async def run_test(): + generator = trigger.run() + result = await generator.asend(None) + assert result == TriggerEvent( + {"state": "SUCCEEDED", "message": "Job test_job completed with state SUCCEEDED.", "job_id": "test_job"} + ) + mock_asyncio_loop.assert_called_once() + mock_loop.run_in_executor.assert_called_once() + mock_loop.run_in_executor.return_value = [] + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(run_test()) + + @patch("anyscale_provider.triggers.anyscale.AnyscaleHook.get_job_status") + @patch("anyscale_provider.triggers.anyscale.AnyscaleJobTrigger._is_terminal_state") + def test_run_error(self, mock_terminal_state, mock_hook): + trigger = AnyscaleJobTrigger(conn_id="test_conn", job_id="test_job", poll_interval=1, fetch_logs=False) + mock_terminal_state.return_value = True + mock_hook.return_value = JobStatus( + id="test_job", state=JobState.FAILED, name="", config=JobConfig(entrypoint="122"), runs=[] + ) + + async def run_test(): + generator = trigger.run() + result = await generator.asend(None) + assert result == TriggerEvent( + {"state": "FAILED", "message": "Job test_job completed with state FAILED.", "job_id": "test_job"} + ) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(run_test()) + class TestAnyscaleServiceTrigger(unittest.TestCase): def setUp(self): @@ -285,6 +357,83 @@ def test_get_current_status_canary_100_percent(self, mock_get_service_status): # Ensure the mock was called correctly mock_get_service_status.assert_called_once_with("AstroService") + @patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._get_current_state") + @patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._check_current_state") + def test_run_success(self, mock_check_current_state, mock_get_current_state): + trigger = AnyscaleServiceTrigger( + conn_id="default_conn", + service_name="AstroService", + expected_state=ServiceState.RUNNING, + canary_percent=100.0, + ) + mock_check_current_state.return_value = False + mock_get_current_state.return_value = ServiceState.RUNNING + + async def run_test(): + generator = trigger.run() + result = await generator.asend(None) + assert result == TriggerEvent( + { + "state": ServiceState.RUNNING, + "message": "Service deployment succeeded", + "service_name": "AstroService", + } + ) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(run_test()) + + @patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._get_current_state") + @patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._check_current_state") + def test_run_failure(self, mock_check_current_state, mock_get_current_state): + trigger = AnyscaleServiceTrigger( + conn_id="default_conn", + service_name="AstroService", + expected_state=ServiceState.RUNNING, + canary_percent=100.0, + ) + mock_check_current_state.return_value = False + mock_get_current_state.return_value = ServiceState.UNKNOWN + + async def run_test(): + generator = trigger.run() + result = await generator.asend(None) + assert result == TriggerEvent( + { + "state": ServiceState.SYSTEM_FAILURE, + "message": "Service AstroService entered an unexpected state: UNKNOWN", + "service_name": "AstroService", + } + ) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(run_test()) + + @patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._get_current_state") + @patch("anyscale_provider.triggers.anyscale.AnyscaleServiceTrigger._check_current_state") + def test_run_service_exception(self, mock_check_current_state, mock_get_current_state): + trigger = AnyscaleServiceTrigger( + conn_id="default_conn", + service_name="AstroService", + expected_state=ServiceState.RUNNING, + canary_percent=100.0, + ) + mock_check_current_state.return_value = False + mock_get_current_state.side_effect = Exception("Unknown error") + + async def run_test(): + generator = trigger.run() + result = await generator.asend(None) + assert result == TriggerEvent( + {"state": ServiceState.SYSTEM_FAILURE, "message": "Unknown error", "service_name": "AstroService"} + ) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(run_test()) + if __name__ == "__main__": unittest.main()