Skip to content

Commit

Permalink
Issue #95 basic test for crossbackend run_partitioned_job
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Mar 21, 2023
1 parent 418f8fa commit 9f5322e
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 4 deletions.
12 changes: 10 additions & 2 deletions src/openeo_aggregator/partitionedjobs/crossbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import itertools
import logging
import time
from contextlib import nullcontext
from typing import Callable, Dict, List, Sequence

import openeo
Expand Down Expand Up @@ -180,7 +181,9 @@ def _loop():
yield i


def run_partitioned_job(pjob: PartitionedJob, connection: openeo.Connection) -> dict:
def run_partitioned_job(
pjob: PartitionedJob, connection: openeo.Connection, fail_fast: bool = True
) -> dict:
"""
Run partitioned job (probably with dependencies between subjobs)
Expand All @@ -202,7 +205,10 @@ def run_partitioned_job(pjob: PartitionedJob, connection: openeo.Connection) ->
# Map subjob_id to a batch job instances
batch_jobs: Dict[str, BatchJob] = {}

skip_intermittent_failures = SkipIntermittentFailures(limit=3)
if not fail_fast:
skip_intermittent_failures = SkipIntermittentFailures(limit=3)
else:
skip_intermittent_failures = nullcontext()

for _ in _loop():
need_sleep = True
Expand Down Expand Up @@ -251,6 +257,8 @@ def run_partitioned_job(pjob: PartitionedJob, connection: openeo.Connection) ->
f"Started batch job {batch_job.job_id!r} for subjob {subjob_id!r}"
)
except Exception as e:
if fail_fast:
raise
states[subjob_id] = SUBJOB_STATES.ERROR
_log.warning(
f"Failed to start batch job for subjob {subjob_id!r}: {e}",
Expand Down
126 changes: 124 additions & 2 deletions tests/partitionedjobs/test_crossbackend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
from openeo_aggregator.partitionedjobs import SubJob
from openeo_aggregator.partitionedjobs.crossbackend import CrossBackendSplitter
import dataclasses
import re
from typing import Dict, List, Optional
from unittest import mock

import openeo
import pytest
import requests
import requests_mock
from openeo_driver.jobregistry import JOB_STATUS
from openeo_driver.testing import DictSubSet

from openeo_aggregator.partitionedjobs import PartitionedJob, SubJob
from openeo_aggregator.partitionedjobs.crossbackend import (
CrossBackendSplitter,
run_partitioned_job,
)


class TestCrossBackendSplitter:
Expand Down Expand Up @@ -77,3 +92,110 @@ def test_basic(self):
),
}
assert res.dependencies == {"main": ["B2:lc2"]}


@dataclasses.dataclass
class _FakeJob:
    """In-memory stand-in for a backend batch job, as tracked by ``_FakeAggregator``.

    Records the process graph the job was created with, whether the
    create/start requests have been observed, and the sequence of statuses
    to report on successive status polls.
    """

    # Process graph the job was created with (None until a create request is seen).
    pg: Optional[dict] = None
    # Whether a job-create request has been received for this job.
    created: bool = False
    # Whether a job-start (POST .../results) request has been received.
    started: bool = False
    # Statuses to report on successive status polls once started; the
    # status handler keeps repeating the last entry when the train runs out.
    status_train: List[str] = dataclasses.field(
        default_factory=lambda: [
            JOB_STATUS.QUEUED,
            JOB_STATUS.RUNNING,
            JOB_STATUS.FINISHED,
        ]
    )


class _FakeAggregator:
    """Minimal fake of the aggregator's batch-job HTTP API.

    Keeps an in-memory ``job id -> _FakeJob`` mapping and answers the
    create/status/start requests issued by ``run_partitioned_job``.
    """

    def __init__(self, url: str = "http://oeoa.test"):
        self.url = url
        self.jobs: Dict[str, _FakeJob] = {}
        # NOTE(review): not referenced by any handler below — presumably a
        # leftover or hook for future tests; confirm before removing.
        self.job_status_stacks = {}

    def setup_requests_mock(self, requests_mock: requests_mock.Mocker):
        """Register this fake's HTTP handlers on the given requests mocker."""
        requests_mock.get(f"{self.url}/", json={"api_version": "1.1.0"})
        requests_mock.post(f"{self.url}/jobs", text=self._handle_job_create)
        requests_mock.get(re.compile("/jobs/([^/]*)$"), json=self._handle_job_status)
        requests_mock.post(
            re.compile("/jobs/([^/]*)/results$"), text=self._handle_job_start
        )

    def _handle_job_create(self, request: requests.Request, context):
        """Handle a job-create POST: register a ``_FakeJob`` under a derived id."""
        pg = request.json()["process"]["process_graph"]
        # The job id is derived from the (sorted) collection ids used in
        # the process graph's load_collection nodes.
        collection_ids = sorted(
            node["arguments"]["id"]
            for node in pg.values()
            if node["process_id"] == "load_collection"
        )
        cids = "-".join(collection_ids)
        assert cids
        job_id = f"job-{cids}".lower()
        job = self.jobs.get(job_id)
        if job is None:
            self.jobs[job_id] = _FakeJob(pg=pg, created=True)
        else:
            # Pre-seeded job: it may be created at most once.
            assert not job.created
            job.pg = pg
            job.created = True

        context.headers["Location"] = f"{self.url}/jobs/{job_id}"
        context.headers["OpenEO-Identifier"] = job_id
        context.status_code = 201

    def _job_id_from_request(self, request: requests.Request) -> str:
        """Extract the (lowercased) job id from the request path."""
        match = re.search(r"/jobs/([^/]*)", request.path)
        return match.group(1).lower()

    def _handle_job_status(self, request: requests.Request, context):
        """Handle a job-status GET: walk the job's status train."""
        job = self.jobs[self._job_id_from_request(request)]
        assert job.created
        if not job.started:
            return {"status": JOB_STATUS.CREATED}
        train = job.status_train
        if len(train) > 1:
            # Advance through the train: one status per poll.
            return {"status": train.pop(0)}
        if train:
            # Keep repeating the final status.
            return {"status": train[0]}
        return {"status": JOB_STATUS.ERROR}

    def _handle_job_start(self, request: requests.Request, context):
        """Handle a job-start POST: mark the job as started."""
        self.jobs[self._job_id_from_request(request)].started = True
        context.status_code = 202


class TestRunPartitionedJobs:
    """Tests for ``run_partitioned_job`` against a fake aggregator backend."""

    @pytest.fixture
    def aggregator(self, requests_mock) -> _FakeAggregator:
        """Fake aggregator with its HTTP handlers mounted on requests_mock."""
        fake = _FakeAggregator()
        fake.setup_requests_mock(requests_mock=requests_mock)
        return fake

    def test_simple(self, aggregator):
        connection = openeo.Connection(aggregator.url)
        process_graph = {
            "lc": {
                "process_id": "load_collection",
                "arguments": {"id": "S2"},
            }
        }
        pjob = PartitionedJob(
            process={},
            metadata={},
            job_options={},
            subjobs={
                "one": SubJob(process_graph=process_graph, backend_id="b1")
            },
        )

        # Patch out sleeping so the polling loop runs instantly.
        with mock.patch("time.sleep") as sleep_mock:
            result = run_partitioned_job(
                pjob=pjob, connection=connection, fail_fast=True
            )
            assert result == {"one": DictSubSet({"state": "finished"})}
            assert sleep_mock.call_count >= 1

0 comments on commit 9f5322e

Please sign in to comment.