Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: generalize conv mem layout and ND #935

Merged
merged 12 commits into from
Feb 10, 2025
Prev Previous commit
Next Next commit
patch conv
  • Loading branch information
alexander-camuto committed Feb 8, 2025
commit 3cd9d2ad80d1960f53b76adccb2e147404807135
2 changes: 2 additions & 0 deletions benches/accum_conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ impl Circuit<Fr> for MyCircuit {
padding: vec![(0, 0)],
stride: vec![1; 2],
group: 1,
data_format: DataFormat::NCHW,
kernel_format: KernelFormat::OIHW,
}),
)
.unwrap();
Expand Down
3 changes: 3 additions & 0 deletions examples/conv2d_mnist/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use instant::Instant;
use mnist::*;
use rand::rngs::OsRng;
use std::marker::PhantomData;
use tract_onnx::tract_core::ndarray::Data;

mod params;

Expand Down Expand Up @@ -208,6 +209,8 @@ where
padding: vec![(PADDING, PADDING); 2],
stride: vec![STRIDE; 2],
group: 1,
data_format: DataFormat::NCHW,
kernel_format: KernelFormat::OIHW,
};
let x = config
.layer_config
Expand Down
Loading