Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(avm)!: variants for CAST/NOT opcode #8497

Merged
merged 5 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions avm-transpiler/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ pub enum AvmOpcode {
OR_16,
XOR_8,
XOR_16,
NOT,
NOT_8,
NOT_16,
SHL_8,
SHL_16,
SHR_8,
SHR_16,
CAST,
CAST_8,
CAST_16,
// Execution environment
ADDRESS,
STORAGEADDRESS,
Expand Down Expand Up @@ -127,13 +129,15 @@ impl AvmOpcode {
AvmOpcode::OR_16 => "OR_16",
AvmOpcode::XOR_8 => "XOR_8",
AvmOpcode::XOR_16 => "XOR_16",
AvmOpcode::NOT => "NOT",
AvmOpcode::NOT_8 => "NOT_8",
AvmOpcode::NOT_16 => "NOT_16",
AvmOpcode::SHL_8 => "SHL_8",
AvmOpcode::SHL_16 => "SHL_16",
AvmOpcode::SHR_8 => "SHR_8",
AvmOpcode::SHR_16 => "SHR_16",
// Compute - Type Conversions
AvmOpcode::CAST => "CAST",
AvmOpcode::CAST_8 => "CAST_8",
AvmOpcode::CAST_16 => "CAST_16",

// Execution Environment
AvmOpcode::ADDRESS => "ADDRESS",
Expand Down
10 changes: 8 additions & 2 deletions avm-transpiler/src/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,12 @@ fn generate_cast_instruction(
destination_indirect: bool,
dst_tag: AvmTypeTag,
) -> AvmInstruction {
let bits_needed = bits_needed_for(&source).max(bits_needed_for(&destination));
let avm_opcode = match bits_needed {
8 => AvmOpcode::CAST_8,
16 => AvmOpcode::CAST_16,
_ => panic!("CAST only supports 8 and 16 bit encodings, needed {}", bits_needed),
};
let mut indirect_flags = ALL_DIRECT;
if source_indirect {
indirect_flags |= ZEROTH_OPERAND_INDIRECT;
Expand All @@ -831,10 +837,10 @@ fn generate_cast_instruction(
indirect_flags |= FIRST_OPERAND_INDIRECT;
}
AvmInstruction {
opcode: AvmOpcode::CAST,
opcode: avm_opcode,
indirect: Some(indirect_flags),
tag: Some(dst_tag),
operands: vec![AvmOperand::U32 { value: source }, AvmOperand::U32 { value: destination }],
operands: vec![make_operand(bits_needed, &source), make_operand(bits_needed, &destination)],
}
}

Expand Down
60 changes: 30 additions & 30 deletions barretenberg/cpp/src/barretenberg/vm/avm/tests/execution.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,11 +730,11 @@ TEST_F(AvmExecutionTests, setAndCastOpcodes)
"02" // U16
"B813" // val 47123
"0011" // dst_offset 17
+ to_hex(OpCode::CAST) + // opcode CAST
+ to_hex(OpCode::CAST_8) + // opcode CAST
"00" // Indirect flag
"01" // U8
"00000011" // addr a
"00000012" // addr casted a
"11" // addr a
"12" // addr casted a
+ to_hex(OpCode::RETURN) + // opcode RETURN
"00" // Indirect flag
"00000000" // ret offset 0
Expand All @@ -747,12 +747,12 @@ TEST_F(AvmExecutionTests, setAndCastOpcodes)

// SUB
EXPECT_THAT(instructions.at(1),
AllOf(Field(&Instruction::op_code, OpCode::CAST),
AllOf(Field(&Instruction::op_code, OpCode::CAST_8),
Field(&Instruction::operands,
ElementsAre(VariantWith<uint8_t>(0),
VariantWith<AvmMemoryTag>(AvmMemoryTag::U8),
VariantWith<uint32_t>(17),
VariantWith<uint32_t>(18)))));
VariantWith<uint8_t>(17),
VariantWith<uint8_t>(18)))));

auto trace = gen_trace_from_instr(instructions);

