Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Intratask checkpointing #771

Merged
merged 32 commits into from
Jan 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
9060306
wip - intratask checkpointing
kumare3 Dec 7, 2021
1fbaa9d
sync checkpointer with tests
kumare3 Dec 8, 2021
f6102c8
Checkpinter in entrypoint
kumare3 Dec 10, 2021
bcf2fc5
checkpoint in progress
kumare3 Dec 10, 2021
61f6366
wip
kumare3 Dec 22, 2021
a7b8893
Merge branch 'master' into intratask-checkpoint-1
kumare3 Jan 3, 2022
d8ce7cd
Intratask checkpointer
kumare3 Jan 4, 2022
3deda7d
Checkpoint updated
kumare3 Jan 7, 2022
c51f507
Intra-task checkpointing
kumare3 Jan 8, 2022
e05cfac
Test and entrypoint updated
kumare3 Jan 11, 2022
740f366
Merge branch 'master' into intratask-checkpoint-1
kumare3 Jan 11, 2022
db95d90
Merge branch 'master' into intratask-checkpoint-1
kumare3 Jan 12, 2022
6265264
lint fixed
kumare3 Jan 12, 2022
f19ed15
test fixes
kumare3 Jan 12, 2022
1213386
fmt
kumare3 Jan 12, 2022
f535c4d
Merge branch 'master' into intratask-checkpoint-1
kumare3 Jan 13, 2022
8646b8a
updated entrypoint
kumare3 Jan 13, 2022
bf098c5
updated
kumare3 Jan 13, 2022
700923f
update
kumare3 Jan 13, 2022
cf93b63
print
kumare3 Jan 13, 2022
890aac1
updated
kumare3 Jan 13, 2022
48d8fc4
SyncCheckpointer working
kumare3 Jan 14, 2022
a91e076
updated
kumare3 Jan 14, 2022
3762461
update
kumare3 Jan 14, 2022
631c66d
Merge branch 'master' into intratask-checkpoint-1
kumare3 Jan 16, 2022
6f286bf
fixed import problems
kumare3 Jan 16, 2022
3f7f617
fixed test
kumare3 Jan 16, 2022
e0683a3
fixed imports
kumare3 Jan 16, 2022
609ced0
fixed lints and errors
kumare3 Jan 18, 2022
13e1d37
Merge branch 'master' into intratask-checkpoint-1
kumare3 Jan 19, 2022
fa518bb
lint fix
kumare3 Jan 19, 2022
7ff8171
addressed comments
kumare3 Jan 20, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 87 additions & 35 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os as _os
import pathlib
import traceback as _traceback
from typing import List
from typing import List, Optional

import click as _click
from flyteidl.core import literals_pb2 as _literals_pb2
Expand All @@ -16,6 +16,7 @@
from flytekit.core import constants as _constants
from flytekit.core import utils
from flytekit.core.base_task import IgnoreOutputs, PythonTask
from flytekit.core.checkpointer import SyncCheckpoint
from flytekit.core.context_manager import (
ExecutionParameters,
ExecutionState,
Expand Down Expand Up @@ -164,8 +165,10 @@ def _dispatch_execute(
@contextlib.contextmanager
def setup_execution(
raw_output_data_prefix: str,
dynamic_addl_distro: str = None,
dynamic_dest_dir: str = None,
checkpoint_path: Optional[str] = None,
prev_checkpoint: Optional[str] = None,
dynamic_addl_distro: Optional[str] = None,
dynamic_dest_dir: Optional[str] = None,
):
ctx = FlyteContextManager.current_context()

Expand All @@ -175,6 +178,11 @@ def setup_execution(
pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)
from flytekit import __version__ as _api_version

checkpointer = None
if checkpoint_path is not None:
checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint)
logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}")

