Skip to content

Commit

Permalink
build compilation pipeline in miden's sem tests;
Browse files Browse the repository at this point in the history
rename translate to parse in wasm frontend;
  • Loading branch information
greenhat committed Jun 26, 2023
1 parent 7691b4e commit ce2f0eb
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 49 deletions.
6 changes: 4 additions & 2 deletions crates/codegen-midenvm/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ pub use miden_inst::*;
use ozk_miden_dialect::ops::*;
use pliron::context::Context;
use pliron::context::Ptr;
use pliron::dialects::builtin::op_interfaces::get_callees_syms;
use pliron::dialects::builtin::op_interfaces::SymbolOpInterface;
use pliron::op::Op;
use pliron::operation::Operation;
use pliron::with_context::AttachContext;
use thiserror::Error;
Expand All @@ -18,7 +20,7 @@ use crate::MidenError;
use crate::MidenTargetConfig;

pub fn emit_prog(
ctx: &mut Context,
ctx: &Context,
op: Ptr<Operation>,
target_config: &MidenTargetConfig,
) -> Result<InstBuffer, MidenError> {
Expand Down Expand Up @@ -47,7 +49,7 @@ pub fn topo_sort_procedures(
for proc in procedures {
let proc_name = proc.get_symbol_name(ctx);
topo_sort.insert(proc_name.clone());
for dep in proc.get_callees_sym(ctx) {
for dep in get_callees_syms(ctx, proc.get_operation()) {
topo_sort.add_dependency(dep, proc_name.clone());
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/codegen-midenvm/src/codegen/inst_buf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ impl InstBuffer {
}
}

pub(crate) fn pretty_print(&self) -> String {
pub fn pretty_print(&self) -> String {
self.inner
.iter()
.map(|inst| {
Expand Down
83 changes: 51 additions & 32 deletions crates/codegen-midenvm/tests/sem_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

use std::ops::RangeFrom;

use c2zk_ir_transform::miden::lowering::WasmToMidenArithLoweringPass;
use c2zk_ir_transform::miden::lowering::WasmToMidenCFLoweringPass;
use c2zk_ir_transform::miden::lowering::WasmToMidenFinalLoweringPass;
use miden_assembly::Assembler;
use miden_processor::math::Felt;
use miden_processor::AdviceInputs;
Expand All @@ -14,12 +17,58 @@ use miden_processor::StackInputs;
use miden_processor::VmState;
use miden_processor::VmStateIterator;
use miden_stdlib::StdLibrary;
use ozk_codegen_midenvm::emit_prog;
use ozk_codegen_midenvm::MidenTargetConfig;
use ozk_compiler::compile;
use ozk_frontend_wasm::WasmFrontendConfig;
use ozk_wasm_dialect::ops::ModuleOp;
use pliron::context::Context;
use pliron::context::Ptr;
use pliron::dialects::builtin;
use pliron::dialects::builtin::op_interfaces::SingleBlockRegionInterface;
use pliron::linked_list::ContainsLinkedList;
use pliron::op::Op;
use pliron::operation::Operation;
use pliron::pass::PassManager;
use wasmtime::*;
use winter_math::StarkField;

pub fn compile(source: &[u8]) -> String {
let frontend_config = WasmFrontendConfig::default();
let target_config = MidenTargetConfig::default();
let mut ctx = Context::new();
frontend_config.register(&mut ctx);
target_config.register(&mut ctx);
let wasm_module_op =
ozk_frontend_wasm::parse_module(&mut ctx, source, &frontend_config).unwrap();
let miden_prog = run_conversion_passes(&mut ctx, wasm_module_op);
let inst_buf = emit_prog(&ctx, miden_prog, &target_config).unwrap();
inst_buf.pretty_print()
}

fn run_conversion_passes(ctx: &mut Context, wasm_module: ModuleOp) -> Ptr<Operation> {
// we need to wrap the wasm in an op because passes cannot replace the root op
let wrapper_module = builtin::ops::ModuleOp::new(ctx, "wrapper");
wasm_module
.get_operation()
.insert_at_back(wrapper_module.get_body(ctx, 0), ctx);
let mut pass_manager = PassManager::new();
pass_manager.add_pass(Box::<WasmToMidenCFLoweringPass>::default());
pass_manager.add_pass(Box::<WasmToMidenArithLoweringPass>::default());
pass_manager.add_pass(Box::<WasmToMidenFinalLoweringPass>::default());
pass_manager
.run(ctx, wrapper_module.get_operation())
.unwrap();
let inner_module = wrapper_module
.get_body(ctx, 0)
.deref(ctx)
.iter(ctx)
.collect::<Vec<Ptr<Operation>>>()
.first()
.cloned()
.unwrap();
inner_module
}

pub fn check_wasm(
source: &[u8],
input: Vec<u64>,
Expand All @@ -34,24 +83,6 @@ pub fn check_wasm(
check_miden(wat, input, secret_input, expected_output, expected_miden);
}

// fn compile(
// source: &[u8],
// frontend_config: &WasmFrontendConfig,
// target_config: &MidenTargetConfig,
// ) -> Vec<u8> {
// let mut ctx = Context::new();
// frontend_config.register(&mut ctx);
// target_config.register(&mut ctx);
// let module = translate_module(&mut ctx, source).unwrap();
// target_config
// .pass_manager
// .run(&mut ctx, module.get_operation())
// .unwrap();
// let target = MidenTarget::new(target_config);
// let code = target.compile_module(module).unwrap();
// code
// }

#[allow(unreachable_code)]
pub fn check_miden(
source: String,
Expand All @@ -60,20 +91,8 @@ pub fn check_miden(
expected_output: Vec<u64>,
expected_miden: expect_test::Expect,
) {
let frontend_config = WasmFrontendConfig::default();
let target_config = MidenTargetConfig::default();
let wasm = wat::parse_str(source).unwrap();
let miden_prog = compile(&wasm, frontend_config.into(), target_config.into()).unwrap();
// let module = translate(&mut ctx, &wasm, frontend_config).unwrap();
// run_ir_passes(&mut module, &target_config.ir_passes);
// let inst_buf = compile_prog(module, &target_config).unwrap();
// todo!("compile_module");
// let inst_buf: InstBuffer = InstBuffer::new(&target_config);
// let out_source = inst_buf.pretty_print();
// expected_miden.assert_eq(&out_source);
// let program = inst_buf.pretty_print();
let program = String::from_utf8(miden_prog).unwrap();

let program = compile(&wasm);
let assembler = Assembler::default()
.with_library(&StdLibrary::default())
.unwrap();
Expand Down
25 changes: 16 additions & 9 deletions crates/dialects/miden/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use pliron::dialects::builtin::attributes::FloatAttr;
use pliron::dialects::builtin::attributes::IntegerAttr;
use pliron::dialects::builtin::attributes::StringAttr;
use pliron::dialects::builtin::attributes::TypeAttr;
use pliron::dialects::builtin::op_interfaces::CallOpInterface;
use pliron::dialects::builtin::op_interfaces::OneRegionInterface;
use pliron::dialects::builtin::op_interfaces::SingleBlockRegionInterface;
use pliron::dialects::builtin::op_interfaces::SymbolOpInterface;
Expand Down Expand Up @@ -127,15 +128,15 @@ impl ProcOp {
.flat_map(|bb| bb.deref(ctx).iter(ctx))
}

pub fn get_callees_sym(&self, ctx: &Context) -> impl Iterator<Item = String> {
let mut callees = Vec::new();
for op in self.op_iter(ctx) {
if let Some(call_op) = op.deref(ctx).get_op(ctx).downcast_ref::<CallOp>() {
callees.push(call_op.get_callee_sym(ctx));
}
}
callees.into_iter()
}
// pub fn get_callees_sym(&self, ctx: &Context) -> impl Iterator<Item = String> {
// let mut callees = Vec::new();
// for op in self.op_iter(ctx) {
// if let Some(call_op) = op.deref(ctx).get_op(ctx).downcast_ref::<CallOp>() {
// callees.push(call_op.get_callee_sym(ctx));
// }
// }
// callees.into_iter()
// }
}

impl OneRegionInterface for ProcOp {}
Expand Down Expand Up @@ -358,6 +359,12 @@ impl Verify for CallOp {
}
}

impl CallOpInterface for CallOp {
fn get_callee_sym(&self, ctx: &Context) -> String {
self.get_callee_sym(ctx)
}
}

declare_op!(
/// Push local variable with the given index onto the stack.
///
Expand Down
2 changes: 1 addition & 1 deletion crates/frontend-wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ mod types;

pub use crate::config::WasmFrontendConfig;
pub use crate::error::WasmError;
pub use crate::module_translator::translate_module;
pub use crate::module_translator::parse_module;

// Convenience reexport of the wasmparser crate that we're linking against,
// since a number of types in `wasmparser` show up in the public API of
Expand Down
11 changes: 8 additions & 3 deletions crates/frontend-wasm/src/module_translator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
use crate::error::WasmError;
use crate::func_builder::FuncBuilder;
use crate::types::{from_func_type, from_val_type, FuncIndex};
use crate::WasmFrontendConfig;
use crate::{code_translator::translate_operator, mod_builder::ModuleBuilder};
use ozk_wasm_dialect::ops::ModuleOp;
use pliron::context::Context;
Expand All @@ -15,12 +16,16 @@ use wasmparser::{
Payload, Type, TypeRef, Validator, ValidatorResources, WasmModuleResources,
};

/// Translate a sequence of bytes forming a valid Wasm binary into a list of valid IR
pub fn translate_module(ctx: &mut Context, data: &[u8]) -> Result<ModuleOp, WasmError> {
/// Translate a sequence of bytes forming a valid Wasm binary into a `wasm.module` operation.
pub fn parse_module(
ctx: &mut Context,
wasm: &[u8],
_config: &WasmFrontendConfig,
) -> Result<ModuleOp, WasmError> {
let mut validator = Validator::new();
let mut mod_builder = ModuleBuilder::new();

for payload in Parser::new(0).parse_all(data) {
for payload in Parser::new(0).parse_all(wasm) {
// dbg!(&mod_builder);
match payload? {
Payload::Version {
Expand Down
2 changes: 1 addition & 1 deletion crates/frontend/src/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub fn translate(
config: FrontendConfig,
) -> Result<ModuleOp, FrontendError> {
Ok(match config {
FrontendConfig::Wasm(_) => ozk_frontend_wasm::translate_module(ctx, source)?,
FrontendConfig::Wasm(config) => ozk_frontend_wasm::parse_module(ctx, source, &config)?,
})
}

Expand Down

0 comments on commit ce2f0eb

Please sign in to comment.