Skip to content

Commit

Permalink
Added Launch options
Browse files Browse the repository at this point in the history
Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 committed Sep 8, 2023
1 parent 5ef611f commit fc4d28f
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 17 deletions.
1 change: 1 addition & 0 deletions flytekit/clis/sdk_in_container/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

def make_field(o: click.Option) -> Field:
if o.multiple:
o.help = click.style("Multiple values allowed.", bold=True) + f"{o.help}"
return field(default_factory=lambda: o.default, metadata={"click.option": o})
return field(default=o.default, metadata={"click.option": o})

Expand Down
109 changes: 92 additions & 17 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
from dataclasses_json import DataClassJsonMixin
from rich.progress import Progress

from flytekit import FlyteContext, Literal
from flytekit import Annotations, FlyteContext, Labels, Literal
from flytekit.clis.sdk_in_container.constants import PyFlyteParams, get_option_from_metadata, make_field
from flytekit.clis.sdk_in_container.helpers import get_remote, patch_image_config
from flytekit.configuration import DefaultImages, ImageConfig
from flytekit.core import context_manager
from flytekit.core.base_task import PythonTask
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase
from flytekit.interaction.click_types import FlyteLiteralConverter, JsonParamType
from flytekit.interaction.click_types import FlyteLiteralConverter, key_value_callback
from flytekit.models import security
from flytekit.models.common import RawOutputDataConfig
from flytekit.models.interface import Parameter, Variable
from flytekit.models.types import SimpleType
from flytekit.remote import FlyteLaunchPlan, FlyteRemote, FlyteTask, FlyteWorkflow
Expand Down Expand Up @@ -142,18 +144,20 @@ class RunLevelParams(PyFlyteParams):
help="Whether to overwrite the cache if it already exists",
)
)
envs: typing.Dict[str, str] = make_field(
envvars: typing.Dict[str, str] = make_field(
click.Option(
param_decls=["--envs", "envs"],
param_decls=["--envvars", "--env"],
required=False,
type=JsonParamType(),
multiple=True,
type=str,
show_default=True,
help="Environment variables to set in the container",
callback=key_value_callback,
help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`",
)
)
tag: typing.List[str] = make_field(
tags: typing.List[str] = make_field(
click.Option(
param_decls=["--tag", "tag"],
param_decls=["--tags", "--tag"],
required=False,
multiple=True,
type=str,
Expand All @@ -170,6 +174,65 @@ class RunLevelParams(PyFlyteParams):
help="Name to assign to this execution",
)
)
labels: typing.Dict[str, str] = make_field(
click.Option(
param_decls=["--labels", "--label"],
required=False,
multiple=True,
type=str,
show_default=True,
callback=key_value_callback,
help="Labels to be attached to the execution of the format `label_key=label_value`.",
)
)
annotations: typing.Dict[str, str] = make_field(
click.Option(
param_decls=["--annotations", "--annotation"],
required=False,
multiple=True,
type=str,
show_default=True,
callback=key_value_callback,
help="Annotations to be attached to the execution of the format `key=value`.",
)
)
raw_output_data_prefix: str = make_field(
click.Option(
param_decls=["--raw-output-data-prefix", "--raw-data-prefix"],
required=False,
type=str,
show_default=True,
help="File Path prefix to store raw output data."
" Examples are file://, s3://, gs:// etc as supported by fsspec."
" If not specified, raw data will be stored in default configured location in remote of locally"
" to temp file system."
+ click.style(
"Note, this is not metadata, but only the raw data location "
"used to store Flytefile, Flytedirectory, Structuredataset,"
" dataframes"
),
)
)
max_parallelism: int = make_field(
click.Option(
param_decls=["--max-parallelism"],
required=False,
type=int,
show_default=True,
help="Number of nodes of a workflow that can be executed in parallel. If not specified,"
" project/domain defaults are used. If 0 then it is unlimited.",
)
)
disable_notifications: bool = make_field(
click.Option(
param_decls=["--disable-notifications"],
required=False,
is_flag=True,
default=False,
show_default=True,
help="Should notifications be disabled for this execution.",
)
)
remote: bool = make_field(
click.Option(
param_decls=["--remote"],
Expand Down Expand Up @@ -333,6 +396,24 @@ def to_click_option(
)


def options_from_run_params(run_level_params: RunLevelParams) -> Options:
return Options(
labels=Labels(run_level_params.labels) if run_level_params.labels else None,
annotations=Annotations(run_level_params.annotations) if run_level_params.annotations else None,
raw_output_data_config=RawOutputDataConfig(output_location_prefix=run_level_params.raw_output_data_prefix)
if run_level_params.raw_output_data_prefix
else None,
max_parallelism=run_level_params.max_parallelism,
disable_notifications=run_level_params.disable_notifications,
security_context=security.SecurityContext(
run_as=security.Identity(k8s_service_account=run_level_params.service_account)
)
if run_level_params.service_account
else None,
notifications=[],
)


def run_remote(
remote: FlyteRemote,
entity: typing.Union[FlyteWorkflow, FlyteTask, FlyteLaunchPlan],
Expand All @@ -345,12 +426,6 @@ def run_remote(
"""
Helper method that executes the given remote FlyteLaunchplan, FlyteWorkflow or FlyteTask
"""
options = None
service_account = run_level_params.service_account
if service_account:
# options are only passed for the execution. This is to prevent errors when registering a duplicate workflow
# It is assumed that the users expectations is to override the service account only for the execution
options = Options.default_from(k8s_service_account=service_account)

execution = remote.execute(
entity,
Expand All @@ -359,11 +434,11 @@ def run_remote(
domain=domain,
name=run_level_params.name,
wait=run_level_params.wait_execution,
options=options,
options=options_from_run_params(run_level_params),
type_hints=type_hints,
overwrite_cache=run_level_params.overwrite_cache,
envs=run_level_params.envs,
tags=run_level_params.tag,
envs=run_level_params.envvars,
tags=run_level_params.tags,
)

console_url = remote.generate_console_url(execution)
Expand Down
15 changes: 15 additions & 0 deletions flytekit/interaction/click_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,18 @@ def convert(self, ctx, param, value) -> typing.Union[Literal, typing.Any]:
raise
except Exception as e:
raise click.BadParameter(f"Failed to convert param {param}, {value} to {self._python_type}") from e


def key_value_callback(_: typing.Any, param: str, values: typing.List[str]) -> typing.Optional[typing.Dict[str, str]]:
"""
Callback for click to parse key-value pairs.
"""
if not values:
return None
result = {}
for v in values:
if "=" not in v:
raise click.BadParameter(f"Expected key-value pair of the form key=value, got {v}")
k, v = v.split("=", 1)
result[k.strip()] = v.strip()
return result
16 changes: 16 additions & 0 deletions tests/flytekit/unit/interaction/test_click_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
FileParamType,
FlyteLiteralConverter,
JsonParamType,
key_value_callback,
)
from flytekit.models.types import SimpleType
from flytekit.remote import FlyteRemote
Expand Down Expand Up @@ -141,3 +142,18 @@ def test_json_type():
yaml.dump({"a": "b"}, f)
f.flush()
assert t.convert(value=f.name, param=None, ctx=None) == {"a": "b"}


def test_key_value_callback():
"""Write a test that verifies that the callback works correctly."""
ctx = click.Context(click.Command("test_command"), obj={"remote": True})
assert key_value_callback(ctx, "a", None) is None
assert key_value_callback(ctx, "a", ["a=b"]) == {"a": "b"}
assert key_value_callback(ctx, "a", ["a=b", "c=d"]) == {"a": "b", "c": "d"}
assert key_value_callback(ctx, "a", ["a=b", "c=d", "e=f"]) == {"a": "b", "c": "d", "e": "f"}
with pytest.raises(click.BadParameter):
key_value_callback(ctx, "a", ["a=b", "c"])
with pytest.raises(click.BadParameter):
key_value_callback(ctx, "a", ["a=b", "c=d", "e"])
with pytest.raises(click.BadParameter):
key_value_callback(ctx, "a", ["a=b", "c=d", "e=f", "g"])

0 comments on commit fc4d28f

Please sign in to comment.