Skip to content

Commit 48d6450

Browse files
Refactor the event channel (#1912)
* Refactor the event channel * Fix for PR comments * Remove unsafe Co-authored-by: yangzhong <[email protected]>
1 parent a03eea4 commit 48d6450

File tree

12 files changed

+642
-364
lines changed

12 files changed

+642
-364
lines changed

ballista/rust/core/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ simd = ["datafusion/simd"]
3131

3232
[dependencies]
3333
ahash = { version = "0.7", default-features = false }
34-
async-trait = "0.1.36"
34+
async-trait = "0.1.41"
3535
futures = "0.3"
3636
hashbrown = "0.12"
3737
log = "0.4"

ballista/rust/core/src/event_loop.rs

+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::atomic::{AtomicBool, Ordering};
19+
use std::sync::Arc;
20+
21+
use async_trait::async_trait;
22+
use log::{error, info};
23+
use tokio::sync::mpsc;
24+
25+
use crate::error::{BallistaError, Result};
26+
27+
#[async_trait]
28+
pub trait EventAction<E>: Send + Sync {
29+
fn on_start(&self);
30+
31+
fn on_stop(&self);
32+
33+
async fn on_receive(&self, event: E) -> Result<Option<E>>;
34+
35+
fn on_error(&self, error: BallistaError);
36+
}
37+
38+
#[derive(Clone)]
39+
pub struct EventLoop<E> {
40+
name: String,
41+
stopped: Arc<AtomicBool>,
42+
buffer_size: usize,
43+
action: Arc<dyn EventAction<E>>,
44+
tx_event: Option<mpsc::Sender<E>>,
45+
}
46+
47+
impl<E: Send + 'static> EventLoop<E> {
48+
pub fn new(
49+
name: String,
50+
buffer_size: usize,
51+
action: Arc<dyn EventAction<E>>,
52+
) -> Self {
53+
Self {
54+
name,
55+
stopped: Arc::new(AtomicBool::new(false)),
56+
buffer_size,
57+
action,
58+
tx_event: None,
59+
}
60+
}
61+
62+
fn run(&self, mut rx_event: mpsc::Receiver<E>) {
63+
assert!(
64+
self.tx_event.is_some(),
65+
"The event sender should be initialized first!"
66+
);
67+
let tx_event = self.tx_event.as_ref().unwrap().clone();
68+
let name = self.name.clone();
69+
let stopped = self.stopped.clone();
70+
let action = self.action.clone();
71+
tokio::spawn(async move {
72+
info!("Starting the event loop {}", name);
73+
while !stopped.load(Ordering::SeqCst) {
74+
if let Some(event) = rx_event.recv().await {
75+
match action.on_receive(event).await {
76+
Ok(Some(event)) => {
77+
if let Err(e) = tx_event.send(event).await {
78+
let msg = format!("Fail to send event due to {}", e);
79+
error!("{}", msg);
80+
action.on_error(BallistaError::General(msg));
81+
}
82+
}
83+
Err(e) => {
84+
error!("Fail to process event due to {}", e);
85+
action.on_error(e);
86+
}
87+
_ => {}
88+
}
89+
} else {
90+
info!("Event Channel closed, shutting down");
91+
break;
92+
}
93+
}
94+
info!("The event loop {} has been stopped", name);
95+
});
96+
}
97+
98+
pub fn start(&mut self) -> Result<()> {
99+
if self.stopped.load(Ordering::SeqCst) {
100+
return Err(BallistaError::General(format!(
101+
"{} has already been stopped",
102+
self.name
103+
)));
104+
}
105+
self.action.on_start();
106+
107+
let (tx_event, rx_event) = mpsc::channel::<E>(self.buffer_size);
108+
self.tx_event = Some(tx_event);
109+
self.run(rx_event);
110+
111+
Ok(())
112+
}
113+
114+
pub fn stop(&self) {
115+
if !self.stopped.swap(true, Ordering::SeqCst) {
116+
self.action.on_stop();
117+
} else {
118+
// Keep quiet to allow calling `stop` multiple times.
119+
}
120+
}
121+
122+
pub fn get_sender(&self) -> Result<EventSender<E>> {
123+
Ok(EventSender {
124+
tx_event: self.tx_event.as_ref().cloned().ok_or_else(|| {
125+
BallistaError::General("Event sender not exist!!!".to_string())
126+
})?,
127+
})
128+
}
129+
}
130+
131+
pub struct EventSender<E> {
132+
tx_event: mpsc::Sender<E>,
133+
}
134+
135+
impl<E> EventSender<E> {
136+
pub async fn post_event(&self, event: E) -> Result<()> {
137+
Ok(self.tx_event.send(event).await.map_err(|e| {
138+
BallistaError::General(format!("Fail to send event due to {}", e))
139+
})?)
140+
}
141+
}

