diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 7bec83346b0..4fc9f69d3e3 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -4,7 +4,7 @@ import os as _os import pathlib import traceback as _traceback -from typing import List +from typing import List, Optional import click as _click from flyteidl.core import literals_pb2 as _literals_pb2 @@ -16,6 +16,7 @@ from flytekit.core import constants as _constants from flytekit.core import utils from flytekit.core.base_task import IgnoreOutputs, PythonTask +from flytekit.core.checkpointer import SyncCheckpoint from flytekit.core.context_manager import ( ExecutionParameters, ExecutionState, @@ -164,8 +165,10 @@ def _dispatch_execute( @contextlib.contextmanager def setup_execution( raw_output_data_prefix: str, - dynamic_addl_distro: str = None, - dynamic_dest_dir: str = None, + checkpoint_path: Optional[str] = None, + prev_checkpoint: Optional[str] = None, + dynamic_addl_distro: Optional[str] = None, + dynamic_dest_dir: Optional[str] = None, ): ctx = FlyteContextManager.current_context() @@ -175,6 +178,11 @@ def setup_execution( pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True) from flytekit import __version__ as _api_version + checkpointer = None + if checkpoint_path is not None: + checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint) + logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}") + execution_parameters = ExecutionParameters( execution_id=_identifier.WorkflowExecutionIdentifier( project=_internal_config.EXECUTION_PROJECT.get(), @@ -202,6 +210,7 @@ def setup_execution( logging=python_logging, tmp_dir=user_workspace_dir, raw_output_prefix=ctx.file_access._raw_output_prefix, + checkpoint=checkpointer, ) # TODO: Remove this check for flytekit 1.0 @@ -266,14 +275,16 @@ def _handle_annotated_task( @_scopes.system_entry_point def _execute_task( - inputs, - output_prefix, - 
raw_output_data_prefix, - test, + inputs: str, + output_prefix: str, + test: bool, + raw_output_data_prefix: str, resolver: str, resolver_args: List[str], - dynamic_addl_distro: str = None, - dynamic_dest_dir: str = None, + checkpoint_path: Optional[str] = None, + prev_checkpoint: Optional[str] = None, + dynamic_addl_distro: Optional[str] = None, + dynamic_dest_dir: Optional[str] = None, ): """ This function should be called for new API tasks (those only available in 0.16 and later that leverage Python @@ -302,7 +313,13 @@ def _execute_task( raise Exception("cannot be <1") with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()): - with setup_execution(raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir) as ctx: + with setup_execution( + raw_output_data_prefix, + checkpoint_path=checkpoint_path, + prev_checkpoint=prev_checkpoint, + dynamic_addl_distro=dynamic_addl_distro, + dynamic_dest_dir=dynamic_dest_dir, + ) as ctx: resolver_obj = load_object_from_module(resolver) # Use the resolver to load the actual task object _task_def = resolver_obj.load_task(loader_args=resolver_args) @@ -321,16 +338,20 @@ def _execute_map_task( raw_output_data_prefix, max_concurrency, test, - dynamic_addl_distro: str, - dynamic_dest_dir: str, resolver: str, resolver_args: List[str], + checkpoint_path: Optional[str] = None, + prev_checkpoint: Optional[str] = None, + dynamic_addl_distro: Optional[str] = None, + dynamic_dest_dir: Optional[str] = None, ): if len(resolver_args) < 1: raise Exception(f"Resolver args cannot be <1, got {resolver_args}") with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()): - with setup_execution(raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir) as ctx: + with setup_execution( + raw_output_data_prefix, checkpoint_path, prev_checkpoint, dynamic_addl_distro, dynamic_dest_dir + ) as ctx: resolver_obj = load_object_from_module(resolver) # Use the resolver to load the actual task object _task_def = 
resolver_obj.load_task(loader_args=resolver_args) @@ -352,6 +373,22 @@ def _execute_map_task( _handle_annotated_task(ctx, map_task, inputs, output_prefix) +def normalize_inputs( + raw_output_data_prefix: Optional[str], checkpoint_path: Optional[str], prev_checkpoint: Optional[str] +): + # Backwards compatibility - if Propeller hasn't filled this in, then it'll come through here as the original + # template string, so let's explicitly set it to None so that the downstream functions will know to fall back + # to the original shard formatter/prefix config. + if raw_output_data_prefix == "{{.rawOutputDataPrefix}}": + raw_output_data_prefix = None + if checkpoint_path == "{{.checkpointOutputPrefix}}": + checkpoint_path = None + if prev_checkpoint == "{{.prevCheckpointPrefix}}" or prev_checkpoint == "" or prev_checkpoint == '""': + prev_checkpoint = None + + return raw_output_data_prefix, checkpoint_path, prev_checkpoint + + @_click.group() def _pass_through(): pass @@ -361,6 +398,8 @@ def _pass_through(): @_click.option("--inputs", required=True) @_click.option("--output-prefix", required=True) @_click.option("--raw-output-data-prefix", required=False) +@_click.option("--checkpoint-path", required=False) +@_click.option("--prev-checkpoint", required=False) @_click.option("--test", is_flag=True) @_click.option("--dynamic-addl-distro", required=False) @_click.option("--dynamic-dest-dir", required=False) @@ -375,6 +414,8 @@ def execute_task_cmd( output_prefix, raw_output_data_prefix, test, + prev_checkpoint, + checkpoint_path, dynamic_addl_distro, dynamic_dest_dir, resolver, @@ -383,26 +424,27 @@ def execute_task_cmd( logger.info(get_version_message()) # We get weird errors if there are no click echo messages at all, so emit an empty string so that unit tests pass. 
_click.echo("") - # Backwards compatibility - if Propeller hasn't filled this in, then it'll come through here as the original - # template string, so let's explicitly set it to None so that the downstream functions will know to fall back - # to the original shard formatter/prefix config. - if raw_output_data_prefix == "{{.rawOutputDataPrefix}}": - raw_output_data_prefix = None + raw_output_data_prefix, checkpoint_path, prev_checkpoint = normalize_inputs( + raw_output_data_prefix, checkpoint_path, prev_checkpoint + ) # For new API tasks (as of 0.16.x), we need to call a different function. # Use the presence of the resolver to differentiate between old API tasks and new API tasks # The addition of a new top-level command seemed out of scope at the time of this writing to pursue given how # pervasive this top level command already (plugins mostly). + logger.debug(f"Running task execution with resolver {resolver}...") _execute_task( - inputs, - output_prefix, - raw_output_data_prefix, - test, - resolver, - resolver_args, - dynamic_addl_distro, - dynamic_dest_dir, + inputs=inputs, + output_prefix=output_prefix, + raw_output_data_prefix=raw_output_data_prefix, + test=test, + resolver=resolver, + resolver_args=resolver_args, + dynamic_addl_distro=dynamic_addl_distro, + dynamic_dest_dir=dynamic_dest_dir, + checkpoint_path=checkpoint_path, + prev_checkpoint=prev_checkpoint, ) @@ -446,6 +488,8 @@ def fast_execute_task_cmd(additional_distribution, dest_dir, task_execute_cmd): @_click.option("--dynamic-addl-distro", required=False) @_click.option("--dynamic-dest-dir", required=False) @_click.option("--resolver", required=True) +@_click.option("--checkpoint-path", required=False) +@_click.option("--prev-checkpoint", required=False) @_click.argument( "resolver-args", type=_click.UNPROCESSED, @@ -461,19 +505,27 @@ def map_execute_task_cmd( dynamic_dest_dir, resolver, resolver_args, + prev_checkpoint, + checkpoint_path, ): logger.info(get_version_message()) + 
raw_output_data_prefix, checkpoint_path, prev_checkpoint = normalize_inputs( + raw_output_data_prefix, checkpoint_path, prev_checkpoint + ) + _execute_map_task( - inputs, - output_prefix, - raw_output_data_prefix, - max_concurrency, - test, - dynamic_addl_distro, - dynamic_dest_dir, - resolver, - resolver_args, + inputs=inputs, + output_prefix=output_prefix, + raw_output_data_prefix=raw_output_data_prefix, + max_concurrency=max_concurrency, + test=test, + dynamic_addl_distro=dynamic_addl_distro, + dynamic_dest_dir=dynamic_dest_dir, + resolver=resolver, + resolver_args=resolver_args, + checkpoint_path=checkpoint_path, + prev_checkpoint=prev_checkpoint, ) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index f590ba2033d..c5c35373d86 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -263,6 +263,9 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr else: logger.info("Cache hit") else: + es = ctx.execution_state + b = es.user_space_params.with_task_sandbox() + ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build() outputs_literal_map = self.dispatch_execute(ctx, input_literal_map) outputs_literals = outputs_literal_map.literals diff --git a/flytekit/core/checkpointer.py b/flytekit/core/checkpointer.py new file mode 100644 index 00000000000..bd17b7748b7 --- /dev/null +++ b/flytekit/core/checkpointer.py @@ -0,0 +1,157 @@ +import io +import tempfile +import typing +from abc import abstractmethod +from pathlib import Path + + +class Checkpoint(object): + """ + Base class for Checkpoint system. 
Checkpoint system allows reading and writing custom checkpoints from user + scripts + """ + + @abstractmethod + def prev_exists(self) -> bool: + raise NotImplementedError("Use one of the derived classes") + + @abstractmethod + def restore(self, path: typing.Union[Path, str]) -> typing.Optional[Path]: + """ + Given a path, if a previous checkpoint exists, will be downloaded to this path. + If download is successful the downloaded path is returned + + .. note:: + + Download will not be performed, if the checkpoint was previously restored. The method will return the + previously downloaded path. + + """ + raise NotImplementedError("Use one of the derived classes") + + @abstractmethod + def save(self, cp: typing.Union[Path, str, io.BufferedReader]): + """ + Args: + cp: Checkpoint file (path, str path or a io.BufferedReader) + + Usage: If you have a io.BufferedReader then the following should work + + .. code-block:: python + + with input_file.open(mode="rb") as b: + checkpointer.save(b) + """ + raise NotImplementedError("Use one of the derived classes") + + @abstractmethod + def read(self) -> typing.Optional[bytes]: + """ + This should only be used if there is a singular checkpoint file written. If more than one checkpoint file is + found, this will raise a ValueError + """ + raise NotImplementedError("Use one of the derived classes") + + @abstractmethod + def write(self, b: bytes): + """ + This will overwrite the checkpoint. It can be retrieved using read or restore + """ + raise NotImplementedError("Use one of the derived classes") + + +class SyncCheckpoint(Checkpoint): + """ + This class is NOT THREAD-SAFE! + Sync Checkpoint, will synchronously checkpoint a user given file or folder. + It will also synchronously download / restore previous checkpoints, when restore is invoked. 
+ + TODO: Implement an async checkpoint system + """ + + SRC_LOCAL_FOLDER = "prev_cp" + TMP_DST_PATH = "_dst_cp" + + def __init__(self, checkpoint_dest: str, checkpoint_src: typing.Optional[str] = None): + """ + Args: + checkpoint_src: If a previous checkpoint should exist, this path should be set to the folder that contains the checkpoint information + checkpoint_dest: Location where the new checkpoint should be copied to + """ + self._checkpoint_dest = checkpoint_dest + self._checkpoint_src = checkpoint_src if checkpoint_src and checkpoint_src != "" else None + self._td = tempfile.TemporaryDirectory() + self._prev_download_path = None + + def __del__(self): + self._td.cleanup() + + def prev_exists(self) -> bool: + return self._checkpoint_src is not None + + def restore(self, path: typing.Optional[typing.Union[Path, str]] = None) -> typing.Optional[Path]: + + # We have to lazy load, until we fix the imports + from flytekit.core.context_manager import FlyteContextManager + + if self._checkpoint_src is None or self._checkpoint_src == "": + return None + + if self._prev_download_path: + return self._prev_download_path + + if path is None: + p = Path(self._td.name) + path = p.joinpath(self.SRC_LOCAL_FOLDER) + path.mkdir() + elif isinstance(path, str): + path = Path(path) + + if not path.is_dir(): + raise ValueError("Checkpoints can be restored to a directory only.") + + FlyteContextManager.current_context().file_access.download_directory(self._checkpoint_src, str(path)) + self._prev_download_path = path + return self._prev_download_path + + def save(self, cp: typing.Union[Path, str, io.BufferedReader]): + # We have to lazy load, until we fix the imports + from flytekit.core.context_manager import FlyteContextManager + + fa = FlyteContextManager.current_context().file_access + if isinstance(cp, (Path, str)): + if isinstance(cp, str): + cp = Path(cp) + if cp.is_dir(): + fa.upload_directory(str(cp), self._checkpoint_dest) + else: + fname = cp.stem + cp.suffix + rpath = 
fa._default_remote.construct_path(False, False, self._checkpoint_dest, fname) + fa.upload(str(cp), rpath) + return + + if not isinstance(cp, io.IOBase): + raise ValueError(f"Only a valid path or IOBase type (reader) should be provided, received {type(cp)}") + + p = Path(self._td.name) + dest_cp = p.joinpath(self.TMP_DST_PATH) + with dest_cp.open("wb") as f: + f.write(cp.read()) + + rpath = fa._default_remote.construct_path(False, False, self._checkpoint_dest, self.TMP_DST_PATH) + fa.upload(str(dest_cp), rpath) + + def read(self) -> typing.Optional[bytes]: + p = self.restore() + if p is None: + return None + files = list(p.iterdir()) + if len(files) == 0 or len(files) > 1: + raise ValueError(f"Expected exactly one checkpoint - found {len(files)}") + f = files[0] + return f.read_bytes() + + def write(self, b: bytes): + f = io.BytesIO(b) + f = typing.cast(io.BufferedReader, f) + self.save(f) diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index a3ee0c2972e..70af22f1e19 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -19,6 +19,7 @@ import os import pathlib import re +import tempfile import traceback import typing from contextlib import contextmanager @@ -34,6 +35,7 @@ from flytekit.configuration import sdk as _sdk_config from flytekit.configuration import secrets from flytekit.core import mock_stats, utils +from flytekit.core.checkpointer import Checkpoint, SyncCheckpoint from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider from flytekit.core.node import Node from flytekit.interfaces.cli_identifiers import WorkflowExecutionIdentifier @@ -159,6 +161,7 @@ class Builder(object): execution_id: str attrs: typing.Dict[str, typing.Any] working_dir: typing.Union[os.PathLike, utils.AutoDeletingTempDir] + checkpoint: typing.Optional[Checkpoint] raw_output_prefix: str def __init__(self, current: typing.Optional[ExecutionParameters] = None): @@ -167,6 +170,7 @@ def 
__init__(self, current: typing.Optional[ExecutionParameters] = None): self.working_dir = current.working_directory if current else None self.execution_id = current.execution_id if current else None self.logging = current.logging if current else None + self.checkpoint = current._checkpoint if current else None self.attrs = current._attrs if current else {} self.raw_output_prefix = current.raw_output_prefix if current else None @@ -183,6 +187,7 @@ def build(self) -> ExecutionParameters: tmp_dir=self.working_dir, execution_id=self.execution_id, logging=self.logging, + checkpoint=self.checkpoint, raw_output_prefix=self.raw_output_prefix, **self.attrs, ) @@ -191,10 +196,26 @@ def build(self) -> ExecutionParameters: def new_builder(current: ExecutionParameters = None) -> Builder: return ExecutionParameters.Builder(current=current) + def with_task_sandbox(self) -> Builder: + prefix = self.working_directory + if isinstance(self.working_directory, utils.AutoDeletingTempDir): + prefix = self.working_directory.name + task_sandbox_dir = tempfile.mkdtemp(prefix=prefix) + p = pathlib.Path(task_sandbox_dir) + cp_dir = p.joinpath("__cp") + cp_dir.mkdir(exist_ok=True) + cp = SyncCheckpoint(checkpoint_dest=str(cp_dir)) + b = self.new_builder(self) + b.checkpoint = cp + b.working_dir = task_sandbox_dir + return b + def builder(self) -> Builder: return ExecutionParameters.Builder(current=self) - def __init__(self, execution_date, tmp_dir, stats, execution_id, logging, raw_output_prefix, **kwargs): + def __init__( + self, execution_date, tmp_dir, stats, execution_id, logging, raw_output_prefix, checkpoint=None, **kwargs + ): """ Args: execution_date: Date when the execution is running @@ -202,6 +223,7 @@ def __init__(self, execution_date, tmp_dir, stats, execution_id, logging, raw_ou stats: handle to emit stats execution_id: Identifier for the xecution logging: handle to logging + checkpoint: Checkpoint Handle to the configured checkpoint system """ self._stats = stats 
self._execution_date = execution_date @@ -213,6 +235,7 @@ def __init__(self, execution_date, tmp_dir, stats, execution_id, logging, raw_ou self._attrs = kwargs # It is safe to recreate the Secrets Manager self._secrets_manager = SecretsManager() + self._checkpoint = checkpoint @property def stats(self) -> taggable.TaggableStats: @@ -274,6 +297,12 @@ def execution_id(self) -> str: def secrets(self) -> SecretsManager: return self._secrets_manager + @property + def checkpoint(self) -> Checkpoint: + if self._checkpoint is None: + raise NotImplementedError("Checkpointing is not available, please check the version of the platform.") + return self._checkpoint + def __getattr__(self, attr_name: str) -> typing.Any: """ This houses certain task specific context. For example in Spark, it houses the SparkSession, etc diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 29838cffcda..f760be5d3c0 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -78,6 +78,10 @@ def get_command(self, settings: SerializationSettings) -> List[str]: "{{.outputPrefix}}", "--raw-output-data-prefix", "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", "--resolver", self._run_task.task_resolver.location, "--", diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 0226760f085..c5f8413fea5 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -2,6 +2,7 @@ import importlib import re +from abc import ABC from typing import Callable, Dict, List, Optional, TypeVar from flytekit.core.base_task import PythonTask, TaskResolverMixin @@ -17,7 +18,7 @@ T = TypeVar("T") -class PythonAutoContainerTask(PythonTask[T], metaclass=FlyteTrackedABC): +class PythonAutoContainerTask(PythonTask[T], ABC, metaclass=FlyteTrackedABC): """ A Python AutoContainer task should be used as the base for all extensions 
that want the user's code to be in the container and the container information to be automatically captured. @@ -119,6 +120,10 @@ def get_default_command(self, settings: SerializationSettings) -> List[str]: "{{.outputPrefix}}", "--raw-output-data-prefix", "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", "--resolver", self.task_resolver.location, "--", diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py index 71dd45f581d..db3d4b573ba 100644 --- a/flytekit/core/utils.py +++ b/flytekit/core/utils.py @@ -5,7 +5,7 @@ import time as _time from hashlib import sha224 as _sha224 from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional from flytekit.configuration import resources as _resource_config from flytekit.models import task as _task_models @@ -53,18 +53,18 @@ def _get_container_definition( image: str, command: List[str], args: List[str], - data_loading_config: _task_models.DataLoadingConfig, - storage_request: str = None, - ephemeral_storage_request: str = None, - cpu_request: str = None, - gpu_request: str = None, - memory_request: str = None, - storage_limit: str = None, - ephemeral_storage_limit: str = None, - cpu_limit: str = None, - gpu_limit: str = None, - memory_limit: str = None, - environment: Dict[str, str] = None, + data_loading_config: Optional[_task_models.DataLoadingConfig] = None, + storage_request: Optional[str] = None, + ephemeral_storage_request: Optional[str] = None, + cpu_request: Optional[str] = None, + gpu_request: Optional[str] = None, + memory_request: Optional[str] = None, + storage_limit: Optional[str] = None, + ephemeral_storage_limit: Optional[str] = None, + cpu_limit: Optional[str] = None, + gpu_limit: Optional[str] = None, + memory_limit: Optional[str] = None, + environment: Optional[Dict[str, str]] = None, ) -> _task_models.Container: storage_limit = storage_limit or 
_resource_config.DEFAULT_STORAGE_LIMIT.get() storage_request = storage_request or _resource_config.DEFAULT_STORAGE_REQUEST.get() diff --git a/plugins/flytekit-k8s-pod/tests/test_pod.py b/plugins/flytekit-k8s-pod/tests/test_pod.py index 82a0bcf6f8f..9677f512113 100644 --- a/plugins/flytekit-k8s-pod/tests/test_pod.py +++ b/plugins/flytekit-k8s-pod/tests/test_pod.py @@ -69,6 +69,10 @@ def simple_pod_task(i: int): "{{.outputPrefix}}", "--raw-output-data-prefix", "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", "--resolver", "flytekit.core.python_auto_container.default_task_resolver", "--", @@ -134,6 +138,10 @@ def simple_pod_task(i: int): "{{.outputPrefix}}", "--raw-output-data-prefix", "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", "--resolver", "flytekit.core.python_auto_container.default_task_resolver", "--", @@ -321,6 +329,10 @@ def simple_pod_task(i: int): "{{.outputPrefix}}", "--raw-output-data-prefix", "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", "--resolver", "flytekit.core.python_auto_container.default_task_resolver", "--", @@ -367,6 +379,10 @@ def simple_pod_task(i: int): "{{.outputPrefix}}", "--raw-output-data-prefix", "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", "--resolver", "flytekit.core.python_auto_container.default_task_resolver", "--", diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 3a71b567d89..4ccb3c8bcd8 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -5,7 +5,7 @@ import pytest from flyteidl.core.errors_pb2 import ErrorDocument -from 
flytekit.bin.entrypoint import _dispatch_execute, setup_execution +from flytekit.bin.entrypoint import _dispatch_execute, normalize_inputs, setup_execution from flytekit.core import context_manager from flytekit.core.base_task import IgnoreOutputs from flytekit.core.dynamic_workflow_task import dynamic @@ -284,8 +284,19 @@ def test_setup_bad_prefix(): def test_setup_cloud_prefix(): - with setup_execution("s3://") as ctx: + with setup_execution("s3://", checkpoint_path=None, prev_checkpoint=None) as ctx: assert isinstance(ctx.file_access._default_remote, S3Persistence) - with setup_execution("gs://") as ctx: + with setup_execution("gs://", checkpoint_path=None, prev_checkpoint=None) as ctx: assert isinstance(ctx.file_access._default_remote, GCSPersistence) + + +def test_normalize_inputs(): + assert normalize_inputs("{{.rawOutputDataPrefix}}", "{{.checkpointOutputPrefix}}", "{{.prevCheckpointPrefix}}") == ( + None, + None, + None, + ) + assert normalize_inputs("/raw", "/cp1", '""') == ("/raw", "/cp1", None) + assert normalize_inputs("/raw", "/cp1", "") == ("/raw", "/cp1", None) + assert normalize_inputs("/raw", "/cp1", "/prev") == ("/raw", "/cp1", "/prev") diff --git a/tests/flytekit/unit/core/test_checkpoint.py b/tests/flytekit/unit/core/test_checkpoint.py new file mode 100644 index 00000000000..0199737335e --- /dev/null +++ b/tests/flytekit/unit/core/test_checkpoint.py @@ -0,0 +1,117 @@ +from pathlib import Path + +import pytest + +import flytekit +from flytekit.core.checkpointer import SyncCheckpoint + + +def test_sync_checkpoint_write(tmpdir): + td_path = Path(tmpdir) + cp = SyncCheckpoint(checkpoint_dest=tmpdir) + assert cp.read() is None + assert cp.restore() is None + dst_path = td_path.joinpath(SyncCheckpoint.TMP_DST_PATH) + assert not dst_path.exists() + cp.write(b"bytes") + assert dst_path.exists() + + +def test_sync_checkpoint_save_file(tmpdir): + td_path = Path(tmpdir) + cp = SyncCheckpoint(checkpoint_dest=tmpdir) + dst_path = 
td_path.joinpath(SyncCheckpoint.TMP_DST_PATH) + assert not dst_path.exists() + inp = td_path.joinpath("test") + with inp.open("wb") as f: + f.write(b"blah") + with inp.open("rb") as f: + cp.save(f) + assert dst_path.exists() + + with pytest.raises(ValueError): + # Unsupported object + cp.save(SyncCheckpoint) # noqa + + +def test_sync_checkpoint_save_filepath(tmpdir): + td_path = Path(tmpdir) + cp = SyncCheckpoint(checkpoint_dest=tmpdir) + dst_path = td_path.joinpath("test") + assert not dst_path.exists() + inp = td_path.joinpath("test") + with inp.open("wb") as f: + f.write(b"blah") + cp.save(inp) + assert dst_path.exists() + + +def test_sync_checkpoint_restore(tmpdir): + td_path = Path(tmpdir) + dest = td_path.joinpath("dest") + dest.mkdir() + src = td_path.joinpath("src") + src.mkdir() + prev = src.joinpath("prev") + p = b"prev-bytes" + with prev.open("wb") as f: + f.write(p) + cp = SyncCheckpoint(checkpoint_dest=str(dest), checkpoint_src=str(src)) + user_dest = td_path.joinpath("user_dest") + + with pytest.raises(ValueError): + cp.restore(user_dest) + + user_dest.mkdir() + assert cp.restore(user_dest) == user_dest + assert cp.restore("other_path") == user_dest + + +def test_sync_checkpoint_restore_default_path(tmpdir): + td_path = Path(tmpdir) + dest = td_path.joinpath("dest") + dest.mkdir() + src = td_path.joinpath("src") + src.mkdir() + prev = src.joinpath("prev") + p = b"prev-bytes" + with prev.open("wb") as f: + f.write(p) + cp = SyncCheckpoint(checkpoint_dest=str(dest), checkpoint_src=str(src)) + assert cp.read() == p + assert cp._prev_download_path is not None + assert cp.restore() == cp._prev_download_path + + +def test_sync_checkpoint_read_multiple_files(tmpdir): + """ + Read can only work with one file. 
+ """ + td_path = Path(tmpdir) + dest = td_path.joinpath("dest") + dest.mkdir() + src = td_path.joinpath("src") + src.mkdir() + prev = src.joinpath("prev") + prev2 = src.joinpath("prev2") + p = b"prev-bytes" + with prev.open("wb") as f: + f.write(p) + with prev2.open("wb") as f: + f.write(p) + cp = SyncCheckpoint(checkpoint_dest=str(dest), checkpoint_src=str(src)) + + with pytest.raises(ValueError, match="Expected exactly one checkpoint - found 2"): + cp.read() + + +@flytekit.task +def t1(n: int) -> int: + ctx = flytekit.current_context() + cp = ctx.checkpoint + cp.write(bytes(n + 1)) + return n + 1 + + +def test_checkpoint_task(): + assert t1(n=5) == 6 diff --git a/tests/flytekit/unit/core/test_checkpointer.py b/tests/flytekit/unit/core/test_checkpointer.py new file mode 100644 index 00000000000..dda786545bf --- /dev/null +++ b/tests/flytekit/unit/core/test_checkpointer.py @@ -0,0 +1,67 @@ +import typing +from pathlib import Path + +import py.path + +from flytekit.core.checkpointer import SyncCheckpoint + +CHECKPOINT_FILE = "cp" + + +def create_folder_write_file(tmpdir: py.path.local) -> typing.Tuple[py.path.local, py.path.local, py.path.local]: + outputs = tmpdir.mkdir("outputs") + + # Make an input test directory with one file called cp + inputs = tmpdir.mkdir("inputs") + input_file = inputs.join(CHECKPOINT_FILE) + input_file.write_text("Hello!", encoding="utf-8") + + return inputs, input_file, outputs + + +def test_sync_checkpoint_file(tmpdir: py.path.local): + inputs, input_file, outputs = create_folder_write_file(tmpdir) + cp = SyncCheckpoint(checkpoint_dest=str(outputs)) + # Lets try to restore - should not work! 
+ assert not cp.restore("/tmp") + # Now save + cp.save(str(input_file)) + # Expect file in tmpdir + expected_dst = outputs.join(CHECKPOINT_FILE) + assert outputs.listdir() == [expected_dst] + + +def test_sync_checkpoint_reader(tmpdir: py.path.local): + inputs, input_file, outputs = create_folder_write_file(tmpdir) + cp = SyncCheckpoint(checkpoint_dest=str(outputs)) + # Lets try to restore - should not work! + assert not cp.restore("/tmp") + # Now save + with input_file.open(mode="rb") as b: + cp.save(b) + # Expect file in tmpdir + expected_dst = outputs.join(SyncCheckpoint.TMP_DST_PATH) + assert outputs.listdir() == [expected_dst] + + +def test_sync_checkpoint_folder(tmpdir: py.path.local): + inputs, input_file, outputs = create_folder_write_file(tmpdir) + cp = SyncCheckpoint(checkpoint_dest=str(outputs)) + # Lets try to restore - should not work! + assert not cp.restore("/tmp") + # Now save + cp.save(Path(str(inputs))) + # Expect file in tmpdir + expected_dst = outputs.join(CHECKPOINT_FILE) + assert outputs.listdir() == [expected_dst] + + +def test_sync_checkpoint_previous(tmpdir: py.path.local): + inputs, input_file, outputs = create_folder_write_file(tmpdir) + cp = SyncCheckpoint(checkpoint_dest=str(outputs), checkpoint_src=str(inputs)) + scratch = tmpdir.mkdir("user_scratch") + assert cp.restore(str(scratch)) == scratch + assert scratch.listdir() == [scratch.join(CHECKPOINT_FILE)] + + # ensure download is not performed again + assert cp.restore("x") == scratch diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index 31253ce1ddb..d1f95852c1c 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -76,6 +76,10 @@ def test_serialization(): "{{.outputPrefix}}", "--raw-output-data-prefix", "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", "--resolver", 
"flytekit.core.python_auto_container.default_task_resolver", "--", diff --git a/tests/flytekit/unit/core/test_python_auto_container.py b/tests/flytekit/unit/core/test_python_auto_container.py index edd31405544..4394710dd5a 100644 --- a/tests/flytekit/unit/core/test_python_auto_container.py +++ b/tests/flytekit/unit/core/test_python_auto_container.py @@ -1,7 +1,9 @@ +from typing import Any + import pytest -from flytekit.core.context_manager import Image, ImageConfig -from flytekit.core.python_auto_container import get_registerable_container_image +from flytekit.core.context_manager import Image, ImageConfig, SerializationSettings +from flytekit.core.python_auto_container import PythonAutoContainerTask, get_registerable_container_image @pytest.fixture @@ -10,7 +12,46 @@ def default_image_config(): return ImageConfig(default_image=default_image) +@pytest.fixture +def default_serialization_settings(default_image_config): + return SerializationSettings( + project="p", domain="d", version="v", image_config=default_image_config, env={"FOO": "bar"} + ) + + def test_image_name_interpolation(default_image_config): img_to_interpolate = "{{.image.default.fqn}}:{{.image.default.version}}-special" img = get_registerable_container_image(img=img_to_interpolate, cfg=default_image_config) assert img == "docker.io/xyz:some-git-hash-special" + + +class DummyAutoContainerTask(PythonAutoContainerTask): + def execute(self, **kwargs) -> Any: + pass + + +task = DummyAutoContainerTask(name="x", task_config=None, task_type="t") + + +def test_default_command(default_serialization_settings): + cmd = task.get_default_command(settings=default_serialization_settings) + assert cmd == [ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", + "--resolver", + 
"flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "test_python_auto_container", + "task-name", + "task", + ]