Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Link and neighbour iterators counting self-loops twice #132

Merged
merged 3 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions .github/pre-commit
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,35 @@ if [[ "${IGNORE_RUSTHOOKS:=0}" -ne 0 ]]; then
exit 0
fi

if ! cargo fmt -- --check
if ! cargo fmt --all -- --check
then
echo "There are some code style issues."
echo "Run cargo fmt first."
exit 1
fi

if ! cargo check --all --all-features --workspace
then
echo "There are some compilation warnings."
exit 1
fi

if ! cargo test --all-features --workspace
then
echo "There are some test issues."
exit 1
fi

if ! cargo clippy --all-targets --all-features --workspace -- -D warnings
then
echo "There are some clippy issues."
exit 1
fi

if ! cargo test --all-features
RUSTDOCFLAGS="-Dwarnings"
if ! cargo doc --no-deps --all-features --workspace
then
echo "There are some test issues."
echo "There are some clippy issues."
exit 1
fi

Expand Down
2 changes: 1 addition & 1 deletion justfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ test:

# Auto-fix all clippy warnings
fix:
cargo clippy --all-targets --all-features --workspace --fix --allow-staged
cargo clippy --all-targets --all-features --workspace --fix --allow-staged --allow-dirty

# Run the pre-commit checks
check:
Expand Down
105 changes: 101 additions & 4 deletions src/multiportgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,22 +344,23 @@ impl LinkView for MultiPortGraph {

#[inline]
fn links(&self, node: NodeIndex, direction: Direction) -> Self::NodeLinks<'_> {
NodeLinks::new(self, self.ports(node, direction))
NodeLinks::new(self, self.ports(node, direction), 0..0)
}

#[inline]
fn all_links(&self, node: NodeIndex) -> Self::NodeLinks<'_> {
NodeLinks::new(self, self.all_ports(node))
let output_ports = self.graph.node_outgoing_ports(node);
NodeLinks::new(self, self.all_ports(node), output_ports)
}

#[inline]
fn neighbours(&self, node: NodeIndex, direction: Direction) -> Self::Neighbours<'_> {
Neighbours::new(self, self.subports(node, direction))
Neighbours::new(self, self.subports(node, direction), node, false)
}

#[inline]
fn all_neighbours(&self, node: NodeIndex) -> Self::Neighbours<'_> {
Neighbours::new(self, self.all_subports(node))
Neighbours::new(self, self.all_subports(node), node, true)
}

#[inline]
Expand Down Expand Up @@ -811,6 +812,7 @@ pub mod test {
let mut g = MultiPortGraph::new();
let node0 = g.add_node(1, 2);
let node1 = g.add_node(2, 1);
let node0_input0 = g.input(node0, 0).unwrap();
let (node0_output0, node0_output1) = g.outputs(node0).collect_tuple().unwrap();
let (node1_input0, node1_input1) = g.inputs(node1).collect_tuple().unwrap();

Expand Down Expand Up @@ -866,5 +868,100 @@ pub mod test {
);
assert_eq!(g.all_neighbours(node0).collect_vec(), [node1, node1, node1]);
assert_eq!(g.port_links(node0_output0).collect_vec(), links[0..2]);

// Self-link
// The `all_links` / `all_neighbours` iterators should only return these once.
g.link_nodes(node0, 0, node0, 0).unwrap();
assert_eq!(
g.subport_outputs(node0).collect_vec(),
[
SubportIndex::new_multi(node0_output0, 0),
SubportIndex::new_multi(node0_output0, 1),
SubportIndex::new_multi(node0_output0, 2),
SubportIndex::new_unique(node0_output1),
]
);
assert_eq!(
g.subport_inputs(node0).collect_vec(),
[SubportIndex::new_unique(node0_input0)]
);

let links = [
(
SubportIndex::new_multi(node0_output0, 0),
SubportIndex::new_unique(node1_input0),
),
(
SubportIndex::new_multi(node0_output0, 1),
SubportIndex::new_multi(node1_input1, 0),
),
(
SubportIndex::new_multi(node0_output0, 2),
SubportIndex::new_unique(node0_input0),
),
(
SubportIndex::new_unique(node0_output1),
SubportIndex::new_multi(node1_input1, 1),
),
];
assert_eq!(
g.input_links(node0).collect_vec(),
[(
SubportIndex::new_unique(node0_input0),
SubportIndex::new_multi(node0_output0, 2),
)]
);
assert_eq!(g.output_links(node0).collect_vec(), links);
assert_eq!(g.all_links(node0).collect_vec(), links);
assert_eq!(g.input_neighbours(node0).collect_vec(), [node0]);
assert_eq!(
g.output_neighbours(node0).collect_vec(),
[node1, node1, node0, node1]
);
assert_eq!(
g.all_neighbours(node0).collect_vec(),
[node1, node1, node0, node1]
);
assert_eq!(g.port_links(node0_output0).collect_vec(), links[0..3]);
}

#[test]
fn insert_graph() -> Result<(), Box<dyn std::error::Error>> {
let mut g = crate::MultiPortGraph::new();
// Add dummy nodes to produce different node ids than in the other graph.
g.add_node(0, 0);
g.add_node(0, 0);
let node0g = g.add_node(1, 1);
let node1g = g.add_node(1, 1);
g.link_nodes(node0g, 0, node1g, 0)?;

let mut h = PortGraph::new();
let node0h = h.add_node(2, 2);
let node1h = h.add_node(1, 1);
h.link_nodes(node0h, 0, node1h, 0)?;
h.link_nodes(node0h, 1, node0h, 0)?;
h.link_nodes(node1h, 0, node0h, 1)?;

let map = g.insert_graph(&h)?;
assert_eq!(map.len(), 2);

assert_eq!(g.node_count(), 6);
assert_eq!(g.link_count(), 4);
assert!(g.contains_node(map[&node0h]));
assert!(g.contains_node(map[&node1h]));
assert_eq!(
g.input_neighbours(map[&node0h]).collect_vec(),
vec![map[&node0h], map[&node1h]]
);
assert_eq!(
g.output_neighbours(map[&node0h]).collect_vec(),
vec![map[&node1h], map[&node0h]]
);
assert_eq!(
g.all_neighbours(map[&node0h]).collect_vec(),
vec![map[&node1h], map[&node1h], map[&node0h]]
);

Ok(())
}
}
97 changes: 70 additions & 27 deletions src/multiportgraph/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::ops::Range;

