Skip to content

Commit

Permalink
Merge branch '70-random-forest-providing-training-job-with-aggregator-job-id-fails'
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Sep 15, 2022
2 parents 2500468 + 0c8541c commit 24746a5
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Properly rewrite model id in `load_ml_model` ([#70](https://github.com/Open-EO/openeo-aggregator/issues/70))

## [0.4.x]

### Added
Expand Down
24 changes: 24 additions & 0 deletions src/openeo_aggregator/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,10 @@ def get_backend_for_process_graph(self, process_graph: dict, api_version: str) -
aggregator_job_id=arguments["id"]
)
backend_candidates = [b for b in backend_candidates if b == job_backend_id]
elif process_id == "load_ml_model":
model_backend_id = self._process_load_ml_model(arguments)[0]
if model_backend_id:
backend_candidates = [b for b in backend_candidates if b == model_backend_id]
except Exception as e:
_log.error(f"Failed to parse process graph: {e!r}", exc_info=True)
raise ProcessGraphInvalidException()
Expand Down Expand Up @@ -435,13 +439,33 @@ def preprocess(node: Any) -> Any:
assert job_backend_id == backend_id, f"{job_backend_id} != {backend_id}"
# Create new load_result node dict with updated job id
return dict_merge(node, arguments=dict_merge(arguments, id=job_id))
if process_id == "load_ml_model":
model_id = self._process_load_ml_model(arguments, expected_backend=backend_id)[1]
if model_id:
return dict_merge(node, arguments=dict_merge(arguments, id=model_id))
return {k: preprocess(v) for k, v in node.items()}
elif isinstance(node, list):
return [preprocess(x) for x in node]
return node

return preprocess(process_graph)

def _process_load_ml_model(
        self, arguments: dict, expected_backend: Optional[str] = None
) -> Tuple[Optional[str], Optional[str]]:
    """
    Handle a `load_ml_model` node: detect/strip the backend id from the model id
    when the model id is an aggregator job id.

    :param arguments: arguments dict of a `load_ml_model` process node;
        its "id" entry may be an aggregator job id, an "http(s)" URL, or absent.
    :param expected_backend: optional backend id to enforce: when given and the
        backend encoded in the model id differs, a lookup failure is raised.
    :return: `(backend_id, model_id)` tuple where `backend_id` is the backend
        encoded in the aggregator job id (or `None` when the model id is a URL
        or absent) and `model_id` is the stripped backend-level job id, the
        original URL, or `None` when no "id" argument was given.
    :raises BackendLookupFailureException: when the backend detected in the
        model id does not match `expected_backend`.
    """
    model_id = arguments.get("id")
    # A non-URL model id is assumed to be an aggregator job id with an embedded backend id.
    if model_id and not model_id.startswith("http"):
        # TODO: load_ml_model's `id` could also be file path (see https://github.com/Open-EO/openeo-processes/issues/384)
        job_id, job_backend_id = JobIdMapping.parse_aggregator_job_id(
            backends=self.backends,
            aggregator_job_id=model_id
        )
        if expected_backend and job_backend_id != expected_backend:
            raise BackendLookupFailureException(f"{job_backend_id} != {expected_backend}")
        return job_backend_id, job_id
    # URL (or missing) model id: no backend can be derived; pass the value through unchanged.
    return None, model_id


class AggregatorBatchJobs(BatchJobs):

Expand Down
41 changes: 37 additions & 4 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,10 +670,43 @@ def post_result(request: requests.Request, context):
pg = {
"lr": {"process_id": "load_result", "arguments": {"id": job_id}},
"lc": {"process_id": "load_collection", "arguments": {"id": "S2"}},
"merge": {"process_id": "merge_cubes", "arguments": {
"cube1": {"from_node": "lr"},
"cube2": {"from_node": "lc"}
}}
}
api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN)
request = {"process": {"process_graph": pg}}
if expected_success:
api100.post("/result", json=request).assert_status_code(200)
assert (b1_mock.call_count, b2_mock.call_count) == {1: (1, 0), 2: (0, 1)}[s2_backend]
else:
api100.post("/result", json=request).assert_error(400, "BackendLookupFailure")
assert (b1_mock.call_count, b2_mock.call_count) == (0, 0)

@pytest.mark.parametrize(["job_id", "s2_backend", "expected_success"], [
("b1-b6tch-j08", 1, True),
("b2-b6tch-j08", 1, False),
("b1-b6tch-j08", 2, False),
("b2-b6tch-j08", 2, True),
("https://example.com/ml_model_metadata.json", 1, True), # In this case it picks the first backend.
("https://example.com/ml_model_metadata.json", 2, True),
])
def test_load_result_job_id_parsing_with_load_ml_model(
self, api100, requests_mock, backend1, backend2, job_id, s2_backend, expected_success
):
"""Issue #70: random forest: providing training job with aggregator job id fails"""

backend_root = {1: backend1, 2: backend2}[s2_backend]
requests_mock.get(backend_root + "/collections", json={"collections": [{"id": "S2"}]})

def post_result(request: requests.Request, context):
pg = request.json()["process"]["process_graph"]
assert pg["lmm"]["arguments"]["id"] in ["b6tch-j08", "https://example.com/ml_model_metadata.json"]
context.headers["Content-Type"] = "application/json"

b1_mock = requests_mock.post(backend1 + "/result", json=post_result)
b2_mock = requests_mock.post(backend2 + "/result", json=post_result)

pg = {
"lmm": {"process_id": "load_ml_model", "arguments": {"id": job_id}},
"lc": {"process_id": "load_collection", "arguments": {"id": "S2"}},
}
api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN)
request = {"process": {"process_graph": pg}}
Expand Down

0 comments on commit 24746a5

Please sign in to comment.