Skip to content
This repository has been archived by the owner on Jul 19, 2024. It is now read-only.

Commit

Permalink
Batch upload flyte directory (flyteorg#1806)
Browse files Browse the repository at this point in the history
* Batch upload flyte directory

Signed-off-by: Kevin Su <[email protected]>

* Update get method

Signed-off-by: Kevin Su <[email protected]>

* Move batch size to type engine

Signed-off-by: Kevin Su <[email protected]>

* comment

Signed-off-by: Kevin Su <[email protected]>

* update comment

Signed-off-by: Kevin Su <[email protected]>

* Update flytekit/core/type_engine.py

Co-authored-by: Eduardo Apolinario <[email protected]>

* Add test

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
Signed-off-by: Future Outlier <[email protected]>
  • Loading branch information
2 people authored and Future Outlier committed Oct 3, 2023
1 parent a8435ea commit 6bdb02d
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 8 deletions.
1 change: 1 addition & 0 deletions flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@
from flytekit.core.resources import Resources
from flytekit.core.schedule import CronSchedule, FixedRate
from flytekit.core.task import Secret, reference_task, task
from flytekit.core.type_engine import BatchSize
from flytekit.core.workflow import ImperativeWorkflow as Workflow
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
from flytekit.deck import Deck
Expand Down
10 changes: 5 additions & 5 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def exists(self, path: str) -> bool:
return anon_fs.exists(path)
raise oe

def get(self, from_path: str, to_path: str, recursive: bool = False):
def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
file_system = self.get_filesystem_for_path(from_path)
if recursive:
from_path, to_path = self.recursive_paths(from_path, to_path)
Expand All @@ -194,13 +194,13 @@ def get(self, from_path: str, to_path: str, recursive: bool = False):
return shutil.copytree(
self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True
)
return file_system.get(from_path, to_path, recursive=recursive)
return file_system.get(from_path, to_path, recursive=recursive, **kwargs)
except OSError as oe:
logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}")
file_system = self.get_filesystem(get_protocol(from_path), anonymous=True)
if file_system is not None:
logger.debug(f"Attempting anonymous get with {file_system}")
return file_system.get(from_path, to_path, recursive=recursive)
return file_system.get(from_path, to_path, recursive=recursive, **kwargs)
raise oe

def put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
Expand Down Expand Up @@ -287,7 +287,7 @@ def upload_directory(self, local_path: str, remote_path: str):
"""
return self.put_data(local_path, remote_path, is_multipart=True)

def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False):
def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False, **kwargs):
"""
:param remote_path:
:param local_path:
Expand All @@ -296,7 +296,7 @@ def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False
try:
pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)
with timeit(f"Download data to local from {remote_path}"):
self.get(remote_path, to_path=local_path, recursive=is_multipart)
self.get(remote_path, to_path=local_path, recursive=is_multipart, **kwargs)
except Exception as ex:
raise FlyteAssertion(
f"Failed to get data from {remote_path} to {local_path} (recursive={is_multipart}).\n\n"
Expand Down
31 changes: 31 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,37 @@
DEFINITIONS = "definitions"


class BatchSize:
    """
    Annotation marker used on a FlyteDirectory when the contents of the directory should be
    downloaded/uploaded in batches. For example,

    .. code-block:: python

        @task
        def t1(directory: Annotated[FlyteDirectory, BatchSize(10)]) -> Annotated[FlyteDirectory, BatchSize(100)]:
            ...
            return FlyteDirectory(...)

    In the above example flytekit will download all files from the input `directory` in chunks of 10, i.e. first it
    downloads 10 files, loads them to memory, then writes those 10 to local disk, then it loads the next 10, so on
    and so forth. Similarly, for outputs, in this case flytekit is going to upload the resulting directory in chunks
    of 100.

    Instances compare and hash by value, so two ``Annotated`` types carrying the same batch size are equal to
    each other (``typing`` compares ``Annotated`` metadata with ``==``).

    :param val: number of files to transfer per batch.
    """

    def __init__(self, val: int):
        self._val = val

    @property
    def val(self) -> int:
        """Number of files transferred per batch, exactly as passed to the constructor."""
        return self._val

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._val!r})"

    def __eq__(self, other: object) -> bool:
        # Returning NotImplemented lets Python fall back to the other operand's comparison.
        if not isinstance(other, BatchSize):
            return NotImplemented
        return self._val == other._val

    def __hash__(self) -> int:
        # Defined together with __eq__ so instances stay hashable and equal instances hash alike.
        return hash((BatchSize, self._val))


