Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch upload flyte directory #1806

Merged
merged 7 commits into from
Aug 26, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@
from flytekit.core.resources import Resources
from flytekit.core.schedule import CronSchedule, FixedRate
from flytekit.core.task import Secret, reference_task, task
from flytekit.core.type_engine import BatchSize
from flytekit.core.workflow import ImperativeWorkflow as Workflow
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
from flytekit.deck import Deck
Expand Down
10 changes: 5 additions & 5 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def exists(self, path: str) -> bool:
return anon_fs.exists(path)
raise oe

def get(self, from_path: str, to_path: str, recursive: bool = False):
def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
file_system = self.get_filesystem_for_path(from_path)
if recursive:
from_path, to_path = self.recursive_paths(from_path, to_path)
Expand All @@ -194,13 +194,13 @@ def get(self, from_path: str, to_path: str, recursive: bool = False):
return shutil.copytree(
self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True
)
return file_system.get(from_path, to_path, recursive=recursive)
return file_system.get(from_path, to_path, recursive=recursive, **kwargs)
except OSError as oe:
logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}")
file_system = self.get_filesystem(get_protocol(from_path), anonymous=True)
if file_system is not None:
logger.debug(f"Attempting anonymous get with {file_system}")
return file_system.get(from_path, to_path, recursive=recursive)
return file_system.get(from_path, to_path, recursive=recursive, **kwargs)
raise oe

def put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
Expand Down Expand Up @@ -287,7 +287,7 @@ def upload_directory(self, local_path: str, remote_path: str):
"""
return self.put_data(local_path, remote_path, is_multipart=True)

def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False):
def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False, **kwargs):
"""
:param remote_path:
:param local_path:
Expand All @@ -296,7 +296,7 @@ def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False
try:
pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)
with timeit(f"Download data to local from {remote_path}"):
self.get(remote_path, to_path=local_path, recursive=is_multipart)
self.get(remote_path, to_path=local_path, recursive=is_multipart, **kwargs)
except Exception as ex:
raise FlyteAssertion(
f"Failed to get data from {remote_path} to {local_path} (recursive={is_multipart}).\n\n"
Expand Down
31 changes: 31 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,37 @@
DEFINITIONS = "definitions"


class BatchSize:
"""
This is used to annotate a FlyteDirectory when we want to download/upload the contents of the directory in batches. For example,

@task
def t1(directory: Annotated[FlyteDirectory, BatchSize(10)]) -> Annotated[FlyteDirectory, BatchSize(100)]:
...
return FlyteDirectory(...)

In the above example flytekit will download all files from the input `directory` in chunks of 10, i.e. first it
downloads 10 files, loads them to memory, then writes those 10 to local disk, then it loads the next 10, so on
and so forth. Similarly, for outputs, in this case flytekit is going to upload the resulting directory in chunks of
100.
"""

def __init__(self, val: int):
self._val = val

@property
def val(self) -> int:
return self._val


def get_batch_size(t: Type) -> Optional[int]:
if is_annotated(t):
for annotation in get_args(t)[1:]:
if isinstance(annotation, BatchSize):
return annotation.val
return None
eapolinario marked this conversation as resolved.
Show resolved Hide resolved


class TypeTransformerFailedError(TypeError, AssertionError, ValueError):
...

Expand Down
15 changes: 12 additions & 3 deletions flytekit/types/directory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from dataclasses_json import config, dataclass_json
from fsspec.utils import get_protocol
from marshmallow import fields
from typing_extensions import get_args

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer
from flytekit.core.type_engine import TypeEngine, TypeTransformer, get_batch_size, is_annotated
from flytekit.models import types as _type_models
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
Expand Down Expand Up @@ -322,6 +323,10 @@ def to_literal(

remote_directory = None
should_upload = True
batch_size = get_batch_size(python_type)

if is_annotated(python_type):
eapolinario marked this conversation as resolved.
Show resolved Hide resolved
python_type = get_args(python_type)[0]
meta = BlobMetadata(type=self._blob_type(format=self.get_format(python_type)))

# There are two kinds of literals we handle, either an actual FlyteDirectory, or a string path to a directory.
Expand Down Expand Up @@ -358,7 +363,7 @@ def to_literal(
if should_upload:
if remote_directory is None:
remote_directory = ctx.file_access.get_random_remote_directory()
ctx.file_access.put_data(source_path, remote_directory, is_multipart=True)
ctx.file_access.put_data(source_path, remote_directory, is_multipart=True, batch_size=batch_size)
eapolinario marked this conversation as resolved.
Show resolved Hide resolved
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_directory)))

# If not uploading, then we can only take the original source path as the uri.
Expand All @@ -379,8 +384,12 @@ def to_python_value(
# For the remote case, return an FlyteDirectory object that can download
local_folder = ctx.file_access.get_random_local_directory()

batch_size = get_batch_size(expected_python_type)
if is_annotated(expected_python_type):
eapolinario marked this conversation as resolved.
Show resolved Hide resolved
expected_python_type = get_args(expected_python_type)[0]

def _downloader():
return ctx.file_access.get_data(uri, local_folder, is_multipart=True)
return ctx.file_access.get_data(uri, local_folder, is_multipart=True, batch_size=batch_size)

expected_format = self.get_format(expected_python_type)

Expand Down