use super::{MultiPortGraph, SubportIndex};
use crate::portgraph::{self, NodePorts};
use crate::{LinkView, NodeIndex, PortIndex, PortOffset, PortView};
use crate::{Direction, LinkView, NodeIndex, PortIndex, PortOffset, PortView};

/// Iterator over the nodes of a graph.
#[derive(Clone)]
Expand Down Expand Up @@ -127,14 +127,27 @@ pub struct Neighbours<'a> {
multigraph: &'a MultiPortGraph,
subports: NodeSubports<'a>,
current_copy_node: Option<NodeIndex>,
/// The node for which the neighbours are being iterated.
node: NodeIndex,
/// Whether to ignore self-loops in the input -> output direction.
/// This is used to avoid counting self-loops twice when iterating both
/// input and output neighbours.
ignore_dupped_self_loops: bool,
}

impl<'a> Neighbours<'a> {
pub(super) fn new(multigraph: &'a MultiPortGraph, subports: NodeSubports<'a>) -> Self {
pub(super) fn new(
multigraph: &'a MultiPortGraph,
subports: NodeSubports<'a>,
node: NodeIndex,
ignore_dupped_self_loops: bool,
) -> Self {
Self {
multigraph,
subports,
current_copy_node: None,
node,
ignore_dupped_self_loops,
}
}
}
Expand All @@ -143,26 +156,42 @@ impl<'a> Iterator for Neighbours<'a> {
type Item = NodeIndex;

fn next(&mut self) -> Option<Self::Item> {
let link = self.subports.find_map(|subport| {
let port_index = subport.port();
if !self.multigraph.is_multiport(port_index) {
self.multigraph.graph.port_link(port_index)
} else {
// There is a copy node
if subport.offset() == 0 {
self.current_copy_node = self.multigraph.get_copy_node(port_index);
loop {
let link = self.subports.find_map(|subport| {
let port_index = subport.port();
if !self.multigraph.is_multiport(port_index) {
self.multigraph.graph.port_link(port_index)
} else {
// There is a copy node
if subport.offset() == 0 {
self.current_copy_node = self.multigraph.get_copy_node(port_index);
}
let copy_node = self
.current_copy_node
.expect("Copy node not connected to a multiport.");
let dir = self.multigraph.graph.port_direction(port_index).unwrap();
let offset = PortOffset::new(dir, subport.offset());
let subport_index =
self.multigraph.graph.port_index(copy_node, offset).unwrap();
self.multigraph.graph.port_link(subport_index)
}
let copy_node = self
.current_copy_node
.expect("Copy node not connected to a multiport.");
let dir = self.multigraph.graph.port_direction(port_index).unwrap();
let offset = PortOffset::new(dir, subport.offset());
let subport_index = self.multigraph.graph.port_index(copy_node, offset).unwrap();
self.multigraph.graph.port_link(subport_index)
})?;
let link_subport = self.multigraph.get_subport_from_index(link).unwrap();
let node = self
.multigraph
.graph
.port_node(link_subport.port())
.unwrap();
// Ignore self-loops in the input -> output direction.
if self.ignore_dupped_self_loops
&& node == self.node
&& self.multigraph.port_direction(link_subport.port()).unwrap()
== Direction::Outgoing
{
continue;
}
})?;
let link_subport = self.multigraph.get_subport_from_index(link).unwrap();
self.multigraph.graph.port_node(link_subport.port())
return Some(node);
}
}
}

Expand All @@ -178,14 +207,22 @@ pub struct NodeLinks<'a> {
multigraph: &'a MultiPortGraph,
ports: NodePorts,
current_links: Option<PortLinks<'a>>,
/// Ignore links to the given target ports.
/// This is used to avoid counting self-loops twice.
ignore_target_ports: Range<usize>,
}

impl<'a> NodeLinks<'a> {
pub(super) fn new(multigraph: &'a MultiPortGraph, ports: NodePorts) -> Self {
pub(super) fn new(
multigraph: &'a MultiPortGraph,
ports: NodePorts,
ignore_target_ports: Range<usize>,
) -> Self {
Self {
multigraph,
ports,
current_links: None,
ignore_target_ports,
}
}
}
Expand All @@ -196,14 +233,20 @@ impl<'a> Iterator for NodeLinks<'a> {

fn next(&mut self) -> Option<Self::Item> {
loop {
if let Some(links) = &mut self.current_links {
if let Some(link) = links.next() {
return Some(link);
}
let Some(links) = &mut self.current_links else {
let port = self.ports.next()?;
self.current_links = Some(PortLinks::new(self.multigraph, port));
continue;
};
let Some((from, to)) = links.next() else {
self.current_links = None;
continue;
};
// Ignore self-loops in the input -> output direction.
if self.ignore_target_ports.contains(&to.port().index()) {
continue;
}
let port = self.ports.next()?;
self.current_links = Some(PortLinks::new(self.multigraph, port));
return Some((from, to));
}
}
}
Expand Down
Loading
Loading