Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add clear output to OutputHandler #30

Merged
merged 9 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ edition = "2021"
async-trait = "0.1.74"
bytes = { version = "1.5.0", features = ["serde"] }
chrono = { version = "0.4.31", features = ["serde"] }
enum-as-inner = "0.6.0"
hex = "0.4.3"
indoc = "2.0.4"
lazy_static = "1.4.0"
Expand Down
2 changes: 1 addition & 1 deletion src/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub mod outputs;
// export Handlers
pub use debug::DebugHandler;
pub use msg_count::MessageCountHandler;
pub use outputs::OutputHandler;
pub use outputs::SimpleOutputHandler;

#[async_trait::async_trait]
pub trait Handler: Debug + Send + Sync {
Expand Down
94 changes: 78 additions & 16 deletions src/handlers/outputs.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,91 @@
use tokio::sync::{Mutex, RwLock};

use crate::handlers::Handler;
use crate::jupyter::response::Response;
use std::collections::HashMap;
use crate::notebook::Output;

use std::fmt::Debug;
use std::sync::Arc;

#[async_trait::async_trait]
pub trait OutputHandler: Handler + Debug + Send + Sync {
async fn add_cell_content(&self, content: &HashMap<String, serde_json::Value>);
async fn clear_cell_content(&self);
#[derive(Debug, Clone)]
pub struct SimpleOutputHandler {
// interior mutability here because .handle needs to set this and is &self, and when trying
// to change that to &mut self then it broke the delegation of ZMQ messages to Actions over
// in actions.rs. TODO: come back to this when I'm better at Rust?
clear_on_next_output: Arc<Mutex<bool>>,
pub output: Arc<RwLock<Vec<Output>>>,
}

#[allow(clippy::single_match)]
async fn handle_output(&self, msg: &Response) {
match msg {
Response::ExecuteResult(result) => {
self.add_cell_content(&result.content.data).await;
}
_ => {}
impl Default for SimpleOutputHandler {
fn default() -> Self {
Self::new()
}
}

impl SimpleOutputHandler {
pub fn new() -> Self {
Self {
clear_on_next_output: Arc::new(Mutex::new(false)),
output: Arc::new(RwLock::new(vec![])),
}
}

async fn add_cell_content(&self, content: Output) {
self.output.write().await.push(content);
println!("add_cell_content");
}

async fn clear_cell_content(&self) {
self.output.write().await.clear();
println!("clear_cell_content");
}
}

// Need this here so that structs can impl OutputHandler and not get yelled
// at about also needing to impl Handler
#[async_trait::async_trait]
impl<T: OutputHandler + Send + Sync> Handler for T {
impl Handler for SimpleOutputHandler {
async fn handle(&self, msg: &Response) {
self.handle_output(msg).await;
let mut clear_on_next_output = self.clear_on_next_output.lock().await;
match msg {
Response::ExecuteResult(m) => {
let output = Output::ExecuteResult(m.content.clone());
if *clear_on_next_output {
self.clear_cell_content().await;
*clear_on_next_output = false;
}
self.add_cell_content(output).await;
}
Response::Stream(m) => {
let output = Output::Stream(m.content.clone());
if *clear_on_next_output {
self.clear_cell_content().await;
*clear_on_next_output = false;
}
self.add_cell_content(output).await;
}
Response::DisplayData(m) => {
let output = Output::DisplayData(m.content.clone());
if *clear_on_next_output {
self.clear_cell_content().await;
*clear_on_next_output = false;
}
self.add_cell_content(output).await;
}
Response::Error(m) => {
let output = Output::Error(m.content.clone());
if *clear_on_next_output {
self.clear_cell_content().await;
*clear_on_next_output = false;
}
self.add_cell_content(output).await;
}
Response::ClearOutput(m) => {
if m.content.wait {
*clear_on_next_output = true;
} else {
self.clear_cell_content().await;
}
}
_ => {}
}
}
}
3 changes: 1 addition & 2 deletions src/jupyter/iopub_content/clear_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ use bytes::Bytes;

use serde::Deserialize;

#[allow(dead_code)]
#[derive(Deserialize, Debug)]
pub struct ClearOutput {
wait: bool,
pub wait: bool,
}

impl From<Bytes> for ClearOutput {
Expand Down
23 changes: 10 additions & 13 deletions src/jupyter/iopub_content/display_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@ https://jupyter-client.readthedocs.io/en/latest/messaging.html#display-data
use std::collections::HashMap;

use bytes::Bytes;
use serde::{Deserialize, Deserializer};
use serde::{Deserialize, Deserializer, Serialize};

#[allow(dead_code)]
#[derive(Deserialize, Debug)]
#[derive(Clone, Deserialize, Debug, Serialize, PartialEq)]
pub struct Transient {
display_id: String,
pub display_id: String,
}

// If the transient field is an empty dict, deserialize it as None
Expand All @@ -28,15 +27,14 @@ where
}
}

#[allow(dead_code)]
#[derive(Deserialize, Debug)]
#[derive(Clone, Deserialize, Debug, Serialize, PartialEq)]
pub struct DisplayData {
data: HashMap<String, serde_json::Value>,
metadata: serde_json::Value,
pub data: HashMap<String, serde_json::Value>,
pub metadata: serde_json::Value,
// Dev note: serde(default) is important here, when using custom deserialize_with and Option
// then it will throw errors when the field is missing unless default is included.
#[serde(default, deserialize_with = "deserialize_transient")]
transient: Option<Transient>,
pub transient: Option<Transient>,
}

impl From<Bytes> for DisplayData {
Expand All @@ -46,15 +44,14 @@ impl From<Bytes> for DisplayData {
}
}

#[allow(dead_code)]
#[derive(Deserialize, Debug)]
pub struct UpdateDisplayData {
data: HashMap<String, serde_json::Value>,
metadata: serde_json::Value,
pub data: HashMap<String, serde_json::Value>,
pub metadata: serde_json::Value,
// Dev note: serde(default) is important here, when using custom deserialize_with and Option
// then it will throw errors when the field is missing unless default is included.
#[serde(default, deserialize_with = "deserialize_transient")]
transient: Option<Transient>,
pub transient: Option<Transient>,
}

impl From<Bytes> for UpdateDisplayData {
Expand Down
11 changes: 5 additions & 6 deletions src/jupyter/iopub_content/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@ https://jupyter-client.readthedocs.io/en/latest/messaging.html#execution-errors
*/

use bytes::Bytes;
use serde::Deserialize;
use serde::{Deserialize, Serialize};

#[allow(dead_code)]
#[derive(Deserialize, Debug)]
#[derive(Clone, Serialize, PartialEq, Deserialize, Debug)]
pub struct Error {
ename: String,
evalue: String,
traceback: Vec<String>,
pub ename: String,
pub evalue: String,
pub traceback: Vec<String>,
}

impl From<Bytes> for Error {
Expand Down
8 changes: 4 additions & 4 deletions src/jupyter/iopub_content/execute_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ use std::collections::HashMap;

use bytes::Bytes;

use serde::Deserialize;
use serde::{Deserialize, Serialize};

#[allow(dead_code)]
#[derive(Deserialize, Debug)]
#[derive(Clone, Serialize, PartialEq, Deserialize, Debug)]
pub struct ExecuteResult {
execution_count: u32,
pub execution_count: u32,
pub data: HashMap<String, serde_json::Value>,
metadata: serde_json::Value,
pub metadata: serde_json::Value,
}

impl From<Bytes> for ExecuteResult {
Expand Down
14 changes: 8 additions & 6 deletions src/jupyter/iopub_content/stream.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
/*
https://jupyter-client.readthedocs.io/en/latest/messaging.html#streams-stdout-stderr-etc
*/
use crate::notebook::list_or_string_to_string;
use bytes::Bytes;
use serde::Deserialize;
use serde::{Deserialize, Serialize};

#[derive(Deserialize, Debug)]
#[derive(Clone, Serialize, PartialEq, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
enum StreamName {
pub enum StreamName {
Stdout,
Stderr,
}

#[allow(dead_code)]
#[derive(Deserialize, Debug)]
#[derive(Clone, Serialize, PartialEq, Deserialize, Debug)]
pub struct Stream {
name: StreamName,
text: String,
pub name: StreamName,
#[serde(deserialize_with = "list_or_string_to_string")]
pub text: String,
}

impl From<Bytes> for Stream {
Expand Down
29 changes: 9 additions & 20 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,49 +8,37 @@ use tokio::signal::unix::{signal, SignalKind};
use tokio::time::sleep;

use indoc::indoc;
use std::collections::HashMap;

use std::sync::Arc;
use std::time::Duration;

use kernel_sidecar_rs::handlers::OutputHandler;

#[derive(Debug)]
struct DebugOutputHandler {}

#[async_trait::async_trait]
impl OutputHandler for DebugOutputHandler {
async fn add_cell_content(&self, content: &HashMap<String, serde_json::Value>) {
println!("add_cell_content: {:?}", content);
}

async fn clear_cell_content(&self) {
println!("clear_cell_content");
}
}
use kernel_sidecar_rs::handlers::SimpleOutputHandler;

#[tokio::main]
async fn main() {
let silent = true;
// let kernel = JupyterKernel::ipython(silent);
let kernel = JupyterKernel::deno(silent);
let kernel = JupyterKernel::ipython(silent);
let client = Client::new(kernel.connection_info.clone()).await;
client.heartbeat().await;
// small sleep to make sure iopub is connected,
sleep(Duration::from_millis(50)).await;

let debug_handler = DebugHandler::new();
let msg_count_handler = MessageCountHandler::new();
let output_handler = SimpleOutputHandler::new();
let handlers = vec![
Arc::new(debug_handler) as Arc<dyn Handler>,
Arc::new(msg_count_handler.clone()) as Arc<dyn Handler>,
Arc::new(DebugOutputHandler {}) as Arc<dyn Handler>,
Arc::new(output_handler.clone()) as Arc<dyn Handler>,
];
// let action = client.kernel_info_request(handlers).await;
let code = indoc! {r#"
from IPython.display import clear_output

print("Hello, world!")

print("Before Clear Output")
clear_output()
print("After Clear Output")
"#}
.trim();
let action = client.execute_request(code.to_owned(), handlers).await;
Expand All @@ -68,4 +56,5 @@ async fn main() {
}
}
println!("Message counts: {:?}", msg_count_handler.counts);
println!("Output: {:?}", output_handler.output.read().await);
}
31 changes: 12 additions & 19 deletions src/notebook.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
Models a Notebook document. https://ipython.org/ipython-doc/3/notebook/nbformat.html
*/

use crate::jupyter::iopub_content::display_data::DisplayData;
use crate::jupyter::iopub_content::errors::Error;
use crate::jupyter::iopub_content::execute_result::ExecuteResult;
use crate::jupyter::iopub_content::stream::Stream;
use enum_as_inner::EnumAsInner;
use serde::{Deserialize, Deserializer, Serialize};

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
Expand All @@ -14,26 +19,14 @@ pub struct Notebook {
pub nbformat_minor: u32,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, EnumAsInner)]
#[serde(tag = "output_type", rename_all = "snake_case")]
pub enum Output {
// TODO: look into using the content structs from jupyter/iopub_content instead of redefining?
DisplayData(serde_json::Value),
Stream {
name: String,
#[serde(deserialize_with = "list_or_string_to_string")]
text: String,
},
ExecuteResult {
execution_count: u32,
data: serde_json::Value,
metadata: serde_json::Value,
},
Error {
ename: String,
evalue: String,
traceback: Vec<String>,
},
// TODO: use the content structs from crate::jupyter::iopub_content instead of redefining?
DisplayData(DisplayData),
Stream(Stream),
ExecuteResult(ExecuteResult),
Error(Error),
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
Expand Down Expand Up @@ -71,7 +64,7 @@ pub struct RawCell {
}

// Custom deserialization for source field since it may be a Vec<String> or String
fn list_or_string_to_string<'de, D>(deserializer: D) -> Result<String, D::Error>
pub fn list_or_string_to_string<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{
Expand Down
Loading