Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert NodeId from an alias for usize to u32-sized opaque type #381

Merged
merged 1 commit into from
Oct 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions rten-generate/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -764,8 +764,12 @@ mod tests {
/// Return a model with a given set of inputs and outputs.
fn with_inputs_and_outputs(inputs: &[NodeInfo], outputs: &[NodeInfo]) -> FakeModel {
let node_infos = [inputs, outputs].concat();
let input_ids = (0..inputs.len()).collect();
let output_ids = (inputs.len()..(inputs.len() + outputs.len())).collect();
let input_ids = (0..inputs.len())
.map(|id| NodeId::from_u32(id as u32))
.collect();
let output_ids = (inputs.len()..(inputs.len() + outputs.len()))
.map(|id| NodeId::from_u32(id as u32))
.collect();

FakeModel {
input_ids,
Expand Down Expand Up @@ -796,11 +800,14 @@ mod tests {

impl Model for FakeModel {
fn find_node(&self, name: &str) -> Option<NodeId> {
self.nodes.iter().position(|info| info.name() == name)
self.nodes
.iter()
.position(|info| info.name() == name)
.map(|pos| NodeId::from_u32(pos as u32))
}

fn node_info(&self, id: NodeId) -> Option<NodeInfo> {
self.nodes.get(id).cloned()
self.nodes.get(id.as_u32() as usize).cloned()
}

fn input_ids(&self) -> &[NodeId] {
Expand Down
113 changes: 81 additions & 32 deletions src/graph.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
use std::num::NonZero;
use std::sync::{Arc, Mutex};
use std::time::Duration;

Expand Down Expand Up @@ -281,8 +282,45 @@ impl Node {
}
}

/// ID of a node in a [Model](crate::Model) graph.
pub type NodeId = usize;
/// ID of a node in a [`Model`](crate::Model) graph.
///
/// This is used to identify input and output values as well as internal nodes.
///
/// Node IDs are u32 values <= `i32::MAX`.
#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct NodeId(NonZero<u32>);

impl NodeId {
/// Return the underlying u32 value of the ID.
pub fn as_u32(self) -> u32 {
self.0.get() - 1
}

/// Construct a node ID from a u32 value.
///
/// Panics if the value exceeds `i32::MAX`.
pub fn from_u32(value: u32) -> NodeId {
// Node IDs are limited to `i32::MAX` because the `OperatorNode` type
// in the FlatBuffers schema represents operator input and output IDs
// as `i32`. Negative values are used as a niche to represent missing
// optional inputs.
assert!(value <= i32::MAX as u32);

// Valid node IDs are in the range `[0, i32::MAX]`, so we store them as
// values in `[1, i32::MAX + 1]` internally and reserve 0 as a niche to
// make `Option<NodeId>` the same size as `NodeId`.
NodeId(unsafe {
// Safety: `value + 1` cannot be zero
NonZero::new_unchecked(value + 1)
})
}
}

impl std::fmt::Display for NodeId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.as_u32().fmt(f)
}
}

/// Reasons why a graph execution failed
#[derive(Eq, PartialEq, Debug)]
Expand Down Expand Up @@ -368,14 +406,14 @@ impl NodeRefCount {
/// Increment ref count of node. If the refcount reaches `u8::MAX` it
/// will become "sticky" and never decrement.
fn inc(&mut self, id: NodeId) {
let rc = &mut self.rc[id];
let rc = &mut self.rc[id.as_u32() as usize];
*rc = rc.saturating_add(1);
}

/// Decrement ref count of node and return new count, or `None` if the
/// ref count was already zero.
fn dec(&mut self, id: NodeId) -> Option<usize> {
let rc = &mut self.rc[id];
let rc = &mut self.rc[id.as_u32() as usize];

// If the refcount reaches the max value, it becomes sticky.
if *rc == u8::MAX {
Expand All @@ -389,7 +427,7 @@ impl NodeRefCount {
}

fn count(&self, id: NodeId) -> usize {
self.rc[id] as usize
self.rc[id.as_u32() as usize] as usize
}
}

Expand Down Expand Up @@ -674,10 +712,10 @@ impl Graph {
}

pub fn add_node(&mut self, node: Node) -> NodeId {
let node_id = NodeId::from_u32(self.nodes.len() as u32);
self.nodes.push(node);
let node_id = self.nodes.len() - 1;

if let Some(name) = self.nodes[node_id].name() {
if let Some(name) = self.nodes[node_id.as_u32() as usize].name() {
self.node_id_from_name.insert(name.to_string(), node_id);
}

Expand Down Expand Up @@ -775,7 +813,10 @@ impl Graph {

/// Return an iterator over nodes in the graph.
pub fn iter(&self) -> impl Iterator<Item = (NodeId, &Node)> {
self.nodes.iter().enumerate()
self.nodes
.iter()
.enumerate()
.map(|(i, node)| (NodeId::from_u32(i as u32), node))
}

/// Return the debug name for a node.
Expand All @@ -788,7 +829,7 @@ impl Graph {

/// Retrieve a node by ID
pub fn get_node(&self, id: NodeId) -> Option<&Node> {
self.nodes.get(id)
self.nodes.get(id.as_u32() as usize)
}

/// Look up a node ID given its unique name
Expand All @@ -808,7 +849,7 @@ impl Graph {

/// Retrieve a node by ID
pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut Node> {
self.nodes.get_mut(id)
self.nodes.get_mut(id.as_u32() as usize)
}

/// Return the total number of parameters in all constant nodes in this
Expand Down Expand Up @@ -928,7 +969,7 @@ impl Graph {

let inputs_by_id: FxHashMap<NodeId, InputOrOutput> = inputs.iter().cloned().collect();
let get_value_from_constant_or_input = |node_id: NodeId| -> Option<Input> {
match self.nodes.get(node_id) {
match self.nodes.get(node_id.as_u32() as usize) {
Some(Node::Constant(constant)) => Some(constant.as_input()),
Some(Node::Value(_)) => inputs_by_id.get(&node_id).map(|input| input.as_input()),
_ => {
Expand All @@ -938,21 +979,24 @@ impl Graph {
};

let get_value_from_capture = |node_id: NodeId| -> Option<Input> {
let name = self.nodes.get(node_id).and_then(|n| n.name())?;
let name = self
.nodes
.get(node_id.as_u32() as usize)
.and_then(|n| n.name())?;
captures.as_ref().and_then(|cap| cap.get_input(name))
};

// Count how often each temporary output is used, so we can free them
// when no longer needed.
let mut temp_value_refcount = NodeRefCount::with_capacity(self.nodes.len());
for &op_node_id in plan.iter() {
let Some(Node::Operator(op_node)) = self.nodes.get(op_node_id) else {
let Some(Node::Operator(op_node)) = self.nodes.get(op_node_id.as_u32() as usize) else {
return Err(RunError::PlanningError(
"operator node not found".to_string(),
));
};
for node_id in self.operator_dependencies(op_node) {
if let Some(Node::Value(_)) = self.nodes.get(node_id) {
if let Some(Node::Value(_)) = self.nodes.get(node_id.as_u32() as usize) {
temp_value_refcount.inc(node_id);
}
}
Expand Down Expand Up @@ -984,7 +1028,7 @@ impl Graph {
let mut op_start = Instant::now();

for (step, &op_node_id) in plan.iter().enumerate() {
let Some(Node::Operator(op_node)) = self.nodes.get(op_node_id) else {
let Some(Node::Operator(op_node)) = self.nodes.get(op_node_id.as_u32() as usize) else {
return Err(RunError::PlanningError(
"operator node not found".to_string(),
));
Expand Down Expand Up @@ -1305,7 +1349,7 @@ impl Graph {
// Walk forwards through the plan and prune away steps that cannot be
// computed due to missing inputs.
for &node_id in plan {
let Some(Node::Operator(op_node)) = self.nodes.get(node_id) else {
let Some(Node::Operator(op_node)) = self.nodes.get(node_id.as_u32() as usize) else {
continue;
};

Expand Down Expand Up @@ -1354,12 +1398,11 @@ impl Graph {
inputs: I,
include_captures: bool,
) -> FxHashSet<NodeId> {
let mut resolved: FxHashSet<NodeId> =
inputs
.chain(self.nodes.iter().enumerate().filter_map(|(node_id, node)| {
matches!(node, Node::Constant(_)).then_some(node_id)
}))
.collect();
let mut resolved: FxHashSet<NodeId> = inputs
.chain(self.nodes.iter().enumerate().filter_map(|(node_id, node)| {
matches!(node, Node::Constant(_)).then_some(NodeId::from_u32(node_id as u32))
}))
.collect();

if include_captures {
resolved.extend(self.captures().iter().copied());
Expand Down Expand Up @@ -1514,7 +1557,7 @@ mod tests {
use smallvec::{smallvec, SmallVec};

use super::{CachedPlan, CaptureEnv};
use crate::graph::{Dimension, Graph, Node, RunError, RunOptions, TypedConstant};
use crate::graph::{Dimension, Graph, Node, NodeId, RunError, RunOptions, TypedConstant};
use crate::ops::{
Add, Concat, Conv, Identity, If, InputList, IntoOpResult, Mul, OpError, Operator, Output,
OutputList, Relu, Shape,
Expand Down Expand Up @@ -1943,7 +1986,7 @@ mod tests {
#[test]
fn test_err_if_invalid_output() {
let g = Graph::new();
let result = g.run(vec![], &[123], None);
let result = g.run(vec![], &[NodeId::from_u32(123)], None);
assert_eq!(
result.err(),
Some(RunError::PlanningError("Missing output 123".to_string()))
Expand All @@ -1953,7 +1996,7 @@ mod tests {
#[test]
fn test_err_if_missing_operator_input() {
let mut g = Graph::new();
let (_, output) = g.add_simple_op("op", Relu {}, &[42]);
let (_, output) = g.add_simple_op("op", Relu {}, &[NodeId::from_u32(42)]);
let result = g.run(vec![], &[output], None);
assert_eq!(
result.err(),
Expand Down Expand Up @@ -2268,21 +2311,27 @@ mod tests {

#[test]
fn test_cached_plan_matches() {
let input_ids = &[3, 1, 2];
let output_ids = &[6, 4, 5];
let op_ids = &[10, 11, 12];
let input_ids = &[3, 1, 2].map(NodeId::from_u32);
let output_ids = &[6, 4, 5].map(NodeId::from_u32);
let op_ids = &[10, 11, 12].map(NodeId::from_u32);

let plan = CachedPlan::new(input_ids, output_ids, op_ids.to_vec());

assert!(plan.matches(input_ids, output_ids));

// Same input and output IDs, different orders.
assert!(plan.matches(&[1, 2, 3], &[4, 5, 6]));
assert!(plan.matches(&[3, 2, 1], &[6, 5, 4]));
assert!(plan.matches(
&[1, 2, 3].map(NodeId::from_u32),
&[4, 5, 6].map(NodeId::from_u32)
));
assert!(plan.matches(
&[3, 2, 1].map(NodeId::from_u32),
&[6, 5, 4].map(NodeId::from_u32)
));

// Different input and output IDs
assert!(!plan.matches(&[20, 21, 22], output_ids));
assert!(!plan.matches(input_ids, &[20, 21, 22]));
assert!(!plan.matches(&[20, 21, 22].map(NodeId::from_u32), output_ids));
assert!(!plan.matches(input_ids, &[20, 21, 22].map(NodeId::from_u32)));
}

/// A trivial control flow operator which just forwards inputs to a subgraph
Expand Down
38 changes: 17 additions & 21 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,20 +356,20 @@ impl Model {

let input_ids: Vec<NodeId> = serialized_graph
.inputs()
.map(|ids| ids.iter().map(|id| id as NodeId).collect())
.map(|ids| ids.iter().map(NodeId::from_u32).collect())
.unwrap_or_default();

let output_ids: Vec<NodeId> = serialized_graph
.outputs()
.map(|ids| ids.iter().map(|id| id as NodeId).collect())
.map(|ids| ids.iter().map(NodeId::from_u32).collect())
.unwrap_or_default();

let mut graph = Graph::with_capacity(node_count);
graph.set_input_ids(&input_ids);
graph.set_output_ids(&output_ids);

if let Some(captures) = serialized_graph.captures() {
let captures: Vec<NodeId> = captures.iter().map(|id| id as NodeId).collect();
let captures: Vec<NodeId> = captures.iter().map(NodeId::from_u32).collect();
graph.set_captures(&captures);
}

Expand Down Expand Up @@ -839,7 +839,7 @@ mod tests {
use rten_tensor::prelude::*;
use rten_tensor::Tensor;

use crate::graph::{Dimension, RunError};
use crate::graph::{Dimension, NodeId, RunError};
use crate::model::{Model, ModelOptions};
use crate::model_builder::{
GraphBuilder, IfArgs, MetadataArgs, ModelBuilder, ModelFormat, OpType,
Expand Down Expand Up @@ -1147,7 +1147,7 @@ mod tests {
.load(buffer)
.unwrap();

let result = model.run(vec![], &[output_node as usize], None);
let result = model.run(vec![], &[output_node], None);

assert_eq!(
result.err(),
Expand Down Expand Up @@ -1181,7 +1181,7 @@ mod tests {
let mut op_outputs = Vec::new();

let mut add_operator =
|builder: &mut GraphBuilder, name: &str, op: OpType, input_nodes: &[Option<u32>]| {
|builder: &mut GraphBuilder, name: &str, op: OpType, input_nodes: &[Option<NodeId>]| {
let output_name = format!("{}_out", name);
let op_output_node = builder.add_value(&output_name, None);
builder.add_operator(name, op, input_nodes, &[op_output_node]);
Expand Down Expand Up @@ -1605,8 +1605,8 @@ mod tests {
let result = model
.run(
vec![
(input_node as usize, input.view().into()),
(input_bool as usize, input_bool_data.view().into()),
(input_node, input.view().into()),
(input_bool, input_bool_data.view().into()),
],
&[output_id],
None,
Expand All @@ -1629,11 +1629,7 @@ mod tests {
for output in outputs {
let output_id = model.find_node(output).unwrap();
let result = model
.run(
vec![(input_2d as usize, input.view().into())],
&[output_id],
None,
)
.run(vec![(input_2d, input.view().into())], &[output_id], None)
.unwrap();
assert_eq!(result.len(), 1);
}
Expand All @@ -1645,11 +1641,11 @@ mod tests {
let result = model
.run(
vec![
(range_start_node as usize, start.into()),
(range_limit_node as usize, limit.into()),
(range_delta_node as usize, delta.into()),
(range_start_node, start.into()),
(range_limit_node, limit.into()),
(range_delta_node, delta.into()),
],
&[range_out as usize],
&[range_out],
None,
)
.unwrap();
Expand All @@ -1662,11 +1658,11 @@ mod tests {
let result = model
.run(
vec![
(where_cond as usize, cond.into()),
(where_x as usize, x.into()),
(where_y as usize, y.into()),
(where_cond, cond.into()),
(where_x, x.into()),
(where_y, y.into()),
],
&[where_out as usize],
&[where_out],
None,
)
.unwrap();
Expand Down
Loading
Loading