-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Rewrite for inlining a single Call #1934
base: main
Are you sure you want to change the base?
Changes from all commits
7a3ac0b
6d8dfc1
0d5b5af
75f8842
4c345eb
120f17c
c10c81a
45dc680
d3f9c62
231b125
756902a
8db58de
45dd9ac
c995322
5a1f2f8
7f86624
b32065b
f7069c0
feaaace
49a5726
787773a
ee38b5b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,8 @@ use crate::extension::ExtensionRegistry; | |
use crate::hugr::views::SiblingSubgraph; | ||
use crate::hugr::{HugrView, Node, OpType, RootTagged}; | ||
use crate::hugr::{NodeMetadata, Rewrite}; | ||
use crate::ops::OpTrait; | ||
use crate::types::Substitution; | ||
use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; | ||
|
||
use super::internal::HugrMutInternals; | ||
|
@@ -146,6 +148,29 @@ pub trait HugrMut: HugrMutInternals { | |
self.hugr_mut().remove_node(node); | ||
} | ||
|
||
/// Copies the strict descendants of `root` to under the `new_parent`, optionally applying a | ||
/// [Substitution] to the [OpType]s of the copied nodes. | ||
/// | ||
/// That is, the immediate children of root, are copied to make children of `new_parent`. | ||
/// | ||
/// Note this may invalidate the Hugr in two ways: | ||
/// * Adding children of `root` may make the children-list of `new_parent` invalid e.g. | ||
/// leading to multiple [Input](OpType::Input), [Output](OpType::Output) or | ||
/// [ExitBlock](OpType::ExitBlock) nodes or Input/Output in the wrong positions | ||
/// * Nonlocal edges incoming to the subtree of `root` will be copied to target the subtree under `new_parent` | ||
/// which may be invalid if `new_parent` is not a child of `root`s parent (for `Ext` edges - or | ||
/// correspondingly for `Dom` edges) | ||
fn copy_descendants( | ||
&mut self, | ||
root: Node, | ||
new_parent: Node, | ||
subst: Option<Substitution>, | ||
) -> HashMap<Node, Node> { | ||
panic_invalid_node(self, root); | ||
panic_invalid_node(self, new_parent); | ||
self.hugr_mut().copy_descendants(root, new_parent, subst) | ||
} | ||
|
||
/// Connect two nodes at the given ports. | ||
/// | ||
/// # Panics | ||
|
@@ -450,6 +475,42 @@ impl<T: RootTagged<RootHandle = Node, Node = Node> + AsMut<Hugr>> HugrMut for T | |
} | ||
translate_indices(node_map) | ||
} | ||
|
||
fn copy_descendants( | ||
&mut self, | ||
root: Node, | ||
new_parent: Node, | ||
subst: Option<Substitution>, | ||
) -> HashMap<Node, Node> { | ||
let mut descendants = self.base_hugr().hierarchy.descendants(root.pg_index()); | ||
let root2 = descendants.next(); | ||
debug_assert_eq!(root2, Some(root.pg_index())); | ||
let nodes = Vec::from_iter(descendants); | ||
let node_map = translate_indices( | ||
portgraph::view::Subgraph::with_nodes(&mut self.as_mut().graph, nodes) | ||
.copy_in_parent() | ||
.expect("Is a MultiPortGraph"), | ||
); | ||
|
||
for node in self.children(root).collect::<Vec<_>>() { | ||
self.set_parent(*node_map.get(&node).unwrap(), new_parent); | ||
} | ||
|
||
// Copy the optypes, metadata, and hierarchy | ||
for (&node, &new_node) in node_map.iter() { | ||
for ch in self.children(node).collect::<Vec<_>>() { | ||
self.set_parent(*node_map.get(&ch).unwrap(), new_node); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is nondeterministic because you are iterating a HashMap. the order of children in their parent will depend on iteration order. It's a little out of scope for this PR, but would you consider changing EDIT: no I'm wrong. You mutate the nodes in a nondeterministic order, but I don't think this is actually nondeterministic. for each parent you set_parent it's children in a deterministic order. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, ISTR this hit me just getting Input/Output in the wrong order, so I had to make sure it did it "right"... |
||
} | ||
let new_optype = match subst { | ||
None => self.get_optype(node).clone(), | ||
Some(ref subst) => self.get_optype(node).substitute(subst), | ||
}; | ||
self.as_mut().op_types.set(new_node.pg_index(), new_optype); | ||
let meta = self.base_hugr().metadata.get(node.pg_index()).clone(); | ||
self.as_mut().metadata.set(new_node.pg_index(), meta); | ||
} | ||
node_map | ||
} | ||
} | ||
|
||
/// Internal implementation of `insert_hugr` and `insert_view` methods for | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,252 @@ | ||
//! Rewrite to inline a Call to a FuncDefn by copying the body of the function | ||
//! into a DFG which replaces the Call node. | ||
use derive_more::{Display, Error}; | ||
|
||
use crate::ops::{DataflowParent, OpType, DFG}; | ||
use crate::types::Substitution; | ||
use crate::{HugrView, Node}; | ||
|
||
use super::{HugrMut, Rewrite}; | ||
|
||
/// Rewrite to inline a [Call](OpType::Call) to a known [FuncDefn](OpType::FuncDefn) | ||
pub struct InlineCall(Node); | ||
|
||
/// Error in performing [InlineCall] rewrite. | ||
#[derive(Clone, Debug, Display, Error, PartialEq)] | ||
#[non_exhaustive] | ||
pub enum InlineCallError { | ||
/// The specified Node was not a [Call](OpType::Call) | ||
#[display("Node to inline {_0} expected to be a Call but actually {_1}")] | ||
NotCallNode(Node, OpType), | ||
/// The node was a Call, but the target was not a [FuncDefn](OpType::FuncDefn) | ||
/// - presumably a [FuncDecl](OpType::FuncDecl), if the Hugr is valid. | ||
#[display("Call targetted node {_0} which must be a FuncDefn but was {_1}")] | ||
CallTargetNotFuncDefn(Node, OpType), | ||
} | ||
|
||
impl InlineCall { | ||
/// Create a new instance that will inline the specified node | ||
/// (i.e. that should be a [Call](OpType::Call)) | ||
pub fn new(node: Node) -> Self { | ||
Self(node) | ||
} | ||
} | ||
|
||
impl Rewrite for InlineCall { | ||
type ApplyResult = (); | ||
doug-q marked this conversation as resolved.
Show resolved
Hide resolved
|
||
type Error = InlineCallError; | ||
fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), Self::Error> { | ||
let call_ty = h.get_optype(self.0); | ||
if !call_ty.is_call() { | ||
return Err(InlineCallError::NotCallNode(self.0, call_ty.clone())); | ||
} | ||
let func = h.static_source(self.0).unwrap(); | ||
let func_ty = h.get_optype(func); | ||
if !func_ty.is_func_defn() { | ||
return Err(InlineCallError::CallTargetNotFuncDefn( | ||
func, | ||
func_ty.clone(), | ||
)); | ||
} | ||
Ok(()) | ||
} | ||
|
||
fn apply(self, h: &mut impl HugrMut) -> Result<(), Self::Error> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think you are considering the case of polymorphic functions here? I think you need to substitute the type args of the call into the signature and all descendents. I fo think it's ok to reject polymorphic functions for now There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh dang that's a good point. Knew some things would get forgotten what with this sleeping over hack fortnight... So I could do this later, but it's not much new code - you just need to iterate through the same
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In fact on the third I am tempted to make There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think the copying is important. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you were a type theorist, substitution would be a very low level op....lower level than copying a whole load of nodes which include types ;-) |
||
self.verify(h)?; // Now we know we have a Call to a FuncDefn. | ||
let orig_func = h.static_source(self.0).unwrap(); | ||
h.disconnect(self.0, h.get_optype(self.0).static_input_port().unwrap()); | ||
|
||
let new_op = OpType::from(DFG { | ||
signature: h | ||
.get_optype(orig_func) | ||
.as_func_defn() | ||
.unwrap() | ||
.inner_signature() | ||
.into_owned(), | ||
}); | ||
let (in_ports, out_ports) = (new_op.input_count(), new_op.output_count()); | ||
let ty_args = h | ||
.replace_op(self.0, new_op) | ||
.unwrap() | ||
.as_call() | ||
.unwrap() | ||
.type_args | ||
.clone(); | ||
h.set_num_ports(self.0, in_ports as _, out_ports as _); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @trvto Is getting a panic here. I think it's being caused by the incoming order edge on the call. I think you need to disconnect it and reconnect them on the DFG, because the called_function_port has been deleted. I think this would work because we know how DFGs and Call ports work, but it would be nice to debug_assert! our understanding in case it changes. |
||
|
||
h.copy_descendants( | ||
orig_func, | ||
self.0, | ||
(!ty_args.is_empty()).then_some(Substitution::new(&ty_args)), | ||
); | ||
Ok(()) | ||
} | ||
|
||
/// Failure only occurs if the node is not a Call, or the target not a FuncDefn. | ||
/// (Any later failure means an invalid Hugr and `panic`.) | ||
const UNCHANGED_ON_FAILURE: bool = true; | ||
|
||
fn invalidation_set(&self) -> impl Iterator<Item = Node> { | ||
Some(self.0).into_iter() | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use std::iter::successors; | ||
|
||
use itertools::Itertools; | ||
|
||
use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}; | ||
use crate::ops::handle::{FuncID, NodeHandle}; | ||
use crate::ops::{Input, Value}; | ||
use crate::std_extensions::arithmetic::{ | ||
int_ops::{self, IntOpDef}, | ||
int_types::{self, ConstInt, INT_TYPES}, | ||
}; | ||
use crate::{types::Signature, HugrView, Node}; | ||
|
||
use super::{HugrMut, InlineCall, InlineCallError}; | ||
|
||
fn calls(h: &impl HugrView<Node = Node>) -> Vec<Node> { | ||
h.nodes().filter(|n| h.get_optype(*n).is_call()).collect() | ||
} | ||
|
||
fn extension_ops(h: &impl HugrView<Node = Node>) -> Vec<Node> { | ||
h.nodes() | ||
.filter(|n| h.get_optype(*n).is_extension_op()) | ||
.collect() | ||
} | ||
|
||
#[test] | ||
fn test_inline() -> Result<(), Box<dyn std::error::Error>> { | ||
let mut mb = ModuleBuilder::new(); | ||
let cst3 = mb.add_constant(Value::from(ConstInt::new_u(4, 3)?)); | ||
let sig = Signature::new_endo(INT_TYPES[4].clone()) | ||
.with_extension_delta(int_ops::EXTENSION_ID) | ||
.with_extension_delta(int_types::EXTENSION_ID); | ||
let func = { | ||
let mut fb = mb.define_function("foo", sig.clone())?; | ||
let c1 = fb.load_const(&cst3); | ||
let [i] = fb.input_wires_arr(); | ||
let add = fb.add_dataflow_op(IntOpDef::iadd.with_log_width(4), [i, c1])?; | ||
fb.finish_with_outputs(add.outputs())? | ||
}; | ||
let mut main = mb.define_function("main", sig)?; | ||
let call1 = main.call(func.handle(), &[], main.input_wires())?; | ||
let call2 = main.call(func.handle(), &[], call1.outputs())?; | ||
main.finish_with_outputs(call2.outputs())?; | ||
let mut hugr = mb.finish_hugr()?; | ||
let call1 = call1.node(); | ||
let call2 = call2.node(); | ||
assert_eq!( | ||
hugr.output_neighbours(func.node()).collect_vec(), | ||
[call1, call2] | ||
); | ||
assert_eq!(calls(&hugr), [call1, call2]); | ||
assert_eq!(extension_ops(&hugr).len(), 1); | ||
|
||
hugr.apply_rewrite(InlineCall(call1.node())).unwrap(); | ||
hugr.validate().unwrap(); | ||
assert_eq!(hugr.output_neighbours(func.node()).collect_vec(), [call2]); | ||
assert_eq!(calls(&hugr), [call2]); | ||
assert_eq!(extension_ops(&hugr).len(), 2); | ||
|
||
hugr.apply_rewrite(InlineCall(call2.node())).unwrap(); | ||
hugr.validate().unwrap(); | ||
assert_eq!(hugr.output_neighbours(func.node()).next(), None); | ||
assert_eq!(calls(&hugr), []); | ||
assert_eq!(extension_ops(&hugr).len(), 3); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_recursion() -> Result<(), Box<dyn std::error::Error>> { | ||
let mut mb = ModuleBuilder::new(); | ||
let sig = Signature::new_endo(INT_TYPES[5].clone()) | ||
.with_extension_delta(int_ops::EXTENSION_ID) | ||
.with_extension_delta(int_types::EXTENSION_ID); | ||
let (func, rec_call) = { | ||
let mut fb = mb.define_function("foo", sig.clone())?; | ||
let cst1 = fb.add_load_value(ConstInt::new_u(5, 1)?); | ||
let [i] = fb.input_wires_arr(); | ||
let add = fb.add_dataflow_op(IntOpDef::iadd.with_log_width(5), [i, cst1])?; | ||
let call = fb.call( | ||
&FuncID::<true>::from(fb.container_node()), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not for this PR but this recursive handle should probably in the function builder API |
||
&[], | ||
add.outputs(), | ||
)?; | ||
(fb.finish_with_outputs(call.outputs())?, call) | ||
}; | ||
let mut main = mb.define_function("main", sig)?; | ||
let call = main.call(func.handle(), &[], main.input_wires())?; | ||
let main = main.finish_with_outputs(call.outputs())?; | ||
let mut hugr = mb.finish_hugr()?; | ||
|
||
let func = func.node(); | ||
let mut call = call.node(); | ||
for i in 2..10 { | ||
hugr.apply_rewrite(InlineCall(call))?; | ||
hugr.validate().unwrap(); | ||
assert_eq!(extension_ops(&hugr).len(), i); | ||
let v = calls(&hugr); | ||
assert!(v.iter().all(|n| hugr.static_source(*n) == Some(func))); | ||
|
||
let [rec, nonrec] = v.try_into().expect("Should be two"); | ||
assert_eq!(rec, rec_call.node()); | ||
assert_eq!(hugr.output_neighbours(func).collect_vec(), [rec, nonrec]); | ||
call = nonrec; | ||
|
||
let mut ancestors = successors(hugr.get_parent(call), |n| hugr.get_parent(*n)); | ||
for _ in 1..i { | ||
assert!(hugr.get_optype(ancestors.next().unwrap()).is_dfg()); | ||
} | ||
assert_eq!(ancestors.next(), Some(main.node())); | ||
assert_eq!(ancestors.next(), Some(hugr.root())); | ||
assert_eq!(ancestors.next(), None); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||
} | ||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_bad() { | ||
let mut modb = ModuleBuilder::new(); | ||
let decl = modb | ||
.declare( | ||
"UndefinedFunc", | ||
Signature::new_endo(INT_TYPES[3].clone()).into(), | ||
) | ||
.unwrap(); | ||
let mut main = modb | ||
.define_function("main", Signature::new_endo(INT_TYPES[3].clone())) | ||
.unwrap(); | ||
let call = main.call(&decl, &[], main.input_wires()).unwrap(); | ||
let main = main.finish_with_outputs(call.outputs()).unwrap(); | ||
let h = modb.finish_hugr().unwrap(); | ||
let mut h2 = h.clone(); | ||
assert_eq!( | ||
h2.apply_rewrite(InlineCall(call.node())), | ||
Err(InlineCallError::CallTargetNotFuncDefn( | ||
decl.node(), | ||
h.get_optype(decl.node()).clone() | ||
)) | ||
); | ||
assert_eq!(h, h2); | ||
let [inp, _out, _call] = h | ||
.children(main.node()) | ||
.collect::<Vec<_>>() | ||
.try_into() | ||
.unwrap(); | ||
assert_eq!( | ||
h2.apply_rewrite(InlineCall(inp)), | ||
Err(InlineCallError::NotCallNode( | ||
inp, | ||
Input { | ||
types: INT_TYPES[3].clone().into() | ||
} | ||
.into() | ||
)) | ||
) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like the hierarchy part isn't covered, do you need a test with more nesting?