def get_batch_size(t: Type) -> Optional[int]:
    """
    Look up the batch size attached to an annotated type.

    :param t: the type to inspect.
    :return: the ``val`` of the first ``BatchSize`` found among the annotations of ``t``, or ``None`` when
        ``t`` is not annotated or carries no ``BatchSize`` annotation.
    """
    if not is_annotated(t):
        return None
    # get_args(t)[0] is the underlying type; everything after it is annotation metadata.
    metadata = get_args(t)[1:]
    return next((item.val for item in metadata if isinstance(item, BatchSize)), None)


class TypeTransformerFailedError(TypeError, AssertionError, ValueError):
    """
    Exception that is simultaneously a ``TypeError``, an ``AssertionError`` and a ``ValueError``, so a handler
    written for any one of those three also catches it.

    NOTE(review): presumably raised when a type transformer cannot convert a value - confirm against the
    raise sites, which are not visible here.
    """

Expand Down
10 changes: 7 additions & 3 deletions flytekit/types/directory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from marshmallow import fields

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer
from flytekit.core.type_engine import TypeEngine, TypeTransformer, get_batch_size
from flytekit.models import types as _type_models
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
Expand Down Expand Up @@ -321,6 +321,8 @@ def to_literal(

remote_directory = None
should_upload = True
batch_size = get_batch_size(python_type)

meta = BlobMetadata(type=self._blob_type(format=self.get_format(python_type)))

# There are two kinds of literals we handle, either an actual FlyteDirectory, or a string path to a directory.
Expand Down Expand Up @@ -357,7 +359,7 @@ def to_literal(
if should_upload:
if remote_directory is None:
remote_directory = ctx.file_access.get_random_remote_directory()
ctx.file_access.put_data(source_path, remote_directory, is_multipart=True)
ctx.file_access.put_data(source_path, remote_directory, is_multipart=True, batch_size=batch_size)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_directory)))

# If not uploading, then we can only take the original source path as the uri.
Expand All @@ -378,8 +380,10 @@ def to_python_value(
# For the remote case, return an FlyteDirectory object that can download
local_folder = ctx.file_access.get_random_local_directory()

batch_size = get_batch_size(expected_python_type)

def _downloader():
return ctx.file_access.get_data(uri, local_folder, is_multipart=True)
return ctx.file_access.get_data(uri, local_folder, is_multipart=True, batch_size=batch_size)

expected_format = self.get_format(expected_python_type)

Expand Down
10 changes: 10 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,16 @@ def test_dir_no_downloader_default():
assert pv.download() == local_dir


def test_dir_with_batch_size():
    """
    A FlyteDirectory type annotated with BatchSize should survive a to_literal / to_python_value round trip:
    the original path must come back as the ``remote_source`` of the resulting python value.
    """
    annotated_dir = Annotated[FlyteDirectory, BatchSize(100)]
    original = annotated_dir("s3://bucket/key")
    dir_transformer = TypeEngine.get_transformer(annotated_dir)
    context = FlyteContext.current_context()

    literal_type = dir_transformer.get_literal_type(annotated_dir)
    literal_value = dir_transformer.to_literal(context, original, annotated_dir, literal_type)

    expected_source = original.path
    round_tripped = dir_transformer.to_python_value(context, literal_value, annotated_dir)
    assert expected_source == round_tripped.remote_source


def test_dict_transformer():
d = DictTransformer()

Expand Down

0 comments on commit 6bdb02d

Please sign in to comment.