Skip to content

Commit

Permalink
Restore the Imaginate node with the full node graph architecture (but…
Browse files Browse the repository at this point in the history
… a flaky deadlock remains) (#1908)

* Rework imaginate trigger mechanism

* Fix imaginate generation
  • Loading branch information
TrueDoctor authored Aug 7, 2024
1 parent 8041b12 commit 06a409f
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 36 deletions.
4 changes: 3 additions & 1 deletion editor/src/messages/portfolio/document/document_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ pub enum DocumentMessage {
GridOverlays(OverlayContext),
GridVisibility(bool),
GroupSelectedLayers,
ImaginateGenerate,
ImaginateGenerate {
imaginate_node: Vec<NodeId>,
},
ImaginateRandom {
imaginate_node: Vec<NodeId>,
then_generate: bool,
Expand Down
14 changes: 12 additions & 2 deletions editor/src/messages/portfolio/document/document_message_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,17 @@ impl MessageHandler<DocumentMessage, DocumentMessageData<'_>> for DocumentMessag

responses.add(DocumentMessage::MoveSelectedLayersToGroup { parent: new_group_folder });
}
DocumentMessage::ImaginateGenerate => responses.add(PortfolioMessage::SubmitGraphRender { document_id, ignore_hash: false }),
DocumentMessage::ImaginateGenerate { imaginate_node } => {
let random_value = generate_uuid();
responses.add(NodeGraphMessage::SetInputValue {
node_id: *imaginate_node.last().unwrap(),
// Needs to match the index of the seed parameter in `pub const IMAGINATE_NODE: DocumentNodeDefinition` in `document_node_type.rs`
input_index: 17,
value: graph_craft::document::value::TaggedValue::U64(random_value),
});

responses.add(PortfolioMessage::SubmitGraphRender { document_id, ignore_hash: false });
}
DocumentMessage::ImaginateRandom { imaginate_node, then_generate } => {
// Generate a random seed. We only want values between -2^53 and 2^53, because integer values
// outside of this range can get rounded in f64
Expand All @@ -511,7 +521,7 @@ impl MessageHandler<DocumentMessage, DocumentMessageData<'_>> for DocumentMessag

// Generate the image
if then_generate {
responses.add(DocumentMessage::ImaginateGenerate);
responses.add(DocumentMessage::ImaginateGenerate { imaginate_node });
}
}
DocumentMessage::ImportSvg {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4272,14 +4272,14 @@ pub static IMAGINATE_NODE: Lazy<DocumentNodeDefinition> = Lazy::new(|| DocumentN
DocumentNode {
inputs: vec![NodeInput::network(concrete!(ImageFrame<Color>), 0)],
implementation: DocumentNodeImplementation::proto("graphene_core::memo::MonitorNode<_, _, _>"),
manual_composition: Some(generic!(T)),
manual_composition: Some(concrete!(())),
skip_deduplication: true,
..Default::default()
},
DocumentNode {
inputs: vec![
NodeInput::node(NodeId(0), 0),
NodeInput::scope("editor-api"),
NodeInput::network(concrete!(&WasmEditorApi), 1),
NodeInput::network(concrete!(ImaginateController), 2),
NodeInput::network(concrete!(f64), 3),
NodeInput::network(concrete!(Option<DVec2>), 4),
Expand All @@ -4295,8 +4295,9 @@ pub static IMAGINATE_NODE: Lazy<DocumentNodeDefinition> = Lazy::new(|| DocumentN
NodeInput::network(concrete!(ImaginateMaskStartingFill), 14),
NodeInput::network(concrete!(bool), 15),
NodeInput::network(concrete!(bool), 16),
NodeInput::network(concrete!(u64), 17),
],
implementation: DocumentNodeImplementation::proto("graphene_std::raster::ImaginateNode<_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _>"),
implementation: DocumentNodeImplementation::proto("graphene_std::raster::ImaginateNode<_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _>"),
..Default::default()
},
]
Expand All @@ -4310,7 +4311,7 @@ pub static IMAGINATE_NODE: Lazy<DocumentNodeDefinition> = Lazy::new(|| DocumentN
NodeInput::value(TaggedValue::ImageFrame(ImageFrame::empty()), true),
NodeInput::scope("editor-api"),
NodeInput::value(TaggedValue::ImaginateController(Default::default()), false),
NodeInput::value(TaggedValue::U64(0), false), // Remember to keep index used in `ImaginateRandom` updated with this entry's index
NodeInput::value(TaggedValue::F64(0.), false), // Remember to keep index used in `ImaginateRandom` updated with this entry's index
NodeInput::value(TaggedValue::OptionalDVec2(None), false),
NodeInput::value(TaggedValue::U32(30), false),
NodeInput::value(TaggedValue::ImaginateSamplingMethod(ImaginateSamplingMethod::EulerA), false),
Expand All @@ -4324,6 +4325,7 @@ pub static IMAGINATE_NODE: Lazy<DocumentNodeDefinition> = Lazy::new(|| DocumentN
NodeInput::value(TaggedValue::ImaginateMaskStartingFill(ImaginateMaskStartingFill::Fill), false),
NodeInput::value(TaggedValue::Bool(false), false),
NodeInput::value(TaggedValue::Bool(false), false),
NodeInput::value(TaggedValue::U64(0), false),
],
..Default::default()
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1769,9 +1769,13 @@ pub fn imaginate_properties(document_node: &DocumentNode, node_id: NodeId, conte
.tooltip("Fill layer frame by generating a new image")
.on_update({
let controller = controller.clone();
let imaginate_node = imaginate_node.clone();
move |_| {
controller.trigger_regenerate();
DocumentMessage::ImaginateGenerate.into()
DocumentMessage::ImaginateGenerate {
imaginate_node: imaginate_node.clone(),
}
.into()
}
})
.widget_holder(),
Expand All @@ -1781,9 +1785,13 @@ pub fn imaginate_properties(document_node: &DocumentNode, node_id: NodeId, conte
.disabled(!matches!(imaginate_status, ImaginateStatus::ReadyDone))
.on_update({
let controller = controller.clone();
let imaginate_node = imaginate_node.clone();
move |_| {
controller.set_status(ImaginateStatus::Ready);
DocumentMessage::ImaginateGenerate.into()
DocumentMessage::ImaginateGenerate {
imaginate_node: imaginate_node.clone(),
}
.into()
}
})
.widget_holder(),
Expand Down Expand Up @@ -1858,8 +1866,8 @@ pub fn imaginate_properties(document_node: &DocumentNode, node_id: NodeId, conte

let mut widgets = start_widgets(document_node, node_id, resolution_index, "Resolution", FrontendGraphDataType::Number, false);

let round = |x: DVec2| {
let (x, y) = pick_safe_imaginate_resolution(x.into());
let round = |size: DVec2| {
let (x, y) = pick_safe_imaginate_resolution(size.into());
DVec2::new(x as f64, y as f64)
};

Expand Down
10 changes: 5 additions & 5 deletions node-graph/graph-craft/src/imaginate_input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl std::cmp::PartialEq for ImaginateCache {

impl core::hash::Hash for ImaginateCache {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.0.lock().unwrap().hash(state);
self.0.try_lock().map(|g| g.hash(state));
}
}

Expand All @@ -50,11 +50,11 @@ pub struct ImaginateController(Arc<InternalImaginateControl>);

impl ImaginateController {
pub fn get_status(&self) -> ImaginateStatus {
self.0.status.lock().as_deref().cloned().unwrap_or_default()
self.0.status.try_lock().as_deref().cloned().unwrap_or_default()
}

pub fn set_status(&self, status: ImaginateStatus) {
if let Ok(mut lock) = self.0.status.lock() {
if let Ok(mut lock) = self.0.status.try_lock() {
*lock = status
}
}
Expand All @@ -68,13 +68,13 @@ impl ImaginateController {
}

pub fn request_termination(&self) {
if let Some(handle) = self.0.termination_sender.lock().ok().and_then(|mut lock| lock.take()) {
if let Some(handle) = self.0.termination_sender.try_lock().ok().and_then(|mut lock| lock.take()) {
handle.terminate()
}
}

pub fn set_termination_handle<H: ImaginateTerminationHandle>(&self, handle: Box<H>) {
if let Ok(mut lock) = self.0.termination_sender.lock() {
if let Ok(mut lock) = self.0.termination_sender.try_lock() {
*lock = Some(handle)
}
}
Expand Down
17 changes: 10 additions & 7 deletions node-graph/gstd/src/imaginate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -500,28 +500,31 @@ fn base64_to_image<D: AsRef<[u8]>, P: Pixel>(base64_data: D) -> Result<Image<P>,
}

pub fn pick_safe_imaginate_resolution((width, height): (f64, f64)) -> (u64, u64) {
const NATIVE_MODEL_RESOLUTION: f64 = 512.;
let size = if width * height == 0. { DVec2::splat(NATIVE_MODEL_RESOLUTION) } else { DVec2::new(width, height) };

const MAX_RESOLUTION: u64 = 1000 * 1000;

// this is the maximum width/height that can be obtained
// This is the maximum width/height that can be obtained
const MAX_DIMENSION: u64 = (MAX_RESOLUTION / 64) & !63;

// round the resolution to the nearest multiple of 64
let size = (DVec2::new(width, height).round().clamp(DVec2::ZERO, DVec2::splat(MAX_DIMENSION as _)).as_u64vec2() + U64Vec2::splat(32)).max(U64Vec2::splat(64)) & !U64Vec2::splat(63);
// Round the resolution to the nearest multiple of 64
let size = (size.round().clamp(DVec2::ZERO, DVec2::splat(MAX_DIMENSION as _)).as_u64vec2() + U64Vec2::splat(32)).max(U64Vec2::splat(64)) & !U64Vec2::splat(63);
let resolution = size.x * size.y;

if resolution > MAX_RESOLUTION {
// scale down the image, so it is smaller than MAX_RESOLUTION
// Scale down the image, so it is smaller than MAX_RESOLUTION
let scale = (MAX_RESOLUTION as f64 / resolution as f64).sqrt();
let size = size.as_dvec2() * scale;

if size.x < 64.0 {
// the image is extremely wide
// The image is extremely wide
(64, MAX_DIMENSION)
} else if size.y < 64.0 {
// the image is extremely high
// The image is extremely high
(MAX_DIMENSION, 64)
} else {
// round down to a multiple of 64, so that the resolution still is smaller than MAX_RESOLUTION
// Round down to a multiple of 64, so that the resolution still is smaller than MAX_RESOLUTION
(size.as_u64vec2() & !U64Vec2::splat(63)).into()
}
} else {
Expand Down
41 changes: 32 additions & 9 deletions node-graph/gstd/src/raster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,28 +474,32 @@ fn empty_image<_P: Pixel>(transform: DAffine2, color: _P) -> ImageFrame<_P> {
#[cfg(feature = "serde")]
macro_rules! generate_imaginate_node {
($($val:ident: $t:ident: $o:ty,)*) => {
pub struct ImaginateNode<P: Pixel, E, C, $($t,)*> {
pub struct ImaginateNode<P: Pixel, E, C, G, $($t,)*> {
editor_api: E,
controller: C,
generation_id: G,
$($val: $t,)*
cache: std::sync::Arc<std::sync::Mutex<HashMap<u64, Image<P>>>>,
last_generation: std::sync::atomic::AtomicU64,
}

impl<'e, P: Pixel, E, C, $($t,)*> ImaginateNode<P, E, C, $($t,)*>
impl<'e, P: Pixel, E, C, G, $($t,)*> ImaginateNode<P, E, C, G, $($t,)*>
where $($t: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, $o>>,)*
E: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, &'e WasmEditorApi>>,
C: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, ImaginateController>>,
G: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, u64>>,
{
#[allow(clippy::too_many_arguments)]
pub fn new(editor_api: E, controller: C, $($val: $t,)* ) -> Self {
Self { editor_api, controller, $($val,)* cache: Default::default() }
pub fn new(editor_api: E, controller: C, $($val: $t,)* generation_id: G ) -> Self {
Self { editor_api, controller, generation_id, $($val,)* cache: Default::default(), last_generation: std::sync::atomic::AtomicU64::new(u64::MAX) }
}
}

impl<'i, 'e: 'i, P: Pixel + 'i + Hash + Default + Send, E: 'i, C: 'i, $($t: 'i,)*> Node<'i, ImageFrame<P>> for ImaginateNode<P, E, C, $($t,)*>
impl<'i, 'e: 'i, P: Pixel + 'i + Hash + Default + Send, E: 'i, C: 'i, G: 'i, $($t: 'i,)*> Node<'i, ImageFrame<P>> for ImaginateNode<P, E, C, G, $($t,)*>
where $($t: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, $o>>,)*
E: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, &'e WasmEditorApi>>,
C: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, ImaginateController>>,
G: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, u64>>,
{
type Output = DynFuture<'i, ImageFrame<P>>;

Expand All @@ -509,26 +513,45 @@ macro_rules! generate_imaginate_node {
let hash = hasher.finish();
let editor_api = self.editor_api.eval(());
let cache = self.cache.clone();
let generation_future = self.generation_id.eval(());
let last_generation = &self.last_generation;

Box::pin(async move {
// let controller: std::pin::Pin<Box<dyn std::future::Future<Output = ImaginateController> + Send>> = controller;
let controller: ImaginateController = controller.await;
if controller.take_regenerate_trigger() {
let generation_id = generation_future.await;
if generation_id != last_generation.swap(generation_id, std::sync::atomic::Ordering::SeqCst) {
let image = super::imaginate::imaginate(frame.image, editor_api, controller, $($val,)*).await;

cache.lock().unwrap().insert(hash, image.clone());

return ImageFrame { image, ..frame }
return wrap_image_frame(image, frame.transform);
}
let image = cache.lock().unwrap().get(&hash).cloned().unwrap_or_default();

ImageFrame { image, ..frame }
return wrap_image_frame(image, frame.transform);
})
}
}
}
}

fn wrap_image_frame<P: Pixel>(image: Image<P>, transform: DAffine2) -> ImageFrame<P> {
if !transform.decompose_scale().abs_diff_eq(DVec2::ZERO, 0.00001) {
ImageFrame {
image,
transform,
alpha_blending: AlphaBlending::default(),
}
} else {
let resolution = DVec2::new(image.height as f64, image.width as f64);
ImageFrame {
image,
transform: DAffine2::from_scale_angle_translation(resolution, 0., transform.translation),
alpha_blending: AlphaBlending::default(),
}
}
}

#[cfg(feature = "serde")]
generate_imaginate_node! {
seed: Seed: f64,
Expand Down
9 changes: 5 additions & 4 deletions node-graph/interpreted-executor/src/node_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -568,14 +568,14 @@ fn node_registry() -> HashMap<ProtoNodeIdentifier, HashMap<NodeIOTypes, NodeCons
raster_node!(graphene_core::raster::PosterizeNode<_>, params: [f64]),
raster_node!(graphene_core::raster::ExposureNode<_, _, _>, params: [f64, f64, f64]),
vec![(
ProtoNodeIdentifier::new("graphene_std::raster::ImaginateNode<_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _>"),
ProtoNodeIdentifier::new("graphene_std::raster::ImaginateNode<_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _>"),
|args: Vec<graph_craft::proto::SharedNodeContainer>| {
Box::pin(async move {
use graphene_std::raster::ImaginateNode;
macro_rules! instantiate_imaginate_node {
($($i:expr,)*) => { ImaginateNode::new($(graphene_std::any::input_node(args[$i].clone()),)* ) };
}
let node: ImaginateNode<Color, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _> = instantiate_imaginate_node!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,);
let node: ImaginateNode<Color, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _> = instantiate_imaginate_node!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,);
let any = graphene_std::any::DynAnyNode::new(node);
any.into_type_erased()
})
Expand All @@ -584,9 +584,9 @@ fn node_registry() -> HashMap<ProtoNodeIdentifier, HashMap<NodeIOTypes, NodeCons
concrete!(ImageFrame<Color>),
concrete!(ImageFrame<Color>),
vec![
fn_type!(WasmEditorApi),
fn_type!(&WasmEditorApi),
fn_type!(ImaginateController),
fn_type!(u64),
fn_type!(f64),
fn_type!(Option<DVec2>),
fn_type!(u32),
fn_type!(ImaginateSamplingMethod),
Expand All @@ -600,6 +600,7 @@ fn node_registry() -> HashMap<ProtoNodeIdentifier, HashMap<NodeIOTypes, NodeCons
fn_type!(ImaginateMaskStartingFill),
fn_type!(bool),
fn_type!(bool),
fn_type!(u64),
],
),
)],
Expand Down

0 comments on commit 06a409f

Please sign in to comment.