@@ -167,6 +167,7 @@ impl Context<'_> {
167
167
let lhs_typ = self . function . dfg . type_of_value ( lhs) . unwrap_numeric ( ) ;
168
168
let base = self . field_constant ( FieldElement :: from ( 2_u128 ) ) ;
169
169
let pow = self . pow ( base, rhs) ;
170
+ let pow = self . pow_or_max_for_bit_size ( pow, rhs, bit_size, lhs_typ) ;
170
171
let pow = self . insert_cast ( pow, lhs_typ) ;
171
172
if lhs_typ. is_unsigned ( ) {
172
173
// unsigned right bit shift is just a normal division
@@ -205,6 +206,56 @@ impl Context<'_> {
205
206
}
206
207
}
207
208
209
+ /// Returns `pow` or the maximum value allowed for `typ` if 2^rhs is guaranteed to exceed that maximum.
210
+ fn pow_or_max_for_bit_size (
211
+ & mut self ,
212
+ pow : ValueId ,
213
+ rhs : ValueId ,
214
+ bit_size : u32 ,
215
+ typ : NumericType ,
216
+ ) -> ValueId {
217
+ let max = if typ. is_unsigned ( ) {
218
+ if bit_size == 128 {
219
+ u128:: MAX
220
+ } else {
221
+ ( 1_u128 << bit_size) - 1
222
+ }
223
+ } else {
224
+ 1_u128 << ( bit_size - 1 )
225
+ } ;
226
+ let max = self . field_constant ( FieldElement :: from ( max) ) ;
227
+
228
+ // Here we check whether rhs is less than the bit_size: if it's not then it will overflow.
229
+ // Then we do:
230
+ //
231
+ // rhs_is_less_than_bit_size = lt rhs, bit_size
232
+ // rhs_is_not_less_than_bit_size = not rhs_is_less_than_bit_size
233
+ // pow_when_is_less_than_bit_size = rhs_is_less_than_bit_size * pow
234
+ // pow_when_is_not_less_than_bit_size = rhs_is_not_less_than_bit_size * max
235
+ // pow = add pow_when_is_less_than_bit_size, pow_when_is_not_less_than_bit_size
236
+ //
237
+ // All operations here are unchecked because they work on field types.
238
+ let bit_size_as_field = self . numeric_constant ( bit_size as u128 , typ) ;
239
+ let rhs_is_less_than_bit_size = self . insert_binary ( rhs, BinaryOp :: Lt , bit_size_as_field) ;
240
+ let rhs_is_not_less_than_bit_size = self . insert_not ( rhs_is_less_than_bit_size) ;
241
+ let rhs_is_less_than_bit_size =
242
+ self . insert_cast ( rhs_is_less_than_bit_size, NumericType :: NativeField ) ;
243
+ let rhs_is_not_less_than_bit_size =
244
+ self . insert_cast ( rhs_is_not_less_than_bit_size, NumericType :: NativeField ) ;
245
+ let pow_when_is_less_than_bit_size =
246
+ self . insert_binary ( rhs_is_less_than_bit_size, BinaryOp :: Mul { unchecked : true } , pow) ;
247
+ let pow_when_is_not_less_than_bit_size = self . insert_binary (
248
+ rhs_is_not_less_than_bit_size,
249
+ BinaryOp :: Mul { unchecked : true } ,
250
+ max,
251
+ ) ;
252
+ self . insert_binary (
253
+ pow_when_is_less_than_bit_size,
254
+ BinaryOp :: Add { unchecked : true } ,
255
+ pow_when_is_not_less_than_bit_size,
256
+ )
257
+ }
258
+
208
259
/// Computes lhs^rhs via square&multiply, using the bits decomposition of rhs
209
260
/// Pseudo-code of the computation:
210
261
/// let mut r = 1;
0 commit comments