Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Rewrite for inlining a single Call #1934

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ missing_docs = "warn"
debug_assert_with_mut_call = "warn"

[workspace.dependencies]
portgraph = { version = "0.13.0" }
portgraph = { version = "0.13.2" }
insta = { version = "1.34.0" }
bitvec = "1.0.1"
capnp = "0.20.1"
Expand Down
61 changes: 61 additions & 0 deletions hugr-core/src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use crate::extension::ExtensionRegistry;
use crate::hugr::views::SiblingSubgraph;
use crate::hugr::{HugrView, Node, OpType, RootTagged};
use crate::hugr::{NodeMetadata, Rewrite};
use crate::ops::OpTrait;
use crate::types::Substitution;
use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex};

use super::internal::HugrMutInternals;
Expand Down Expand Up @@ -146,6 +148,29 @@ pub trait HugrMut: HugrMutInternals {
self.hugr_mut().remove_node(node);
}

/// Copies the strict descendants of `root` to under the `new_parent`, optionally applying a
/// [Substitution] to the [OpType]s of the copied nodes.
///
/// That is, the immediate children of root, are copied to make children of `new_parent`.
///
/// Note this may invalidate the Hugr in two ways:
/// * Adding children of `root` may make the children-list of `new_parent` invalid e.g.
/// leading to multiple [Input](OpType::Input), [Output](OpType::Output) or
/// [ExitBlock](OpType::ExitBlock) nodes or Input/Output in the wrong positions
/// * Nonlocal edges incoming to the subtree of `root` will be copied to target the subtree under `new_parent`
/// which may be invalid if `new_parent` is not a child of `root`s parent (for `Ext` edges - or
/// correspondingly for `Dom` edges)
fn copy_descendants(
&mut self,
root: Node,
new_parent: Node,
subst: Option<Substitution>,
) -> HashMap<Node, Node> {
panic_invalid_node(self, root);
panic_invalid_node(self, new_parent);
self.hugr_mut().copy_descendants(root, new_parent, subst)
}

/// Connect two nodes at the given ports.
///
/// # Panics
Expand Down Expand Up @@ -450,6 +475,42 @@ impl<T: RootTagged<RootHandle = Node, Node = Node> + AsMut<Hugr>> HugrMut for T
}
translate_indices(node_map)
}

fn copy_descendants(
&mut self,
root: Node,
new_parent: Node,
subst: Option<Substitution>,
) -> HashMap<Node, Node> {
let mut descendants = self.base_hugr().hierarchy.descendants(root.pg_index());
let root2 = descendants.next();
debug_assert_eq!(root2, Some(root.pg_index()));
let nodes = Vec::from_iter(descendants);
let node_map = translate_indices(
portgraph::view::Subgraph::with_nodes(&mut self.as_mut().graph, nodes)
.copy_in_parent()
.expect("Is a MultiPortGraph"),
);

for node in self.children(root).collect::<Vec<_>>() {
self.set_parent(*node_map.get(&node).unwrap(), new_parent);
}

// Copy the optypes, metadata, and hierarchy
for (&node, &new_node) in node_map.iter() {
for ch in self.children(node).collect::<Vec<_>>() {
self.set_parent(*node_map.get(&ch).unwrap(), new_node);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like the hierarchy part isn't covered, do you need a test with more nesting?

Copy link
Collaborator

@doug-q doug-q Feb 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is nondeterministic because you are iterating a HashMap. the order of children in their parent will depend on iteration order. It's a little out of scope for this PR, but would you consider changing translate_indices to return a BTreeMap?

EDIT: no I'm wrong. You mutate the nodes in a nondeterministic order, but I don't think this is actually nondeterministic. for each parent you set_parent it's children in a deterministic order.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, ISTR this hit me just getting Input/Output in the wrong order, so I had to make sure it did it "right"...

}
let new_optype = match subst {
None => self.get_optype(node).clone(),
Some(ref subst) => self.get_optype(node).substitute(subst),
};
self.as_mut().op_types.set(new_node.pg_index(), new_optype);
let meta = self.base_hugr().metadata.get(node.pg_index()).clone();
self.as_mut().metadata.set(new_node.pg_index(), meta);
}
node_map
}
}

/// Internal implementation of `insert_hugr` and `insert_view` methods for
Expand Down
1 change: 1 addition & 0 deletions hugr-core/src/hugr/rewrite.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Rewrite operations on the HUGR - replacement, outlining, etc.

pub mod consts;
pub mod inline_call;
pub mod inline_dfg;
pub mod insert_identity;
pub mod outline_cfg;
Expand Down
252 changes: 252 additions & 0 deletions hugr-core/src/hugr/rewrite/inline_call.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
//! Rewrite to inline a Call to a FuncDefn by copying the body of the function
//! into a DFG which replaces the Call node.
use derive_more::{Display, Error};

use crate::ops::{DataflowParent, OpType, DFG};
use crate::types::Substitution;
use crate::{HugrView, Node};

use super::{HugrMut, Rewrite};

/// Rewrite to inline a [Call](OpType::Call) to a known [FuncDefn](OpType::FuncDefn)
pub struct InlineCall(Node);

