Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Align name of Forward enum variant for RNN ops with ONNX attribute value #67

Merged
merged 2 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rten-convert/rten_convert/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def convert_attr(val: str):
f'Replacing unsupported value "{val}" for "{name}" attr in {op} op with "{fallback}"'
)
return convert_attr(fallback)
raise ValueError(f"Unsupported value {val} for {name} attr")
raise ValueError(f'Unsupported value "{val}" for "{name}" attr')

def ignore_attr(self, name: str):
"""
Expand Down
2 changes: 1 addition & 1 deletion rten-convert/rten_convert/schema_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class OperatorType(object):


class RNNDirection(object):
Forwards = 0
Forward = 0
Reverse = 1
Bidirectional = 2

Expand Down
8 changes: 4 additions & 4 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -704,10 +704,10 @@ fn read_gru_op(node: &OperatorNode) -> ReadOpResult {

let hidden_size = attrs.hidden_size() as usize;
let direction = match attrs.direction() {
sg::RNNDirection::Forwards => Direction::Forwards,
sg::RNNDirection::Forward => Direction::Forward,
sg::RNNDirection::Reverse => Direction::Reverse,
sg::RNNDirection::Bidirectional => Direction::Bidirectional,
_ => Direction::Forwards,
_ => Direction::Forward,
};

Ok(Box::new(ops::GRU {
Expand Down Expand Up @@ -762,10 +762,10 @@ fn read_lstm_op(node: &OperatorNode) -> ReadOpResult {

let hidden_size = attrs.hidden_size() as usize;
let direction = match attrs.direction() {
sg::RNNDirection::Forwards => Direction::Forwards,
sg::RNNDirection::Forward => Direction::Forward,
sg::RNNDirection::Reverse => Direction::Reverse,
sg::RNNDirection::Bidirectional => Direction::Bidirectional,
_ => Direction::Forwards,
_ => Direction::Forward,
};

Ok(Box::new(ops::LSTM {
Expand Down
26 changes: 13 additions & 13 deletions src/ops/rnn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::ops::{
/// Direction that an RNN operator will traverse the input sequence in.
#[derive(Copy, Clone, Debug)]
pub enum Direction {
Forwards,
Forward,
Reverse,
Bidirectional,
}
Expand All @@ -24,25 +24,25 @@ impl Direction {
/// Number of directions that an RNN operator will traverse the sequence in.
pub fn num_directions(self) -> usize {
match self {
Self::Forwards | Self::Reverse => 1,
Self::Forward | Self::Reverse => 1,
Self::Bidirectional => 2,
}
}
}

/// Forwards or backwards iterator over values in a range.
/// Forward or backward iterator over values in a range.
enum Sequence {
Forwards(Range<usize>),
Backwards(Rev<Range<usize>>),
Forward(Range<usize>),
Backward(Rev<Range<usize>>),
}

impl Iterator for Sequence {
type Item = usize;

fn next(&mut self) -> Option<usize> {
match self {
Sequence::Forwards(range) => range.next(),
Sequence::Backwards(rev_range) => rev_range.next(),
Sequence::Forward(range) => range.next(),
Sequence::Backward(rev_range) => rev_range.next(),
}
}
}
Expand All @@ -57,9 +57,9 @@ fn sequence_for_dir(op_dirs: Direction, dir: usize, seq_len: usize) -> Sequence
(0, Direction::Reverse) | (1, Direction::Bidirectional)
);
if reversed {
Sequence::Backwards((0..seq_len).rev())
Sequence::Backward((0..seq_len).rev())
} else {
Sequence::Forwards(0..seq_len)
Sequence::Forward(0..seq_len)
}
}

Expand Down Expand Up @@ -1045,23 +1045,23 @@ mod tests {
let cases = &[
Case {
name: "lstm_forwards",
dir: Direction::Forwards,
dir: Direction::Forward,
},
Case {
name: "lstm_initial",
dir: Direction::Forwards,
dir: Direction::Forward,
},
Case {
name: "lstm_bidirectional",
dir: Direction::Bidirectional,
},
Case {
name: "gru_forwards",
dir: Direction::Forwards,
dir: Direction::Forward,
},
Case {
name: "gru_initial",
dir: Direction::Forwards,
dir: Direction::Forward,
},
Case {
name: "gru_bidirectional",
Expand Down
2 changes: 1 addition & 1 deletion src/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ enum OperatorType: ubyte {
}

enum RNNDirection: ubyte {
Forwards,
Forward,
Reverse,
Bidirectional
}
Expand Down
20 changes: 10 additions & 10 deletions src/schema_generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ pub const ENUM_MAX_RNNDIRECTION: u8 = 2;
)]
#[allow(non_camel_case_types)]
pub const ENUM_VALUES_RNNDIRECTION: [RNNDirection; 3] = [
RNNDirection::Forwards,
RNNDirection::Forward,
RNNDirection::Reverse,
RNNDirection::Bidirectional,
];
Expand All @@ -502,17 +502,17 @@ pub const ENUM_VALUES_RNNDIRECTION: [RNNDirection; 3] = [
pub struct RNNDirection(pub u8);
#[allow(non_upper_case_globals)]
impl RNNDirection {
pub const Forwards: Self = Self(0);
pub const Forward: Self = Self(0);
pub const Reverse: Self = Self(1);
pub const Bidirectional: Self = Self(2);

pub const ENUM_MIN: u8 = 0;
pub const ENUM_MAX: u8 = 2;
pub const ENUM_VALUES: &'static [Self] = &[Self::Forwards, Self::Reverse, Self::Bidirectional];
pub const ENUM_VALUES: &'static [Self] = &[Self::Forward, Self::Reverse, Self::Bidirectional];
/// Returns the variant's name or "" if unknown.
pub fn variant_name(self) -> Option<&'static str> {
match self {
Self::Forwards => Some("Forwards"),
Self::Forward => Some("Forward"),
Self::Reverse => Some("Reverse"),
Self::Bidirectional => Some("Bidirectional"),
_ => None,
Expand Down Expand Up @@ -3662,7 +3662,7 @@ impl<'a> GRUAttrs<'a> {
// which contains a valid value in this slot
unsafe {
self._tab
.get::<RNNDirection>(GRUAttrs::VT_DIRECTION, Some(RNNDirection::Forwards))
.get::<RNNDirection>(GRUAttrs::VT_DIRECTION, Some(RNNDirection::Forward))
.unwrap()
}
}
Expand Down Expand Up @@ -3714,7 +3714,7 @@ impl<'a> Default for GRUAttrsArgs {
#[inline]
fn default() -> Self {
GRUAttrsArgs {
direction: RNNDirection::Forwards,
direction: RNNDirection::Forward,
hidden_size: 0,
linear_before_reset: false,
}
Expand All @@ -3731,7 +3731,7 @@ impl<'a: 'b, 'b> GRUAttrsBuilder<'a, 'b> {
self.fbb_.push_slot::<RNNDirection>(
GRUAttrs::VT_DIRECTION,
direction,
RNNDirection::Forwards,
RNNDirection::Forward,
);
}
#[inline]
Expand Down Expand Up @@ -4041,7 +4041,7 @@ impl<'a> LSTMAttrs<'a> {
// which contains a valid value in this slot
unsafe {
self._tab
.get::<RNNDirection>(LSTMAttrs::VT_DIRECTION, Some(RNNDirection::Forwards))
.get::<RNNDirection>(LSTMAttrs::VT_DIRECTION, Some(RNNDirection::Forward))
.unwrap()
}
}
Expand Down Expand Up @@ -4080,7 +4080,7 @@ impl<'a> Default for LSTMAttrsArgs {
#[inline]
fn default() -> Self {
LSTMAttrsArgs {
direction: RNNDirection::Forwards,
direction: RNNDirection::Forward,
hidden_size: 0,
}
}
Expand All @@ -4096,7 +4096,7 @@ impl<'a: 'b, 'b> LSTMAttrsBuilder<'a, 'b> {
self.fbb_.push_slot::<RNNDirection>(
LSTMAttrs::VT_DIRECTION,
direction,
RNNDirection::Forwards,
RNNDirection::Forward,
);
}
#[inline]
Expand Down
Loading