ballista/rust/core/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ pub fn print_version() {
2525
pub mod client;
2626
pub mod config;
2727
pub mod error;
28+
pub mod event_loop;
2829
pub mod execution_plans;
2930
pub mod utils;
3031

ballista/rust/executor/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ snmalloc = ["snmalloc-rs"]
3232
arrow = { version = "10.0" }
3333
arrow-flight = { version = "10.0" }
3434
anyhow = "1"
35-
async-trait = "0.1.36"
35+
async-trait = "0.1.41"
3636
ballista-core = { path = "../core", version = "0.6.0" }
3737
configure_me = "0.4.0"
3838
datafusion = { path = "../../../datafusion", version = "7.0.0" }

ballista/rust/scheduler/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ tonic = "0.6"
5454
tower = { version = "0.4" }
5555
warp = "0.3"
5656
parking_lot = "0.12"
57+
async-trait = "0.1.41"
5758

5859
[dev-dependencies]
5960
ballista-core = { path = "../core", version = "0.6.0" }

ballista/rust/scheduler/src/main.rs

+12-21
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,13 @@ use ballista_scheduler::state::EtcdClient;
4040
#[cfg(feature = "sled")]
4141
use ballista_scheduler::state::StandaloneClient;
4242

43-
use ballista_scheduler::scheduler_server::{
44-
SchedulerEnv, SchedulerServer, TaskScheduler,
45-
};
43+
use ballista_scheduler::scheduler_server::SchedulerServer;
4644
use ballista_scheduler::state::{ConfigBackend, ConfigBackendClient};
4745

4846
use ballista_core::config::TaskSchedulingPolicy;
4947
use ballista_core::serde::BallistaCodec;
5048
use log::info;
51-
use tokio::sync::{mpsc, RwLock};
49+
use tokio::sync::RwLock;
5250

5351
#[macro_use]
5452
extern crate configure_me;
@@ -81,24 +79,15 @@ async fn start_server(
8179
"Starting Scheduler grpc server with task scheduling policy of {:?}",
8280
policy
8381
);
84-
let scheduler_server: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> =
82+
let mut scheduler_server: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> =
8583
match policy {
86-
TaskSchedulingPolicy::PushStaged => {
87-
// TODO make the buffer size configurable
88-
let (tx_job, rx_job) = mpsc::channel::<String>(10000);
89-
let scheduler_server = SchedulerServer::new_with_policy(
90-
config_backend.clone(),
91-
namespace.clone(),
92-
policy,
93-
Some(SchedulerEnv { tx_job }),
94-
Arc::new(RwLock::new(ExecutionContext::new())),
95-
BallistaCodec::default(),
96-
);
97-
let task_scheduler =
98-
TaskScheduler::new(Arc::new(scheduler_server.clone()));
99-
task_scheduler.start(rx_job);
100-
scheduler_server
101-
}
84+
TaskSchedulingPolicy::PushStaged => SchedulerServer::new_with_policy(
85+
config_backend.clone(),
86+
namespace.clone(),
87+
policy,
88+
Arc::new(RwLock::new(ExecutionContext::new())),
89+
BallistaCodec::default(),
90+
),
10291
_ => SchedulerServer::new(
10392
config_backend.clone(),
10493
namespace.clone(),
@@ -107,6 +96,8 @@ async fn start_server(
10796
),
10897
};
10998

99+
scheduler_server.init().await?;
100+
110101
Server::bind(&addr)
111102
.serve(make_service_fn(move |request: &AddrStream| {
112103
let scheduler_grpc_server =
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
use std::time::Duration;
20+
21+
use async_trait::async_trait;
22+
use log::{debug, warn};
23+
24+
use ballista_core::error::{BallistaError, Result};
25+
use ballista_core::event_loop::EventAction;
26+
use ballista_core::serde::protobuf::{LaunchTaskParams, TaskDefinition};
27+
use ballista_core::serde::scheduler::ExecutorData;
28+
use ballista_core::serde::{AsExecutionPlan, AsLogicalPlan};
29+
30+
use crate::scheduler_server::task_scheduler::TaskScheduler;
31+
use crate::scheduler_server::ExecutorsClient;
32+
use crate::state::SchedulerState;
33+
34+
#[derive(Clone)]
35+
pub(crate) enum SchedulerServerEvent {
36+
JobSubmitted(String),
37+
}
38+
39+
pub(crate) struct SchedulerServerEventAction<
40+
T: 'static + AsLogicalPlan,
41+
U: 'static + AsExecutionPlan,
42+
> {
43+
state: Arc<SchedulerState<T, U>>,
44+
executors_client: ExecutorsClient,
45+
}
46+
47+
impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
48+
SchedulerServerEventAction<T, U>
49+
{
50+
pub fn new(
51+
state: Arc<SchedulerState<T, U>>,
52+
executors_client: ExecutorsClient,
53+
) -> Self {
54+
Self {
55+
state,
56+
executors_client,
57+
}
58+
}
59+
60+
async fn offer_resources(
61+
&self,
62+
job_id: String,
63+
) -> Result<Option<SchedulerServerEvent>> {
64+
let mut available_executors = self.state.get_available_executors_data();
65+
// In case of there's no enough resources, reschedule the tasks of the job
66+
if available_executors.is_empty() {
67+
// TODO Maybe it's better to use an exclusive runtime for this kind task scheduling
68+
warn!("Not enough available executors for task running");
69+
tokio::time::sleep(Duration::from_millis(100)).await;
70+
return Ok(Some(SchedulerServerEvent::JobSubmitted(job_id)));
71+
}
72+
73+
let (tasks_assigment, num_tasks) = self
74+
.state
75+
.fetch_tasks(&mut available_executors, &job_id)
76+
.await?;
77+
if num_tasks > 0 {
78+
self.launch_tasks(&available_executors, tasks_assigment)
79+
.await?;
80+
}
81+
82+
Ok(None)
83+
}
84+
85+
async fn launch_tasks(
86+
&self,
87+
executors: &[ExecutorData],
88+
tasks_assigment: Vec<Vec<TaskDefinition>>,
89+
) -> Result<()> {
90+
for (idx_executor, tasks) in tasks_assigment.into_iter().enumerate() {
91+
if !tasks.is_empty() {
92+
let executor_data = &executors[idx_executor];
93+
debug!(
94+
"Start to launch tasks {:?} to executor {:?}",
95+
tasks
96+
.iter()
97+
.map(|task| {
98+
if let Some(task_id) = task.task_id.as_ref() {
99+
format!(
100+
"{}/{}/{}",
101+
task_id.job_id,
102+
task_id.stage_id,
103+
task_id.partition_id
104+
)
105+
} else {
106+
"".to_string()
107+
}
108+
})
109+
.collect::<Vec<String>>(),
110+
executor_data.executor_id
111+
);
112+
let mut client = {
113+
let clients = self.executors_client.read().await;
114+
clients.get(&executor_data.executor_id).unwrap().clone()
115+
};
116+
// Update the resources first
117+
self.state.save_executor_data(executor_data.clone());
118+
// TODO check whether launching task is successful or not
119+
client.launch_task(LaunchTaskParams { task: tasks }).await?;
120+
} else {
121+
// Since the task assignment policy is round robin,
122+
// if find tasks for one executor is empty, just break fast
123+
break;
124+
}
125+
}
126+
127+
Ok(())
128+
}
129+
}
130+
131+
#[async_trait]
132+
impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
133+
EventAction<SchedulerServerEvent> for SchedulerServerEventAction<T, U>
134+
{
135+
// TODO
136+
fn on_start(&self) {}
137+
138+
// TODO
139+
fn on_stop(&self) {}
140+
141+
async fn on_receive(
142+
&self,
143+
event: SchedulerServerEvent,
144+
) -> Result<Option<SchedulerServerEvent>> {
145+
match event {
146+
SchedulerServerEvent::JobSubmitted(job_id) => {
147+
self.offer_resources(job_id).await
148+
}
149+
}
150+
}
151+
152+
// TODO
153+
fn on_error(&self, _error: BallistaError) {}
154+
}

0 commit comments

Comments
 (0)