Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(avm): fix usage of Fr with tagged memory #4240

Merged
merged 1 commit into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 30 additions & 27 deletions yarn-project/acir-simulator/src/avm/avm_memory_types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,43 @@ import { Fr } from '@aztec/foundation/fields';

import { strict as assert } from 'assert';

export interface MemoryValue {
add(rhs: MemoryValue): MemoryValue;
sub(rhs: MemoryValue): MemoryValue;
mul(rhs: MemoryValue): MemoryValue;
div(rhs: MemoryValue): MemoryValue;
export abstract class MemoryValue {
public abstract add(rhs: MemoryValue): MemoryValue;
public abstract sub(rhs: MemoryValue): MemoryValue;
public abstract mul(rhs: MemoryValue): MemoryValue;
public abstract div(rhs: MemoryValue): MemoryValue;

// We need this to be able to build an instance of the subclasses.
public abstract build(n: bigint): MemoryValue;

// Use sparingly.
toBigInt(): bigint;
public abstract toBigInt(): bigint;
}

export interface IntegralValue extends MemoryValue {
shl(rhs: IntegralValue): IntegralValue;
shr(rhs: IntegralValue): IntegralValue;
and(rhs: IntegralValue): IntegralValue;
or(rhs: IntegralValue): IntegralValue;
xor(rhs: IntegralValue): IntegralValue;
not(): IntegralValue;
export abstract class IntegralValue extends MemoryValue {
public abstract shl(rhs: IntegralValue): IntegralValue;
public abstract shr(rhs: IntegralValue): IntegralValue;
public abstract and(rhs: IntegralValue): IntegralValue;
public abstract or(rhs: IntegralValue): IntegralValue;
public abstract xor(rhs: IntegralValue): IntegralValue;
public abstract not(): IntegralValue;
}

// TODO: Optimize calculation of mod, etc. Can only do once per class?
abstract class UnsignedInteger implements IntegralValue {
abstract class UnsignedInteger extends IntegralValue {
private readonly bitmask: bigint;
private readonly mod: bigint;

protected constructor(private n: bigint, private bits: bigint) {
super();
assert(bits > 0);
// x % 2^n == x & (2^n - 1)
this.mod = 1n << bits;
this.bitmask = this.mod - 1n;
assert(n < this.mod);
}

// We need this to be able to build an instance of the subclass
// and not of type UnsignedInteger.
protected abstract build(n: bigint): UnsignedInteger;
public abstract build(n: bigint): UnsignedInteger;

public add(rhs: UnsignedInteger): UnsignedInteger {
assert(this.bits == rhs.bits);
Expand Down Expand Up @@ -93,18 +95,14 @@ abstract class UnsignedInteger implements IntegralValue {
public toBigInt(): bigint {
return this.n;
}

public equals(rhs: UnsignedInteger) {
return this.bits == rhs.bits && this.toBigInt() == rhs.toBigInt();
}
}

export class Uint8 extends UnsignedInteger {
constructor(n: number | bigint) {
super(BigInt(n), 8n);
}

protected build(n: bigint): Uint8 {
public build(n: bigint): Uint8 {
return new Uint8(n);
}
}
Expand All @@ -114,7 +112,7 @@ export class Uint16 extends UnsignedInteger {
super(BigInt(n), 16n);
}

protected build(n: bigint): Uint16 {
public build(n: bigint): Uint16 {
return new Uint16(n);
}
}
Expand All @@ -124,7 +122,7 @@ export class Uint32 extends UnsignedInteger {
super(BigInt(n), 32n);
}

protected build(n: bigint): Uint32 {
public build(n: bigint): Uint32 {
return new Uint32(n);
}
}
Expand All @@ -134,7 +132,7 @@ export class Uint64 extends UnsignedInteger {
super(BigInt(n), 64n);
}

protected build(n: bigint): Uint64 {
public build(n: bigint): Uint64 {
return new Uint64(n);
}
}
Expand All @@ -144,19 +142,24 @@ export class Uint128 extends UnsignedInteger {
super(BigInt(n), 128n);
}

protected build(n: bigint): Uint128 {
public build(n: bigint): Uint128 {
return new Uint128(n);
}
}

export class Field implements MemoryValue {
export class Field extends MemoryValue {
public static readonly MODULUS: bigint = Fr.MODULUS;
private readonly rep: Fr;

constructor(v: number | bigint | Fr) {
super();
this.rep = new Fr(v);
}

public build(n: bigint): Field {
return new Field(n);
}

public add(rhs: Field): Field {
return new Field(this.rep.add(rhs.rep));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ describe('External Calls', () => {
const addr = new Fr(123456n);

const argsOffset = 2;
const args = [new Fr(1n), new Fr(2n), new Fr(3n)];
const args = [new Field(1n), new Field(2n), new Field(3n)];
const argsSize = args.length;

const retOffset = 8;
const retSize = 2;

const successOffset = 7;

machineState.memory.set(0, gas);
machineState.memory.set(1, addr);
machineState.memory.set(0, new Field(gas));
machineState.memory.set(1, new Field(addr));
machineState.memory.setSlice(2, args);

const otherContextInstructions: [Opcode, any[]][] = [
Expand All @@ -72,10 +72,10 @@ describe('External Calls', () => {
await instruction.execute(machineState, journal);

const successValue = machineState.memory.get(successOffset);
expect(successValue).toEqual(new Fr(1n));
expect(successValue).toEqual(new Field(1n));

const retValue = machineState.memory.getSlice(retOffset, retSize);
expect(retValue).toEqual([new Fr(1n), new Fr(2n)]);
expect(retValue).toEqual([new Field(1n), new Field(2n)]);

// Check that the storage call has been merged into the parent journal
const { storageWrites } = journal.flush();
Expand Down Expand Up @@ -126,7 +126,7 @@ describe('External Calls', () => {

// No revert has occurred, but the nested execution has failed
const successValue = machineState.memory.get(successOffset);
expect(successValue).toEqual(new Fr(0n));
expect(successValue).toEqual(new Field(0n));
});
});
});
10 changes: 6 additions & 4 deletions yarn-project/acir-simulator/src/avm/opcodes/external_calls.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ export class Call extends Instruction {

// We only take as much data as was specified in the return size -> TODO: should we be reverting here
const returnData = returnObject.output.slice(0, this.retSize);
const convertedReturnData = returnData.map(f => new Field(f));

// Write our return data into memory
machineState.memory.set(this.successOffset, new Fr(success));
machineState.memory.setSlice(this.retOffset, returnData);
machineState.memory.set(this.successOffset, new Field(success ? 1 : 0));
machineState.memory.setSlice(this.retOffset, convertedReturnData);

if (success) {
avmContext.mergeJournal();
Expand Down Expand Up @@ -84,10 +85,11 @@ export class StaticCall extends Instruction {

// We only take as much data as was specified in the return size -> TODO: should we be reverting here
const returnData = returnObject.output.slice(0, this.retSize);
const convertedReturnData = returnData.map(f => new Field(f));

// Write our return data into memory
machineState.memory.set(this.successOffset, new Fr(success));
machineState.memory.setSlice(this.retOffset, returnData);
machineState.memory.set(this.successOffset, new Field(success ? 1 : 0));
machineState.memory.setSlice(this.retOffset, convertedReturnData);

if (success) {
avmContext.mergeJournal();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,6 @@ describe('Storage Instructions', () => {
expect(journal.readStorage).toBeCalledWith(address, new Fr(a.toBigInt()));

const actual = machineState.memory.get(1);
expect(actual).toEqual(expectedResult);
expect(actual).toEqual(new Field(expectedResult));
});
});
8 changes: 6 additions & 2 deletions yarn-project/acir-simulator/src/avm/opcodes/storage.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Fr } from '@aztec/foundation/fields';

import { AvmMachineState } from '../avm_machine_state.js';
import { Field } from '../avm_memory_types.js';
import { AvmInterpreterError } from '../interpreter/interpreter.js';
import { AvmJournal } from '../journal/journal.js';
import { Instruction } from './instruction.js';
Expand Down Expand Up @@ -44,9 +45,12 @@ export class SLoad extends Instruction {
async execute(machineState: AvmMachineState, journal: AvmJournal): Promise<void> {
const slot = machineState.memory.get(this.slotOffset);

const data = journal.readStorage(machineState.executionEnvironment.storageAddress, new Fr(slot.toBigInt()));
const data: Fr = await journal.readStorage(
machineState.executionEnvironment.storageAddress,
new Fr(slot.toBigInt()),
);

machineState.memory.set(this.destOffset, await data);
machineState.memory.set(this.destOffset, new Field(data));

this.incrementPc(machineState);
}
Expand Down
Loading