diff --git a/barretenberg/cpp/src/barretenberg/ecc/curves/grumpkin/c_bind.cpp b/barretenberg/cpp/src/barretenberg/ecc/curves/grumpkin/c_bind.cpp index a989d45f5c6b..b206430a3f33 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/curves/grumpkin/c_bind.cpp +++ b/barretenberg/cpp/src/barretenberg/ecc/curves/grumpkin/c_bind.cpp @@ -63,12 +63,17 @@ WASM_EXPORT void ecc_grumpkin__reduce512_buffer_mod_circuit_modulus(uint8_t* inp write(result, target_output.lo); } -WASM_EXPORT void grumpkin_fr_sqrt(uint8_t const* field_buf, uint8_t* result) +WASM_EXPORT void grumpkin_fr_sqrt(uint8_t const* input, uint8_t* result) { using serialize::write; - auto fr = from_buffer(field_buf); - auto [is_sqr, root] = fr.sqrt(); - write(result, root); + auto input_fr = from_buffer(input); + auto [is_sqr, root] = input_fr.sqrt(); + + uint8_t* is_sqrt_result_ptr = result; + uint8_t* root_result_ptr = result + 1; + + write(is_sqrt_result_ptr, is_sqr); + write(root_result_ptr, root); } // NOLINTEND(cert-dcl37-c, cert-dcl51-cpp, bugprone-reserved-identifier) \ No newline at end of file diff --git a/yarn-project/foundation/src/fields/fields.test.ts b/yarn-project/foundation/src/fields/fields.test.ts index 2f8686bb4c5b..d1a728e67d77 100644 --- a/yarn-project/foundation/src/fields/fields.test.ts +++ b/yarn-project/foundation/src/fields/fields.test.ts @@ -109,7 +109,7 @@ describe('Bn254 arithmetic', () => { expect(actual).toEqual(expected); }); - it('High Bonudary', () => { + it('High Boundary', () => { // -1 - (-1) = 0 const a = new Fr(Fr.MODULUS - 1n); const b = new Fr(Fr.MODULUS - 1n); @@ -184,6 +184,16 @@ describe('Bn254 arithmetic', () => { }); }); + describe('Square root', () => { + it('Should return the correct square root', () => { + const a = new Fr(16); + const expected = new Fr(4); + + const actual = a.sqrt(); + expect(actual).toEqual(expected); + }); + }); + describe('Comparison', () => { it.each([ [new Fr(5), new Fr(10), -1], diff --git a/yarn-project/foundation/src/fields/fields.ts b/yarn-project/foundation/src/fields/fields.ts index 94a39a18d07d..9682b04d975d 100644 --- a/yarn-project/foundation/src/fields/fields.ts +++ b/yarn-project/foundation/src/fields/fields.ts @@ -280,11 +280,22 @@ export class Fr extends BaseField { return new Fr(this.toBigInt() / rhs.toBigInt()); } - sqrt() { + /** + * Computes the square root of the field element. + * @returns The square root of the field element if it exists (undefined if not). + */ + sqrt(): Fr | undefined { const wasm = BarretenbergSync.getSingleton().getWasm(); wasm.writeMemory(0, this.toBuffer()); wasm.call('grumpkin_fr_sqrt', 0, Fr.SIZE_IN_BYTES); - return Fr.fromBuffer(Buffer.from(wasm.getMemorySlice(Fr.SIZE_IN_BYTES, Fr.SIZE_IN_BYTES * 2))); + const isSqrtBuf = Buffer.from(wasm.getMemorySlice(Fr.SIZE_IN_BYTES, Fr.SIZE_IN_BYTES + 1)); + const isSqrt = isSqrtBuf[0] === 1; + if (!isSqrt) { + return undefined; + } + + const rootBuf = Buffer.from(wasm.getMemorySlice(Fr.SIZE_IN_BYTES + 1, Fr.SIZE_IN_BYTES * 2 + 1)) + return Fr.fromBuffer(rootBuf); } toJSON() {