execution_parameters = ExecutionParameters(
execution_id=_identifier.WorkflowExecutionIdentifier(
project=_internal_config.EXECUTION_PROJECT.get(),
Expand Down Expand Up @@ -202,6 +210,7 @@ def setup_execution(
logging=python_logging,
tmp_dir=user_workspace_dir,
raw_output_prefix=ctx.file_access._raw_output_prefix,
checkpoint=checkpointer,
)

# TODO: Remove this check for flytekit 1.0
Expand Down Expand Up @@ -266,14 +275,16 @@ def _handle_annotated_task(

@_scopes.system_entry_point
def _execute_task(
inputs,
output_prefix,
raw_output_data_prefix,
test,
inputs: str,
output_prefix: str,
test: bool,
raw_output_data_prefix: str,
resolver: str,
resolver_args: List[str],
dynamic_addl_distro: str = None,
dynamic_dest_dir: str = None,
checkpoint_path: Optional[str] = None,
prev_checkpoint: Optional[str] = None,
dynamic_addl_distro: Optional[str] = None,
dynamic_dest_dir: Optional[str] = None,
):
"""
This function should be called for new API tasks (those only available in 0.16 and later that leverage Python
Expand Down Expand Up @@ -302,7 +313,13 @@ def _execute_task(
raise Exception("cannot be <1")

with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()):
with setup_execution(raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir) as ctx:
with setup_execution(
raw_output_data_prefix,
checkpoint_path=checkpoint_path,
prev_checkpoint=prev_checkpoint,
dynamic_addl_distro=dynamic_addl_distro,
dynamic_dest_dir=dynamic_dest_dir,
) as ctx:
resolver_obj = load_object_from_module(resolver)
# Use the resolver to load the actual task object
_task_def = resolver_obj.load_task(loader_args=resolver_args)
Expand All @@ -321,16 +338,20 @@ def _execute_map_task(
raw_output_data_prefix,
max_concurrency,
test,
dynamic_addl_distro: str,
dynamic_dest_dir: str,
resolver: str,
resolver_args: List[str],
checkpoint_path: Optional[str] = None,
prev_checkpoint: Optional[str] = None,
dynamic_addl_distro: Optional[str] = None,
dynamic_dest_dir: Optional[str] = None,
):
if len(resolver_args) < 1:
raise Exception(f"Resolver args cannot be <1, got {resolver_args}")

with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()):
with setup_execution(raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir) as ctx:
with setup_execution(
raw_output_data_prefix, checkpoint_path, prev_checkpoint, dynamic_addl_distro, dynamic_dest_dir
) as ctx:
resolver_obj = load_object_from_module(resolver)
# Use the resolver to load the actual task object
_task_def = resolver_obj.load_task(loader_args=resolver_args)
Expand All @@ -352,6 +373,22 @@ def _execute_map_task(
_handle_annotated_task(ctx, map_task, inputs, output_prefix)


def normalize_inputs(
raw_output_data_prefix: Optional[str], checkpoint_path: Optional[str], prev_checkpoint: Optional[str]
):
# Backwards compatibility - if Propeller hasn't filled this in, then it'll come through here as the original
# template string, so let's explicitly set it to None so that the downstream functions will know to fall back
# to the original shard formatter/prefix config.
if raw_output_data_prefix == "{{.rawOutputDataPrefix}}":
raw_output_data_prefix = None
if checkpoint_path == "{{.checkpointOutputPrefix}}":
checkpoint_path = None
if prev_checkpoint == "{{.prevCheckpointPrefix}}" or prev_checkpoint == "" or prev_checkpoint == '""':
prev_checkpoint = None

return raw_output_data_prefix, checkpoint_path, prev_checkpoint


@_click.group()
def _pass_through():
pass
Expand All @@ -361,6 +398,8 @@ def _pass_through():
@_click.option("--inputs", required=True)
@_click.option("--output-prefix", required=True)
@_click.option("--raw-output-data-prefix", required=False)
@_click.option("--checkpoint-path", required=False)
@_click.option("--prev-checkpoint", required=False)
@_click.option("--test", is_flag=True)
@_click.option("--dynamic-addl-distro", required=False)
@_click.option("--dynamic-dest-dir", required=False)
Expand All @@ -375,6 +414,8 @@ def execute_task_cmd(
output_prefix,
raw_output_data_prefix,
test,
prev_checkpoint,
checkpoint_path,
dynamic_addl_distro,
dynamic_dest_dir,
resolver,
Expand All @@ -383,26 +424,27 @@ def execute_task_cmd(
logger.info(get_version_message())
# We get weird errors if there are no click echo messages at all, so emit an empty string so that unit tests pass.
_click.echo("")
# Backwards compatibility - if Propeller hasn't filled this in, then it'll come through here as the original
# template string, so let's explicitly set it to None so that the downstream functions will know to fall back
# to the original shard formatter/prefix config.
if raw_output_data_prefix == "{{.rawOutputDataPrefix}}":
raw_output_data_prefix = None
raw_output_data_prefix, checkpoint_path, prev_checkpoint = normalize_inputs(
raw_output_data_prefix, checkpoint_path, prev_checkpoint
)

# For new API tasks (as of 0.16.x), we need to call a different function.
# Use the presence of the resolver to differentiate between old API tasks and new API tasks
# The addition of a new top-level command seemed out of scope at the time of this writing to pursue given how
# pervasive this top level command already (plugins mostly).

logger.debug(f"Running task execution with resolver {resolver}...")
_execute_task(
inputs,
output_prefix,
raw_output_data_prefix,
test,
resolver,
resolver_args,
dynamic_addl_distro,
dynamic_dest_dir,
inputs=inputs,
output_prefix=output_prefix,
raw_output_data_prefix=raw_output_data_prefix,
test=test,
resolver=resolver,
resolver_args=resolver_args,
dynamic_addl_distro=dynamic_addl_distro,
dynamic_dest_dir=dynamic_dest_dir,
checkpoint_path=checkpoint_path,
prev_checkpoint=prev_checkpoint,
)


Expand Down Expand Up @@ -446,6 +488,8 @@ def fast_execute_task_cmd(additional_distribution, dest_dir, task_execute_cmd):
@_click.option("--dynamic-addl-distro", required=False)
@_click.option("--dynamic-dest-dir", required=False)
@_click.option("--resolver", required=True)
@_click.option("--checkpoint-path", required=False)
@_click.option("--prev-checkpoint", required=False)
@_click.argument(
"resolver-args",
type=_click.UNPROCESSED,
Expand All @@ -461,19 +505,27 @@ def map_execute_task_cmd(
dynamic_dest_dir,
resolver,
resolver_args,
prev_checkpoint,
checkpoint_path,
):
logger.info(get_version_message())

raw_output_data_prefix, checkpoint_path, prev_checkpoint = normalize_inputs(
raw_output_data_prefix, checkpoint_path, prev_checkpoint
)

_execute_map_task(
inputs,
output_prefix,
raw_output_data_prefix,
max_concurrency,
test,
dynamic_addl_distro,
dynamic_dest_dir,
resolver,
resolver_args,
inputs=inputs,
output_prefix=output_prefix,
raw_output_data_prefix=raw_output_data_prefix,
max_concurrency=max_concurrency,
test=test,
dynamic_addl_distro=dynamic_addl_distro,
dynamic_dest_dir=dynamic_dest_dir,
resolver=resolver,
resolver_args=resolver_args,
checkpoint_path=checkpoint_path,
prev_checkpoint=prev_checkpoint,
)


Expand Down
3 changes: 3 additions & 0 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
else:
logger.info("Cache hit")
else:
es = ctx.execution_state
b = es.user_space_params.with_task_sandbox()
ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build()
outputs_literal_map = self.dispatch_execute(ctx, input_literal_map)
outputs_literals = outputs_literal_map.literals

Expand Down
Loading