Skip to content

Commit

Permalink
Intratask checkpointing (flyteorg#771)
Browse files Browse the repository at this point in the history
* wip - intratask checkpointing

Signed-off-by: Ketan Umare <[email protected]>

* sync checkpointer with tests

Signed-off-by: Ketan Umare <[email protected]>

* Checkpinter in entrypoint

Signed-off-by: Ketan Umare <[email protected]>

* checkpoint in progress

Signed-off-by: Ketan Umare <[email protected]>

* wip

Signed-off-by: Ketan Umare <[email protected]>

* Intratask checkpointer

Signed-off-by: Ketan Umare <[email protected]>

* Checkpoint updated

Signed-off-by: Ketan Umare <[email protected]>

* Intra-task checkpointing

Signed-off-by: Ketan Umare <[email protected]>

* Test and entrypoint updated

Signed-off-by: Ketan Umare <[email protected]>

* lint fixed

Signed-off-by: Ketan Umare <[email protected]>

* test fixes

Signed-off-by: Ketan Umare <[email protected]>

* fmt

Signed-off-by: Ketan Umare <[email protected]>

* updated entrypoint

Signed-off-by: Ketan Umare <[email protected]>

* updated

Signed-off-by: Ketan Umare <[email protected]>

* update

Signed-off-by: Ketan Umare <[email protected]>

* print

Signed-off-by: Ketan Umare <[email protected]>

* updated

Signed-off-by: Ketan Umare <[email protected]>

* SyncCheckpointer working

Signed-off-by: Ketan Umare <[email protected]>

* updated

Signed-off-by: Ketan Umare <[email protected]>

* update

Signed-off-by: Ketan Umare <[email protected]>

* fixed import problems

Signed-off-by: Ketan Umare <[email protected]>

* fixed test

Signed-off-by: Ketan Umare <[email protected]>

* fixed imports

Signed-off-by: Ketan Umare <[email protected]>

* fixed lints and errors

Signed-off-by: Ketan Umare <[email protected]>

* lint fix

Signed-off-by: Ketan Umare <[email protected]>

* addressed comments

Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 authored and kennyworkman committed Feb 8, 2022
1 parent d81f192 commit ff837a8
Show file tree
Hide file tree
Showing 13 changed files with 561 additions and 55 deletions.
122 changes: 87 additions & 35 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os as _os
import pathlib
import traceback as _traceback
from typing import List
from typing import List, Optional

import click as _click
from flyteidl.core import literals_pb2 as _literals_pb2
Expand All @@ -16,6 +16,7 @@
from flytekit.core import constants as _constants
from flytekit.core import utils
from flytekit.core.base_task import IgnoreOutputs, PythonTask
from flytekit.core.checkpointer import SyncCheckpoint
from flytekit.core.context_manager import (
ExecutionParameters,
ExecutionState,
Expand Down Expand Up @@ -164,8 +165,10 @@ def _dispatch_execute(
@contextlib.contextmanager
def setup_execution(
raw_output_data_prefix: str,
dynamic_addl_distro: str = None,
dynamic_dest_dir: str = None,
checkpoint_path: Optional[str] = None,
prev_checkpoint: Optional[str] = None,
dynamic_addl_distro: Optional[str] = None,
dynamic_dest_dir: Optional[str] = None,
):
ctx = FlyteContextManager.current_context()

Expand All @@ -175,6 +178,11 @@ def setup_execution(
pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)
from flytekit import __version__ as _api_version

checkpointer = None
if checkpoint_path is not None:
checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint)
logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}")

execution_parameters = ExecutionParameters(
execution_id=_identifier.WorkflowExecutionIdentifier(
project=_internal_config.EXECUTION_PROJECT.get(),
Expand Down Expand Up @@ -202,6 +210,7 @@ def setup_execution(
logging=python_logging,
tmp_dir=user_workspace_dir,
raw_output_prefix=ctx.file_access._raw_output_prefix,
checkpoint=checkpointer,
)

# TODO: Remove this check for flytekit 1.0
Expand Down Expand Up @@ -266,14 +275,16 @@ def _handle_annotated_task(

@_scopes.system_entry_point
def _execute_task(
inputs,
output_prefix,
raw_output_data_prefix,
test,
inputs: str,
output_prefix: str,
test: bool,
raw_output_data_prefix: str,
resolver: str,
resolver_args: List[str],
dynamic_addl_distro: str = None,
dynamic_dest_dir: str = None,
checkpoint_path: Optional[str] = None,
prev_checkpoint: Optional[str] = None,
dynamic_addl_distro: Optional[str] = None,
dynamic_dest_dir: Optional[str] = None,
):
"""
This function should be called for new API tasks (those only available in 0.16 and later that leverage Python
Expand Down Expand Up @@ -302,7 +313,13 @@ def _execute_task(
raise Exception("cannot be <1")

with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()):
with setup_execution(raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir) as ctx:
with setup_execution(
raw_output_data_prefix,
checkpoint_path=checkpoint_path,
prev_checkpoint=prev_checkpoint,
dynamic_addl_distro=dynamic_addl_distro,
dynamic_dest_dir=dynamic_dest_dir,
) as ctx:
resolver_obj = load_object_from_module(resolver)
# Use the resolver to load the actual task object
_task_def = resolver_obj.load_task(loader_args=resolver_args)
Expand All @@ -321,16 +338,20 @@ def _execute_map_task(
raw_output_data_prefix,
max_concurrency,
test,
dynamic_addl_distro: str,
dynamic_dest_dir: str,
resolver: str,
resolver_args: List[str],
checkpoint_path: Optional[str] = None,
prev_checkpoint: Optional[str] = None,
dynamic_addl_distro: Optional[str] = None,
dynamic_dest_dir: Optional[str] = None,
):
if len(resolver_args) < 1:
raise Exception(f"Resolver args cannot be <1, got {resolver_args}")

with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()):
with setup_execution(raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir) as ctx:
with setup_execution(
raw_output_data_prefix, checkpoint_path, prev_checkpoint, dynamic_addl_distro, dynamic_dest_dir
) as ctx:
resolver_obj = load_object_from_module(resolver)
# Use the resolver to load the actual task object
_task_def = resolver_obj.load_task(loader_args=resolver_args)
Expand All @@ -352,6 +373,22 @@ def _execute_map_task(
_handle_annotated_task(ctx, map_task, inputs, output_prefix)


def normalize_inputs(
raw_output_data_prefix: Optional[str], checkpoint_path: Optional[str], prev_checkpoint: Optional[str]
):
# Backwards compatibility - if Propeller hasn't filled this in, then it'll come through here as the original
# template string, so let's explicitly set it to None so that the downstream functions will know to fall back
# to the original shard formatter/prefix config.
if raw_output_data_prefix == "{{.rawOutputDataPrefix}}":
raw_output_data_prefix = None
if checkpoint_path == "{{.checkpointOutputPrefix}}":
checkpoint_path = None
if prev_checkpoint == "{{.prevCheckpointPrefix}}" or prev_checkpoint == "" or prev_checkpoint == '""':
prev_checkpoint = None

return raw_output_data_prefix, checkpoint_path, prev_checkpoint


@_click.group()
def _pass_through():
pass
Expand All @@ -361,6 +398,8 @@ def _pass_through():
@_click.option("--inputs", required=True)
@_click.option("--output-prefix", required=True)
@_click.option("--raw-output-data-prefix", required=False)
@_click.option("--checkpoint-path", required=False)
@_click.option("--prev-checkpoint", required=False)
@_click.option("--test", is_flag=True)
@_click.option("--dynamic-addl-distro", required=False)
@_click.option("--dynamic-dest-dir", required=False)
Expand All @@ -375,6 +414,8 @@ def execute_task_cmd(
output_prefix,
raw_output_data_prefix,
test,
prev_checkpoint,
checkpoint_path,
dynamic_addl_distro,
dynamic_dest_dir,
resolver,
Expand All @@ -383,26 +424,27 @@ def execute_task_cmd(
logger.info(get_version_message())
# We get weird errors if there are no click echo messages at all, so emit an empty string so that unit tests pass.
_click.echo("")
# Backwards compatibility - if Propeller hasn't filled this in, then it'll come through here as the original
# template string, so let's explicitly set it to None so that the downstream functions will know to fall back
# to the original shard formatter/prefix config.
if raw_output_data_prefix == "{{.rawOutputDataPrefix}}":
raw_output_data_prefix = None
raw_output_data_prefix, checkpoint_path, prev_checkpoint = normalize_inputs(
raw_output_data_prefix, checkpoint_path, prev_checkpoint
)

# For new API tasks (as of 0.16.x), we need to call a different function.
# Use the presence of the resolver to differentiate between old API tasks and new API tasks
# The addition of a new top-level command seemed out of scope at the time of this writing to pursue given how
# pervasive this top level command already (plugins mostly).

logger.debug(f"Running task execution with resolver {resolver}...")
_execute_task(
inputs,
output_prefix,
raw_output_data_prefix,
test,
resolver,
resolver_args,
dynamic_addl_distro,
dynamic_dest_dir,
inputs=inputs,
output_prefix=output_prefix,
raw_output_data_prefix=raw_output_data_prefix,
test=test,
resolver=resolver,
resolver_args=resolver_args,
dynamic_addl_distro=dynamic_addl_distro,
dynamic_dest_dir=dynamic_dest_dir,
checkpoint_path=checkpoint_path,
prev_checkpoint=prev_checkpoint,
)


Expand Down Expand Up @@ -446,6 +488,8 @@ def fast_execute_task_cmd(additional_distribution, dest_dir, task_execute_cmd):
@_click.option("--dynamic-addl-distro", required=False)
@_click.option("--dynamic-dest-dir", required=False)
@_click.option("--resolver", required=True)
@_click.option("--checkpoint-path", required=False)
@_click.option("--prev-checkpoint", required=False)
@_click.argument(
"resolver-args",
type=_click.UNPROCESSED,
Expand All @@ -461,19 +505,27 @@ def map_execute_task_cmd(
dynamic_dest_dir,
resolver,
resolver_args,
prev_checkpoint,
checkpoint_path,
):
logger.info(get_version_message())

raw_output_data_prefix, checkpoint_path, prev_checkpoint = normalize_inputs(
raw_output_data_prefix, checkpoint_path, prev_checkpoint
)

_execute_map_task(
inputs,
output_prefix,
raw_output_data_prefix,
max_concurrency,
test,
dynamic_addl_distro,
dynamic_dest_dir,
resolver,
resolver_args,
inputs=inputs,
output_prefix=output_prefix,
raw_output_data_prefix=raw_output_data_prefix,
max_concurrency=max_concurrency,
test=test,
dynamic_addl_distro=dynamic_addl_distro,
dynamic_dest_dir=dynamic_dest_dir,
resolver=resolver,
resolver_args=resolver_args,
checkpoint_path=checkpoint_path,
prev_checkpoint=prev_checkpoint,
)


Expand Down
3 changes: 3 additions & 0 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
else:
logger.info("Cache hit")
else:
es = ctx.execution_state
b = es.user_space_params.with_task_sandbox()
ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build()
outputs_literal_map = self.dispatch_execute(ctx, input_literal_map)
outputs_literals = outputs_literal_map.literals

Expand Down
Loading

0 comments on commit ff837a8

Please sign in to comment.