generated from astronomer/airflow-provider-sample
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathray_taskflow_example.py
57 lines (45 loc) · 1.27 KB
/
ray_taskflow_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from datetime import datetime
from pathlib import Path
from airflow.decorators import dag, task
from ray_provider.decorators.ray import ray
# Connection ID passed to the Ray task decorator through RAY_TASK_CONFIG below.
CONN_ID = "ray_conn"

# Both paths are anchored at this DAG file's own directory, so they resolve the
# same way regardless of the process's current working directory.
RAY_SPEC = Path(__file__).parent / "scripts" / "ray.yaml"
FOLDER_PATH = Path(__file__).parent / "ray_scripts"

# Options handed to the Ray task decorator (`config=` argument in the DAG below).
RAY_TASK_CONFIG = dict(
    conn_id=CONN_ID,
    # Ray runtime environment: the local ray_scripts folder is the working
    # directory, and numpy is pip-installed because the task body imports it.
    runtime_env=dict(working_dir=str(FOLDER_PATH), pip=["numpy"]),
    # Resources requested for the task.
    num_cpus=1,
    num_gpus=0,
    memory=0,
    # NOTE(review): presumably seconds between job-status polls — confirm
    # against the Ray provider documentation.
    poll_interval=5,
    ray_cluster_yaml=str(RAY_SPEC),
    xcom_task_key="dashboard",
)
@dag(
    dag_id="Ray_Taskflow_Example",
    start_date=datetime(2023, 1, 1),
    schedule=None,
    catchup=False,
    tags=["ray", "example"],
)
def ray_taskflow_dag():
    """Example TaskFlow DAG that hands a small dataset to a Ray-backed task.

    ``generate_data`` produces a list of integers; ``process_data_with_ray``
    receives it, squares each value in Ray remote tasks and returns the mean
    of the squares.
    """

    @task
    def generate_data():
        """Return the fixed sample dataset: the integers 1, 2 and 3."""
        return [1, 2, 3]

    # NOTE(review): this function imports numpy and ray itself instead of
    # relying on module-level names, which suggests the Ray provider executes
    # its source on the Ray side. Keep it self-contained — confirm against the
    # provider documentation before referencing anything from module scope.
    @ray.task(config=RAY_TASK_CONFIG)
    def process_data_with_ray(data):
        """Square every value of *data* on Ray and return the mean of the squares.

        The mean is also printed before it is returned.
        """
        import numpy as np
        import ray

        @ray.remote
        def square(value):
            return value**2

        ray.init()
        # One remote task per element; ray.get waits for all of their results.
        pending = [square.remote(value) for value in np.array(data)]
        mean = np.mean(ray.get(pending))
        print(f"Mean of this population is {mean}")
        return mean

    # Feeding generate_data()'s output straight into the Ray task wires the
    # two TaskFlow tasks together.
    process_data_with_ray(generate_data())


ray_example_dag = ray_taskflow_dag()