Skip to content

Commit

Permalink
winch: Refactor the Masm associated types (#6451)
Browse files Browse the repository at this point in the history
This commit is a follow up to #6443,
in which we discussed potentially having `PtrSize` and `ABI` as
associated types to the `MacroAssembler` trait.

I considered having `PtrSize` associated to the `ABI`, but given the
amount of ABI details needed at the `MacroAssembler` level, I decided to
go with the approach in this change.

The chosen approach ended up cutting a decent amount of boilerplate from
the `MacroAssembler` itself, but also from each of the touchpoints where
the `MacroAssembler` is used.

This change also standardizes the signatures of the `ABI` trait. Some of
them borrowed `&self` and some didn't, but in practice, there's no need
to have any of them borrow `&self`.
  • Loading branch information
saulecabrera authored May 25, 2023
1 parent 72b641d commit f70b0f3
Show file tree
Hide file tree
Showing 13 changed files with 168 additions and 237 deletions.
8 changes: 4 additions & 4 deletions winch/codegen/src/abi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,20 @@ pub(crate) use local::*;
/// specific registers, etc.
pub(crate) trait ABI {
/// The required stack alignment.
fn stack_align(&self) -> u8;
fn stack_align() -> u8;

/// The required stack alignment for calls.
fn call_stack_align(&self) -> u8;
fn call_stack_align() -> u8;

/// The offset to the argument base, relative to the frame pointer.
fn arg_base_offset(&self) -> u8;
fn arg_base_offset() -> u8;

/// The offset to the return address, relative to the frame pointer.
fn ret_addr_offset() -> u8;

/// Construct the ABI-specific signature from a WebAssembly
/// function type.
fn sig(&self, wasm_sig: &FuncType, call_conv: &CallingConvention) -> ABISig;
fn sig(wasm_sig: &FuncType, call_conv: &CallingConvention) -> ABISig;

/// Returns the number of bits in a word.
fn word_bits() -> u32;
Expand Down
33 changes: 12 additions & 21 deletions winch/codegen/src/codegen/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl<'a> FnCall<'a> {
/// want to calculate any adjustments to the caller's frame, after
/// having saved any live registers, so that we can account for
/// any pushes generated by register spilling.
pub fn new<A: ABI, M: MacroAssembler>(
pub fn new<M: MacroAssembler>(
callee_sig: &'a ABISig,
context: &mut CodeGenContext,
masm: &mut M,
Expand Down Expand Up @@ -135,52 +135,43 @@ impl<'a> FnCall<'a> {
Self {
abi_sig: &callee_sig,
arg_stack_space,
call_stack_space: (spilled_regs * <A as ABI>::word_bytes())
+ (memory_values * <A as ABI>::word_bytes()),
call_stack_space: (spilled_regs * <M::ABI as ABI>::word_bytes())
+ (memory_values * <M::ABI as ABI>::word_bytes()),
sp_offset_at_callsite,
}
}

/// Emit a direct function call, to a locally defined function.
pub fn direct<M: MacroAssembler, A: ABI>(
pub fn direct<M: MacroAssembler>(
&self,
masm: &mut M,
context: &mut CodeGenContext,
callee: FuncIndex,
alignment: u32,
addend: u32,
) {
let reserved_stack = masm.call(alignment, addend, self.arg_stack_space, |masm| {
self.assign_args(context, masm, <A as ABI>::scratch_reg());
let reserved_stack = masm.call(self.arg_stack_space, |masm| {
self.assign_args(context, masm, <M::ABI as ABI>::scratch_reg());
CalleeKind::Direct(callee.as_u32())
});
self.post_call::<M, A>(masm, context, reserved_stack);
self.post_call::<M>(masm, context, reserved_stack);
}

/// Emit an indirect function call, using a raw address.
pub fn indirect<M: MacroAssembler, A: ABI>(
pub fn indirect<M: MacroAssembler>(
&self,
masm: &mut M,
context: &mut CodeGenContext,
addr: M::Address,
alignment: u32,
addend: u32,
) {
let reserved_stack = masm.call(alignment, addend, self.arg_stack_space, |masm| {
let scratch = <A as ABI>::scratch_reg();
let reserved_stack = masm.call(self.arg_stack_space, |masm| {
let scratch = <M::ABI as ABI>::scratch_reg();
self.assign_args(context, masm, scratch);
masm.load(addr, scratch, OperandSize::S64);
CalleeKind::Indirect(scratch)
});
self.post_call::<M, A>(masm, context, reserved_stack);
self.post_call::<M>(masm, context, reserved_stack);
}

fn post_call<M: MacroAssembler, A: ABI>(
&self,
masm: &mut M,
context: &mut CodeGenContext,
size: u32,
) {
fn post_call<M: MacroAssembler>(&self, masm: &mut M, context: &mut CodeGenContext, size: u32) {
masm.free_stack(self.call_stack_space + size);
context.drop_last(self.abi_sig.params.len());
// The stack pointer at the end of the function call
Expand Down
62 changes: 21 additions & 41 deletions winch/codegen/src/codegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use call::FnCall;
use wasmparser::{
BinaryReader, FuncType, FuncValidator, ValType, ValidatorResources, VisitOperator,
};
use wasmtime_environ::{FuncIndex, PtrSize};
use wasmtime_environ::FuncIndex;

mod context;
pub(crate) use context::*;
Expand All @@ -18,11 +18,9 @@ pub use env::*;
pub mod call;

/// The code generation abstraction.
pub(crate) struct CodeGen<'a, A, M, P>
pub(crate) struct CodeGen<'a, M>
where
M: MacroAssembler,
A: ABI,
P: PtrSize,
{
/// The ABI-specific representation of the function signature, excluding results.
sig: ABISig,
Expand All @@ -31,33 +29,26 @@ where
pub context: CodeGenContext<'a>,

/// A reference to the function compilation environment.
pub env: FuncEnv<'a, P>,
pub env: FuncEnv<'a, M::Ptr>,

/// The MacroAssembler.
pub masm: &'a mut M,

/// A reference to the current ABI.
pub abi: &'a A,
}

impl<'a, A, M, P> CodeGen<'a, A, M, P>
impl<'a, M> CodeGen<'a, M>
where
M: MacroAssembler,
A: ABI,
P: PtrSize,
{
pub fn new(
masm: &'a mut M,
abi: &'a A,
context: CodeGenContext<'a>,
env: FuncEnv<'a, P>,
env: FuncEnv<'a, M::Ptr>,
sig: ABISig,
) -> Self {
Self {
sig,
context,
masm,
abi,
env,
}
}
Expand Down Expand Up @@ -89,17 +80,17 @@ where
) -> Result<()> {
self.spill_register_arguments();
let defined_locals_range = &self.context.frame.defined_locals_range;
self.masm.zero_mem_range(
defined_locals_range.as_range(),
<A as ABI>::word_bytes(),
&mut self.context.regalloc,
);
self.masm
.zero_mem_range(defined_locals_range.as_range(), &mut self.context.regalloc);

// Save the vmctx pointer to its local slot in case we need to reload it
// at any point.
let vmctx_addr = self.masm.local_address(&self.context.frame.vmctx_slot);
self.masm
.store(<A as ABI>::vmctx_reg().into(), vmctx_addr, OperandSize::S64);
self.masm.store(
<M::ABI as ABI>::vmctx_reg().into(),
vmctx_addr,
OperandSize::S64,
);

while !body.eof() {
let offset = body.original_position();
Expand Down Expand Up @@ -141,7 +132,7 @@ where
params.extend_from_slice(&callee.ty.params());
let sig = FuncType::new(params, callee.ty.results().to_owned());

let caller_vmctx = <A as ABI>::vmctx_reg();
let caller_vmctx = <M::ABI as ABI>::vmctx_reg();
let callee_vmctx = self.context.any_gpr(self.masm);
let callee_vmctx_offset = self.env.vmoffsets.vmctx_vmfunction_import_vmctx(index);
let callee_vmctx_addr = self.masm.address_at_reg(caller_vmctx, callee_vmctx_offset);
Expand All @@ -161,32 +152,21 @@ where
stack.insert(location as usize, Val::reg(caller_vmctx));
stack.insert(location as usize, Val::reg(callee_vmctx));
(
self.abi.sig(&sig, &CallingConvention::Default),
<M::ABI as ABI>::sig(&sig, &CallingConvention::Default),
Some(callee_addr),
)
} else {
(self.abi.sig(&callee.ty, &CallingConvention::Default), None)
(
<M::ABI as ABI>::sig(&callee.ty, &CallingConvention::Default),
None,
)
};

let fncall = FnCall::new::<A, M>(&sig, &mut self.context, self.masm);
let alignment = self.abi.call_stack_align();
let addend = self.abi.arg_base_offset();
let fncall = FnCall::new::<M>(&sig, &mut self.context, self.masm);
if let Some(addr) = callee_addr {
fncall.indirect::<M, A>(
self.masm,
&mut self.context,
addr,
alignment.into(),
addend.into(),
);
fncall.indirect::<M>(self.masm, &mut self.context, addr);
} else {
fncall.direct::<M, A>(
self.masm,
&mut self.context,
index,
alignment.into(),
addend.into(),
);
fncall.direct::<M>(self.masm, &mut self.context, index);
}
}

Expand Down
10 changes: 5 additions & 5 deletions winch/codegen/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ pub(crate) struct Frame {

impl Frame {
/// Allocate a new Frame.
pub fn new<A: ABI>(sig: &ABISig, defined_locals: &DefinedLocals, abi: &A) -> Result<Self> {
let (mut locals, defined_locals_start) = Self::compute_arg_slots(sig, abi)?;
pub fn new<A: ABI>(sig: &ABISig, defined_locals: &DefinedLocals) -> Result<Self> {
let (mut locals, defined_locals_start) = Self::compute_arg_slots::<A>(sig)?;

// The defined locals have a zero-based offset by default
// so we need to add the defined locals start to the offset.
Expand All @@ -96,7 +96,7 @@ impl Frame {
let vmctx_slots_size = <A as ABI>::word_bytes();
let vmctx_offset = defined_locals_start + defined_locals.stack_size + vmctx_slots_size;

let locals_size = align_to(vmctx_offset, abi.stack_align().into());
let locals_size = align_to(vmctx_offset, <A as ABI>::stack_align().into());

Ok(Self {
locals,
Expand All @@ -113,7 +113,7 @@ impl Frame {
self.locals.get(index as usize)
}

fn compute_arg_slots<A: ABI>(sig: &ABISig, abi: &A) -> Result<(Locals, u32)> {
fn compute_arg_slots<A: ABI>(sig: &ABISig) -> Result<(Locals, u32)> {
// Go over the function ABI-signature and
// calculate the stack slots.
//
Expand Down Expand Up @@ -142,7 +142,7 @@ impl Frame {
// we want positive addressing from the stack pointer
// for both locals and stack arguments.

let arg_base_offset = abi.arg_base_offset().into();
let arg_base_offset = <A as ABI>::arg_base_offset().into();
let mut next_stack = 0u32;
let slots: Locals = sig
.params
Expand Down
17 changes: 7 additions & 10 deletions winch/codegen/src/isa/aarch64/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ impl RegIndexEnv {

impl ABI for Aarch64ABI {
// TODO change to 16 once SIMD is supported
fn stack_align(&self) -> u8 {
fn stack_align() -> u8 {
8
}

fn call_stack_align(&self) -> u8 {
fn call_stack_align() -> u8 {
16
}

fn arg_base_offset(&self) -> u8 {
fn arg_base_offset() -> u8 {
16
}

Expand All @@ -63,7 +63,7 @@ impl ABI for Aarch64ABI {
64
}

fn sig(&self, wasm_sig: &FuncType, call_conv: &CallingConvention) -> ABISig {
fn sig(wasm_sig: &FuncType, call_conv: &CallingConvention) -> ABISig {
assert!(call_conv.is_apple_aarch64() || call_conv.is_default());

if wasm_sig.results().len() > 1 {
Expand Down Expand Up @@ -162,8 +162,7 @@ mod tests {
fn xreg_abi_sig() {
let wasm_sig = FuncType::new([I32, I64, I32, I64, I32, I32, I64, I32, I64], []);

let abi = Aarch64ABI::default();
let sig = abi.sig(&wasm_sig, &CallingConvention::Default);
let sig = Aarch64ABI::sig(&wasm_sig, &CallingConvention::Default);
let params = sig.params;

match_reg_arg(params.get(0).unwrap(), I32, regs::xreg(0));
Expand All @@ -181,8 +180,7 @@ mod tests {
fn vreg_abi_sig() {
let wasm_sig = FuncType::new([F32, F64, F32, F64, F32, F32, F64, F32, F64], []);

let abi = Aarch64ABI::default();
let sig = abi.sig(&wasm_sig, &CallingConvention::Default);
let sig = Aarch64ABI::sig(&wasm_sig, &CallingConvention::Default);
let params = sig.params;

match_reg_arg(params.get(0).unwrap(), F32, regs::vreg(0));
Expand All @@ -200,8 +198,7 @@ mod tests {
fn mixed_abi_sig() {
let wasm_sig = FuncType::new([F32, I32, I64, F64, I32, F32, F64, F32, F64], []);

let abi = Aarch64ABI::default();
let sig = abi.sig(&wasm_sig, &CallingConvention::Default);
let sig = Aarch64ABI::sig(&wasm_sig, &CallingConvention::Default);
let params = sig.params;

match_reg_arg(params.get(0).unwrap(), F32, regs::vreg(0));
Expand Down
11 changes: 5 additions & 6 deletions winch/codegen/src/isa/aarch64/masm.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use super::{
abi::Aarch64ABI,
address::Address,
asm::{Assembler, Operand},
regs,
};
use crate::{
abi::local::LocalSlot,
abi::{self, local::LocalSlot},
codegen::CodeGenContext,
isa::reg::Reg,
masm::{CalleeKind, DivKind, MacroAssembler as Masm, OperandSize, RegImm, RemKind},
Expand Down Expand Up @@ -54,6 +55,8 @@ impl MacroAssembler {

impl Masm for MacroAssembler {
type Address = Address;
type Ptr = u8;
type ABI = Aarch64ABI;

fn prologue(&mut self) {
let lr = regs::lr();
Expand Down Expand Up @@ -138,8 +141,6 @@ impl Masm for MacroAssembler {

fn call(
&mut self,
_alignment: u32,
_addend: u32,
_stack_args_size: u32,
_load_callee: impl FnMut(&mut Self) -> CalleeKind,
) -> u32 {
Expand Down Expand Up @@ -191,9 +192,7 @@ impl Masm for MacroAssembler {
}

fn push(&mut self, reg: Reg) -> u32 {
// The push is counted as pushing the 64-bit width in
// 64-bit architectures.
let size = 8u32;
let size = <Self::ABI as abi::ABI>::word_bytes();
self.reserve_stack(size);
let address = Address::from_shadow_sp(size as i64);
self.asm.str(reg, address, OperandSize::S64);
Expand Down
7 changes: 3 additions & 4 deletions winch/codegen/src/isa/aarch64/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,15 @@ impl TargetIsa for Aarch64 {
let mut body = body.get_binary_reader();
let mut masm = Aarch64Masm::new(self.shared_flags.clone());
let stack = Stack::new();
let abi = abi::Aarch64ABI::default();
let abi_sig = abi.sig(sig, &CallingConvention::Default);
let abi_sig = abi::Aarch64ABI::sig(sig, &CallingConvention::Default);

let defined_locals = DefinedLocals::new(&mut body, validator)?;
let frame = Frame::new(&abi_sig, &defined_locals, &abi)?;
let frame = Frame::new::<abi::Aarch64ABI>(&abi_sig, &defined_locals)?;
// TODO: Add floating point bitmask
let regalloc = RegAlloc::new(RegSet::new(ALL_GPR, 0), scratch());
let codegen_context = CodeGenContext::new(regalloc, stack, &frame);
let env = FuncEnv::new(self.pointer_bytes(), translation);
let mut codegen = CodeGen::new(&mut masm, &abi, codegen_context, env, abi_sig);
let mut codegen = CodeGen::new(&mut masm, codegen_context, env, abi_sig);

codegen.emit(&mut body, validator)?;
Ok(masm.finalize())
Expand Down
Loading

0 comments on commit f70b0f3

Please sign in to comment.