/// Error in performing [InlineCall] rewrite.
#[derive(Clone, Debug, Display, Error, PartialEq)]
#[non_exhaustive]
pub enum InlineCallError {
/// The specified Node was not a [Call](OpType::Call)
#[display("Node to inline {_0} expected to be a Call but actually {_1}")]
NotCallNode(Node, OpType),
/// The node was a Call, but the target was not a [FuncDefn](OpType::FuncDefn)
/// - presumably a [FuncDecl](OpType::FuncDecl), if the Hugr is valid.
#[display("Call targetted node {_0} which must be a FuncDefn but was {_1}")]
CallTargetNotFuncDefn(Node, OpType),
}

impl InlineCall {
/// Create a new instance that will inline the specified node
/// (i.e. that should be a [Call](OpType::Call))
pub fn new(node: Node) -> Self {
Self(node)
}
}

impl Rewrite for InlineCall {
type ApplyResult = ();
type Error = InlineCallError;
fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), Self::Error> {
let call_ty = h.get_optype(self.0);
if !call_ty.is_call() {
return Err(InlineCallError::NotCallNode(self.0, call_ty.clone()));
}
let func = h.static_source(self.0).unwrap();
let func_ty = h.get_optype(func);
if !func_ty.is_func_defn() {
return Err(InlineCallError::CallTargetNotFuncDefn(
func,
func_ty.clone(),
));
}
Ok(())
}

fn apply(self, h: &mut impl HugrMut) -> Result<(), Self::Error> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you are considering the case of polymorphic functions here? I think you need to substitute the type args of the call into the signature and all descendents.

I fo think it's ok to reject polymorphic functions for now

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh dang that's a good point. Knew some things would get forgotten what with this sleeping over hack fortnight...

So I could do this later, but it's not much new code - you just need to iterate through the same nodes and apply the substitution. However, this does involve cloning the OpType in copy_descendants followed by applying the substitution (which is kinda a copy itself). So we could

  • decide this bit of extra copying isn't important
  • make substitution destructive (mutating). This would be perhaps the biggest change.
  • pass a impl Fn(&OpType) -> OpType into copy_descendants. Seems a little bit arbitrary (shouldn't we also have callbacks to copy metadata, etc....) but I think it's reasonable to argue that we should have this one. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact on the third I am tempted to make copy_descendants take an Option<Substitution> rather than a generic callback. Substitution is pretty fundamental and it avoids the argument that we should also do the same thing for e.g. metadata ;)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the copying is important.
I do like passing an optional substitution, this is very nice.
It seems like a bit of a weird inter-module dependency, this is very low level, substitution pretty high level. But this is just a vibe.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you were a type theorist, substitution would be a very low level op....lower level than copying a whole load of nodes which include types ;-)

self.verify(h)?; // Now we know we have a Call to a FuncDefn.
let orig_func = h.static_source(self.0).unwrap();
h.disconnect(self.0, h.get_optype(self.0).static_input_port().unwrap());

let new_op = OpType::from(DFG {
signature: h
.get_optype(orig_func)
.as_func_defn()
.unwrap()
.inner_signature()
.into_owned(),
});
let (in_ports, out_ports) = (new_op.input_count(), new_op.output_count());
let ty_args = h
.replace_op(self.0, new_op)
.unwrap()
.as_call()
.unwrap()
.type_args
.clone();
h.set_num_ports(self.0, in_ports as _, out_ports as _);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@trvto Is getting a panic here. I think it's being caused by the incoming order edge on the call. I think you need to disconnect it and reconnect them on the DFG, because the called_function_port has been deleted.

I think this would work because we know how DFGs and Call ports work, but it would be nice to debug_assert! our understanding in case it changes.


h.copy_descendants(
orig_func,
self.0,
(!ty_args.is_empty()).then_some(Substitution::new(&ty_args)),
);
Ok(())
}

/// Failure only occurs if the node is not a Call, or the target not a FuncDefn.
/// (Any later failure means an invalid Hugr and `panic`.)
const UNCHANGED_ON_FAILURE: bool = true;

fn invalidation_set(&self) -> impl Iterator<Item = Node> {
Some(self.0).into_iter()
}
}

#[cfg(test)]
mod test {
use std::iter::successors;

use itertools::Itertools;

use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder};
use crate::ops::handle::{FuncID, NodeHandle};
use crate::ops::{Input, Value};
use crate::std_extensions::arithmetic::{
int_ops::{self, IntOpDef},
int_types::{self, ConstInt, INT_TYPES},
};
use crate::{types::Signature, HugrView, Node};

use super::{HugrMut, InlineCall, InlineCallError};

fn calls(h: &impl HugrView<Node = Node>) -> Vec<Node> {
h.nodes().filter(|n| h.get_optype(*n).is_call()).collect()
}

fn extension_ops(h: &impl HugrView<Node = Node>) -> Vec<Node> {
h.nodes()
.filter(|n| h.get_optype(*n).is_extension_op())
.collect()
}

