[Extended Resources] GPU Accelerators #1843

Merged Nov 1, 2023 (71 commits)

Commits
e029c2a
pip through to container
wild-endeavor Aug 16, 2023
550c3ff
move around
wild-endeavor Aug 16, 2023
dfdb38e
add asserts
wild-endeavor Aug 16, 2023
8bd3446
delete bad line
wild-endeavor Aug 16, 2023
f12e10c
switch to abc and add support for gpu unpartitioned
jeevb Sep 1, 2023
94b6cb8
Add Azure-specific headers when uploading to blob storage (#1784)
devictr Aug 17, 2023
7024010
Add async delete function in base_agent (#1800)
Future-Outlier Aug 19, 2023
34e0c68
Add support for execution name prefixes (#1803)
troychiu Aug 21, 2023
e636025
Remove ref in output (#1794)
wild-endeavor Aug 21, 2023
993df0d
Inherit directly from DataClassJsonMixin instead of using @dataclass_…
ringohoffman Aug 21, 2023
f322aef
Async file sensor (#1790)
pingsutw Aug 23, 2023
3cc6326
Eager workflows to support async workflows (#1579)
cosmicBboy Aug 25, 2023
99d2ea8
Enable SecretsManager.get to load and return bytes (#1798)
ysysys3074 Aug 25, 2023
68f87d9
Batch upload flyte directory (#1806)
pingsutw Aug 26, 2023
ca33b5f
Better error messaging for overrides (#1807)
kumare3 Aug 28, 2023
f48e4e9
Run remote Launchplan from `pyflyte run` (#1785)
kumare3 Aug 29, 2023
26f7de0
Add is none function (#1757)
pingsutw Aug 29, 2023
88a108f
Dynamic workflow should not throw nested task warning (#1812)
Aug 31, 2023
112f740
Add a manual image building GH action (#1816)
wild-endeavor Sep 1, 2023
40a789f
catch abfs protocol in data_persistence.py/get_filesystem and set ano…
fiedlerNr9 Sep 1, 2023
9a599bf
None doesnt work
jeevb Sep 1, 2023
18be31a
unpartitioned selector
jeevb Sep 1, 2023
015e24f
Fix list of annotated structured dataset (#1817)
wild-endeavor Sep 1, 2023
d167977
Support the flytectl config.yaml admin.clientSecretEnvVar option in f…
chaohengstudent Sep 6, 2023
1b6a027
Async agent delete function for while loop case (#1802)
Future-Outlier Sep 7, 2023
9e0a91a
refactor
jeevb Sep 12, 2023
810a5cf
fix docs warnings (#1827)
samhita-alla Sep 11, 2023
305864d
Fix extract_task_module (#1829)
pingsutw Sep 11, 2023
c8fc69d
Feat: Add type support for pydantic BaseModels (#1660)
ArthurBook Sep 11, 2023
e70ac1e
add test for unspecified mig
jeevb Sep 12, 2023
5201d1a
add support for overriding accelerator
jeevb Sep 12, 2023
b8fc677
cleanup
jeevb Sep 12, 2023
fbe8bc5
move from core to extras
jeevb Sep 12, 2023
ad49582
fixes
jeevb Sep 12, 2023
cefa3d3
fixes
jeevb Sep 12, 2023
ab1e0d6
fixes
jeevb Sep 12, 2023
e98517b
cleanup
jeevb Sep 13, 2023
68f423d
Make FlyteRemote slightly more copy/pastable (#1830)
katrogan Sep 12, 2023
2f681e9
Pyflyte meta inputs (#1823)
kumare3 Sep 12, 2023
62255f5
Use mashumaro to serialize/deserialize dataclass (#1735)
hhcs9527 Sep 12, 2023
a1af299
Databricks Agent (#1797)
Future-Outlier Sep 12, 2023
f0fc698
Prometheus metrics (#1815)
pingsutw Sep 13, 2023
c4ade35
Pyflyte register optionally activates schedule (#1832)
kumare3 Sep 14, 2023
3690e41
Remove versions 3.9 and 3.10 (#1831)
wild-endeavor Sep 14, 2023
92d4340
Snowflake agent (#1799)
hhcs9527 Sep 15, 2023
8a7a092
Update agent metric name (#1835)
pingsutw Sep 15, 2023
3e61111
MemVerge MMCloud Agent (#1821)
edwinyyyu Sep 15, 2023
948ae71
Add download badges in readme (#1836)
pingsutw Sep 18, 2023
f63faaf
Eager local entrypoint and support for offloaded types (#1833)
cosmicBboy Sep 18, 2023
92a7fd6
update requirements and add snowflake agent to api reference (#1838)
samhita-alla Sep 19, 2023
2c1f729
Fix: Make sure decks created in elastic task workers are transferred …
fg91 Sep 19, 2023
99abcb4
add accept grpc (#1841)
wild-endeavor Sep 20, 2023
3220a3e
Feat: Enable `flytekit` to authenticate with proxy in front of FlyteA…
fg91 Sep 20, 2023
5c81d17
bump flyteidl
jeevb Sep 20, 2023
8691dcb
Merge branch 'master' into gpu-selector
jeevb Sep 20, 2023
6f112a0
make requirements
jeevb Sep 21, 2023
92acdae
fix failing tests
jeevb Sep 21, 2023
662d1ee
move gpu accelerator to flyteidl.core.Resources
jeevb Sep 23, 2023
ab0c555
Use ResourceExtensions for extended resources
jeevb Oct 2, 2023
c38b76b
cleanup
jeevb Oct 2, 2023
c370262
Switch to using ExtendedResources in TaskTemplate
jeevb Oct 3, 2023
e239b33
Merge remote-tracking branch 'origin/master' into gpu-selector
jeevb Oct 6, 2023
2f342b8
cleanups
jeevb Oct 6, 2023
b95ace8
Merge branch 'master' into gpu-selector
jeevb Oct 23, 2023
41cf59c
update flyteidl
jeevb Oct 24, 2023
ee42a67
Replace _core_task imports with tasks_pb2
jeevb Oct 26, 2023
3c59c67
less verbose definitions
jeevb Oct 27, 2023
ae1f44d
Attempt at less confusing syntax
jeevb Oct 27, 2023
1d19dce
Streamline UX
jeevb Oct 31, 2023
0d35c5f
Run make fmt
jeevb Oct 31, 2023
6ead99c
Merge branch 'master' into gpu-selector
wild-endeavor Nov 1, 2023
8 changes: 8 additions & 0 deletions flytekit/core/base_task.py
Expand Up @@ -24,6 +24,8 @@
from dataclasses import dataclass
from typing import Any, Coroutine, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast

from flyteidl.core import tasks_pb2

from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import (
ExecutionParameters,
Expand Down Expand Up @@ -344,6 +346,12 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]
        """
        return None

    def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]:
        """
        Returns the extended resources to allocate to the task on hosted Flyte.
        """
        return None

    def local_execution_mode(self) -> ExecutionState.Mode:
        """ """
        return ExecutionState.Mode.LOCAL_TASK_EXECUTION
Expand Down
8 changes: 8 additions & 0 deletions flytekit/core/node.py
Expand Up @@ -4,6 +4,8 @@
import typing
from typing import Any, List

from flyteidl.core import tasks_pb2

from flytekit.core.resources import Resources, convert_resources_to_resource_model
from flytekit.core.utils import _dnsify
from flytekit.loggers import logger
Expand Down Expand Up @@ -62,6 +64,7 @@ def __init__(
self._aliases: _workflow_model.Alias = None
self._outputs = None
self._resources: typing.Optional[_resources_model] = None
self._extended_resources: typing.Optional[tasks_pb2.ExtendedResources] = None

def runs_before(self, other: Node):
"""
Expand Down Expand Up @@ -172,6 +175,11 @@ def with_overrides(self, *args, **kwargs):
            assert_not_promise(v, "container_image")
            self.flyte_entity._container_image = v

        if "accelerator" in kwargs:
            v = kwargs["accelerator"]
            assert_not_promise(v, "accelerator")
            self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=v.to_flyte_idl())

        return self


Expand Down
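The `with_overrides` hunk above can be tried out without flytekit installed. The sketch below is only an illustration of the control flow: `ExtendedResourcesMsg`, `Promise`, `FakeAccelerator`, and this `Node` are hypothetical stand-ins, not the real flytekit or flyteidl classes, and the promise check is assumed to raise `AssertionError` the way `assert_not_promise` appears to.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ExtendedResourcesMsg:
    """Hypothetical stand-in for tasks_pb2.ExtendedResources."""

    gpu_accelerator: Any


class Promise:
    """Hypothetical placeholder for a value that is only known at run time."""


class FakeAccelerator:
    """Hypothetical accelerator exposing the to_flyte_idl() method from the diff."""

    def __init__(self, device: str) -> None:
        self._device = device

    def to_flyte_idl(self) -> dict:
        # The real method returns a protobuf message; a dict keeps the sketch dependency-free.
        return {"device": self._device}


class Node:
    """Minimal node that only models the accelerator override."""

    def __init__(self) -> None:
        self._extended_resources: Optional[ExtendedResourcesMsg] = None

    def with_overrides(self, **kwargs: Any) -> "Node":
        if "accelerator" in kwargs:
            v = kwargs["accelerator"]
            # Assumed behavior: overrides must be concrete values, so promises are rejected.
            if isinstance(v, Promise):
                raise AssertionError("Cannot use a promise in the accelerator override")
            self._extended_resources = ExtendedResourcesMsg(gpu_accelerator=v.to_flyte_idl())
        return self


node = Node().with_overrides(accelerator=FakeAccelerator("nvidia-tesla-t4"))
```

With the PR applied, the equivalent call inside a workflow would presumably look like `my_task().with_overrides(accelerator=T4)`, where `my_task` is a placeholder name and `T4` is one of the constants added in `flytekit/extras/accelerators.py`.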
15 changes: 15 additions & 0 deletions flytekit/core/python_auto_container.py
Expand Up @@ -5,6 +5,8 @@
from abc import ABC
from typing import Callable, Dict, List, Optional, TypeVar, Union

from flyteidl.core import tasks_pb2

from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.context_manager import FlyteContextManager
Expand All @@ -13,6 +15,7 @@
from flytekit.core.tracked_abc import FlyteTrackedABC
from flytekit.core.tracker import TrackedInstance, extract_task_module
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec
from flytekit.loggers import logger
from flytekit.models import task as _task_model
Expand Down Expand Up @@ -44,6 +47,7 @@ def __init__(
secret_requests: Optional[List[Secret]] = None,
pod_template: Optional[PodTemplate] = None,
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
**kwargs,
):
"""
Expand All @@ -70,6 +74,7 @@ def __init__(
- `AWS Parameter store <https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-parameter-store.html>`__
:param pod_template: Custom PodTemplate for this task.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param accelerator: The accelerator to use for this task.
"""
sec_ctx = None
if secret_requests:
Expand Down Expand Up @@ -110,6 +115,7 @@ def __init__(
self._get_command_fn = self.get_default_command

self.pod_template = pod_template
self.accelerator = accelerator

@property
def task_resolver(self) -> TaskResolverMixin:
Expand Down Expand Up @@ -219,6 +225,15 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]
            return {}
        return {_PRIMARY_CONTAINER_NAME_FIELD: self.pod_template.primary_container_name}

    def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]:
        """
        Returns the extended resources to allocate to the task on hosted Flyte.
        """
        if self.accelerator is None:
            return None

        return tasks_pb2.ExtendedResources(gpu_accelerator=self.accelerator.to_flyte_idl())


class DefaultTaskResolver(TrackedInstance, TaskResolverMixin):
"""
Expand Down
6 changes: 6 additions & 0 deletions flytekit/core/task.py
Expand Up @@ -8,6 +8,7 @@
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.reference_entity import ReferenceEntity, TaskReference
from flytekit.core.resources import Resources
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.image_spec.image_spec import ImageSpec
from flytekit.models.documentation import Documentation
from flytekit.models.security import Secret
Expand Down Expand Up @@ -102,6 +103,7 @@ def task(
enable_deck: Optional[bool] = ...,
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]:
...

Expand Down Expand Up @@ -129,6 +131,7 @@ def task(
enable_deck: Optional[bool] = ...,
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
) -> Union[PythonFunctionTask[T], Callable[..., FuncOut]]:
...

Expand All @@ -155,6 +158,7 @@ def task(
enable_deck: Optional[bool] = None,
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
) -> Union[Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]], PythonFunctionTask[T], Callable[..., FuncOut]]:
"""
This is the core decorator to use for any task type in flytekit.
Expand Down Expand Up @@ -248,6 +252,7 @@ def foo2():
:param docs: Documentation about this task
:param pod_template: Custom PodTemplate for this task.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param accelerator: The accelerator to use for this task.
"""

def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
Expand Down Expand Up @@ -277,6 +282,7 @@ def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
docs=docs,
pod_template=pod_template,
pod_template_name=pod_template_name,
accelerator=accelerator,
)
update_wrapper(task_instance, fn)
return task_instance
Expand Down
90 changes: 90 additions & 0 deletions flytekit/extras/accelerators.py
@@ -0,0 +1,90 @@
import abc
import copy
from typing import ClassVar, Generic, Optional, Type, TypeVar

from flyteidl.core import tasks_pb2

T = TypeVar("T")
MIG = TypeVar("MIG", bound="MultiInstanceGPUAccelerator")


class BaseAccelerator(abc.ABC, Generic[T]):
    @abc.abstractmethod
    def to_flyte_idl(self) -> T:
        ...


class GPUAccelerator(BaseAccelerator):
    def __init__(self, device: str) -> None:
        self._device = device

    def to_flyte_idl(self) -> tasks_pb2.GPUAccelerator:
        return tasks_pb2.GPUAccelerator(device=self._device)


A10G = GPUAccelerator("nvidia-a10g")
L4 = GPUAccelerator("nvidia-l4-vws")
K80 = GPUAccelerator("nvidia-tesla-k80")
M60 = GPUAccelerator("nvidia-tesla-m60")
P4 = GPUAccelerator("nvidia-tesla-p4")
P100 = GPUAccelerator("nvidia-tesla-p100")
T4 = GPUAccelerator("nvidia-tesla-t4")
V100 = GPUAccelerator("nvidia-tesla-v100")


class MultiInstanceGPUAccelerator(BaseAccelerator):
    device: ClassVar[str]
    _partition_size: Optional[str]

    @property
    def unpartitioned(self: MIG) -> MIG:
        instance = copy.deepcopy(self)
        instance._partition_size = None
        return instance

    @classmethod
    def partitioned(cls: Type[MIG], partition_size: str) -> MIG:
        instance = cls()
        instance._partition_size = partition_size
        return instance

    def to_flyte_idl(self) -> tasks_pb2.GPUAccelerator:
        msg = tasks_pb2.GPUAccelerator(device=self.device)
        if not hasattr(self, "_partition_size"):
            return msg

        if self._partition_size is None:
            msg.unpartitioned = True
        else:
            msg.partition_size = self._partition_size
        return msg


class _A100_Base(MultiInstanceGPUAccelerator):
    device = "nvidia-tesla-a100"


class _A100(_A100_Base):
    partition_1g_5gb = _A100_Base.partitioned("1g.5gb")
    partition_2g_10gb = _A100_Base.partitioned("2g.10gb")
    partition_3g_20gb = _A100_Base.partitioned("3g.20gb")
    partition_4g_20gb = _A100_Base.partitioned("4g.20gb")
    partition_7g_40gb = _A100_Base.partitioned("7g.40gb")


A100 = _A100()


class _A100_80GB_Base(MultiInstanceGPUAccelerator):
    device = "nvidia-a100-80gb"


class _A100_80GB(_A100_80GB_Base):
    partition_1g_10gb = _A100_80GB_Base.partitioned("1g.10gb")
    partition_2g_20gb = _A100_80GB_Base.partitioned("2g.20gb")
    partition_3g_40gb = _A100_80GB_Base.partitioned("3g.40gb")
    partition_4g_40gb = _A100_80GB_Base.partitioned("4g.40gb")
    partition_7g_80gb = _A100_80GB_Base.partitioned("7g.80gb")


A100_80GB = _A100_80GB()
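`MultiInstanceGPUAccelerator.to_flyte_idl` distinguishes three states: partitioning left unspecified, an explicitly unpartitioned GPU, and a named partition. The following is a minimal, dependency-free sketch of that logic. `GPUAcceleratorMsg` and `MultiInstanceGPU` are hypothetical stand-ins for `tasks_pb2.GPUAccelerator` and the class in the diff, so this illustrates the pattern rather than the real API.

```python
import copy
from dataclasses import dataclass
from typing import ClassVar, Optional


@dataclass
class GPUAcceleratorMsg:
    """Hypothetical stand-in for tasks_pb2.GPUAccelerator (not the real protobuf class)."""

    device: str
    unpartitioned: bool = False
    partition_size: Optional[str] = None


class MultiInstanceGPU:
    """Mirrors the three-state logic of MultiInstanceGPUAccelerator in the diff."""

    device: ClassVar[str] = "nvidia-tesla-a100"

    @property
    def unpartitioned(self) -> "MultiInstanceGPU":
        # Copy first so the shared module-level instance is never mutated.
        instance = copy.deepcopy(self)
        instance._partition_size = None  # explicitly request a whole GPU
        return instance

    @classmethod
    def partitioned(cls, partition_size: str) -> "MultiInstanceGPU":
        instance = cls()
        instance._partition_size = partition_size
        return instance

    def to_msg(self) -> GPUAcceleratorMsg:
        msg = GPUAcceleratorMsg(device=self.device)
        if not hasattr(self, "_partition_size"):
            return msg  # attribute never set: partitioning left unspecified
        if self._partition_size is None:
            msg.unpartitioned = True
        else:
            msg.partition_size = self._partition_size
        return msg


a100 = MultiInstanceGPU()
unspecified = a100.to_msg()
whole = a100.unpartitioned.to_msg()
slice_ = MultiInstanceGPU.partitioned("1g.5gb").to_msg()
```

The absence of the `_partition_size` attribute, rather than a sentinel value, is what encodes "unspecified" here, which is why `unpartitioned` works on a deep copy instead of the original object.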
16 changes: 13 additions & 3 deletions flytekit/models/core/workflow.py
@@ -1,6 +1,7 @@
import datetime
import typing

from flyteidl.core import tasks_pb2
from flyteidl.core import workflow_pb2 as _core_workflow

from flytekit.models import common as _common
Expand Down Expand Up @@ -562,24 +563,33 @@ def from_flyte_idl(cls, pb2_object):


 class TaskNodeOverrides(_common.FlyteIdlEntity):
-    def __init__(self, resources: typing.Optional[Resources] = None):
+    def __init__(
+        self, resources: typing.Optional[Resources], extended_resources: typing.Optional[tasks_pb2.ExtendedResources]
+    ):
         self._resources = resources
+        self._extended_resources = extended_resources

     @property
     def resources(self) -> Resources:
         return self._resources

+    @property
+    def extended_resources(self) -> tasks_pb2.ExtendedResources:
+        return self._extended_resources
+
     def to_flyte_idl(self):
         return _core_workflow.TaskNodeOverrides(
             resources=self.resources.to_flyte_idl() if self.resources is not None else None,
+            extended_resources=self.extended_resources,
         )

     @classmethod
     def from_flyte_idl(cls, pb2_object):
         resources = Resources.from_flyte_idl(pb2_object.resources)
+        extended_resources = pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None
         if bool(resources.requests) or bool(resources.limits):
-            return cls(resources=resources)
-        return cls(resources=None)
+            return cls(resources=resources, extended_resources=extended_resources)
+        return cls(resources=None, extended_resources=extended_resources)


class TaskNode(_common.FlyteIdlEntity):
Expand Down
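The updated `TaskNodeOverrides.from_flyte_idl` treats the two fields independently: `extended_resources` survives even when `resources` is empty and collapses to `None`. The sketch below imitates that behavior without protobuf. `ResourcesMsg`, `OverridesMsg`, and `Overrides` are hypothetical stand-ins, and the `HasField` method only approximates protobuf's presence check for message-typed fields.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class ResourcesMsg:
    """Hypothetical stand-in for the Resources message."""

    requests: Dict[str, str] = field(default_factory=dict)
    limits: Dict[str, str] = field(default_factory=dict)


@dataclass
class OverridesMsg:
    """Hypothetical stand-in for the TaskNodeOverrides protobuf message."""

    resources: ResourcesMsg = field(default_factory=ResourcesMsg)
    extended_resources: Optional[dict] = None

    def HasField(self, name: str) -> bool:
        # Approximates protobuf presence: a message field is "set" when it is not None.
        return getattr(self, name) is not None


@dataclass
class Overrides:
    """Model-side object, analogous to TaskNodeOverrides in the diff."""

    resources: Optional[ResourcesMsg]
    extended_resources: Optional[dict]

    @classmethod
    def from_msg(cls, msg: OverridesMsg) -> "Overrides":
        extended = msg.extended_resources if msg.HasField("extended_resources") else None
        # Empty requests and limits are normalized to None, as in the diff.
        has_resources = bool(msg.resources.requests) or bool(msg.resources.limits)
        return cls(resources=msg.resources if has_resources else None, extended_resources=extended)


gpu_only = Overrides.from_msg(OverridesMsg(extended_resources={"device": "nvidia-tesla-t4"}))
empty = Overrides.from_msg(OverridesMsg())
```

Note that the new constructor in the diff drops the default for `resources`, so callers are expected to pass both arguments explicitly, as the `translator.py` changes do.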
13 changes: 13 additions & 0 deletions flytekit/models/task.py
Expand Up @@ -336,6 +336,7 @@ def __init__(
config=None,
k8s_pod=None,
sql=None,
extended_resources=None,
):
"""
A task template represents the full set of information necessary to perform a unit of work in the Flyte system.
Expand All @@ -359,6 +360,7 @@ def __init__(
in tandem with the custom.
:param K8sPod k8s_pod: Alternative to the container used to execute this task.
:param Sql sql: This is used to execute query in FlytePropeller instead of running container or k8s_pod.
:param flyteidl.core.tasks_pb2.ExtendedResources extended_resources: The extended resources to allocate to the task.
"""
if (
(container is not None and k8s_pod is not None)
Expand All @@ -377,6 +379,7 @@ def __init__(
self._security_context = security_context
self._k8s_pod = k8s_pod
self._sql = sql
self._extended_resources = extended_resources

@property
def id(self):
Expand Down Expand Up @@ -451,6 +454,14 @@ def k8s_pod(self):
def sql(self):
return self._sql

@property
def extended_resources(self):
"""
If not None, the extended resources to allocate to the task.
:rtype: flyteidl.core.tasks_pb2.ExtendedResources
"""
return self._extended_resources

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.tasks_pb2.TaskTemplate
Expand All @@ -464,6 +475,7 @@ def to_flyte_idl(self):
container=self.container.to_flyte_idl() if self.container else None,
task_type_version=self.task_type_version,
security_context=self.security_context.to_flyte_idl() if self.security_context else None,
extended_resources=self.extended_resources,
config={k: v for k, v in self.config.items()} if self.config is not None else None,
k8s_pod=self.k8s_pod.to_flyte_idl() if self.k8s_pod else None,
sql=self.sql.to_flyte_idl() if self.sql else None,
Expand All @@ -487,6 +499,7 @@ def from_flyte_idl(cls, pb2_object):
security_context=_sec.SecurityContext.from_flyte_idl(pb2_object.security_context)
if pb2_object.security_context and pb2_object.security_context.ByteSize() > 0
else None,
extended_resources=pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None,
config={k: v for k, v in pb2_object.config.items()} if pb2_object.config is not None else None,
k8s_pod=K8sPod.from_flyte_idl(pb2_object.k8s_pod) if pb2_object.HasField("k8s_pod") else None,
sql=Sql.from_flyte_idl(pb2_object.sql) if pb2_object.HasField("sql") else None,
Expand Down
9 changes: 6 additions & 3 deletions flytekit/tools/translator.py
Expand Up @@ -214,6 +214,7 @@ def get_serializable_task(
config=entity.get_config(settings),
k8s_pod=pod,
sql=entity.get_sql(settings),
extended_resources=entity.get_extended_resources(settings),
)
if settings.should_fast_serialize() and isinstance(entity, PythonAutoContainerTask):
entity.reset_command_fn()
Expand Down Expand Up @@ -440,7 +441,8 @@ def get_serializable_node(
             upstream_node_ids=[n.id for n in upstream_nodes],
             output_aliases=[],
             task_node=workflow_model.TaskNode(
-                reference_id=task_spec.template.id, overrides=TaskNodeOverrides(resources=entity._resources)
+                reference_id=task_spec.template.id,
+                overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources),
             ),
         )
         if entity._aliases:
Expand Down Expand Up @@ -516,7 +518,8 @@ def get_serializable_node(
             upstream_node_ids=[n.id for n in upstream_nodes],
             output_aliases=[],
             task_node=workflow_model.TaskNode(
-                reference_id=entity.flyte_entity.id, overrides=TaskNodeOverrides(resources=entity._resources)
+                reference_id=entity.flyte_entity.id,
+                overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources),
             ),
         )
     elif isinstance(entity.flyte_entity, FlyteWorkflow):
Expand Down Expand Up @@ -565,7 +568,7 @@ def get_serializable_array_node(
     task_spec = get_serializable(entity_mapping, settings, entity, options)
     task_node = workflow_model.TaskNode(
         reference_id=task_spec.template.id,
-        overrides=TaskNodeOverrides(resources=node._resources),
+        overrides=TaskNodeOverrides(resources=node._resources, extended_resources=node._extended_resources),
     )
     node = workflow_model.Node(
         id=entity.name,
Expand Down