Skip to content

Commit a7544f4

Browse files
feat: generalize conv mem layout and ND (#935)
1 parent c19fa52 commit a7544f4

File tree

17 files changed

+672
-178
lines changed

17 files changed

+672
-178
lines changed

.github/workflows/rust.yml

+2
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,8 @@ jobs:
276276
locked: true
277277
# - name: The Worm Mock
278278
# run: cargo nextest run --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
279+
- name: Large 1D Conv Mock
280+
run: cargo nextest run --verbose tests::large_mock_::large_tests_7_expects -- --include-ignored
279281
- name: MNIST Gan Mock
280282
run: cargo nextest run --verbose tests::large_mock_::large_tests_4_expects -- --include-ignored
281283
- name: NanoGPT Mock

Cargo.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

benches/accum_conv.rs

+2
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ impl Circuit<Fr> for MyCircuit {
7373
padding: vec![(0, 0)],
7474
stride: vec![1; 2],
7575
group: 1,
76+
data_format: DataFormat::NCHW,
77+
kernel_format: KernelFormat::OIHW,
7678
}),
7779
)
7880
.unwrap();

examples/conv2d_mnist/main.rs

+3
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ use mnist::*;
3232
use rand::rngs::OsRng;
3333
use std::marker::PhantomData;
3434

35+
3536
mod params;
3637

3738
const K: usize = 20;
@@ -208,6 +209,8 @@ where
208209
padding: vec![(PADDING, PADDING); 2],
209210
stride: vec![STRIDE; 2],
210211
group: 1,
212+
data_format: DataFormat::NCHW,
213+
kernel_format: KernelFormat::OIHW,
211214
};
212215
let x = config
213216
.layer_config

examples/onnx/1d_conv/input.json

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
{
2+
"input_data": [
3+
[
4+
8761,
5+
7654,
6+
8501,
7+
2404,
8+
6929,
9+
8858,
10+
5946,
11+
3673,
12+
4131,
13+
3854,
14+
8137,
15+
8239,
16+
9038,
17+
6299,
18+
1118,
19+
9737,
20+
208,
21+
7954,
22+
3691,
23+
610,
24+
3468,
25+
3314,
26+
8658,
27+
8366,
28+
2850,
29+
477,
30+
6114,
31+
232,
32+
4601,
33+
7420,
34+
5713,
35+
2936,
36+
6061,
37+
2870,
38+
8421,
39+
177,
40+
7107,
41+
7382,
42+
6115,
43+
5487,
44+
8502,
45+
2559,
46+
1875,
47+
129,
48+
8533,
49+
8201,
50+
8414,
51+
4775,
52+
9817,
53+
3127,
54+
8761,
55+
7654,
56+
8501,
57+
2404,
58+
6929,
59+
8858,
60+
5946,
61+
3673,
62+
4131,
63+
3854,
64+
8137,
65+
8239,
66+
9038,
67+
6299,
68+
1118,
69+
9737,
70+
208,
71+
7954,
72+
3691,
73+
610,
74+
3468,
75+
3314,
76+
8658,
77+
8366,
78+
2850,
79+
477,
80+
6114,
81+
232,
82+
4601,
83+
7420,
84+
5713,
85+
2936,
86+
6061,
87+
2870,
88+
8421,
89+
177,
90+
7107,
91+
7382,
92+
6115,
93+
5487,
94+
8502,
95+
2559,
96+
1875,
97+
129,
98+
8533,
99+
8201,
100+
8414,
101+
4775,
102+
9817,
103+
3127
104+
]
105+
]
106+
}

examples/onnx/1d_conv/network.onnx

4.28 MB
Binary file not shown.

src/circuit/ops/hybrid.rs

+13-6
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::{
33
circuit::{layouts, utils, Tolerance},
44
fieldutils::{integer_rep_to_felt, IntegerRep},
55
graph::multiplier_to_scale,
6-
tensor::{self, Tensor, TensorType, ValTensor},
6+
tensor::{self, DataFormat, Tensor, TensorType, ValTensor},
77
};
88
use halo2curves::ff::PrimeField;
99
use serde::{Deserialize, Serialize};
@@ -57,11 +57,13 @@ pub enum HybridOp {
5757
stride: Vec<usize>,
5858
kernel_shape: Vec<usize>,
5959
normalized: bool,
60+
data_format: DataFormat,
6061
},
6162
MaxPool {
6263
padding: Vec<(usize, usize)>,
6364
stride: Vec<usize>,
6465
pool_dims: Vec<usize>,
66+
data_format: DataFormat,
6567
},
6668
ReduceMin {
6769
axes: Vec<usize>,
@@ -154,20 +156,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
154156
padding,
155157
stride,
156158
kernel_shape,
157-
normalized,
159+
normalized, data_format
158160
} => format!(
159-
"SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={})",
160-
padding, stride, kernel_shape, normalized
161+
"SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={}, data_format={:?})",
162+
padding, stride, kernel_shape, normalized, data_format
161163
),
162164
HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes),
163165
HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim),
164166
HybridOp::MaxPool {
165167
padding,
166168
stride,
167169
pool_dims,
170+
data_format,
168171
} => format!(
169-
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?})",
170-
padding, stride, pool_dims
172+
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?}, data_format={:?})",
173+
padding, stride, pool_dims, data_format
171174
),
172175
HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
173176
HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim),
@@ -239,6 +242,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
239242
stride,
240243
kernel_shape,
241244
normalized,
245+
data_format,
242246
} => layouts::sumpool(
243247
config,
244248
region,
@@ -247,6 +251,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
247251
stride,
248252
kernel_shape,
249253
*normalized,
254+
*data_format,
250255
)?,
251256
HybridOp::Recip {
252257
input_scale,
@@ -287,13 +292,15 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
287292
padding,
288293
stride,
289294
pool_dims,
295+
data_format,
290296
} => layouts::max_pool(
291297
config,
292298
region,
293299
values[..].try_into()?,
294300
padding,
295301
stride,
296302
pool_dims,
303+
*data_format,
297304
)?,
298305
HybridOp::ReduceMax { axes } => {
299306
layouts::max_axes(config, region, values[..].try_into()?, axes)?

0 commit comments

Comments
 (0)