#[test]
fn test_inline() -> Result<(), Box<dyn std::error::Error>> {
let mut mb = ModuleBuilder::new();
let cst3 = mb.add_constant(Value::from(ConstInt::new_u(4, 3)?));
let sig = Signature::new_endo(INT_TYPES[4].clone())
.with_extension_delta(int_ops::EXTENSION_ID)
.with_extension_delta(int_types::EXTENSION_ID);
let func = {
let mut fb = mb.define_function("foo", sig.clone())?;
let c1 = fb.load_const(&cst3);
let [i] = fb.input_wires_arr();
let add = fb.add_dataflow_op(IntOpDef::iadd.with_log_width(4), [i, c1])?;
fb.finish_with_outputs(add.outputs())?
};
let mut main = mb.define_function("main", sig)?;
let call1 = main.call(func.handle(), &[], main.input_wires())?;
let call2 = main.call(func.handle(), &[], call1.outputs())?;
main.finish_with_outputs(call2.outputs())?;
let mut hugr = mb.finish_hugr()?;
let call1 = call1.node();
let call2 = call2.node();
assert_eq!(
hugr.output_neighbours(func.node()).collect_vec(),
[call1, call2]
);
assert_eq!(calls(&hugr), [call1, call2]);
assert_eq!(extension_ops(&hugr).len(), 1);

hugr.apply_rewrite(InlineCall(call1.node())).unwrap();
hugr.validate().unwrap();
assert_eq!(hugr.output_neighbours(func.node()).collect_vec(), [call2]);
assert_eq!(calls(&hugr), [call2]);
assert_eq!(extension_ops(&hugr).len(), 2);

hugr.apply_rewrite(InlineCall(call2.node())).unwrap();
hugr.validate().unwrap();
assert_eq!(hugr.output_neighbours(func.node()).next(), None);
assert_eq!(calls(&hugr), []);
assert_eq!(extension_ops(&hugr).len(), 3);

Ok(())
}

#[test]
fn test_recursion() -> Result<(), Box<dyn std::error::Error>> {
let mut mb = ModuleBuilder::new();
let sig = Signature::new_endo(INT_TYPES[5].clone())
.with_extension_delta(int_ops::EXTENSION_ID)
.with_extension_delta(int_types::EXTENSION_ID);
let (func, rec_call) = {
let mut fb = mb.define_function("foo", sig.clone())?;
let cst1 = fb.add_load_value(ConstInt::new_u(5, 1)?);
let [i] = fb.input_wires_arr();
let add = fb.add_dataflow_op(IntOpDef::iadd.with_log_width(5), [i, cst1])?;
let call = fb.call(
&FuncID::<true>::from(fb.container_node()),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not for this PR but this recursive handle should probably in the function builder API

&[],
add.outputs(),
)?;
(fb.finish_with_outputs(call.outputs())?, call)
};
let mut main = mb.define_function("main", sig)?;
let call = main.call(func.handle(), &[], main.input_wires())?;
let main = main.finish_with_outputs(call.outputs())?;
let mut hugr = mb.finish_hugr()?;

let func = func.node();
let mut call = call.node();
for i in 2..10 {
hugr.apply_rewrite(InlineCall(call))?;
hugr.validate().unwrap();
assert_eq!(extension_ops(&hugr).len(), i);
let v = calls(&hugr);
assert!(v.iter().all(|n| hugr.static_source(*n) == Some(func)));

let [rec, nonrec] = v.try_into().expect("Should be two");
assert_eq!(rec, rec_call.node());
assert_eq!(hugr.output_neighbours(func).collect_vec(), [rec, nonrec]);
call = nonrec;

let mut ancestors = successors(hugr.get_parent(call), |n| hugr.get_parent(*n));
for _ in 1..i {
assert!(hugr.get_optype(ancestors.next().unwrap()).is_dfg());
}
assert_eq!(ancestors.next(), Some(main.node()));
assert_eq!(ancestors.next(), Some(hugr.root()));
assert_eq!(ancestors.next(), None);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

}
Ok(())
}

#[test]
fn test_bad() {
let mut modb = ModuleBuilder::new();
let decl = modb
.declare(
"UndefinedFunc",
Signature::new_endo(INT_TYPES[3].clone()).into(),
)
.unwrap();
let mut main = modb
.define_function("main", Signature::new_endo(INT_TYPES[3].clone()))
.unwrap();
let call = main.call(&decl, &[], main.input_wires()).unwrap();
let main = main.finish_with_outputs(call.outputs()).unwrap();
let h = modb.finish_hugr().unwrap();
let mut h2 = h.clone();
assert_eq!(
h2.apply_rewrite(InlineCall(call.node())),
Err(InlineCallError::CallTargetNotFuncDefn(
decl.node(),
h.get_optype(decl.node()).clone()
))
);
assert_eq!(h, h2);
let [inp, _out, _call] = h
.children(main.node())
.collect::<Vec<_>>()
.try_into()
.unwrap();
assert_eq!(
h2.apply_rewrite(InlineCall(inp)),
Err(InlineCallError::NotCallNode(
inp,
Input {
types: INT_TYPES[3].clone().into()
}
.into()
))
)
}
}
Loading