Expand Down Expand Up @@ -1238,16 +1238,16 @@ TEST_F(AvmExecutionTests, embeddedCurveAddOpCode)
"00000000" // cd_offset
"00000001" // copy_size
"00000000" // dst_offset
+ to_hex(OpCode::CAST) + // opcode CAST inf to U8
+ to_hex(OpCode::CAST_8) + // opcode CAST inf to U8
"00" // Indirect flag
"01" // U8 tag field
"00000002" // a_is_inf
"00000002" // a_is_inf
+ to_hex(OpCode::CAST) + // opcode CAST inf to U8
"02" // a_is_inf
"02" // a_is_inf
+ to_hex(OpCode::CAST_8) + // opcode CAST inf to U8
"00" // Indirect flag
"01" // U8 tag field
"00000005" // b_is_inf
"00000005" // b_is_inf
"05" // b_is_inf
"05" // b_is_inf
+ to_hex(OpCode::SET_8) + // opcode SET for direct src_length
"00" // Indirect flag
"03" // U32
Expand Down Expand Up @@ -1314,16 +1314,16 @@ TEST_F(AvmExecutionTests, msmOpCode)
"00000000" // cd_offset 0
"00000001" // copy_size (10 elements)
"00000000" // dst_offset 0
+ to_hex(OpCode::CAST) + // opcode CAST inf to U8
+ to_hex(OpCode::CAST_8) + // opcode CAST inf to U8
"00" // Indirect flag
"01" // U8 tag field
"00000002" // a_is_inf
"00000002" //
+ to_hex(OpCode::CAST) + // opcode CAST inf to U8
"02" // a_is_inf
"02" //
+ to_hex(OpCode::CAST_8) + // opcode CAST inf to U8
"00" // Indirect flag
"01" // U8 tag field
"00000005" // b_is_inf
"00000005" //
"05" // b_is_inf
"05" //
+ to_hex(OpCode::SET_8) + // opcode SET for length
"00" // Indirect flag
"03" // U32
Expand Down Expand Up @@ -1758,11 +1758,11 @@ TEST_F(AvmExecutionTests, kernelOutputEmitOpcodes)
"01" // value 1
"01" // dst_offset 1
// Cast set to field
+ to_hex(OpCode::CAST) + // opcode CAST
+ to_hex(OpCode::CAST_8) + // opcode CAST
"00" // Indirect flag
"06" // tag field
"00000001" // dst 1
"00000001" // dst 1
"01" // dst 1
"01" // dst 1
+ to_hex(OpCode::EMITNOTEHASH) + // opcode EMITNOTEHASH
"00" // Indirect flag
"00000001" // src offset 1
Expand Down Expand Up @@ -1859,11 +1859,11 @@ TEST_F(AvmExecutionTests, kernelOutputStorageLoadOpcodeSimple)
"03" // U32
"09" // value 9
"01" // dst_offset 1
+ to_hex(OpCode::CAST) + // opcode CAST (Cast set to field)
+ to_hex(OpCode::CAST_8) + // opcode CAST (Cast set to field)
"00" // Indirect flag
"06" // tag field
"00000001" // dst 1
"00000001" // dst 1
"01" // dst 1
"01" // dst 1
+ to_hex(OpCode::SLOAD) + // opcode SLOAD
"00" // Indirect flag
"00000001" // slot offset 1
Expand Down Expand Up @@ -1972,11 +1972,11 @@ TEST_F(AvmExecutionTests, kernelOutputStorageOpcodes)
"09" // value 9
"01" // dst_offset 1
// Cast set to field
+ to_hex(OpCode::CAST) + // opcode CAST
+ to_hex(OpCode::CAST_8) + // opcode CAST
"00" // Indirect flag
"06" // tag field
"00000001" // dst 1
"00000001" // dst 1
"01" // dst 1
"01" // dst 1
+ to_hex(OpCode::SLOAD) + // opcode SLOAD
"00" // Indirect flag
"00000001" // slot offset 1
Expand Down Expand Up @@ -2047,11 +2047,11 @@ TEST_F(AvmExecutionTests, kernelOutputHashExistsOpcodes)
"01" // value 1
"01" // dst_offset 1
// Cast set to field
+ to_hex(OpCode::CAST) + // opcode CAST
+ to_hex(OpCode::CAST_8) + // opcode CAST
"00" // Indirect flag
"06" // tag field
"00000001" // dst 1
"00000001" // dst 1
"01" // dst 1
"01" // dst 1
+ to_hex(OpCode::NOTEHASHEXISTS) + // opcode NOTEHASHEXISTS
"00" // Indirect flag
"00000001" // slot offset 1
Expand Down
15 changes: 8 additions & 7 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/alu_trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ FF AvmAluTraceBuilder::op_not(FF const& a, AvmMemoryTag in_tag, uint32_t const c

alu_trace.push_back(AvmAluTraceBuilder::AluTraceEntry{
.alu_clk = clk,
.opcode = OpCode::NOT,
.opcode = OpCode::NOT_8, // FIXME: take into account all opcodes.
.tag = in_tag,
.alu_ia = a,
.alu_ic = c,
Expand Down Expand Up @@ -585,7 +585,7 @@ FF AvmAluTraceBuilder::op_cast(FF const& a, AvmMemoryTag in_tag, uint32_t clk)
}
alu_trace.push_back(AvmAluTraceBuilder::AluTraceEntry{
.alu_clk = clk,
.opcode = OpCode::CAST,
.opcode = OpCode::CAST_8, // FIXME: take into account all opcodes.
.tag = in_tag,
.alu_ia = a,
.alu_ic = c,
Expand Down Expand Up @@ -618,9 +618,10 @@ bool AvmAluTraceBuilder::is_range_check_required() const
bool AvmAluTraceBuilder::is_alu_row_enabled(const AvmAluTraceBuilder::AluTraceEntry& r)
{
return (r.opcode == OpCode::ADD_8 || r.opcode == OpCode::SUB_8 || r.opcode == OpCode::MUL_8 ||
r.opcode == OpCode::EQ_8 || r.opcode == OpCode::NOT || r.opcode == OpCode::LT_8 ||
r.opcode == OpCode::LTE_8 || r.opcode == OpCode::SHR_8 || r.opcode == OpCode::SHL_8 ||
r.opcode == OpCode::CAST || r.opcode == OpCode::DIV_8);
r.opcode == OpCode::EQ_8 || r.opcode == OpCode::NOT_8 || r.opcode == OpCode::NOT_16 ||
r.opcode == OpCode::LT_8 || r.opcode == OpCode::LTE_8 || r.opcode == OpCode::SHR_8 ||
r.opcode == OpCode::SHL_8 || r.opcode == OpCode::CAST_8 || r.opcode == OpCode::CAST_8 ||
r.opcode == OpCode::CAST_16 || r.opcode == OpCode::DIV_8);
}

/**
Expand All @@ -640,11 +641,11 @@ void AvmAluTraceBuilder::finalize(std::vector<AvmFullRow<FF>>& main_trace)
dest.alu_op_add = FF(src.opcode == OpCode::ADD_8 || src.opcode == OpCode::ADD_16 ? 1 : 0);
dest.alu_op_sub = FF(src.opcode == OpCode::SUB_8 || src.opcode == OpCode::SUB_16 ? 1 : 0);
dest.alu_op_mul = FF(src.opcode == OpCode::MUL_8 || src.opcode == OpCode::MUL_16 ? 1 : 0);
dest.alu_op_not = FF(src.opcode == OpCode::NOT ? 1 : 0);
dest.alu_op_not = FF(src.opcode == OpCode::NOT_8 || src.opcode == OpCode::NOT_16 ? 1 : 0);
dest.alu_op_eq = FF(src.opcode == OpCode::EQ_8 || src.opcode == OpCode::EQ_16 ? 1 : 0);
dest.alu_op_lt = FF(src.opcode == OpCode::LT_8 || src.opcode == OpCode::LT_16 ? 1 : 0);
dest.alu_op_lte = FF(src.opcode == OpCode::LTE_8 || src.opcode == OpCode::LTE_16 ? 1 : 0);
dest.alu_op_cast = FF(src.opcode == OpCode::CAST ? 1 : 0);
dest.alu_op_cast = FF(src.opcode == OpCode::CAST_8 || src.opcode == OpCode::CAST_16 ? 1 : 0);
dest.alu_op_shr = FF(src.opcode == OpCode::SHR_8 || src.opcode == OpCode::SHR_16 ? 1 : 0);
dest.alu_op_shl = FF(src.opcode == OpCode::SHL_8 || src.opcode == OpCode::SHL_16 ? 1 : 0);
dest.alu_op_div = FF(src.opcode == OpCode::DIV_8 || src.opcode == OpCode::DIV_16 ? 1 : 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,15 @@ const std::unordered_map<OpCode, std::vector<OperandType>> OPCODE_WIRE_FORMAT =
{ OpCode::OR_16, three_operand_format16 },
{ OpCode::XOR_8, three_operand_format8 },
{ OpCode::XOR_16, three_operand_format16 },
{ OpCode::NOT, { OperandType::INDIRECT, OperandType::TAG, OperandType::UINT8, OperandType::UINT8 } },
{ OpCode::NOT_8, { OperandType::INDIRECT, OperandType::TAG, OperandType::UINT8, OperandType::UINT8 } },
{ OpCode::NOT_16, { OperandType::INDIRECT, OperandType::TAG, OperandType::UINT16, OperandType::UINT16 } },
{ OpCode::SHL_8, three_operand_format8 },
{ OpCode::SHL_16, three_operand_format16 },
{ OpCode::SHR_8, three_operand_format8 },
{ OpCode::SHR_16, three_operand_format16 },
// Compute - Type Conversions
{ OpCode::CAST, { OperandType::INDIRECT, OperandType::TAG, OperandType::UINT32, OperandType::UINT32 } },
{ OpCode::CAST_8, { OperandType::INDIRECT, OperandType::TAG, OperandType::UINT8, OperandType::UINT8 } },
{ OpCode::CAST_16, { OperandType::INDIRECT, OperandType::TAG, OperandType::UINT16, OperandType::UINT16 } },

// Execution Environment - Globals
{ OpCode::ADDRESS, getter_format },
Expand Down
24 changes: 18 additions & 6 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -556,10 +556,16 @@ std::vector<Row> Execution::gen_trace(std::vector<Instruction> const& instructio
std::get<uint16_t>(inst.operands.at(4)),
std::get<AvmMemoryTag>(inst.operands.at(1)));
break;
case OpCode::NOT:
case OpCode::NOT_8:
trace_builder.op_not(std::get<uint8_t>(inst.operands.at(0)),
std::get<uint32_t>(inst.operands.at(2)),
std::get<uint32_t>(inst.operands.at(3)),
std::get<uint8_t>(inst.operands.at(2)),
std::get<uint8_t>(inst.operands.at(3)),
std::get<AvmMemoryTag>(inst.operands.at(1)));
break;
case OpCode::NOT_16:
trace_builder.op_not(std::get<uint8_t>(inst.operands.at(0)),
std::get<uint16_t>(inst.operands.at(2)),
std::get<uint16_t>(inst.operands.at(3)),
std::get<AvmMemoryTag>(inst.operands.at(1)));
break;
case OpCode::SHL_8:
Expand Down Expand Up @@ -592,10 +598,16 @@ std::vector<Row> Execution::gen_trace(std::vector<Instruction> const& instructio
break;

// Compute - Type Conversions
case OpCode::CAST:
case OpCode::CAST_8:
trace_builder.op_cast(std::get<uint8_t>(inst.operands.at(0)),
std::get<uint32_t>(inst.operands.at(2)),
std::get<uint32_t>(inst.operands.at(3)),
std::get<uint8_t>(inst.operands.at(2)),
std::get<uint8_t>(inst.operands.at(3)),
std::get<AvmMemoryTag>(inst.operands.at(1)));
break;
case OpCode::CAST_16:
trace_builder.op_cast(std::get<uint8_t>(inst.operands.at(0)),
std::get<uint16_t>(inst.operands.at(2)),
std::get<uint16_t>(inst.operands.at(3)),
std::get<AvmMemoryTag>(inst.operands.at(1)));
break;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ const std::unordered_map<OpCode, FixedGasTable::GasRow> GAS_COST_TABLE = {
{ OpCode::OR_16, make_cost(AVM_OR_BASE_L2_GAS, 0, AVM_OR_DYN_L2_GAS, 0) },
{ OpCode::XOR_8, make_cost(AVM_XOR_BASE_L2_GAS, 0, AVM_XOR_DYN_L2_GAS, 0) },
{ OpCode::XOR_16, make_cost(AVM_XOR_BASE_L2_GAS, 0, AVM_XOR_DYN_L2_GAS, 0) },
{ OpCode::NOT, make_cost(AVM_NOT_BASE_L2_GAS, 0, AVM_NOT_DYN_L2_GAS, 0) },
{ OpCode::NOT_8, make_cost(AVM_NOT_BASE_L2_GAS, 0, AVM_NOT_DYN_L2_GAS, 0) },
{ OpCode::NOT_16, make_cost(AVM_NOT_BASE_L2_GAS, 0, AVM_NOT_DYN_L2_GAS, 0) },
{ OpCode::SHL_8, make_cost(AVM_SHL_BASE_L2_GAS, 0, AVM_SHL_DYN_L2_GAS, 0) },
{ OpCode::SHL_16, make_cost(AVM_SHL_BASE_L2_GAS, 0, AVM_SHL_DYN_L2_GAS, 0) },
{ OpCode::SHR_8, make_cost(AVM_SHR_BASE_L2_GAS, 0, AVM_SHR_DYN_L2_GAS, 0) },
{ OpCode::SHR_16, make_cost(AVM_SHR_BASE_L2_GAS, 0, AVM_SHR_DYN_L2_GAS, 0) },
{ OpCode::CAST, make_cost(AVM_CAST_BASE_L2_GAS, 0, AVM_CAST_DYN_L2_GAS, 0) },
{ OpCode::CAST_8, make_cost(AVM_CAST_BASE_L2_GAS, 0, AVM_CAST_DYN_L2_GAS, 0) },
{ OpCode::CAST_16, make_cost(AVM_CAST_BASE_L2_GAS, 0, AVM_CAST_DYN_L2_GAS, 0) },
{ OpCode::ADDRESS, make_cost(AVM_ADDRESS_BASE_L2_GAS, 0, AVM_ADDRESS_DYN_L2_GAS, 0) },
{ OpCode::STORAGEADDRESS, make_cost(AVM_STORAGEADDRESS_BASE_L2_GAS, 0, AVM_STORAGEADDRESS_DYN_L2_GAS, 0) },
{ OpCode::SENDER, make_cost(AVM_SENDER_BASE_L2_GAS, 0, AVM_SENDER_DYN_L2_GAS, 0) },
Expand Down
Loading
Loading