// context.rs
use std::rc::Rc;
use std::sync::{Mutex, RwLock};
use acvm::{acir::AcirField, FieldElement};
use iter_extended::vecmap;
use noirc_errors::Location;
use noirc_frontend::ast::{BinaryOpKind, Signedness};
use noirc_frontend::monomorphization::ast::{self, LocalId, Parameters};
use noirc_frontend::monomorphization::ast::{FuncId, Program};
use crate::errors::RuntimeError;
use crate::ssa::function_builder::FunctionBuilder;
use crate::ssa::ir::basic_block::BasicBlockId;
use crate::ssa::ir::function::FunctionId as IrFunctionId;
use crate::ssa::ir::function::{Function, RuntimeType};
use crate::ssa::ir::instruction::BinaryOp;
use crate::ssa::ir::instruction::Instruction;
use crate::ssa::ir::map::AtomicCounter;
use crate::ssa::ir::types::{NumericType, Type};
use crate::ssa::ir::value::ValueId;
use super::value::{Tree, Value, Values};
use super::SSA_WORD_SIZE;
use fxhash::FxHashMap as HashMap;
/// The FunctionContext is the main context object for translating a
/// function into SSA form during the SSA-gen pass.
///
/// This context can be used to build any amount of functions,
/// so long as it is cleared out in between each function via
/// calling self.new_function().
///
/// If compiling many functions across multiple threads, there should
/// be a separate FunctionContext for each thread. Each FunctionContext
/// can communicate via the SharedContext field which, as its name suggests,
/// is the only part of the context that needs to be shared between threads.
pub(super) struct FunctionContext<'a> {
definitions: HashMap<LocalId, Values>,
pub(super) builder: FunctionBuilder,
shared_context: &'a SharedContext,
/// Contains any loops we're currently in the middle of translating.
/// These are ordered such that an inner loop is at the end of the vector and
/// outer loops are at the beginning. When a loop is finished, it is popped.
loops: Vec<Loop>,
}
/// Shared context for all functions during ssa codegen. This is the only
/// object that is shared across all threads when generating ssa in multiple threads.
///
/// The main job of the SharedContext is to remember which functions are already
/// compiled, what their IDs are, and keep a queue of which functions still need to
/// be compiled.
///
/// SSA can be generated by continuously popping from this function_queue and using
/// FunctionContext to generate from the popped function id. Once the queue is empty,
/// no other functions are reachable and the SSA generation is finished.
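///
/// A rough single-threaded driver sketch (hypothetical glue code for illustration only;
/// the real ssa-gen pass differs in its details):
///
/// ```ignore
/// let context = SharedContext::new(program);
/// context.get_or_queue_function(main_id); // seed the queue so FunctionContext::new can pop
/// let mut codegen = FunctionContext::new(main_name, &main_params, runtime, &context);
/// // ... generate SSA for main ...
/// while let Some((src_id, dest_id)) = context.pop_next_function_in_queue() {
///     let function = /* look up src_id in context.program */;
///     codegen.new_function(dest_id, function);
///     // ... generate SSA for `function` ...
/// }
/// ```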
pub(super) struct SharedContext {
/// All currently known functions which have already been assigned function ids.
/// These functions are all either currently having their SSA generated or are
/// already finished.
functions: RwLock<HashMap<FuncId, IrFunctionId>>,
/// Queue of which functions still need to be compiled.
///
/// The queue is currently Last-in First-out (LIFO) but this is an
/// implementation detail that can be trivially changed and should
/// not impact the resulting SSA besides changing which IDs are assigned
/// to which functions.
function_queue: Mutex<FunctionQueue>,
/// Shared counter used to assign the ID of the next function
function_counter: AtomicCounter<Function>,
/// The entire monomorphized source program
pub(super) program: Program,
}
#[derive(Copy, Clone)]
pub(super) struct Loop {
pub(super) loop_entry: BasicBlockId,
pub(super) loop_index: ValueId,
pub(super) loop_end: BasicBlockId,
}
/// The queue of functions remaining to compile
type FunctionQueue = Vec<(ast::FuncId, IrFunctionId)>;
impl<'a> FunctionContext<'a> {
/// Create a new FunctionContext to compile the first function in the shared_context's
/// function queue.
///
/// This will pop from the function queue, so it is expected the shared_context's function
/// queue is non-empty at the time of calling this function. This can be ensured by calling
/// `shared_context.get_or_queue_function(function_to_queue)` before calling this constructor.
///
/// `function_name` and `parameters` are expected to be the name and parameters of the function
/// this constructor will pop from the function queue.
pub(super) fn new(
function_name: String,
parameters: &Parameters,
runtime: RuntimeType,
shared_context: &'a SharedContext,
) -> Self {
let function_id = shared_context
.pop_next_function_in_queue()
.expect("No function in queue for the FunctionContext to compile")
.1;
let mut builder = FunctionBuilder::new(function_name, function_id);
builder.set_runtime(runtime);
let definitions = HashMap::default();
let mut this = Self { definitions, builder, shared_context, loops: Vec::new() };
this.add_parameters_to_scope(parameters);
this
}
/// Finish building the current function and switch to building a new function with the
/// given name, id, and parameters.
///
/// Note that the previous function cannot be resumed after calling this. Developers should
/// avoid calling new_function until the previous function is completely finished with ssa-gen.
pub(super) fn new_function(&mut self, id: IrFunctionId, func: &ast::Function) {
self.definitions.clear();
if func.unconstrained {
self.builder.new_brillig_function(func.name.clone(), id);
} else {
self.builder.new_function(func.name.clone(), id, func.inline_type);
}
self.add_parameters_to_scope(&func.parameters);
}
/// Add each parameter to the current scope.
///
/// Parameters are flattened, so any struct parameter is added as one
/// entry for each of its fields (recursively).
fn add_parameters_to_scope(&mut self, parameters: &Parameters) {
for (id, mutable, _, typ) in parameters {
self.add_parameter_to_scope(*id, typ, *mutable);
}
}
/// Adds a "single" parameter to scope.
///
/// Single is in quotes here because in the case of tuple parameters, the tuple is flattened
/// into a new parameter for each field recursively.
fn add_parameter_to_scope(
&mut self,
parameter_id: LocalId,
parameter_type: &ast::Type,
mutable: bool,
) {
// Add a separate parameter for each field type in 'parameter_type'
let parameter_value = Self::map_type(parameter_type, |typ| {
let value = self.builder.add_parameter(typ);
if mutable {
self.new_mutable_variable(value)
} else {
value.into()
}
});
self.definitions.insert(parameter_id, parameter_value);
}
/// Allocate a single slot of memory and store into it the given initial value of the variable.
/// Always returns a Value::Mutable wrapping the allocate instruction.
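///
/// For an initial value `v0`, the emitted instructions look like (illustration only):
///
/// ```text
/// v1 = allocate       ; one slot sized for v0's type
/// store v0 in v1      ; returned as Value::Mutable(v1, type_of(v0))
/// ```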
pub(super) fn new_mutable_variable(&mut self, value_to_store: ValueId) -> Value {
let element_type = self.builder.current_function.dfg.type_of_value(value_to_store);
let alloc = self.builder.insert_allocate(element_type);
self.builder.insert_store(alloc, value_to_store);
let typ = self.builder.type_of_value(value_to_store);
Value::Mutable(alloc, typ)
}
/// Maps the given type to a Tree of the result type.
///
/// This can be used to (for example) flatten a tuple type, creating
/// and returning a new parameter for each field type.
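///
/// Illustrative sketch (not a compiled doctest; types abbreviated):
///
/// ```ignore
/// // Mapping Tuple(Field, Tuple(Bool, Field)) yields a Tree mirroring the nesting:
/// let params = FunctionContext::map_type(&typ, |t| builder.add_parameter(t));
/// // => Branch([Leaf(v0: Field), Branch([Leaf(v1: u1), Leaf(v2: Field)])])
/// ```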
pub(super) fn map_type<T>(typ: &ast::Type, mut f: impl FnMut(Type) -> T) -> Tree<T> {
Self::map_type_helper(typ, &mut f)
}
// This helper is needed because we need to take f by mutable reference:
// otherwise we could not use it multiple times, once per iteration of vecmap.
fn map_type_helper<T>(typ: &ast::Type, f: &mut dyn FnMut(Type) -> T) -> Tree<T> {
match typ {
ast::Type::Tuple(fields) => {
Tree::Branch(vecmap(fields, |field| Self::map_type_helper(field, f)))
}
ast::Type::Unit => Tree::empty(),
// A mutable reference wraps each element into a reference.
// This can be multiple values if the element type is a tuple.
ast::Type::MutableReference(element) => {
Self::map_type_helper(element, &mut |typ| f(Type::Reference(Rc::new(typ))))
}
ast::Type::FmtString(len, fields) => {
// A format string is represented by multiple values:
// the message string, the number of fields to be formatted, and
// then the encapsulated fields themselves
let final_fmt_str_fields =
vec![ast::Type::String(*len), ast::Type::Field, *fields.clone()];
let fmt_str_tuple = ast::Type::Tuple(final_fmt_str_fields);
Self::map_type_helper(&fmt_str_tuple, f)
}
ast::Type::Slice(elements) => {
let element_types = Self::convert_type(elements).flatten();
Tree::Branch(vec![
Tree::Leaf(f(Type::length_type())),
Tree::Leaf(f(Type::Slice(Rc::new(element_types)))),
])
}
other => Tree::Leaf(f(Self::convert_non_tuple_type(other))),
}
}
/// Convert a monomorphized type to an SSA type, preserving the structure
/// of any tuples within.
pub(super) fn convert_type(typ: &ast::Type) -> Tree<Type> {
// Do nothing in the closure here - map_type_helper already calls
// convert_non_tuple_type internally.
Self::map_type_helper(typ, &mut |x| x)
}
/// Converts a non-tuple type into an SSA type. Panics if a tuple type is passed.
///
/// This function is needed since this SSA IR has no concept of tuples and thus no type for
/// them. Use `convert_type` if tuple types need to be handled correctly.
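///
/// A few representative conversions, taken from the match below (illustration only):
///
/// ```ignore
/// ast::Type::Field           // => Type::field()
/// ast::Type::Bool            // => Type::unsigned(1)
/// ast::Type::String(5)       // => Type::Array([char], 5)
/// ast::Type::Array(3, Field) // => Type::Array([Field], 3)
/// ```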
pub(super) fn convert_non_tuple_type(typ: &ast::Type) -> Type {
match typ {
ast::Type::Field => Type::field(),
ast::Type::Array(len, element) => {
let element_types = Self::convert_type(element).flatten();
Type::Array(Rc::new(element_types), *len as usize)
}
ast::Type::Integer(Signedness::Signed, bits) => Type::signed((*bits).into()),
ast::Type::Integer(Signedness::Unsigned, bits) => Type::unsigned((*bits).into()),
ast::Type::Bool => Type::unsigned(1),
ast::Type::String(len) => Type::Array(Rc::new(vec![Type::char()]), *len as usize),
ast::Type::FmtString(_, _) => {
panic!("convert_non_tuple_type called on a fmt string: {typ}")
}
ast::Type::Unit => panic!("convert_non_tuple_type called on a unit type"),
ast::Type::Tuple(_) => panic!("convert_non_tuple_type called on a tuple: {typ}"),
ast::Type::Function(_, _, _) => Type::Function,
ast::Type::Slice(_) => panic!("convert_non_tuple_type called on a slice: {typ}"),
ast::Type::MutableReference(element) => {
// Recursive call to panic if element is a tuple
let element = Self::convert_non_tuple_type(element);
Type::Reference(Rc::new(element))
}
}
}
/// Returns the unit value, represented as an empty tree of values
pub(super) fn unit_value() -> Values {
Values::empty()
}
/// Insert a numeric constant into the current function
///
/// Unlike FunctionBuilder::numeric_constant, this version checks the given constant
/// is within the range of the given type. This is needed for user provided values where
/// otherwise values like 2^128 can be assigned to a u8 without error or wrapping.
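///
/// A standalone sketch of the idea behind the range check for unsigned types
/// (plain Rust; the real check is `NumericType::value_is_within_limits`):
///
/// ```ignore
/// fn fits_in_unsigned(value: u128, bit_size: u32) -> bool {
///     bit_size >= 128 || value < (1u128 << bit_size)
/// }
/// assert!(fits_in_unsigned(255, 8));  // a u8 holds 0..=255
/// assert!(!fits_in_unsigned(256, 8)); // rejected with IntegerOutOfBounds
/// ```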
pub(super) fn checked_numeric_constant(
&mut self,
value: impl Into<FieldElement>,
typ: Type,
) -> Result<ValueId, RuntimeError> {
let value = value.into();
if let Type::Numeric(typ) = typ {
if !typ.value_is_within_limits(value) {
let call_stack = self.builder.get_call_stack();
return Err(RuntimeError::IntegerOutOfBounds { value, typ, call_stack });
}
} else {
panic!("Expected type for numeric constant to be a numeric type, found {typ}");
}
Ok(self.builder.numeric_constant(value, typ))
}
/// Helper function which adds instructions to the current block to compute the absolute
/// value of the given signed integer input. When the input is negative we return its two's
/// complement, and when it is positive we return the input itself.
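///
/// A plain-integer sketch of the emitted arithmetic (u128 standing in for field values;
/// illustration only):
///
/// ```ignore
/// fn abs_sketch(input: u128, sign: u128, bit_size: u32) -> u128 {
///     // sign is 1 when the encoded value is non-negative, 0 otherwise
///     sign * input + (1 - sign) * ((1u128 << bit_size) - input)
/// }
/// assert_eq!(abs_sketch(0xFF, 0, 8), 1); // 0xFF encodes -1i8, and |-1| == 1
/// assert_eq!(abs_sketch(5, 1, 8), 5);    // non-negative values pass through
/// ```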
fn absolute_value_helper(&mut self, input: ValueId, sign: ValueId, bit_size: u32) -> ValueId {
assert_eq!(self.builder.type_of_value(sign), Type::bool());
// We compute the absolute value of lhs
let bit_width =
self.builder.numeric_constant(FieldElement::from(2_i128.pow(bit_size)), Type::field());
let sign_not = self.builder.insert_not(sign);
// We use unsafe casts here; this is fine since we're casting to a `field` type.
let as_field = self.builder.insert_cast(input, Type::field());
let sign_field = self.builder.insert_cast(sign, Type::field());
let positive_predicate = self.builder.insert_binary(sign_field, BinaryOp::Mul, as_field);
let two_complement = self.builder.insert_binary(bit_width, BinaryOp::Sub, as_field);
let sign_not_field = self.builder.insert_cast(sign_not, Type::field());
let negative_predicate =
self.builder.insert_binary(sign_not_field, BinaryOp::Mul, two_complement);
self.builder.insert_binary(positive_predicate, BinaryOp::Add, negative_predicate)
}
/// Insert constraints ensuring that the operation does not overflow the bit size of the result.
///
/// If the result is unsigned, overflow will be checked during acir-gen (cf. issue #4456),
/// except for bit-shifts, because we convert those to field multiplications.
///
/// If the result is signed, we only prepare it for check_signed_overflow() by casting it to
/// an unsigned value representing the signed integer, using a bigger bit size where the
/// operation requires it in case the operation does overflow. We then delegate the overflow
/// checks to check_signed_overflow() and cast the result back to its type.
/// Note that we do NOT want to check for overflows here; only check_signed_overflow() is
/// allowed to do so. This is because an overflow might be valid: for instance, if 'a' is a
/// signed integer, then 'a - a' as an unsigned result will always overflow the bit size,
/// yet the operation is still valid (i.e. it is not a signed overflow).
fn check_overflow(
&mut self,
result: ValueId,
lhs: ValueId,
rhs: ValueId,
operator: BinaryOpKind,
location: Location,
) -> ValueId {
let result_type = self.builder.current_function.dfg.type_of_value(result);
match result_type {
Type::Numeric(NumericType::Signed { bit_size }) => {
match operator {
BinaryOpKind::Add | BinaryOpKind::Subtract => {
// Result is computed modulo the bit size
let result = self.builder.insert_truncate(result, bit_size, bit_size + 1);
let result =
self.insert_safe_cast(result, Type::unsigned(bit_size), location);
self.check_signed_overflow(result, lhs, rhs, operator, bit_size, location);
self.insert_safe_cast(result, result_type, location)
}
BinaryOpKind::Multiply => {
// Result is computed modulo the bit size
let mut result =
self.builder.insert_cast(result, Type::unsigned(2 * bit_size));
result = self.builder.insert_truncate(result, bit_size, 2 * bit_size);
self.check_signed_overflow(result, lhs, rhs, operator, bit_size, location);
self.insert_safe_cast(result, result_type, location)
}
BinaryOpKind::ShiftLeft | BinaryOpKind::ShiftRight => {
self.check_shift_overflow(result, rhs, bit_size, location)
}
_ => unreachable!("operator {} should not overflow", operator),
}
}
Type::Numeric(NumericType::Unsigned { bit_size }) => {
let dfg = &self.builder.current_function.dfg;
let max_lhs_bits = dfg.get_value_max_num_bits(lhs);
match operator {
BinaryOpKind::Add | BinaryOpKind::Subtract | BinaryOpKind::Multiply => {
// Overflow check is deferred to acir-gen
return result;
}
BinaryOpKind::ShiftLeft => {
if let Some(rhs_const) = dfg.get_numeric_constant(rhs) {
let bit_shift_size = rhs_const.to_u128() as u32;
if max_lhs_bits + bit_shift_size <= bit_size {
// `lhs` has been cast up from a smaller type such that shifting it by a constant
// `rhs` is known not to exceed the maximum bit size.
return result;
}
}
self.check_shift_overflow(result, rhs, bit_size, location);
}
_ => unreachable!("operator {} should not overflow", operator),
}
result
}
_ => result,
}
}
/// Overflow checks for bit-shifts.
/// We use Rust's behavior for bit-shifts:
/// if rhs is greater than or equal to the bit size, we overflow;
/// if not, we do not overflow, and zeros are shifted in for bits falling outside the bit size.
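///
/// The same rule as Rust's checked shifts (illustration only):
///
/// ```ignore
/// assert_eq!(1u8.checked_shl(7), Some(128));
/// assert_eq!(1u8.checked_shl(8), None); // rhs >= bit size: overflow
/// ```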
fn check_shift_overflow(
&mut self,
result: ValueId,
rhs: ValueId,
bit_size: u32,
location: Location,
) -> ValueId {
let one = self.builder.numeric_constant(FieldElement::one(), Type::bool());
assert!(self.builder.current_function.dfg.type_of_value(rhs) == Type::unsigned(8));
let max =
self.builder.numeric_constant(FieldElement::from(bit_size as i128), Type::unsigned(8));
let overflow = self.builder.insert_binary(rhs, BinaryOp::Lt, max);
self.builder.set_location(location).insert_constrain(
overflow,
one,
Some("attempt to bit-shift with overflow".to_owned().into()),
);
self.builder.insert_truncate(result, bit_size, bit_size + 1)
}
/// Insert constraints ensuring that the operation does not overflow the bit size of the result.
/// We assume that:
/// - lhs and rhs are signed integers of bit size bit_size
/// - result is the result of the operation, cast into an unsigned integer and not reduced
///
/// Overflow checking for signed integers is less straightforward than for unsigned integers.
/// We first compute the sign of the operands, and then apply the following rules:
/// - addition (subtraction is handled as addition of the negated rhs):
///   - positive operands => the result must be positive (i.e. less than half the bit size)
///   - negative operands => the result must be negative (i.e. not positive)
///   - different signs => no overflow is possible
/// - multiplication: we check that the product of the operands' absolute values does not
///   overflow the bit size, then we check that the result has the proper sign, using the
///   rule of signs
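///
/// A plain `i8` sketch of the addition rule (the SSA version compares unsigned encodings
/// against half the bit width instead of using `< 0`; illustration only):
///
/// ```ignore
/// let (a, b) = (100i8, 100i8);
/// let result = a.wrapping_add(b);                     // -56: the sign flipped
/// let same_sign = (a < 0) == (b < 0);                 // operands agree in sign
/// let overflow = same_sign && ((result < 0) != (a < 0));
/// assert!(overflow);                                  // indeed, 100i8.checked_add(100) == None
/// ```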
fn check_signed_overflow(
&mut self,
result: ValueId,
lhs: ValueId,
rhs: ValueId,
operator: BinaryOpKind,
bit_size: u32,
location: Location,
) {
let is_sub = operator == BinaryOpKind::Subtract;
let half_width = self.builder.numeric_constant(
FieldElement::from(2_i128.pow(bit_size - 1)),
Type::unsigned(bit_size),
);
// We compute the sign of the operands. The overflow checks for signed integers depend on these signs
let lhs_as_unsigned = self.insert_safe_cast(lhs, Type::unsigned(bit_size), location);
let rhs_as_unsigned = self.insert_safe_cast(rhs, Type::unsigned(bit_size), location);
let lhs_sign = self.builder.insert_binary(lhs_as_unsigned, BinaryOp::Lt, half_width);
let mut rhs_sign = self.builder.insert_binary(rhs_as_unsigned, BinaryOp::Lt, half_width);
let message = if is_sub {
// lhs - rhs = lhs + (-rhs)
rhs_sign = self.builder.insert_not(rhs_sign);
"attempt to subtract with overflow".to_string()
} else {
"attempt to add with overflow".to_string()
};
// same_sign is true if both operands have the same sign
let same_sign = self.builder.insert_binary(lhs_sign, BinaryOp::Eq, rhs_sign);
match operator {
BinaryOpKind::Add | BinaryOpKind::Subtract => {
// Check the result has the same sign as its inputs
let result_sign = self.builder.insert_binary(result, BinaryOp::Lt, half_width);
let sign_diff = self.builder.insert_binary(result_sign, BinaryOp::Eq, lhs_sign);
let sign_diff_with_predicate =
self.builder.insert_binary(sign_diff, BinaryOp::Mul, same_sign);
let overflow_check = Instruction::Constrain(
sign_diff_with_predicate,
same_sign,
Some(message.into()),
);
self.builder.set_location(location).insert_instruction(overflow_check, None);
}
BinaryOpKind::Multiply => {
// Overflow check for the multiplication:
// First we compute the absolute value of operands, and their product
let lhs_abs = self.absolute_value_helper(lhs, lhs_sign, bit_size);
let rhs_abs = self.absolute_value_helper(rhs, rhs_sign, bit_size);
let product_field = self.builder.insert_binary(lhs_abs, BinaryOp::Mul, rhs_abs);
// It must not already overflow the bit_size
self.builder.set_location(location).insert_range_check(
product_field,
bit_size,
Some("attempt to multiply with overflow".to_string()),
);
let product = self.builder.insert_cast(product_field, Type::unsigned(bit_size));
// Then we check the signed product fits in a signed integer of bit_size-bits
let not_same = self.builder.insert_not(same_sign);
let not_same_sign_field =
self.insert_safe_cast(not_same, Type::unsigned(bit_size), location);
let positive_maximum_with_offset =
self.builder.insert_binary(half_width, BinaryOp::Add, not_same_sign_field);
let product_overflow_check =
self.builder.insert_binary(product, BinaryOp::Lt, positive_maximum_with_offset);
let one = self.builder.numeric_constant(FieldElement::one(), Type::bool());
self.builder.set_location(location).insert_constrain(
product_overflow_check,
one,
Some(message.into()),
);
}
_ => unreachable!("operator {} should not overflow", operator),
}
}
/// Insert a binary instruction at the end of the current block.
/// Converts the form of the binary instruction as necessary
/// (e.g. swapping arguments, inserting a not) to represent it in the IR.
/// For example, (a <= b) is represented as !(b < a)
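///
/// The comparison rewrites, checked in plain Rust (illustration only):
///
/// ```ignore
/// let (a, b) = (3, 7);
/// assert_eq!(a <= b, !(b < a)); // swap operands, then negate
/// assert_eq!(a > b, b < a);     // swap operands only
/// ```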
pub(super) fn insert_binary(
&mut self,
mut lhs: ValueId,
operator: BinaryOpKind,
mut rhs: ValueId,
location: Location,
) -> Values {
let op = convert_operator(operator);
if operator_requires_swapped_operands(operator) {
std::mem::swap(&mut lhs, &mut rhs);
}
let mut result = self.builder.set_location(location).insert_binary(lhs, op, rhs);
// Check for integer overflow
if matches!(
operator,
BinaryOpKind::Add
| BinaryOpKind::Subtract
| BinaryOpKind::Multiply
| BinaryOpKind::ShiftLeft
) {
result = self.check_overflow(result, lhs, rhs, operator, location);
}
if operator_requires_not(operator) {
result = self.builder.insert_not(result);
}
result.into()
}
/// Inserts a call instruction at the end of the current block and returns the results
/// of the call.
///
/// Compared to self.builder.insert_call, this version will reshape the returned Vec<ValueId>
/// back into a Values tree of the proper shape.
pub(super) fn insert_call(
&mut self,
function: ValueId,
arguments: Vec<ValueId>,
result_type: &ast::Type,
location: Location,
) -> Values {
let result_types = Self::convert_type(result_type).flatten();
let results =
self.builder.set_location(location).insert_call(function, arguments, result_types);
let mut i = 0;
let reshaped_return_values = Self::map_type(result_type, |_| {
let result = results[i].into();
i += 1;
result
});
assert_eq!(i, results.len());
reshaped_return_values
}
/// Inserts a cast instruction at the end of the current block and returns the results
/// of the cast.
///
/// Compared to `self.builder.insert_cast`, this version will automatically truncate `value` to be a valid `typ`.
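///
/// A plain-integer sketch of the narrowing behavior (illustration only):
///
/// ```ignore
/// let value = 300u32;
/// let truncated = value % (1 << 8); // what Instruction::Truncate to 8 bits computes
/// assert_eq!(truncated as u8, 44);  // the subsequent cast is then value-preserving
/// ```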
pub(super) fn insert_safe_cast(
&mut self,
mut value: ValueId,
typ: Type,
location: Location,
) -> ValueId {
self.builder.set_location(location);
// To ensure that `value` is a valid `typ`, we insert an `Instruction::Truncate` instruction beforehand if
// we're narrowing the type size.
let incoming_type_size = self.builder.type_of_value(value).bit_size();
let target_type_size = typ.bit_size();
if target_type_size < incoming_type_size {
value = self.builder.insert_truncate(value, target_type_size, incoming_type_size);
}
self.builder.insert_cast(value, typ)
}
/// Create a const offset of an address for an array load or store
pub(super) fn make_offset(&mut self, mut address: ValueId, offset: u128) -> ValueId {
if offset != 0 {
let offset = self.builder.numeric_constant(offset, self.builder.type_of_value(address));
address = self.builder.insert_binary(address, BinaryOp::Add, offset);
}
address
}
/// Array indexes are u32. This function casts values used as indexes to u32.
pub(super) fn make_array_index(&mut self, index: ValueId) -> ValueId {
self.builder.insert_cast(index, Type::unsigned(SSA_WORD_SIZE))
}
/// Define a local variable to be some Values that can later be retrieved
/// by calling self.lookup(id)
pub(super) fn define(&mut self, id: LocalId, value: Values) {
let existing = self.definitions.insert(id, value);
assert!(existing.is_none(), "Variable {id:?} was defined twice in ssa-gen pass");
}
/// Looks up the value of a given local variable. Expects the variable to have
/// been previously defined or panics otherwise.
pub(super) fn lookup(&self, id: LocalId) -> Values {
self.definitions.get(&id).expect("lookup: variable not defined").clone()
}
/// Extract the given field of the tuple. Panics if the given Values is not
/// a Tree::Branch or does not have enough fields.
pub(super) fn get_field(tuple: Values, field_index: usize) -> Values {
match tuple {
Tree::Branch(mut trees) => trees.remove(field_index),
Tree::Leaf(value) => {
unreachable!("Tried to extract tuple index {field_index} from non-tuple {value:?}")
}
}
}
/// Extract the given field of the tuple by reference. Panics if the given Values is not
/// a Tree::Branch or does not have enough fields.
pub(super) fn get_field_ref(tuple: &Values, field_index: usize) -> &Values {
match tuple {
Tree::Branch(trees) => &trees[field_index],
Tree::Leaf(value) => {
unreachable!("Tried to extract tuple index {field_index} from non-tuple {value:?}")
}
}
}
/// Replace the given field of the tuple with a new one. Panics if the given Values is not
/// a Tree::Branch or does not have enough fields.
pub(super) fn replace_field(tuple: Values, field_index: usize, new_value: Values) -> Values {
match tuple {
Tree::Branch(mut trees) => {
trees[field_index] = new_value;
Tree::Branch(trees)
}
Tree::Leaf(value) => {
unreachable!("Tried to replace tuple index {field_index} in non-tuple {value:?}")
}
}
}
/// Retrieves the given function, adding it to the function queue
/// if it is not yet compiled.
pub(super) fn get_or_queue_function(&mut self, id: FuncId) -> Values {
let function = self.shared_context.get_or_queue_function(id);
self.builder.import_function(function).into()
}
/// Extracts the current value out of an LValue.
///
/// Goal: Handle the case of assigning to nested expressions such as `foo.bar[i1].baz[i2] = e`
/// while also noting that assigning to arrays will create a new array rather than mutate
/// the original.
///
/// Method: First `extract_current_value` must recurse on the lvalue to extract the current
/// value contained:
///
/// v0 = foo.bar ; allocate instruction for bar
/// v1 = load v0 ; loading the bar array
/// v2 = add i1, baz_index ; field offset for index i1, field baz
/// v3 = array_get v1, index v2 ; foo.bar[i1].baz
///
/// Method (part 2): Then, `assign_new_value` will recurse in the opposite direction to
/// construct the larger value as needed until we can `store` to the nearest
/// allocation.
///
/// v4 = array_set v3, index i2, e ; finally create a new array setting the desired value
/// v5 = array_set v1, index v2, v4 ; now must also create the new bar array
/// store v5 in v0 ; and store the result in the only mutable reference
///
/// The returned `LValue` tracks the current value at each step of the lvalue.
/// This is later used by `assign_new_value` to construct a new updated value that
/// can be assigned to an allocation within the LValue::Ident.
///
/// This is operationally equivalent to extract_current_value_recursive, but splitting these
/// into two separate functions avoids cloning the outermost `Values` returned by the recursive
/// version, as it is only needed for recursion.
pub(super) fn extract_current_value(
&mut self,
lvalue: &ast::LValue,
) -> Result<LValue, RuntimeError> {
Ok(match lvalue {
ast::LValue::Ident(ident) => {
let (reference, should_auto_deref) = self.ident_lvalue(ident);
if should_auto_deref {
LValue::Dereference { reference }
} else {
LValue::Ident
}
}
ast::LValue::Index { array, index, location, .. } => {
self.index_lvalue(array, index, location)?.2
}
ast::LValue::MemberAccess { object, field_index } => {
let (old_object, object_lvalue) = self.extract_current_value_recursive(object)?;
let object_lvalue = Box::new(object_lvalue);
LValue::MemberAccess { old_object, object_lvalue, index: *field_index }
}
ast::LValue::Dereference { reference, .. } => {
let (reference, _) = self.extract_current_value_recursive(reference)?;
LValue::Dereference { reference }
}
})
}
fn dereference_lvalue(&mut self, values: &Values, element_type: &ast::Type) -> Values {
let element_types = Self::convert_type(element_type);
values.map_both(element_types, |value, element_type| {
let reference = value.eval_reference();
self.builder.insert_load(reference, element_type).into()
})
}
/// Compile the given identifier as a reference - i.e. avoid calling .eval().
/// Returns the variable's value and whether the variable is mutable.
fn ident_lvalue(&self, ident: &ast::Ident) -> (Values, bool) {
match &ident.definition {
ast::Definition::Local(id) => (self.lookup(*id), ident.mutable),
other => panic!("Unexpected definition found for mutable value: {other}"),
}
}
/// Compile the given `array[index]` expression as a reference.
/// This will return a tuple of (array, index, lvalue_ref, Option<length>) where the lvalue_ref
/// records the structure of the lvalue expression for use by `assign_new_value`.
/// The optional length is for indexing slices rather than arrays, since slices
/// are represented as a tuple in the form: (length, slice contents).
fn index_lvalue(
&mut self,
array: &ast::LValue,
index: &ast::Expression,
location: &Location,
) -> Result<(ValueId, ValueId, LValue, Option<ValueId>), RuntimeError> {
let (old_array, array_lvalue) = self.extract_current_value_recursive(array)?;
let index = self.codegen_non_tuple_expression(index)?;
let array_lvalue = Box::new(array_lvalue);
let array_values = old_array.clone().into_value_list(self);
let location = *location;
// A slice is represented as a tuple (length, slice contents).
// We need to fetch the second value.
Ok(if array_values.len() > 1 {
let slice_lvalue = LValue::SliceIndex {
old_slice: old_array,
index,
slice_lvalue: array_lvalue,
location,
};
(array_values[1], index, slice_lvalue, Some(array_values[0]))
} else {
let array_lvalue =
LValue::Index { old_array: array_values[0], index, array_lvalue, location };
(array_values[0], index, array_lvalue, None)
})
}
fn extract_current_value_recursive(
&mut self,
lvalue: &ast::LValue,
) -> Result<(Values, LValue), RuntimeError> {
match lvalue {
ast::LValue::Ident(ident) => {
let (variable, should_auto_deref) = self.ident_lvalue(ident);
if should_auto_deref {
let dereferenced = self.dereference_lvalue(&variable, &ident.typ);
Ok((dereferenced, LValue::Dereference { reference: variable }))
} else {
Ok((variable.clone(), LValue::Ident))
}
}
ast::LValue::Index { array, index, element_type, location } => {
let (old_array, index, index_lvalue, max_length) =
self.index_lvalue(array, index, location)?;
let element = self.codegen_array_index(
old_array,
index,
element_type,
*location,
max_length,
)?;
Ok((element, index_lvalue))
}
ast::LValue::MemberAccess { object, field_index: index } => {
let (old_object, object_lvalue) = self.extract_current_value_recursive(object)?;
let object_lvalue = Box::new(object_lvalue);
let element = Self::get_field_ref(&old_object, *index).clone();
Ok((element, LValue::MemberAccess { old_object, object_lvalue, index: *index }))
}
ast::LValue::Dereference { reference, element_type } => {
let (reference, _) = self.extract_current_value_recursive(reference)?;
let dereferenced = self.dereference_lvalue(&reference, element_type);
Ok((dereferenced, LValue::Dereference { reference }))
}
}
}
/// Assigns a new value to the given LValue.
/// The LValue can be created via a previous call to extract_current_value.
/// This method recurses on the given LValue to create a new value to assign an allocation
/// instruction within an LValue::Ident or LValue::Dereference - see the comment on
/// `extract_current_value` for more details.
pub(super) fn assign_new_value(&mut self, lvalue: LValue, new_value: Values) {
match lvalue {
LValue::Ident => unreachable!("Cannot assign to a variable without a reference"),
LValue::Index { old_array: mut array, index, array_lvalue, location } => {
array = self.assign_lvalue_index(new_value, array, index, location);
self.assign_new_value(*array_lvalue, array.into());
}
LValue::SliceIndex { old_slice: slice, index, slice_lvalue, location } => {
let mut slice_values = slice.into_value_list(self);
slice_values[1] =
self.assign_lvalue_index(new_value, slice_values[1], index, location);
// The size of the slice does not change in a slice index assignment so we can reuse the same length value
let new_slice = Tree::Branch(vec![slice_values[0].into(), slice_values[1].into()]);
self.assign_new_value(*slice_lvalue, new_slice);
}
LValue::MemberAccess { old_object, index, object_lvalue } => {
let new_object = Self::replace_field(old_object, index, new_value);
self.assign_new_value(*object_lvalue, new_object);
}
LValue::Dereference { reference } => {
self.assign(reference, new_value);
}
}
}
fn assign_lvalue_index(
&mut self,
new_value: Values,
mut array: ValueId,
index: ValueId,
location: Location,
) -> ValueId {
let index = self.make_array_index(index);
let element_size =
self.builder.numeric_constant(self.element_size(array), Type::unsigned(SSA_WORD_SIZE));
// The actual base index is the user's index * the array element type's size
let mut index =
self.builder.set_location(location).insert_binary(index, BinaryOp::Mul, element_size);
let one = self.builder.numeric_constant(FieldElement::one(), Type::unsigned(SSA_WORD_SIZE));
new_value.for_each(|value| {
let value = value.eval(self);
array = self.builder.insert_array_set(array, index, value);
index = self.builder.insert_binary(index, BinaryOp::Add, one);
});
array
}
fn element_size(&self, array: ValueId) -> FieldElement {
let size = self.builder.type_of_value(array).element_size();
FieldElement::from(size as u128)
}
/// Given an lhs containing only references, create a store instruction to store each value of
/// rhs into its corresponding value in lhs.
fn assign(&mut self, lhs: Values, rhs: Values) {
match (lhs, rhs) {
(Tree::Branch(lhs_branches), Tree::Branch(rhs_branches)) => {
assert_eq!(lhs_branches.len(), rhs_branches.len());
for (lhs, rhs) in lhs_branches.into_iter().zip(rhs_branches) {
self.assign(lhs, rhs);
}
}
(Tree::Leaf(lhs), Tree::Leaf(rhs)) => {
let (lhs, rhs) = (lhs.eval_reference(), rhs.eval(self));
self.builder.insert_store(lhs, rhs);
}
(lhs, rhs) => {
unreachable!(
"assign: Expected lhs and rhs values to match but found {lhs:?} and {rhs:?}"
)
}
}
}
/// Increments the reference count of all parameters. Returns the entry block of the function.
///
/// This is done on parameters rather than call arguments so that we can optimize out
/// paired inc/dec instructions within brillig functions more easily.
pub(crate) fn increment_parameter_rcs(&mut self) -> BasicBlockId {
let entry = self.builder.current_function.entry_block();
let parameters = self.builder.current_function.dfg.block_parameters(entry).to_vec();
for parameter in parameters {
self.builder.increment_array_reference_count(parameter);
}
entry
}
/// Ends a local scope of a function.
/// This will issue DecrementRc instructions for any arrays in the given starting scope
/// block's parameters. Arrays that are also used in terminator instructions for the scope are
/// ignored.
pub(crate) fn end_scope(&mut self, scope: BasicBlockId, terminator_args: &[ValueId]) {
let mut dropped_parameters =
self.builder.current_function.dfg.block_parameters(scope).to_vec();
dropped_parameters.retain(|parameter| !terminator_args.contains(parameter));
for parameter in dropped_parameters {
self.builder.decrement_array_reference_count(parameter);
}
}
pub(crate) fn enter_loop(
&mut self,
loop_entry: BasicBlockId,
loop_index: ValueId,
loop_end: BasicBlockId,
) {
self.loops.push(Loop { loop_entry, loop_index, loop_end });
}
pub(crate) fn exit_loop(&mut self) {
self.loops.pop();
}
pub(crate) fn current_loop(&self) -> Loop {
// The frontend should ensure break/continue are never used outside a loop
*self.loops.last().expect("current_loop: not in a loop!")
}
}
/// True if the given operator cannot be encoded directly and needs
/// to be represented as !(some other operator)
fn operator_requires_not(op: BinaryOpKind) -> bool {
use BinaryOpKind::*;
matches!(op, NotEqual | LessEqual | GreaterEqual)
}
/// True if the given operator cannot be encoded directly and needs
/// to have its lhs and rhs swapped to be represented with another operator.
/// Example: (a > b) needs to be represented as (b < a)
fn operator_requires_swapped_operands(op: BinaryOpKind) -> bool {
use BinaryOpKind::*;
matches!(op, Greater | LessEqual)
}
/// Converts the given operator to the appropriate BinaryOp.
/// Take care when using this to insert a binary instruction: this requires
/// checking operator_requires_not and operator_requires_swapped_operands
/// to represent the full operation correctly.
fn convert_operator(op: BinaryOpKind) -> BinaryOp {
match op {
BinaryOpKind::Add => BinaryOp::Add,
BinaryOpKind::Subtract => BinaryOp::Sub,
BinaryOpKind::Multiply => BinaryOp::Mul,
BinaryOpKind::Divide => BinaryOp::Div,
BinaryOpKind::Modulo => BinaryOp::Mod,
BinaryOpKind::Equal => BinaryOp::Eq,
BinaryOpKind::NotEqual => BinaryOp::Eq, // Requires not
BinaryOpKind::Less => BinaryOp::Lt,
BinaryOpKind::Greater => BinaryOp::Lt, // Requires operand swap
BinaryOpKind::LessEqual => BinaryOp::Lt, // Requires operand swap and not
BinaryOpKind::GreaterEqual => BinaryOp::Lt, // Requires not
BinaryOpKind::And => BinaryOp::And,
BinaryOpKind::Or => BinaryOp::Or,
BinaryOpKind::Xor => BinaryOp::Xor,
BinaryOpKind::ShiftLeft => BinaryOp::Shl,
BinaryOpKind::ShiftRight => BinaryOp::Shr,
}
}
impl SharedContext {
/// Create a new SharedContext for the given monomorphized program.
pub(super) fn new(program: Program) -> Self {
Self {
functions: Default::default(),
function_queue: Default::default(),
function_counter: Default::default(),
program,
}
}
/// Pops the next function from the shared function queue, returning None if the queue is empty.
pub(super) fn pop_next_function_in_queue(&self) -> Option<(ast::FuncId, IrFunctionId)> {
self.function_queue.lock().expect("Failed to lock function_queue").pop()
}
/// Return the matching id for the given function if known. If it is not known this
/// will add the function to the queue of functions to compile, assign it a new id,
/// and return this new id.
pub(super) fn get_or_queue_function(&self, id: ast::FuncId) -> IrFunctionId {
// Start a new block to guarantee the read guard on the map is dropped
// before the lock needs to be acquired again in self.functions.write() below
{
let map = self.functions.read().expect("Failed to read self.functions");
if let Some(existing_id) = map.get(&id) {
return *existing_id;
}
}