From 9f5322e50fd7a46e18da85c8023825ffcbf3e973 Mon Sep 17 00:00:00 2001 From: Stefaan Lippens Date: Tue, 21 Mar 2023 21:22:11 +0100 Subject: [PATCH] Issue #95 basic test for crossbackend run_partitioned_job --- .../partitionedjobs/crossbackend.py | 12 +- tests/partitionedjobs/test_crossbackend.py | 126 +++++++++++++++++- 2 files changed, 134 insertions(+), 4 deletions(-) diff --git a/src/openeo_aggregator/partitionedjobs/crossbackend.py b/src/openeo_aggregator/partitionedjobs/crossbackend.py index c3191600..909e22bc 100644 --- a/src/openeo_aggregator/partitionedjobs/crossbackend.py +++ b/src/openeo_aggregator/partitionedjobs/crossbackend.py @@ -4,6 +4,7 @@ import itertools import logging import time +from contextlib import nullcontext from typing import Callable, Dict, List, Sequence import openeo @@ -180,7 +181,9 @@ def _loop(): yield i -def run_partitioned_job(pjob: PartitionedJob, connection: openeo.Connection) -> dict: +def run_partitioned_job( + pjob: PartitionedJob, connection: openeo.Connection, fail_fast: bool = True +) -> dict: """ Run partitioned job (probably with dependencies between subjobs) @@ -202,7 +205,10 @@ def run_partitioned_job(pjob: PartitionedJob, connection: openeo.Connection) -> # Map subjob_id to a batch job instances batch_jobs: Dict[str, BatchJob] = {} - skip_intermittent_failures = SkipIntermittentFailures(limit=3) + if not fail_fast: + skip_intermittent_failures = SkipIntermittentFailures(limit=3) + else: + skip_intermittent_failures = nullcontext() for _ in _loop(): need_sleep = True @@ -251,6 +257,8 @@ def run_partitioned_job(pjob: PartitionedJob, connection: openeo.Connection) -> f"Started batch job {batch_job.job_id!r} for subjob {subjob_id!r}" ) except Exception as e: + if fail_fast: + raise states[subjob_id] = SUBJOB_STATES.ERROR _log.warning( f"Failed to start batch job for subjob {subjob_id!r}: {e}", diff --git a/tests/partitionedjobs/test_crossbackend.py b/tests/partitionedjobs/test_crossbackend.py index 7d1d23a3..c1f43dcc 100644 --- a/tests/partitionedjobs/test_crossbackend.py +++ b/tests/partitionedjobs/test_crossbackend.py @@ -1,5 +1,20 @@ -from openeo_aggregator.partitionedjobs import SubJob -from openeo_aggregator.partitionedjobs.crossbackend import CrossBackendSplitter +import dataclasses +import re +from typing import Dict, List, Optional +from unittest import mock + +import openeo +import pytest +import requests +import requests_mock +from openeo_driver.jobregistry import JOB_STATUS +from openeo_driver.testing import DictSubSet + +from openeo_aggregator.partitionedjobs import PartitionedJob, SubJob +from openeo_aggregator.partitionedjobs.crossbackend import ( + CrossBackendSplitter, + run_partitioned_job, +) class TestCrossBackendSplitter: @@ -77,3 +92,110 @@ def test_basic(self): ), } assert res.dependencies == {"main": ["B2:lc2"]} + + +@dataclasses.dataclass +class _FakeJob: + pg: Optional[dict] = None + created: bool = False + started: bool = False + status_train: List[str] = dataclasses.field( + default_factory=lambda: [ + JOB_STATUS.QUEUED, + JOB_STATUS.RUNNING, + JOB_STATUS.FINISHED, + ] + ) + + +class _FakeAggregator: + def __init__(self, url: str = "http://oeoa.test"): + self.url = url + self.jobs: Dict[str, _FakeJob] = {} + self.job_status_stacks = {} + + def setup_requests_mock(self, requests_mock: requests_mock.Mocker): + requests_mock.get(f"{self.url}/", json={"api_version": "1.1.0"}) + requests_mock.post(f"{self.url}/jobs", text=self._handle_job_create) + requests_mock.get(re.compile("/jobs/([^/]*)$"), json=self._handle_job_status) 
+        requests_mock.post(
+            re.compile("/jobs/([^/]*)/results$"), text=self._handle_job_start
+        )
+
+    def _handle_job_create(self, request: requests.Request, context):
+        pg = request.json()["process"]["process_graph"]
+        # Derive a deterministic job id from the collection ids used
+        cids = "-".join(
+            sorted(
+                n["arguments"]["id"]
+                for n in pg.values()
+                if n["process_id"] == "load_collection"
+            )
+        )
+        assert cids
+        job_id = f"job-{cids}".lower()
+        if job_id in self.jobs:  # job was pre-registered by the test
+            assert not self.jobs[job_id].created
+            self.jobs[job_id].pg = pg
+            self.jobs[job_id].created = True
+        else:
+            self.jobs[job_id] = _FakeJob(pg=pg, created=True)
+
+        context.headers["Location"] = f"{self.url}/jobs/{job_id}"
+        context.headers["OpenEO-Identifier"] = job_id
+        context.status_code = 201
+
+    def _job_id_from_request(self, request: requests.Request) -> str:
+        return re.search(r"/jobs/([^/]*)", request.path).group(1).lower()
+
+    def _handle_job_status(self, request: requests.Request, context):
+        job_id = self._job_id_from_request(request)
+        job = self.jobs[job_id]
+        assert job.created
+        if job.started:
+            if len(job.status_train) > 1:
+                status = job.status_train.pop(0)  # advance the status train
+            elif len(job.status_train) == 1:
+                status = job.status_train[0]  # keep reporting the final status
+            else:
+                status = JOB_STATUS.ERROR
+        else:
+            status = JOB_STATUS.CREATED
+        return {"status": status}
+
+    def _handle_job_start(self, request: requests.Request, context):
+        job_id = self._job_id_from_request(request)
+        self.jobs[job_id].started = True
+        context.status_code = 202
+
+
+class TestRunPartitionedJobs:
+    @pytest.fixture
+    def aggregator(self, requests_mock) -> _FakeAggregator:
+        aggregator = _FakeAggregator()
+        aggregator.setup_requests_mock(requests_mock=requests_mock)
+        return aggregator
+
+    def test_simple(self, aggregator):
+        connection = openeo.Connection(aggregator.url)
+        pjob = PartitionedJob(
+            process={},
+            metadata={},
+            job_options={},
+            subjobs={
+                "one": SubJob(
+                    process_graph={
+                        "lc": {
+                            "process_id": "load_collection",
+                            "arguments": {"id": "S2"},
+                        }
+                    },
+                    backend_id="b1",
+                )
+            },
+        )
+
+        with mock.patch("time.sleep") as sleep:
+            res = run_partitioned_job(pjob=pjob, connection=connection, fail_fast=True)
+        assert res == {"one": DictSubSet({"state": "finished"})}
+        assert sleep.call_count >= 1
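
Usage note (not part of the patch): a minimal caller-side sketch of the new
fail_fast flag. The aggregator URL, collection id and backend id below are
placeholders. With the default fail_fast=True, the first failure to start a
subjob's batch job is raised immediately; with fail_fast=False the loop falls
back to SkipIntermittentFailures(limit=3) and tolerates occasional failures
up to that limit.

    import openeo

    from openeo_aggregator.partitionedjobs import PartitionedJob, SubJob
    from openeo_aggregator.partitionedjobs.crossbackend import run_partitioned_job

    # Placeholder URL of an aggregator deployment.
    connection = openeo.connect("https://aggregator.example")

    pjob = PartitionedJob(
        process={},
        metadata={},
        job_options={},
        subjobs={
            "one": SubJob(
                process_graph={
                    "lc": {"process_id": "load_collection", "arguments": {"id": "S2"}}
                },
                backend_id="b1",
            )
        },
    )

    # Returns a mapping of subjob id to a state dict, e.g.
    # {"one": {"state": "finished", ...}} as asserted in test_simple.
    states = run_partitioned_job(pjob=pjob, connection=connection, fail_fast=False)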
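
Implementation note: the fail_fast toggle leans on contextlib.nullcontext
(stdlib since Python 3.7) as a no-op context manager, so the
`with skip_intermittent_failures:` block in the run loop needs no branching.
A self-contained sketch of the same pattern, with a toy stand-in for
SkipIntermittentFailures (the real one only absorbs failures up to its limit):

    from contextlib import nullcontext

    class SwallowErrors:
        """Toy stand-in for SkipIntermittentFailures: absorbs any exception."""

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # Returning True suppresses the exception raised in the with-block.
            return exc_type is not None

    def attempt(risky: bool, fail_fast: bool) -> str:
        # Same toggle as in run_partitioned_job: a real guard or a no-op one.
        guard = nullcontext() if fail_fast else SwallowErrors()
        with guard:
            if risky:
                raise RuntimeError("intermittent failure")
        return "survived"

    print(attempt(risky=True, fail_fast=False))  # "survived": error absorbed
    # attempt(risky=True, fail_fast=True) would raise RuntimeError instead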
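
Test note: the `if job_id in self.jobs` branch in _handle_job_create exists so
a test can pre-register a _FakeJob with a custom status_train before the job
is created. A hypothetical follow-up test along those lines; the expected
"error" end state is an assumption based on how run_partitioned_job tracks
subjob states, not something this patch asserts:

    # Hypothetical addition to TestRunPartitionedJobs (not in this patch).
    def test_failing_subjob(self, aggregator):
        connection = openeo.Connection(aggregator.url)
        pjob = PartitionedJob(
            process={},
            metadata={},
            job_options={},
            subjobs={
                "one": SubJob(
                    process_graph={
                        "lc": {"process_id": "load_collection", "arguments": {"id": "S2"}}
                    },
                    backend_id="b1",
                )
            },
        )
        # Pre-register the job ("job-s2" follows _handle_job_create's naming
        # scheme) with a status train that ends in "error".
        aggregator.jobs["job-s2"] = _FakeJob(
            status_train=[JOB_STATUS.QUEUED, JOB_STATUS.RUNNING, JOB_STATUS.ERROR]
        )
        with mock.patch("time.sleep"):
            res = run_partitioned_job(pjob=pjob, connection=connection, fail_fast=True)
        assert res == {"one": DictSubSet({"state": "error"})}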