@@ -103,12 +103,12 @@ namespace nil {
103
103
ArithmetizationParams>>:
104
104
public plonk_component<BlueprintFieldType, ArithmetizationParams, 1 , 0 > {
105
105
106
- static std::size_t comaprisons_per_gate_instance_internal (std::size_t witness_amount) {
106
+ static std::size_t comparisons_per_gate_instance_internal (std::size_t witness_amount) {
107
107
return 1 + (witness_amount - 3 ) / 2 ;
108
108
}
109
109
110
110
static std::size_t bits_per_gate_instance_internal (std::size_t witness_amount) {
111
- return comaprisons_per_gate_instance_internal (witness_amount) * chunk_size;
111
+ return comparisons_per_gate_instance_internal (witness_amount) * chunk_size;
112
112
}
113
113
114
114
static std::size_t rows_amount_internal (std::size_t witness_amount, std::size_t bits_amount) {
@@ -123,7 +123,7 @@ namespace nil {
123
123
124
124
static std::size_t padded_chunks_internal (std::size_t witness_amount, std::size_t bits_amount) {
125
125
return gate_instances_internal (witness_amount, bits_amount) *
126
- comaprisons_per_gate_instance_internal (witness_amount);
126
+ comparisons_per_gate_instance_internal (witness_amount);
127
127
}
128
128
129
129
static std::size_t padding_bits_internal (std::size_t witness_amount, std::size_t bits_amount) {
@@ -206,6 +206,9 @@ namespace nil {
206
206
comparison_mode mode) {
207
207
return rows_amount_internal (witness_amount, bits_amount);
208
208
}
209
+ constexpr static std::size_t get_empty_rows_amount () {
210
+ return 1 ;
211
+ }
209
212
210
213
/*
211
214
It's CRITICAL that these three variables remain on top
@@ -217,12 +220,13 @@ namespace nil {
217
220
/* Do NOT move the above variables! */
218
221
219
222
const std::size_t comparisons_per_gate_instance =
220
- comaprisons_per_gate_instance_internal (this ->witness_amount ());
223
+ comparisons_per_gate_instance_internal (this ->witness_amount ());
221
224
const std::size_t bits_per_gate_instance =
222
225
bits_per_gate_instance_internal (this ->witness_amount ());
223
226
const bool needs_bonus_row = needs_bonus_row_internal(this ->witness_amount ());
224
227
225
228
const std::size_t rows_amount = rows_amount_internal(this ->witness_amount (), bits_amount);
229
+ const std::size_t empty_rows_amount = get_empty_rows_amount();
226
230
227
231
const std::size_t gate_instances = gate_instances_internal(this ->witness_amount (), bits_amount);
228
232
const std::size_t padded_chunks = padded_chunks_internal(this ->witness_amount (), bits_amount);
@@ -245,6 +249,9 @@ namespace nil {
245
249
std::size_t outuput_w = component.needs_bonus_row ? 0 : 3 ;
246
250
flag = var (component.W (outuput_w), start_row_index + component.rows_amount - 1 , false );
247
251
}
252
+ result_type (const comparison_flag &component, std::size_t start_row_index, bool skip) {
253
+ flag = var (component.W (0 ), start_row_index, false );
254
+ }
248
255
249
256
std::vector<var> all_vars () const {
250
257
return {flag};
@@ -281,6 +288,118 @@ namespace nil {
281
288
282
289
check_params (bits_amount, mode);
283
290
};
291
+
292
+ static typename BlueprintFieldType::value_type calculate (std::size_t witness_amount,
293
+ typename BlueprintFieldType::value_type x,
294
+ typename BlueprintFieldType::value_type y,
295
+ std::size_t arg_bits_amount, comparison_mode arg_mode) {
296
+
297
+ using value_type = typename BlueprintFieldType::value_type;
298
+ using integral_type = typename BlueprintFieldType::integral_type;
299
+ using chunk_type = std::uint8_t ;
300
+
301
+ auto chunk_size = 2 ;
302
+ auto padding_bits = padding_bits_internal (witness_amount, arg_bits_amount);
303
+ auto padded_chunks = padded_chunks_internal (witness_amount, arg_bits_amount);
304
+ auto comparisons_per_gate_instance = comparisons_per_gate_instance_internal (witness_amount);
305
+ auto gate_instances = gate_instances_internal (witness_amount, arg_bits_amount);
306
+
307
+ BOOST_ASSERT (chunk_size <= 8 );
308
+
309
+ std::array<integral_type, 2 > integrals = {integral_type (x.data ), integral_type (y.data )};
310
+
311
+ std::array<std::vector<bool >, 2 > bits;
312
+ for (std::size_t i = 0 ; i < 2 ; i++) {
313
+ std::fill (bits[i].begin (), bits[i].end (), false );
314
+ bits[i].resize (arg_bits_amount + padding_bits);
315
+
316
+ nil::marshalling::status_type status;
317
+ std::array<bool , BlueprintFieldType::modulus_bits> bytes_all =
318
+ nil::marshalling::pack<nil::marshalling::option::big_endian>(integrals[i], status);
319
+ std::copy (bytes_all.end () - arg_bits_amount, bytes_all.end (),
320
+ bits[i].begin () + padding_bits);
321
+ assert (status == nil::marshalling::status_type::success);
322
+ }
323
+
324
+ BOOST_ASSERT (padded_chunks * chunk_size ==
325
+ arg_bits_amount + padding_bits);
326
+ std::array<std::vector<chunk_type>, 2 > chunks;
327
+ for (std::size_t i = 0 ; i < 2 ; i++) {
328
+ chunks[i].resize (padded_chunks);
329
+ for (std::size_t j = 0 ; j < padded_chunks; j++) {
330
+ chunk_type chunk_value = 0 ;
331
+ for (std::size_t k = 0 ; k < chunk_size; k++) {
332
+ chunk_value <<= 1 ;
333
+ chunk_value |= bits[i][j * chunk_size + k];
334
+ }
335
+ chunks[i][j] = chunk_value;
336
+ }
337
+ }
338
+
339
+ value_type greater_val = - value_type (2 ).pow (chunk_size),
340
+ last_flag = 0 ;
341
+ std::array<value_type, 2 > sum = {0 , 0 };
342
+
343
+ for (std::size_t i = 0 ; i < gate_instances; i++) {
344
+ std::array<chunk_type, 2 > current_chunk = {0 , 0 };
345
+ std::size_t base_idx, chunk_idx;
346
+
347
+ // I basically used lambdas instead of macros to cut down on code reuse.
348
+ // Note that the captures are by reference!
349
+ auto calculate_flag = [¤t_chunk, &greater_val](value_type last_flag) {
350
+ return last_flag != 0 ? last_flag
351
+ : (current_chunk[0 ] > current_chunk[1 ] ? 1
352
+ : current_chunk[0 ] == current_chunk[1 ] ? 0 : greater_val);
353
+ };
354
+ auto calculate_temp = [¤t_chunk](value_type last_flag) {
355
+ return last_flag != 0 ? last_flag : current_chunk[0 ] - current_chunk[1 ];
356
+ };
357
+ // WARNING: this one is impure! But the code after it gets to look nicer.
358
+ auto place_chunk_pair = [¤t_chunk, &chunks, &sum, &chunk_size](
359
+ std::size_t base_idx, std::size_t chunk_idx) {
360
+ for (std::size_t k = 0 ; k < 2 ; k++) {
361
+ current_chunk[k] = chunks[k][chunk_idx];
362
+ sum[k] *= (1 << chunk_size);
363
+ sum[k] += current_chunk[k];
364
+ }
365
+ };
366
+
367
+ for (std::size_t j = 0 ; j < comparisons_per_gate_instance - 1 ; j++) {
368
+ base_idx = 3 + j * 2 ;
369
+ chunk_idx = i * comparisons_per_gate_instance + j;
370
+
371
+ place_chunk_pair (base_idx, chunk_idx);
372
+ last_flag = calculate_flag (last_flag);
373
+ }
374
+ // Last chunk
375
+ base_idx = 0 ;
376
+ chunk_idx = i * comparisons_per_gate_instance +
377
+ comparisons_per_gate_instance - 1 ;
378
+
379
+ place_chunk_pair (base_idx, chunk_idx);
380
+ last_flag = calculate_flag (last_flag);
381
+ }
382
+ value_type output;
383
+ switch (arg_mode) {
384
+ case comparison_mode::FLAG:
385
+ output = last_flag != greater_val ? last_flag : -1 ;
386
+ break ;
387
+ case comparison_mode::LESS_THAN:
388
+ output = last_flag == greater_val;
389
+ break ;
390
+ case comparison_mode::LESS_EQUAL:
391
+ output = (last_flag == greater_val) || (last_flag == 0 );
392
+ break ;
393
+ case comparison_mode::GREATER_THAN:
394
+ output = last_flag == 1 ;
395
+ break ;
396
+ case comparison_mode::GREATER_EQUAL:
397
+ output = (last_flag == 1 ) || (last_flag == 0 );
398
+ break ;
399
+ }
400
+
401
+ return output;
402
+ }
284
403
};
285
404
286
405
template <typename BlueprintFieldType, typename ArithmetizationParams>
@@ -465,6 +584,31 @@ namespace nil {
465
584
return typename component_type::result_type (component, start_row_index);
466
585
}
467
586
587
+ template <typename BlueprintFieldType, typename ArithmetizationParams>
588
+ typename plonk_comparison_flag<BlueprintFieldType, ArithmetizationParams>::result_type
589
+ generate_empty_assignments (
590
+ const plonk_comparison_flag<BlueprintFieldType, ArithmetizationParams>
591
+ &component,
592
+ assignment<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
593
+ ArithmetizationParams>>
594
+ &assignment,
595
+ const typename plonk_comparison_flag<BlueprintFieldType, ArithmetizationParams>::input_type
596
+ &instance_input,
597
+ const std::uint32_t start_row_index) {
598
+
599
+ using component_type = plonk_comparison_flag<BlueprintFieldType, ArithmetizationParams>;
600
+ using value_type = typename BlueprintFieldType::value_type;
601
+ using integral_type = typename BlueprintFieldType::integral_type;
602
+
603
+ value_type x = var_value (assignment, instance_input.x ),
604
+ y = var_value (assignment, instance_input.y );
605
+
606
+ assignment.witness (component.W (0 ), start_row_index) =
607
+ component_type::calculate (component.witness_amount (), x, y, component.bits_amount , component.mode );
608
+
609
+ return typename component_type::result_type (component, start_row_index, true );
610
+ }
611
+
468
612
template <typename BlueprintFieldType, typename ArithmetizationParams>
469
613
std::vector<std::size_t > generate_gates (
470
614
const plonk_comparison_flag<BlueprintFieldType, ArithmetizationParams>
0 commit comments