diff --git a/CHANGELOG.md b/CHANGELOG.md
index 73d99d009..f3d155faf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -56,6 +56,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   ([#412](https://github.com/Open-EO/openeo-python-client/issues/412)).
 - More robust handling of billing currency/plans in capabilities
   ([#414](https://github.com/Open-EO/openeo-python-client/issues/414))
+- Avoid blindly adding a `save_result` node from `DataCube.execute_batch()` when there is already one
+  ([#401](https://github.com/Open-EO/openeo-python-client/issues/401))
 
 
 ## [0.15.0] - 2023-03-03
diff --git a/openeo/rest/connection.py b/openeo/rest/connection.py
index 730c3e01e..497f30cf0 100644
--- a/openeo/rest/connection.py
+++ b/openeo/rest/connection.py
@@ -1131,7 +1131,7 @@ def download(
         graph: Union[dict, str, Path],
         outputfile: Union[Path, str, None] = None,
         timeout: int = 30 * 60,
-    ):
+    ) -> Union[None, bytes]:
         """
         Downloads the result of a process graph synchronously,
         and save the result to the given file or return bytes object if no outputfile is specified.
diff --git a/openeo/rest/datacube.py b/openeo/rest/datacube.py
index c724c6518..d6242dd5d 100644
--- a/openeo/rest/datacube.py
+++ b/openeo/rest/datacube.py
@@ -58,6 +58,9 @@ class DataCube(_ProcessGraphAbstraction):
     and this process graph can be "grown" to a desired workflow by calling the appropriate methods.
     """
 
+    # TODO: set this based on back-end or user preference?
+    _DEFAULT_RASTER_FORMAT = "GTiff"
+
     def __init__(self, graph: PGNode, connection: 'openeo.Connection', metadata: CollectionMetadata = None):
         super().__init__(pgnode=graph, connection=connection)
         self.metadata = CollectionMetadata.get_or_create(metadata)
@@ -1810,8 +1813,13 @@ def atmospheric_correction(
         })
 
     @openeo_process
-    def save_result(self, format: str = "GTiff", options: dict = None) -> 'DataCube':
+    def save_result(
+        self,
+        format: str = _DEFAULT_RASTER_FORMAT,
+        options: Optional[dict] = None,
+    ) -> "DataCube":
         formats = set(self._connection.list_output_formats().keys())
+        # TODO: map format to correct casing too?
         if format.lower() not in {f.lower() for f in formats}:
             raise ValueError("Invalid format {f!r}. Should be one of {s}".format(f=format, s=formats))
         return self.process(
@@ -1819,27 +1827,31 @@ def save_result(self, format: str = "GTiff", options: dict = None) -> 'DataCube'
             arguments={
                 "data": THIS,
                 "format": format,
+                # TODO: leave out options if unset?
                 "options": options or {}
             }
         )
 
-    def download(
-        self, outputfile: Union[str, pathlib.Path, None] = None, format: Optional[str] = None,
-        options: Optional[dict] = None
-    ):
+    def _ensure_save_result(
+        self,
+        format: Optional[str] = None,
+        options: Optional[dict] = None,
+    ) -> "DataCube":
         """
-        Download image collection, e.g. as GeoTIFF.
-        If outputfile is provided, the result is stored on disk locally, otherwise, a bytes object is returned.
-        The bytes object can be passed on to a suitable decoder for decoding.
+        Make sure there is a (final) `save_result` node in the process graph.
+        If there is already one: check that it is consistent with the given format/options (if any).
+        Otherwise: add a new `save_result` node with the given (or default) format/options.
 
-        :param outputfile: Optional, an output file if the result needs to be stored on disk.
-        :param format: Optional, an output format supported by the backend.
-        :param options: Optional, file format options
-        :return: None if the result is stored to disk, or a bytes object returned by the backend.
+        :param format: (optional) desired `save_result` file format
+        :param options: (optional) desired `save_result` file format parameters
+        :return: cube with a final `save_result` node (the existing one, or a newly added one)
         """
-        if self.result_node().process_id == "save_result":
-            # There is already a `save_result` node: check if it is consistent with given format/options
-            args = self.result_node().arguments
+        # TODO: move to generic data cube parent class (not only for raster cubes, but also vector cubes)
+        result_node = self.result_node()
+        if result_node.process_id == "save_result":
+            # There is already a `save_result` node:
+            # check if it is consistent with given format/options (if any)
+            args = result_node.arguments
             if format is not None and format.lower() != args["format"].lower():
                 raise ValueError(
                     f"Existing `save_result` node with different format {args['format']!r} != {format!r}"
@@ -1851,10 +1863,30 @@ def download(
             cube = self
         else:
             # No `save_result` node yet: automatically add it.
-            if not format:
-                format = guess_format(outputfile) if outputfile else "GTiff"
-            cube = self.save_result(format=format, options=options)
+            cube = self.save_result(
+                format=format or self._DEFAULT_RASTER_FORMAT, options=options
+            )
+        return cube
+
+    def download(
+        self,
+        outputfile: Optional[Union[str, pathlib.Path]] = None,
+        format: Optional[str] = None,
+        options: Optional[dict] = None,
+    ) -> Union[None, bytes]:
+        """
+        Download the raster data cube, e.g. as GeoTIFF.
+        If outputfile is provided, the result is stored on disk locally, otherwise, a bytes object is returned.
+        The bytes object can be passed on to a suitable decoder for decoding.
+
+        :param outputfile: Optional, an output file if the result needs to be stored on disk.
+        :param format: Optional, an output format supported by the backend.
+        :param options: Optional, file format options
+        :return: None if the result is stored to disk, or a bytes object returned by the backend.
+        """
+        if format is None and outputfile is not None:
+            format = guess_format(outputfile)
+        cube = self._ensure_save_result(format=format, options=options)
         return self._connection.download(cube.flat_graph(), outputfile)
 
     def validate(self) -> List[dict]:
@@ -1869,27 +1901,35 @@ def tiled_viewing_service(self, type: str, **kwargs) -> Service:
         return self._connection.create_service(self.flat_graph(), type=type, **kwargs)
 
     def execute_batch(
-            self,
-            outputfile: Union[str, pathlib.Path] = None, out_format: str = None,
-            print=print, max_poll_interval=60, connection_retry_interval=30,
-            job_options=None, **format_options) -> BatchJob:
+        self,
+        outputfile: Optional[Union[str, pathlib.Path]] = None,
+        out_format: Optional[str] = None,
+        *,
+        print: typing.Callable[[str], None] = print,
+        max_poll_interval: float = 60,
+        connection_retry_interval: float = 30,
+        job_options: Optional[dict] = None,
+        # TODO: avoid `format_options` as keyword arguments
+        **format_options,
+    ) -> BatchJob:
         """
         Evaluate the process graph by creating a batch job, and retrieving the results when it is finished.
         This method is mostly recommended if the batch job is expected to run in a reasonable amount of time.
 
         For very long-running jobs, you probably do not want to keep the client running.
 
-        :param job_options:
         :param outputfile: The path of a file to which a result can be written
-        :param out_format: (optional) Format of the job result.
-        :param format_options: String Parameters for the job result format
-
+        :param out_format: (optional) File format to use for the job result.
+        :param job_options: (optional) custom job options.
         """
         if "format" in format_options and not out_format:
             out_format = format_options["format"]  # align with 'download' call arg name
-        if not out_format:
-            out_format = guess_format(outputfile) if outputfile else "GTiff"
-        job = self.create_job(out_format, job_options=job_options, **format_options)
+        if not out_format and outputfile:
+            out_format = guess_format(outputfile)
+
+        job = self.create_job(
+            out_format=out_format, job_options=job_options, **format_options
+        )
         return job.run_synchronous(
             outputfile=outputfile,
             print=print, max_poll_interval=max_poll_interval, connection_retry_interval=connection_retry_interval
@@ -1904,6 +1944,7 @@ def create_job(
         plan: Optional[str] = None,
         budget: Optional[float] = None,
         job_options: Optional[dict] = None,
+        # TODO: avoid `format_options` as keyword arguments
         **format_options,
     ) -> BatchJob:
         """
@@ -1914,22 +1955,18 @@ def create_job(
         it still needs to be started and tracked explicitly.
         Use :py:meth:`execute_batch` instead to have the openEO Python client take care of that job management.
 
-        :param out_format: String Format of the job result.
+        :param out_format: output file format.
         :param title: job title
         :param description: job description
         :param plan: billing plan
         :param budget: maximum cost the request is allowed to produce
-        :param job_options: A dictionary containing (custom) job options
-        :param format_options: String Parameters for the job result format
+        :param job_options: custom job options.
         :return: Created job.
         """
         # TODO: add option to also automatically start the job?
         # TODO: avoid using all kwargs as format_options
         # TODO: centralize `create_job` for `DataCube`, `VectorCube`, `MlModel`, ...
-        cube = self
-        if out_format:
-            # add `save_result` node
-            cube = cube.save_result(format=out_format, options=format_options)
+        cube = self._ensure_save_result(format=out_format, options=format_options)
         return self._connection.create_job(
             process_graph=cube.flat_graph(),
             title=title,
diff --git a/openeo/util.py b/openeo/util.py
index b92fdb7d5..031944915 100644
--- a/openeo/util.py
+++ b/openeo/util.py
@@ -437,7 +437,7 @@ def deep_set(data: dict, *keys, value):
         raise ValueError("No keys given")
 
 
-def guess_format(filename: Union[str, Path]):
+def guess_format(filename: Union[str, Path]) -> str:
     """
     Guess the output format from a given filename and return the corrected format.
     Any names not in the dict get passed through.
diff --git a/tests/rest/datacube/test_datacube.py b/tests/rest/datacube/test_datacube.py
index 90db74fea..37c988da0 100644
--- a/tests/rest/datacube/test_datacube.py
+++ b/tests/rest/datacube/test_datacube.py
@@ -4,9 +4,9 @@
 
 - 1.0.0-style DataCube
 """
-
-from datetime import date, datetime
 import pathlib
+from datetime import date, datetime
+from unittest import mock
 
 import numpy as np
 import pytest
@@ -16,9 +16,10 @@
 from openeo.capabilities import ComparableVersion
 from openeo.rest import BandMathException
 from openeo.rest.datacube import DataCube
-from .conftest import API_URL
-from .. import get_download_graph
+
 from ... import load_json_resource
+from .. import get_download_graph
+from .conftest import API_URL
 
 
 def test_apply_dimension_temporal_cumsum(s2cube, api_version):
@@ -446,3 +447,192 @@ def result_callback(request, context):
     requests_mock.post(API_URL + '/result', content=result_callback)
     result = connection.load_collection("S2").download(format=format)
     assert result == b"data"
+
+
+class TestExecuteBatch:
+    @pytest.fixture
+    def get_create_job_pg(self, connection):
+        """Fixture to help intercept the process graph that was passed to Connection.create_job"""
+        with mock.patch.object(connection, "create_job") as create_job:
+
+            def get() -> dict:
+                assert create_job.call_count == 1
+                return create_job.call_args[1]["process_graph"]
+
+            yield get
+
+    def test_create_job_defaults(self, s2cube, get_create_job_pg, recwarn, caplog):
+        s2cube.create_job()
+        pg = get_create_job_pg()
+        assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
+        assert pg["saveresult1"] == {
+            "process_id": "save_result",
+            "arguments": {
+                "data": {"from_node": "loadcollection1"},
+                "format": "GTiff",
+                "options": {},
+            },
+            "result": True,
+        }
+        assert recwarn.list == []
+        assert caplog.records == []
+
+    @pytest.mark.parametrize(
+        ["out_format", "expected"],
+        [("GTiff", "GTiff"), ("NetCDF", "NetCDF")],
+    )
+    def test_create_job_out_format(
+        self, s2cube, get_create_job_pg, out_format, expected
+    ):
+        s2cube.create_job(out_format=out_format)
+        pg = get_create_job_pg()
+        assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
+        assert pg["saveresult1"] == {
+            "process_id": "save_result",
+            "arguments": {
+                "data": {"from_node": "loadcollection1"},
+                "format": expected,
+                "options": {},
+            },
+            "result": True,
+        }
+
+    @pytest.mark.parametrize(
+        ["save_result_format", "execute_format", "expected"],
+        [
+            ("GTiff", "GTiff", "GTiff"),
+            ("GTiff", None, "GTiff"),
+            ("NetCDF", "NetCDF", "NetCDF"),
+            ("NetCDF", None, "NetCDF"),
+        ],
+    )
+    def test_create_job_existing_save_result(
+        self,
+        s2cube,
+        get_create_job_pg,
+        save_result_format,
+        execute_format,
+        expected,
+    ):
+        cube = s2cube.save_result(format=save_result_format)
+        cube.create_job(out_format=execute_format)
+        pg = get_create_job_pg()
+        assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
+        assert pg["saveresult1"] == {
+            "process_id": "save_result",
+            "arguments": {
+                "data": {"from_node": "loadcollection1"},
+                "format": expected,
+                "options": {},
+            },
+            "result": True,
+        }
+
+    @pytest.mark.parametrize(
+        ["save_result_format", "execute_format"],
+        [("NetCDF", "GTiff"), ("GTiff", "NetCDF")],
+    )
+    def test_create_job_existing_save_result_incompatible(
+        self, s2cube, save_result_format, execute_format
+    ):
+        cube = s2cube.save_result(format=save_result_format)
+        with pytest.raises(ValueError):
+            cube.create_job(out_format=execute_format)
+
+    def test_execute_batch_defaults(self, s2cube, get_create_job_pg, recwarn, caplog):
+        s2cube.execute_batch()
+        pg = get_create_job_pg()
+        assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
+        assert pg["saveresult1"] == {
+            "process_id": "save_result",
+            "arguments": {
+                "data": {"from_node": "loadcollection1"},
+                "format": "GTiff",
+                "options": {},
+            },
+            "result": True,
+        }
+        assert recwarn.list == []
+        assert caplog.records == []
+
+    @pytest.mark.parametrize(
+        ["out_format", "expected"],
+        [("GTiff", "GTiff"), ("NetCDF", "NetCDF")],
+    )
+    def test_execute_batch_out_format(
+        self, s2cube, get_create_job_pg, out_format, expected
+    ):
+        s2cube.execute_batch(out_format=out_format)
+        pg = get_create_job_pg()
+        assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
{"loadcollection1", "saveresult1"} + assert pg["saveresult1"] == { + "process_id": "save_result", + "arguments": { + "data": {"from_node": "loadcollection1"}, + "format": expected, + "options": {}, + }, + "result": True, + } + + @pytest.mark.parametrize( + ["output_file", "expected"], + [("cube.tiff", "GTiff"), ("cube.nc", "netCDF")], + ) + def test_execute_batch_out_format_from_output_file( + self, s2cube, get_create_job_pg, output_file, expected + ): + s2cube.execute_batch(outputfile=output_file) + pg = get_create_job_pg() + assert set(pg.keys()) == {"loadcollection1", "saveresult1"} + assert pg["saveresult1"] == { + "process_id": "save_result", + "arguments": { + "data": {"from_node": "loadcollection1"}, + "format": expected, + "options": {}, + }, + "result": True, + } + + @pytest.mark.parametrize( + ["save_result_format", "execute_format", "expected"], + [ + ("GTiff", "GTiff", "GTiff"), + ("GTiff", None, "GTiff"), + ("NetCDF", "NetCDF", "NetCDF"), + ("NetCDF", None, "NetCDF"), + ], + ) + def test_execute_batch_existing_save_result( + self, + s2cube, + get_create_job_pg, + save_result_format, + execute_format, + expected, + ): + cube = s2cube.save_result(format=save_result_format) + cube.execute_batch(out_format=execute_format) + pg = get_create_job_pg() + assert set(pg.keys()) == {"loadcollection1", "saveresult1"} + assert pg["saveresult1"] == { + "process_id": "save_result", + "arguments": { + "data": {"from_node": "loadcollection1"}, + "format": expected, + "options": {}, + }, + "result": True, + } + + @pytest.mark.parametrize( + ["save_result_format", "execute_format"], + [("NetCDF", "GTiff"), ("GTiff", "NetCDF")], + ) + def test_execute_batch_existing_save_result_incompatible( + self, s2cube, save_result_format, execute_format + ): + cube = s2cube.save_result(format=save_result_format) + with pytest.raises(ValueError): + cube.execute_batch(out_format=execute_format) diff --git a/tests/rest/datacube/test_datacube100.py b/tests/rest/datacube/test_datacube100.py index 124f4ebec..ff90647b2 100644 --- a/tests/rest/datacube/test_datacube100.py +++ b/tests/rest/datacube/test_datacube100.py @@ -2391,33 +2391,76 @@ def test_apply_append_math_keep_context(con100): } -@pytest.mark.parametrize(["save_result_kwargs", "download_kwargs", "expected_fail"], [ - ({}, {}, None), - ({"format": "GTiff"}, {}, None), - ({}, {"format": "GTiff"}, None), - ({"format": "GTiff"}, {"format": "GTiff"}, None), - ({"format": "netCDF"}, {"format": "NETCDF"}, None), - ( +@pytest.mark.parametrize( + ["save_result_kwargs", "download_filename", "download_kwargs", "expected"], + [ + ({}, "result.tiff", {}, b"this is GTiff data"), + ({}, "result.nc", {}, b"this is netCDF data"), + ({"format": "GTiff"}, "result.tiff", {}, b"this is GTiff data"), + ({"format": "GTiff"}, "result.tif", {}, b"this is GTiff data"), + ( + {"format": "GTiff"}, + "result.nc", + {}, + ValueError( + "Existing `save_result` node with different format 'GTiff' != 'netCDF'" + ), + ), + ({}, "result.tiff", {"format": "GTiff"}, b"this is GTiff data"), + ({}, "result.nc", {"format": "netCDF"}, b"this is netCDF data"), + ({}, "result.meh", {"format": "netCDF"}, b"this is netCDF data"), + ( + {"format": "GTiff"}, + "result.tiff", + {"format": "GTiff"}, + b"this is GTiff data", + ), + ( + {"format": "netCDF"}, + "result.tiff", + {"format": "NETCDF"}, + b"this is netCDF data", + ), + ( {"format": "netCDF"}, + "result.json", {"format": "JSON"}, - "Existing `save_result` node with different format 'netCDF' != 'JSON'" - ), - ({"options": {}}, {}, None), - 
({"options": {"quality": "low"}}, {"options": {"quality": "low"}}, None), - ( + ValueError( + "Existing `save_result` node with different format 'netCDF' != 'JSON'" + ), + ), + ({"options": {}}, "result.tiff", {}, b"this is GTiff data"), + ( + {"options": {"quality": "low"}}, + "result.tiff", + {"options": {"quality": "low"}}, + b"this is GTiff data", + ), + ( {"options": {"colormap": "jet"}}, + "result.tiff", {"options": {"quality": "low"}}, - "Existing `save_result` node with different options {'colormap': 'jet'} != {'quality': 'low'}" - ), -]) + ValueError( + "Existing `save_result` node with different options {'colormap': 'jet'} != {'quality': 'low'}" + ), + ), + ], +) def test_save_result_and_download( - con100, requests_mock, tmp_path, save_result_kwargs, download_kwargs, expected_fail + con100, + requests_mock, + tmp_path, + save_result_kwargs, + download_filename, + download_kwargs, + expected, ): def post_result(request, context): pg = request.json()["process"]["process_graph"] process_histogram = collections.Counter(p["process_id"] for p in pg.values()) assert process_histogram["save_result"] == 1 - return b"tiffdata" + format = pg["saveresult1"]["arguments"]["format"] + return f"this is {format} data".encode("utf8") post_result_mock = requests_mock.post(API_URL + "/result", content=post_result) @@ -2425,14 +2468,14 @@ def post_result(request, context): if save_result_kwargs: cube = cube.save_result(**save_result_kwargs) - path = tmp_path / "tmp.tiff" - if expected_fail: - with pytest.raises(ValueError, match=expected_fail): + path = tmp_path / download_filename + if isinstance(expected, ValueError): + with pytest.raises(ValueError, match=str(expected)): cube.download(str(path), **download_kwargs) assert post_result_mock.call_count == 0 else: cube.download(str(path), **download_kwargs) - assert path.read_bytes() == b"tiffdata" + assert path.read_bytes() == expected assert post_result_mock.call_count == 1