diff --git a/kfp/doc/simple_transform_pipeline.md b/kfp/doc/simple_transform_pipeline.md
index e169f2e03..10341c24b 100644
--- a/kfp/doc/simple_transform_pipeline.md
+++ b/kfp/doc/simple_transform_pipeline.md
@@ -18,6 +18,7 @@ Note: the project and the explanation below are based on [KFPv1](https://www.kub
- [Input parameters definition](#inputs)
- [Pipeline definition](#pipeline)
- [Additional configuration](#add_config)
+ - [Tolerations and node selector](#tolerations)
- [Compiling a pipeline](#compilation)
- [Deploying a pipeline](#deploying)
- [Executing pipeline and watching execution results](#execution)
@@ -210,6 +211,17 @@ The final thing that we need to do is set some pipeline global configuration:
dsl.get_pipeline_conf().set_timeout(ONE_WEEK_SEC)
```
+### KFP pods Toleration and node selector (Optional)
+To apply kuberenetes [Tolerations](https://kubernetes.io/docs/concepts/scheduling-eviction/taint-and-toleration/) or [nodeSelector](https://kubernetes.io/docs/concepts/scheduling-eviction/assign-pod-node/#nodeselector) to KFP pods, you need to set `KFP_TOLERATIONS` or `KFP_NODE_SELECTOR` environment variables respectively before compiling the pipeline. Here's an example:
+
+```bash
+export KFP_TOLERATIONS='[{"key": "key","operator": "Equal", "value1": "value", "effect": "NoSchedule"}]'
+
+export KFP_NODE_SELECTOR='{"label_key":"cloud.google.com/gke-accelerator","label_value":"nvidia-tesla-p4"}'
+
+```
+In KFP v1, setting `KFP_TOLERATIONS` will apply to the Ray pods, overriding any tolerations specified in the `ray_head_options` and `ray_worker_options` pipeline parameters if they are present.
+
## Compiling a pipeline
To compile pipeline execute `make workflow-build` command in the same directory where your pipeline is.
diff --git a/kfp/kfp_ray_components/src/create_ray_cluster.py b/kfp/kfp_ray_components/src/create_ray_cluster.py
index fdaf1ec4b..131f20c8c 100644
--- a/kfp/kfp_ray_components/src/create_ray_cluster.py
+++ b/kfp/kfp_ray_components/src/create_ray_cluster.py
@@ -9,6 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
+import json
+import os
import sys
from runtime_utils import KFPUtils, RayRemoteJobs
@@ -42,6 +44,16 @@ def start_ray_cluster(
head_node = head_options | {
"ray_start_params": {"metrics-export-port": "8080", "num-cpus": "0", "dashboard-host": "0.0.0.0"}
}
+ tolerations = os.getenv("KFP_TOLERATIONS", "")
+ if tolerations != "":
+ print(f"Adding tolerations {tolerations} for ray pods")
+ tolerations = json.loads(tolerations)
+ if "tolerations" in head_node:
+ print("Warning: head_node tolerations already defined, will overwrite it")
+ if "tolerations" in worker_node:
+ print("Warning: worker_node tolerations already defined, will overwrite it")
+ head_node["tolerations"] = tolerations
+ worker_node["tolerations"] = tolerations
# create cluster
remote_jobs = RayRemoteJobs(
server_url=server_url,
diff --git a/kfp/kfp_support_lib/kfp_v1_workflow_support/src/workflow_support/compile_utils/component.py b/kfp/kfp_support_lib/kfp_v1_workflow_support/src/workflow_support/compile_utils/component.py
index f7929ccca..1b1c8ba24 100644
--- a/kfp/kfp_support_lib/kfp_v1_workflow_support/src/workflow_support/compile_utils/component.py
+++ b/kfp/kfp_support_lib/kfp_v1_workflow_support/src/workflow_support/compile_utils/component.py
@@ -10,11 +10,20 @@
# limitations under the License.
################################################################################
+import json
import os
import kfp.dsl as dsl
from data_processing.utils import get_logger
from kubernetes import client as k8s_client
+from kubernetes.client import (
+ V1Affinity,
+ V1NodeAffinity,
+ V1NodeSelector,
+ V1NodeSelectorRequirement,
+ V1NodeSelectorTerm,
+ V1Toleration,
+)
logger = get_logger(__name__)
@@ -43,12 +52,55 @@ def add_settings_to_component(
:param image_pull_policy: pull policy to set to the component
:param cache_strategy: cache strategy
"""
+
+ def _add_tolerations() -> None:
+ """
+ Adds Tolerations if specified
+ """
+ try:
+ tolerations = os.getenv("KFP_TOLERATIONS", "")
+ if tolerations != "":
+ print(f"Note: Applying Tolerations {tolerations} to kfp and ray pods")
+
+ # Add Tolerations as env var so it can be used when creating the ray cluster
+ component.add_env_variable(k8s_client.V1EnvVar(name="KFP_TOLERATIONS", value=tolerations))
+
+ tolerations = json.loads(tolerations)
+ for toleration in tolerations:
+ component.add_toleration(
+ V1Toleration(
+ key=toleration["key"],
+ operator=toleration["operator"],
+ value=toleration["value"],
+ effect=toleration["effect"],
+ )
+ )
+ except Exception as e:
+ logger.warning(f"Exception while handling tolerations {e}")
+
+ def _add_node_selector() -> None:
+ """ "
+ Adds mode selector if specified
+ """
+ try:
+ node_selector = os.getenv("KFP_NODE_SELECTOR", "")
+ if node_selector != "":
+ print(f"Note: Applying node_selector {node_selector} to kubeflow pipelines pods")
+ node_selector = json.loads(node_selector)
+ component.add_node_selector_constraint(node_selector["label_key"], node_selector["label_value"])
+ except Exception as e:
+ logger.warning(f"Exception while handling node_selector {e}")
+
# Set cashing
component.execution_options.caching_strategy.max_cache_staleness = cache_strategy
# image pull policy
component.container.set_image_pull_policy(image_pull_policy)
# Set the timeout for the task
component.set_timeout(timeout)
+ # Add tolerations
+ _add_tolerations()
+ # Add affinity
+ _add_node_selector()
@staticmethod
def set_s3_env_vars_to_component(
diff --git a/kfp/kfp_support_lib/kfp_v2_workflow_support/src/workflow_support/compile_utils/component.py b/kfp/kfp_support_lib/kfp_v2_workflow_support/src/workflow_support/compile_utils/component.py
index 6bbc8ef24..695fa936a 100644
--- a/kfp/kfp_support_lib/kfp_v2_workflow_support/src/workflow_support/compile_utils/component.py
+++ b/kfp/kfp_support_lib/kfp_v2_workflow_support/src/workflow_support/compile_utils/component.py
@@ -1,10 +1,16 @@
+import json
+import os
from typing import Dict
import kfp.dsl as dsl
+from data_processing.utils import get_logger
from kfp import kubernetes
+logger = get_logger(__name__)
+
+
RUN_NAME = "KFP_RUN_NAME"
ONE_HOUR_SEC = 60 * 60
@@ -32,6 +38,41 @@ def add_settings_to_component(
:param cache_strategy: cache strategy
"""
+ def _add_tolerations() -> None:
+ try:
+ tolerations = os.getenv("KFP_TOLERATIONS", "")
+ if tolerations != "":
+ # TODO: apply the tolerations defined as env vars to ray pods.
+ # Currently they can be specified in the pipeline params:
+ # ray_head_options and ray_worker_options.
+
+ tolerations = json.loads(tolerations)
+ for toleration in tolerations:
+ kubernetes.add_toleration(
+ task,
+ key=toleration["key"],
+ operator=toleration["operator"],
+ value=toleration["value"],
+ effect=toleration["effect"],
+ )
+
+ except Exception as e:
+ logger.warning(f"Exception while handling tolerations {e}")
+
+ def _add_node_selector() -> None:
+ try:
+ node_selector = os.getenv("KFP_NODE_SELECTOR", "")
+ if node_selector != "":
+ print(f"Note: Applying node_selector {node_selector} to kubeflow pipelines pods")
+ node_selector = json.loads(node_selector)
+ kubernetes.add_node_selector(
+ task,
+ label_key=node_selector["label_key"],
+ label_value=node_selector["label_value"],
+ )
+ except Exception as e:
+ logger.warning(f"Exception while handling node_selector {e}")
+
kubernetes.use_field_path_as_env(
task, env_name=RUN_NAME, field_path="metadata.annotations['pipelines.kubeflow.org/run_name']"
)
@@ -41,6 +82,10 @@ def add_settings_to_component(
kubernetes.set_image_pull_policy(task, image_pull_policy)
# Set the timeout for the task to one day (in seconds)
kubernetes.set_timeout(task, seconds=timeout)
+ # Add tolerations if specified
+ _add_tolerations()
+ # Add node_selector if specified
+ _add_node_selector()
@staticmethod
def set_s3_env_vars_to_component(