Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for copying all the files in source root #1622

Merged
merged 4 commits into from
May 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flytekit/clis/sdk_in_container/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
CTX_PROJECT_ROOT = "project_root"
CTX_MODULE = "module"
CTX_VERBOSE = "verbose"
CTX_COPY_ALL = "copy_all"


project_option = _click.option(
Expand Down
9 changes: 9 additions & 0 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from flytekit import BlobType, Literal, Scalar
from flytekit.clis.sdk_in_container.constants import (
CTX_CONFIG_FILE,
CTX_COPY_ALL,
CTX_DOMAIN,
CTX_MODULE,
CTX_PROJECT,
Expand Down Expand Up @@ -512,6 +513,13 @@ def get_workflow_command_base_params() -> typing.List[click.Option]:
default="/root",
help="Directory inside the image where the tar file containing the code will be copied to",
),
click.Option(
param_decls=["--copy-all", "copy_all"],
required=False,
is_flag=True,
default=False,
help="Copy all files in the source root directory to the destination directory",
),
click.Option(
param_decls=["-i", "--image", "image_config"],
required=False,
Expand Down Expand Up @@ -643,6 +651,7 @@ def _run(*args, **kwargs):
destination_dir=run_level_params.get("destination_dir"),
source_path=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT_ROOT),
module_name=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_MODULE),
copy_all=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_COPY_ALL),
)

options = None
Expand Down
25 changes: 15 additions & 10 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,17 +799,19 @@ def register_script(
project: typing.Optional[str] = None,
domain: typing.Optional[str] = None,
destination_dir: str = ".",
default_launch_plan: typing.Optional[bool] = True,
copy_all: bool = False,
default_launch_plan: bool = True,
options: typing.Optional[Options] = None,
source_path: typing.Optional[str] = None,
module_name: typing.Optional[str] = None,
) -> typing.Union[FlyteWorkflow, FlyteTask]:
"""
Use this method to register a workflow via script mode.
:param destination_dir:
:param domain:
:param project:
:param image_config:
:param destination_dir: The destination directory where the workflow will be copied to.
:param copy_all: If true, the entire source directory will be copied over to the destination directory.
:param domain: The domain to register the workflow in.
:param project: The project to register the workflow in.
:param image_config: The image config to use for the workflow.
:param version: version for the entity to be registered as
:param entity: The workflow to be registered or the task to be registered
:param default_launch_plan: This should be true if a default launch plan should be created for the workflow
Expand All @@ -822,11 +824,14 @@ def register_script(
image_config = ImageConfig.auto_default_image()

with tempfile.TemporaryDirectory() as tmp_dir:
archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz"))
compress_scripts(source_path, str(archive_fname), module_name)
md5_bytes, upload_native_url = self.upload_file(
archive_fname, project or self.default_project, domain or self.default_domain
)
if copy_all:
md5_bytes, upload_native_url = self.fast_package(pathlib.Path(source_path), False, tmp_dir)
else:
archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz"))
compress_scripts(source_path, str(archive_fname), module_name)
md5_bytes, upload_native_url = self.upload_file(
archive_fname, project or self.default_project, domain or self.default_domain
)

serialization_settings = SerializationSettings(
project=project,
Expand Down
10 changes: 10 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ def test_imperative_wf():
assert result.exit_code == 0


def test_copy_all_files():
runner = CliRunner()
result = runner.invoke(
pyflyte.main,
["run", "--copy-all", IMPERATIVE_WORKFLOW_FILE, "wf", "--in1", "hello", "--in2", "world"],
catch_exceptions=False,
)
assert result.exit_code == 0


def test_pyflyte_run_cli():
runner = CliRunner()
parquet_file = os.path.join(DIR_NAME, "testdata/df.parquet")
Expand Down