Skip to content

Commit e30c8f0

Browse files
alixinnevtavernier
authored andcommitted
fix(effects): fix python provider tests
1 parent 25b2a2f commit e30c8f0

File tree

6 files changed

+221
-185
lines changed

6 files changed

+221
-185
lines changed

src/effects.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,17 @@ impl EffectHandle {
162162
let (etx, mut erx) = channel(1);
163163

164164
// Create instance methods
165-
let methods =
166-
InstanceMethods::new(etx, crx, led_count, duration.and_then(|d| d.to_std().ok()));
165+
let methods = Arc::new(InstanceMethods::new(
166+
etx,
167+
crx,
168+
led_count,
169+
duration.and_then(|d| d.to_std().ok()),
170+
));
167171

168172
// Run effect
169173
let join_handle = tokio::task::spawn(async move {
174+
let methods = methods.clone();
175+
170176
// Create the blocking task
171177
let mut run_effect =
172178
tokio::task::spawn_blocking(move || provider.run(&full_path, args, methods));

src/effects/instance.rs

+64-40
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
use std::{
2-
cell::{Cell, RefCell},
3-
time::{Duration, Instant},
4-
};
1+
use std::time::{Duration, Instant};
52

3+
use async_trait::async_trait;
64
use thiserror::Error;
7-
use tokio::sync::mpsc::{Receiver, Sender};
5+
use tokio::sync::{
6+
mpsc::{Receiver, Sender},
7+
Mutex,
8+
};
89

910
use crate::{
1011
image::{RawImage, RawImageError},
@@ -18,12 +19,16 @@ pub enum ControlMessage {
1819
Abort,
1920
}
2021

22+
struct InstanceMethodsData {
23+
crx: Receiver<ControlMessage>,
24+
aborted: bool,
25+
}
26+
2127
pub struct InstanceMethods {
2228
tx: Sender<EffectMessageKind>,
23-
crx: RefCell<Receiver<ControlMessage>>,
2429
led_count: usize,
2530
deadline: Option<Instant>,
26-
aborted: Cell<bool>,
31+
data: Mutex<InstanceMethodsData>,
2732
}
2833

2934
impl InstanceMethods {
@@ -35,23 +40,26 @@ impl InstanceMethods {
3540
) -> Self {
3641
Self {
3742
tx,
38-
crx: crx.into(),
3943
led_count,
4044
deadline: duration.map(|d| Instant::now() + d),
41-
aborted: false.into(),
45+
data: Mutex::new(InstanceMethodsData {
46+
crx: crx.into(),
47+
aborted: false.into(),
48+
}),
4249
}
4350
}
4451

45-
fn completed(&self) -> bool {
46-
self.aborted.get() || self.deadline.map(|d| Instant::now() > d).unwrap_or(false)
52+
fn completed(&self, data: &InstanceMethodsData) -> bool {
53+
data.aborted || self.deadline.map(|d| Instant::now() > d).unwrap_or(false)
4754
}
4855

4956
/// Returns true if the should abort
50-
fn poll_control(&self) -> Result<(), RuntimeMethodError> {
51-
match self.crx.borrow_mut().try_recv() {
57+
async fn poll_control(&self) -> Result<(), RuntimeMethodError> {
58+
let mut data = self.data.lock().await;
59+
match data.crx.try_recv() {
5260
Ok(m) => match m {
5361
ControlMessage::Abort => {
54-
self.aborted.set(true);
62+
data.aborted = true;
5563
return Err(RuntimeMethodError::EffectAborted);
5664
}
5765
},
@@ -62,74 +70,90 @@ impl InstanceMethods {
6270
}
6371
tokio::sync::mpsc::error::TryRecvError::Disconnected => {
6472
// We were disconnected
65-
self.aborted.set(true);
73+
data.aborted = true;
6674
return Err(RuntimeMethodError::EffectAborted);
6775
}
6876
}
6977
}
7078
}
7179

72-
if self.completed() {
80+
if self.completed(&*data) {
7381
Err(RuntimeMethodError::EffectAborted)
7482
} else {
7583
Ok(())
7684
}
7785
}
7886

79-
fn wrap_result<T, E: Into<RuntimeMethodError>>(
87+
async fn wrap_result<T, E: Into<RuntimeMethodError>>(
8088
&self,
8189
res: Result<T, E>,
8290
) -> Result<T, RuntimeMethodError> {
8391
match res {
8492
Ok(t) => Ok(t),
8593
Err(err) => {
8694
// TODO: Log error?
87-
self.aborted.set(true);
95+
self.data.lock().await.aborted = true;
8896
Err(err.into())
8997
}
9098
}
9199
}
92100
}
93101

102+
#[async_trait]
94103
impl RuntimeMethods for InstanceMethods {
95104
fn get_led_count(&self) -> usize {
96105
self.led_count
97106
}
98107

99-
fn abort(&self) -> bool {
100-
self.poll_control().is_err()
108+
async fn abort(&self) -> bool {
109+
self.poll_control().await.is_err()
101110
}
102111

103-
fn set_color(&self, color: crate::models::Color) -> Result<(), RuntimeMethodError> {
104-
self.poll_control()?;
112+
async fn set_color(&self, color: crate::models::Color) -> Result<(), RuntimeMethodError> {
113+
self.poll_control().await?;
105114

106-
self.wrap_result(self.tx.blocking_send(EffectMessageKind::SetColor { color }))
115+
self.wrap_result(self.tx.send(EffectMessageKind::SetColor { color }).await)
116+
.await
107117
}
108118

109-
fn set_led_colors(&self, colors: Vec<crate::models::Color>) -> Result<(), RuntimeMethodError> {
110-
self.poll_control()?;
111-
112-
self.wrap_result(self.tx.blocking_send(EffectMessageKind::SetLedColors {
113-
colors: colors.into(),
114-
}))
119+
async fn set_led_colors(
120+
&self,
121+
colors: Vec<crate::models::Color>,
122+
) -> Result<(), RuntimeMethodError> {
123+
self.poll_control().await?;
124+
125+
self.wrap_result(
126+
self.tx
127+
.send(EffectMessageKind::SetLedColors {
128+
colors: colors.into(),
129+
})
130+
.await,
131+
)
132+
.await
115133
}
116134

117-
fn set_image(&self, image: RawImage) -> Result<(), RuntimeMethodError> {
118-
self.poll_control()?;
119-
120-
self.wrap_result(self.tx.blocking_send(EffectMessageKind::SetImage {
121-
image: image.into(),
122-
}))
135+
async fn set_image(&self, image: RawImage) -> Result<(), RuntimeMethodError> {
136+
self.poll_control().await?;
137+
138+
self.wrap_result(
139+
self.tx
140+
.send(EffectMessageKind::SetImage {
141+
image: image.into(),
142+
})
143+
.await,
144+
)
145+
.await
123146
}
124147
}
125148

126-
pub trait RuntimeMethods {
149+
#[async_trait]
150+
pub trait RuntimeMethods: Send {
127151
fn get_led_count(&self) -> usize;
128-
fn abort(&self) -> bool;
152+
async fn abort(&self) -> bool;
129153

130-
fn set_color(&self, color: Color) -> Result<(), RuntimeMethodError>;
131-
fn set_led_colors(&self, colors: Vec<Color>) -> Result<(), RuntimeMethodError>;
132-
fn set_image(&self, image: RawImage) -> Result<(), RuntimeMethodError>;
154+
async fn set_color(&self, color: Color) -> Result<(), RuntimeMethodError>;
155+
async fn set_led_colors(&self, colors: Vec<Color>) -> Result<(), RuntimeMethodError>;
156+
async fn set_image(&self, image: RawImage) -> Result<(), RuntimeMethodError>;
133157
}
134158

135159
#[derive(Debug, Error)]

src/effects/providers.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::{path::Path, sync::Arc};
22

33
use thiserror::Error;
44

5-
use super::InstanceMethods;
5+
use super::instance::RuntimeMethods;
66

77
#[cfg(feature = "python")]
88
mod python;
@@ -41,7 +41,7 @@ pub trait Provider: std::fmt::Debug + Send + Sync {
4141
&self,
4242
full_script_path: &Path,
4343
args: serde_json::Value,
44-
methods: InstanceMethods,
44+
methods: Arc<dyn RuntimeMethods>,
4545
) -> Result<(), ProviderError>;
4646
}
4747

src/effects/providers/python.rs

+38-29
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{convert::TryFrom, path::Path};
1+
use std::{convert::TryFrom, path::Path, sync::Arc};
22

33
use pyo3::{
44
exceptions::{PyRuntimeError, PyTypeError},
@@ -36,7 +36,7 @@ impl super::Provider for PythonProvider {
3636
&self,
3737
full_script_path: &Path,
3838
args: serde_json::Value,
39-
methods: crate::effects::instance::InstanceMethods,
39+
methods: Arc<dyn RuntimeMethods>,
4040
) -> Result<(), super::ProviderError> {
4141
Ok(do_run(methods, args, |py| {
4242
// Run script
@@ -63,36 +63,39 @@ impl From<RuntimeMethodError> for PyErr {
6363
/// Check if the effect should abort execution
6464
#[pyfunction]
6565
fn abort() -> bool {
66-
Context::with_current(|m| m.abort())
66+
Context::with_current(|m| async move { m.abort().await })
6767
}
6868

6969
/// Set a new color for the leds
7070
#[pyfunction(args = "*")]
7171
#[pyo3(name = "setColor")]
7272
fn set_color(args: &PyTuple) -> Result<(), PyErr> {
7373
Context::with_current(|m| {
74-
if let Result::<(u8, u8, u8), _>::Ok((r, g, b)) = args.extract() {
75-
m.set_color(Color::new(r, g, b))?;
76-
} else if let Result::<(&PyByteArray,), _>::Ok((bytearray,)) = args.extract() {
77-
if bytearray.len() == 3 * m.get_led_count() {
78-
// Safety: we are not modifying bytearray while accessing it
79-
unsafe {
80-
m.set_led_colors(
81-
bytearray
82-
.as_bytes()
83-
.chunks_exact(3)
84-
.map(|rgb| Color::new(rgb[0], rgb[1], rgb[2]))
85-
.collect(),
86-
)?;
74+
async move {
75+
if let Result::<(u8, u8, u8), _>::Ok((r, g, b)) = args.extract() {
76+
m.set_color(Color::new(r, g, b)).await?;
77+
} else if let Result::<(&PyByteArray,), _>::Ok((bytearray,)) = args.extract() {
78+
if bytearray.len() == 3 * m.get_led_count() {
79+
// Safety: we are not modifying bytearray while accessing it
80+
unsafe {
81+
m.set_led_colors(
82+
bytearray
83+
.as_bytes()
84+
.chunks_exact(3)
85+
.map(|rgb| Color::new(rgb[0], rgb[1], rgb[2]))
86+
.collect(),
87+
)
88+
.await?;
89+
}
90+
} else {
91+
return Err(RuntimeMethodError::InvalidByteArray.into());
8792
}
8893
} else {
89-
return Err(RuntimeMethodError::InvalidByteArray.into());
94+
return Err(RuntimeMethodError::InvalidArguments { name: "setColor" }.into());
9095
}
91-
} else {
92-
return Err(RuntimeMethodError::InvalidArguments { name: "setColor" }.into());
93-
}
9496

95-
Ok(())
97+
Ok(())
98+
}
9699
})
97100
}
98101

@@ -101,13 +104,16 @@ fn set_color(args: &PyTuple) -> Result<(), PyErr> {
101104
#[pyo3(name = "setImage")]
102105
fn set_image(width: u16, height: u16, data: &PyByteArray) -> Result<(), PyErr> {
103106
Context::with_current(|m| {
104-
// unwrap: we did all the necessary checks already
105-
m.set_image(
106-
RawImage::try_from((data.to_vec(), width as u32, height as u32))
107-
.map_err(|err| RuntimeMethodError::InvalidImageData(err))?,
108-
)?;
107+
async move {
108+
// unwrap: we did all the necessary checks already
109+
m.set_image(
110+
RawImage::try_from((data.to_vec(), width as u32, height as u32))
111+
.map_err(|err| RuntimeMethodError::InvalidImageData(err))?,
112+
)
113+
.await?;
109114

110-
Ok(())
115+
Ok(())
116+
}
111117
})
112118
}
113119

@@ -117,7 +123,10 @@ fn hyperion(_py: Python, m: &PyModule) -> PyResult<()> {
117123
m.add_function(wrap_pyfunction!(set_color, m)?)?;
118124
m.add_function(wrap_pyfunction!(set_image, m)?)?;
119125

120-
m.add("ledCount", Context::with_current(|m| m.get_led_count()))?;
126+
m.add(
127+
"ledCount",
128+
Context::with_current(|m| async move { m.get_led_count() }),
129+
)?;
121130

122131
Ok(())
123132
}
@@ -127,7 +136,7 @@ extern "C" fn hyperion_init() -> *mut pyo3::ffi::PyObject {
127136
}
128137

129138
fn do_run<T>(
130-
methods: impl RuntimeMethods + 'static,
139+
methods: Arc<dyn RuntimeMethods>,
131140
args: serde_json::Value,
132141
f: impl FnOnce(Python) -> Result<T, PyErr>,
133142
) -> Result<T, PyErr> {

0 commit comments

Comments
 (0)