Skip to content
This repository was archived by the owner on Feb 17, 2025. It is now read-only.

Commit 5d13667

Browse files
tshchelovekIluvmagick
authored andcommitted
comparison flag #308
1 parent 1bec3d3 commit 5d13667

File tree

3 files changed

+155
-8
lines changed

3 files changed

+155
-8
lines changed

include/nil/blueprint/components/algebra/fields/plonk/non_native/comparison_flag.hpp

+148-4
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,12 @@ namespace nil {
103103
ArithmetizationParams>>:
104104
public plonk_component<BlueprintFieldType, ArithmetizationParams, 1, 0> {
105105

106-
static std::size_t comaprisons_per_gate_instance_internal(std::size_t witness_amount) {
106+
static std::size_t comparisons_per_gate_instance_internal(std::size_t witness_amount) {
107107
return 1 + (witness_amount - 3) / 2;
108108
}
109109

110110
static std::size_t bits_per_gate_instance_internal(std::size_t witness_amount) {
111-
return comaprisons_per_gate_instance_internal(witness_amount) * chunk_size;
111+
return comparisons_per_gate_instance_internal(witness_amount) * chunk_size;
112112
}
113113

114114
static std::size_t rows_amount_internal(std::size_t witness_amount, std::size_t bits_amount) {
@@ -123,7 +123,7 @@ namespace nil {
123123

124124
static std::size_t padded_chunks_internal(std::size_t witness_amount, std::size_t bits_amount) {
125125
return gate_instances_internal(witness_amount, bits_amount) *
126-
comaprisons_per_gate_instance_internal(witness_amount);
126+
comparisons_per_gate_instance_internal(witness_amount);
127127
}
128128

129129
static std::size_t padding_bits_internal(std::size_t witness_amount, std::size_t bits_amount) {
@@ -206,6 +206,9 @@ namespace nil {
206206
comparison_mode mode) {
207207
return rows_amount_internal(witness_amount, bits_amount);
208208
}
209+
constexpr static std::size_t get_empty_rows_amount() {
210+
return 1;
211+
}
209212

210213
/*
211214
It's CRITICAL that these three variables remain on top
@@ -217,12 +220,13 @@ namespace nil {
217220
/* Do NOT move the above variables! */
218221

219222
const std::size_t comparisons_per_gate_instance =
220-
comaprisons_per_gate_instance_internal(this->witness_amount());
223+
comparisons_per_gate_instance_internal(this->witness_amount());
221224
const std::size_t bits_per_gate_instance =
222225
bits_per_gate_instance_internal(this->witness_amount());
223226
const bool needs_bonus_row = needs_bonus_row_internal(this->witness_amount());
224227

225228
const std::size_t rows_amount = rows_amount_internal(this->witness_amount(), bits_amount);
229+
const std::size_t empty_rows_amount = get_empty_rows_amount();
226230

227231
const std::size_t gate_instances = gate_instances_internal(this->witness_amount(), bits_amount);
228232
const std::size_t padded_chunks = padded_chunks_internal(this->witness_amount(), bits_amount);
@@ -245,6 +249,9 @@ namespace nil {
245249
std::size_t outuput_w = component.needs_bonus_row ? 0 : 3;
246250
flag = var(component.W(outuput_w), start_row_index + component.rows_amount - 1, false);
247251
}
252+
result_type(const comparison_flag &component, std::size_t start_row_index, bool skip) {
253+
flag = var(component.W(0), start_row_index, false);
254+
}
248255

249256
std::vector<var> all_vars() const {
250257
return {flag};
@@ -281,6 +288,118 @@ namespace nil {
281288

282289
check_params(bits_amount, mode);
283290
};
291+
292+
static typename BlueprintFieldType::value_type calculate(std::size_t witness_amount,
293+
typename BlueprintFieldType::value_type x,
294+
typename BlueprintFieldType::value_type y,
295+
std::size_t arg_bits_amount, comparison_mode arg_mode) {
296+
297+
using value_type = typename BlueprintFieldType::value_type;
298+
using integral_type = typename BlueprintFieldType::integral_type;
299+
using chunk_type = std::uint8_t;
300+
301+
auto chunk_size = 2;
302+
auto padding_bits = padding_bits_internal(witness_amount, arg_bits_amount);
303+
auto padded_chunks = padded_chunks_internal(witness_amount, arg_bits_amount);
304+
auto comparisons_per_gate_instance = comparisons_per_gate_instance_internal(witness_amount);
305+
auto gate_instances = gate_instances_internal(witness_amount, arg_bits_amount);
306+
307+
BOOST_ASSERT(chunk_size <= 8);
308+
309+
std::array<integral_type, 2> integrals = {integral_type(x.data), integral_type(y.data)};
310+
311+
std::array<std::vector<bool>, 2> bits;
312+
for (std::size_t i = 0; i < 2; i++) {
313+
std::fill(bits[i].begin(), bits[i].end(), false);
314+
bits[i].resize(arg_bits_amount + padding_bits);
315+
316+
nil::marshalling::status_type status;
317+
std::array<bool, BlueprintFieldType::modulus_bits> bytes_all =
318+
nil::marshalling::pack<nil::marshalling::option::big_endian>(integrals[i], status);
319+
std::copy(bytes_all.end() - arg_bits_amount, bytes_all.end(),
320+
bits[i].begin() + padding_bits);
321+
assert(status == nil::marshalling::status_type::success);
322+
}
323+
324+
BOOST_ASSERT(padded_chunks * chunk_size ==
325+
arg_bits_amount + padding_bits);
326+
std::array<std::vector<chunk_type>, 2> chunks;
327+
for (std::size_t i = 0; i < 2; i++) {
328+
chunks[i].resize(padded_chunks);
329+
for (std::size_t j = 0; j < padded_chunks; j++) {
330+
chunk_type chunk_value = 0;
331+
for (std::size_t k = 0; k < chunk_size; k++) {
332+
chunk_value <<= 1;
333+
chunk_value |= bits[i][j * chunk_size + k];
334+
}
335+
chunks[i][j] = chunk_value;
336+
}
337+
}
338+
339+
value_type greater_val = - value_type(2).pow(chunk_size),
340+
last_flag = 0;
341+
std::array<value_type, 2> sum = {0, 0};
342+
343+
for (std::size_t i = 0; i < gate_instances; i++) {
344+
std::array<chunk_type, 2> current_chunk = {0, 0};
345+
std::size_t base_idx, chunk_idx;
346+
347+
// I basically used lambdas instead of macros to cut down on code reuse.
348+
// Note that the captures are by reference!
349+
auto calculate_flag = [&current_chunk, &greater_val](value_type last_flag) {
350+
return last_flag != 0 ? last_flag
351+
: (current_chunk[0] > current_chunk[1] ? 1
352+
: current_chunk[0] == current_chunk[1] ? 0 : greater_val);
353+
};
354+
auto calculate_temp = [&current_chunk](value_type last_flag) {
355+
return last_flag != 0 ? last_flag : current_chunk[0] - current_chunk[1];
356+
};
357+
// WARNING: this one is impure! But the code after it gets to look nicer.
358+
auto place_chunk_pair = [&current_chunk, &chunks, &sum, &chunk_size](
359+
std::size_t base_idx, std::size_t chunk_idx) {
360+
for (std::size_t k = 0; k < 2; k++) {
361+
current_chunk[k] = chunks[k][chunk_idx];
362+
sum[k] *= (1 << chunk_size);
363+
sum[k] += current_chunk[k];
364+
}
365+
};
366+
367+
for (std::size_t j = 0; j < comparisons_per_gate_instance - 1; j++) {
368+
base_idx = 3 + j * 2;
369+
chunk_idx = i * comparisons_per_gate_instance + j;
370+
371+
place_chunk_pair(base_idx, chunk_idx);
372+
last_flag = calculate_flag(last_flag);
373+
}
374+
// Last chunk
375+
base_idx = 0;
376+
chunk_idx = i * comparisons_per_gate_instance +
377+
comparisons_per_gate_instance - 1;
378+
379+
place_chunk_pair(base_idx, chunk_idx);
380+
last_flag = calculate_flag(last_flag);
381+
}
382+
value_type output;
383+
switch (arg_mode) {
384+
case comparison_mode::FLAG:
385+
output = last_flag != greater_val ? last_flag : -1;
386+
break;
387+
case comparison_mode::LESS_THAN:
388+
output = last_flag == greater_val;
389+
break;
390+
case comparison_mode::LESS_EQUAL:
391+
output = (last_flag == greater_val) || (last_flag == 0);
392+
break;
393+
case comparison_mode::GREATER_THAN:
394+
output = last_flag == 1;
395+
break;
396+
case comparison_mode::GREATER_EQUAL:
397+
output = (last_flag == 1) || (last_flag == 0);
398+
break;
399+
}
400+
401+
return output;
402+
}
284403
};
285404

286405
template<typename BlueprintFieldType, typename ArithmetizationParams>
@@ -465,6 +584,31 @@ namespace nil {
465584
return typename component_type::result_type(component, start_row_index);
466585
}
467586

587+
template<typename BlueprintFieldType, typename ArithmetizationParams>
588+
typename plonk_comparison_flag<BlueprintFieldType, ArithmetizationParams>::result_type
589+
generate_empty_assignments(
590+
const plonk_comparison_flag<BlueprintFieldType, ArithmetizationParams>
591+
&component,
592+
assignment<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
593+
ArithmetizationParams>>
594+
&assignment,
595+
const typename plonk_comparison_flag<BlueprintFieldType, ArithmetizationParams>::input_type
596+
&instance_input,
597+
const std::uint32_t start_row_index) {
598+
599+
using component_type = plonk_comparison_flag<BlueprintFieldType, ArithmetizationParams>;
600+
using value_type = typename BlueprintFieldType::value_type;
601+
using integral_type = typename BlueprintFieldType::integral_type;
602+
603+
value_type x = var_value(assignment, instance_input.x),
604+
y = var_value(assignment, instance_input.y);
605+
606+
assignment.witness(component.W(0), start_row_index) =
607+
component_type::calculate(component.witness_amount(), x, y, component.bits_amount, component.mode);
608+
609+
return typename component_type::result_type(component, start_row_index, true);
610+
}
611+
468612
template<typename BlueprintFieldType, typename ArithmetizationParams>
469613
std::vector<std::size_t> generate_gates(
470614
const plonk_comparison_flag<BlueprintFieldType, ArithmetizationParams>

test/algebra/fields/plonk/non_native/comparison_flag.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ auto test_comparison_flag(typename BlueprintFieldType::value_type x, typename Bl
116116
nil::crypto3::test_component<component_type, BlueprintFieldType, ArithmetizationParams, hash_type, Lambda>(
117117
component_instance, public_input, result_check, instance_input,
118118
nil::crypto3::detail::connectedness_check_type::STRONG, R, Mode);
119+
nil::crypto3::test_empty_component<component_type, BlueprintFieldType, ArithmetizationParams, hash_type, Lambda>(
120+
component_instance, public_input, result_check, instance_input,
121+
nil::crypto3::detail::connectedness_check_type::STRONG, R, Mode);
119122
} else {
120123
nil::crypto3::test_component_to_fail<component_type, BlueprintFieldType, ArithmetizationParams,
121124
hash_type, Lambda>(

test/hashes/plonk/sha256.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ void test_sha256(std::vector<typename BlueprintFieldType::value_type> public_inp
8787

8888
stretched_component_type stretched_instance(component_instance, WitnessColumns / 2, WitnessColumns);
8989

90-
// crypto3::test_component<stretched_component_type, BlueprintFieldType, ArithmetizationParams, hash_type, Lambda>(
91-
// stretched_instance, public_input, result_check, instance_input);
90+
crypto3::test_component<stretched_component_type, BlueprintFieldType, ArithmetizationParams, hash_type, Lambda>(
91+
stretched_instance, public_input, result_check, instance_input);
9292
} else {
93-
// crypto3::test_component<component_type, BlueprintFieldType, ArithmetizationParams, hash_type, Lambda>(
94-
// component_instance, public_input, result_check, instance_input);
93+
crypto3::test_component<component_type, BlueprintFieldType, ArithmetizationParams, hash_type, Lambda>(
94+
component_instance, public_input, result_check, instance_input);
9595
crypto3::test_empty_component<component_type, BlueprintFieldType, ArithmetizationParams, hash_type, Lambda>(
9696
component_instance, public_input, result_check, instance_input);
9797
}

0 commit comments

Comments
 (0)