Skip to content

Commit

Permalink
Issue #70 finetune load_ml_model handling a bit more (PR #71)
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Sep 15, 2022
1 parent 8303974 commit a5a5338
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 16 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Properly rewrite model id in `load_ml_model` ([#70](https://github.com/Open-EO/openeo-aggregator/issues/70))

## [0.4.x]

### Added
Expand Down
39 changes: 23 additions & 16 deletions src/openeo_aggregator/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,9 @@ def get_backend_for_process_graph(self, process_graph: dict, api_version: str) -
)
backend_candidates = [b for b in backend_candidates if b == job_backend_id]
elif process_id == "load_ml_model":
if not arguments["id"].startswith("http"):
# Extract backend id that can load this ML model.
_, job_backend_id = JobIdMapping.parse_aggregator_job_id(
backends=self.backends,
aggregator_job_id=arguments["id"]
)
backend_candidates = [b for b in backend_candidates if b == job_backend_id]
model_backend_id = self._process_load_ml_model(arguments)[0]
if model_backend_id:
backend_candidates = [b for b in backend_candidates if b == model_backend_id]
except Exception as e:
_log.error(f"Failed to parse process graph: {e!r}", exc_info=True)
raise ProcessGraphInvalidException()
Expand Down Expand Up @@ -443,22 +439,33 @@ def preprocess(node: Any) -> Any:
assert job_backend_id == backend_id, f"{job_backend_id} != {backend_id}"
# Create new load_result node dict with updated job id
return dict_merge(node, arguments=dict_merge(arguments, id=job_id))
if process_id == "load_ml_model" and "id" in arguments:
if not arguments["id"].startswith("http"):
job_id, job_backend_id = JobIdMapping.parse_aggregator_job_id(
backends=self.backends,
aggregator_job_id=arguments["id"]
)
assert job_backend_id == backend_id, f"{job_backend_id} != {backend_id}"
# Create new load_ml_model node dict with updated job id
return dict_merge(node, arguments=dict_merge(arguments, id=job_id))
if process_id == "load_ml_model":
model_id = self._process_load_ml_model(arguments, expected_backend=backend_id)[1]
if model_id:
return dict_merge(node, arguments=dict_merge(arguments, id=model_id))
return {k: preprocess(v) for k, v in node.items()}
elif isinstance(node, list):
return [preprocess(x) for x in node]
return node

return preprocess(process_graph)

def _process_load_ml_model(
    self, arguments: dict, expected_backend: Optional[str] = None
) -> Tuple[Optional[str], Any]:
    """Handle load_ml_model: detect/strip backend_id from model_id if it is a job_id.

    :param arguments: the ``arguments`` dict of a ``load_ml_model`` process graph node.
    :param expected_backend: optional backend id the model is required to live on.
    :return: tuple ``(backend_id, model_id)``:

        - when ``id`` is a non-empty string that does not start with ``"http"``,
          it is parsed as an aggregator job id and the result is
          ``(job_backend_id, job_id)`` as produced by
          ``JobIdMapping.parse_aggregator_job_id``;
        - otherwise (``id`` missing, empty, a URL, or not a string at all)
          the result is ``(None, <original id value>)``.
    :raises BackendLookupFailureException: when ``expected_backend`` is given
        and differs from the backend id encoded in the model id.
    """
    model_id = arguments.get("id")
    # Only plain, non-URL strings can be aggregator job ids. Anything else
    # (missing/empty id, URL, or a non-string value such as a
    # `{"from_parameter": ...}` reference) is passed through untouched instead
    # of failing with an AttributeError on `.startswith`.
    if not isinstance(model_id, str) or not model_id or model_id.startswith("http"):
        return None, model_id

    # TODO: load_ml_model's `id` could also be file path (see https://github.com/Open-EO/openeo-processes/issues/384)
    job_id, job_backend_id = JobIdMapping.parse_aggregator_job_id(
        backends=self.backends,
        aggregator_job_id=model_id
    )
    if expected_backend and job_backend_id != expected_backend:
        raise BackendLookupFailureException(f"{job_backend_id} != {expected_backend}")
    return job_backend_id, job_id


class AggregatorBatchJobs(BatchJobs):

Expand Down

0 comments on commit a5a5338

Please sign in to comment.