# parallel.py - example pipeline showing how to use dsl.ParallelFor().
import kfp
from kfp import dsl
from kfp.components import func_to_container_op
@func_to_container_op
def list_generator_op(parallelism: int) -> str:
    """Generate the items for a ParallelFor loop.

    Args:
        parallelism: Number of items to generate.

    Returns:
        A JSON-encoded list of the integers ``0`` to ``parallelism - 1``.
    """
    # Kept function-local as in the original.
    # NOTE(review): presumably the decorator packages only this function, so
    # a module-level import would not be visible at run time - confirm.
    import json

    indices = list(range(parallelism))
    # ParallelFor requires a JSON payload, hence the serialized string return.
    return json.dumps(indices)
@func_to_container_op
def print_op(msg: str):
    """Print ``msg`` to standard output.

    Args:
        msg: The message to print.
    """
    print(msg)
@dsl.pipeline(
    name="ParallelFor example",
    description="Shows how to use dsl.ParallelFor()",
)
def pipeline(parallelism: int):
    # Fix the pipeline-level parallelism setting at 2. This is separate from
    # the ``parallelism`` argument, which only controls how many items the
    # generator step emits.
    dsl.get_pipeline_conf().set_parallelism(2)

    # Step 1: produce the JSON list of items to loop over.
    generated = list_generator_op(parallelism)

    # Step 2: run one print step per generated item.
    fan_out = dsl.ParallelFor(generated.output)
    with fan_out as item:
        print_op(item)

    # Step 3: print a final message, ordered after the whole loop group.
    print_op("Finished").after(fan_out)
if __name__ == "__main__":
    # When run as a script, compile the pipeline definition to a YAML
    # package in the current working directory.
    output_path = "pipeline.yaml"
    compiler = kfp.compiler.Compiler()
    compiler.compile(pipeline, output_path)