Skip to content

Commit

Permalink
Allow creating pointers to module methods
Browse files Browse the repository at this point in the history
If a module method is defined using "extern" and has a body, the
compiler guarantees calls of the method use the C calling convention.
This in turn makes it safe/possible to pass pointers to such methods to
C libraries.

As per this commit this isn't terribly useful just yet, as the compiler
still passes the process and state variables as hidden arguments. This
will be solved in a separate commit.

This fixes #693.

Changelog: added
  • Loading branch information
yorickpeterse committed Feb 18, 2024
1 parent 6063bba commit 01bd04e
Show file tree
Hide file tree
Showing 14 changed files with 283 additions and 20 deletions.
32 changes: 29 additions & 3 deletions ast/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -866,13 +866,17 @@ impl Parser {
self.optional_type_parameter_definitions()?
};
let arguments = self.optional_method_arguments(allow_variadic)?;
let variadic = arguments.as_ref().map_or(false, |v| v.variadic);
let return_type = self.optional_return_type()?;
let body = if let MethodKind::Extern = kind {
None
} else {
let body = if (self.peek().kind == TokenKind::CurlyOpen
|| kind != MethodKind::Extern)
&& !variadic
{
let token = self.expect(TokenKind::CurlyOpen)?;

Some(self.expressions(token)?)
} else {
None
};

let location = SourceLocation::start_end(
Expand Down Expand Up @@ -4465,6 +4469,27 @@ mod tests {
}))
);

assert_eq!(
top(parse("fn extern foo {}")),
TopLevelExpression::DefineMethod(Box::new(DefineMethod {
public: false,
operator: false,
kind: MethodKind::Extern,
name: Identifier {
name: "foo".to_string(),
location: cols(11, 13)
},
type_parameters: None,
arguments: None,
return_type: None,
body: Some(Expressions {
values: Vec::new(),
location: cols(15, 16)
}),
location: cols(1, 16),
}))
);

assert_eq!(
top(parse("fn extern foo(...)")),
TopLevelExpression::DefineMethod(Box::new(DefineMethod {
Expand Down Expand Up @@ -4498,6 +4523,7 @@ mod tests {
assert_error!("fn foo {", cols(8, 8));
assert_error!("fn foo", cols(6, 6));
assert_error!("fn extern foo[T](arg: T)", cols(14, 14));
assert_error!("fn extern foo(...) {}", cols(20, 20));
}

#[test]
Expand Down
63 changes: 62 additions & 1 deletion compiler/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ pub(crate) struct DefineInstanceMethod {
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct DefineModuleMethod {
pub(crate) public: bool,
pub(crate) c_calling_convention: bool,
pub(crate) name: Identifier,
pub(crate) type_parameters: Vec<TypeParameter>,
pub(crate) arguments: Vec<MethodArgument>,
Expand Down Expand Up @@ -846,6 +847,7 @@ pub(crate) struct Ref {

#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct Mut {
pub(crate) pointer_to_method: Option<types::MethodId>,
pub(crate) resolved_type: types::TypeRef,
pub(crate) value: Expression,
pub(crate) location: SourceLocation,
Expand Down Expand Up @@ -1133,7 +1135,9 @@ impl<'a> LowerToHir<'a> {
) -> TopLevelExpression {
self.operator_method_not_allowed(node.operator, &node.location);

if let ast::MethodKind::Extern = node.kind {
let external = matches!(node.kind, ast::MethodKind::Extern);

if external && node.body.is_none() {
TopLevelExpression::ExternFunction(Box::new(DefineExternFunction {
public: node.public,
name: self.identifier(node.name),
Expand All @@ -1146,6 +1150,7 @@ impl<'a> LowerToHir<'a> {
} else {
TopLevelExpression::ModuleMethod(Box::new(DefineModuleMethod {
public: node.public,
c_calling_convention: external,
name: self.identifier(node.name),
type_parameters: self
.optional_type_parameters(node.type_parameters),
Expand Down Expand Up @@ -2624,6 +2629,7 @@ impl<'a> LowerToHir<'a> {

fn mut_reference(&mut self, node: ast::Mut) -> Box<Mut> {
Box::new(Mut {
pointer_to_method: None,
resolved_type: types::TypeRef::Unknown,
value: self.expression(node.value),
location: node.location,
Expand Down Expand Up @@ -3436,6 +3442,7 @@ mod tests {
hir,
TopLevelExpression::ModuleMethod(Box::new(DefineModuleMethod {
public: false,
c_calling_convention: false,
name: Identifier {
name: "foo".to_string(),
location: cols(4, 6)
Expand Down Expand Up @@ -3521,6 +3528,59 @@ mod tests {
);
}

#[test]
fn test_lower_extern_method_with_body() {
let (hir, diags) = lower_top_expr("fn extern foo(a: A) -> B { 10 }");

assert_eq!(diags, 0);
assert_eq!(
hir,
TopLevelExpression::ModuleMethod(Box::new(DefineModuleMethod {
public: false,
c_calling_convention: true,
name: Identifier {
name: "foo".to_string(),
location: cols(11, 13)
},
type_parameters: Vec::new(),
arguments: vec![MethodArgument {
name: Identifier {
name: "a".to_string(),
location: cols(15, 15)
},
value_type: Type::Named(Box::new(TypeName {
source: None,
resolved_type: types::TypeRef::Unknown,
name: Constant {
name: "A".to_string(),
location: cols(18, 18)
},
arguments: Vec::new(),
location: cols(18, 18)
})),
location: cols(15, 18)
}],
return_type: Some(Type::Named(Box::new(TypeName {
source: None,
resolved_type: types::TypeRef::Unknown,
name: Constant {
name: "B".to_string(),
location: cols(24, 24)
},
arguments: Vec::new(),
location: cols(24, 24)
}))),
body: vec![Expression::Int(Box::new(IntLiteral {
value: 10,
resolved_type: types::TypeRef::Unknown,
location: cols(28, 29)
}))],
method_id: None,
location: cols(1, 31),
})),
);
}

#[test]
fn test_lower_extern_variadic_function() {
let (hir, diags) = lower_top_expr("fn extern foo(...)");
Expand Down Expand Up @@ -5672,6 +5732,7 @@ mod tests {
assert_eq!(
hir,
Expression::Mut(Box::new(Mut {
pointer_to_method: None,
resolved_type: types::TypeRef::Unknown,
value: Expression::Int(Box::new(IntLiteral {
value: 10,
Expand Down
36 changes: 27 additions & 9 deletions compiler/src/llvm/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ use inkwell::types::{
BasicMetadataTypeEnum, BasicType, FunctionType, StructType,
};
use inkwell::AddressSpace;
use types::{BOOL_ID, BYTE_ARRAY_ID, FLOAT_ID, INT_ID, NIL_ID, STRING_ID};
use types::{
CallConvention, BOOL_ID, BYTE_ARRAY_ID, FLOAT_ID, INT_ID, NIL_ID, STRING_ID,
};

/// The size of an object header.
const HEADER_SIZE: u32 = 16;
Expand All @@ -16,6 +18,9 @@ const HEADER_SIZE: u32 = 16;
pub(crate) struct Method<'ctx> {
pub(crate) signature: FunctionType<'ctx>,

/// The calling convention to use for this method.
pub(crate) call_convention: CallConvention,

/// If the function returns a structure on the stack, its type is stored
/// here.
///
Expand Down Expand Up @@ -161,6 +166,7 @@ impl<'ctx> Layouts<'ctx> {
// KiB.
let num_methods = db.number_of_methods();
let dummy_method = Method {
call_convention: CallConvention::Inko,
signature: context.void_type().fn_type(&[], false),
struct_return: None,
};
Expand Down Expand Up @@ -273,8 +279,11 @@ impl<'ctx> Layouts<'ctx> {
context.void_type().fn_type(&args, false)
});

layouts.methods[method.0 as usize] =
Method { signature, struct_return: None };
layouts.methods[method.0 as usize] = Method {
call_convention: CallConvention::new(method.is_extern(db)),
signature,
struct_return: None,
};
}
}

Expand Down Expand Up @@ -318,8 +327,11 @@ impl<'ctx> Layouts<'ctx> {
})
};

layouts.methods[method.0 as usize] =
Method { signature: typ, struct_return: None };
layouts.methods[method.0 as usize] = Method {
call_convention: CallConvention::new(method.is_extern(db)),
signature: typ,
struct_return: None,
};
}
}

Expand All @@ -338,8 +350,11 @@ impl<'ctx> Layouts<'ctx> {
.map(|t| t.fn_type(&args, false))
.unwrap_or_else(|| context.void_type().fn_type(&args, false));

layouts.methods[method.0 as usize] =
Method { signature: typ, struct_return: None };
layouts.methods[method.0 as usize] = Method {
call_convention: CallConvention::new(method.is_extern(db)),
signature: typ,
struct_return: None,
};
}

for &method in &mir.extern_methods {
Expand Down Expand Up @@ -387,8 +402,11 @@ impl<'ctx> Layouts<'ctx> {
context.void_type().fn_type(&args, variadic)
});

layouts.methods[method.0 as usize] =
Method { signature: sig, struct_return: sret };
layouts.methods[method.0 as usize] = Method {
call_convention: CallConvention::C,
signature: sig,
struct_return: sret,
};
}

layouts
Expand Down
12 changes: 11 additions & 1 deletion compiler/src/llvm/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::collections::HashMap;
use std::ops::Deref;
use std::path::Path;
use types::module_name::ModuleName;
use types::{ClassId, MethodId};
use types::{CallConvention, ClassId, MethodId};

/// A wrapper around an LLVM Module that provides some additional methods.
pub(crate) struct Module<'a, 'ctx> {
Expand Down Expand Up @@ -105,6 +105,16 @@ impl<'a, 'ctx> Module<'a, 'ctx> {
self.inner.get_function(name).unwrap_or_else(|| {
let info = &self.layouts.methods[method.0 as usize];
let func = self.inner.add_function(name, info.signature, None);
let conv = match info.call_convention {
// LLVM uses 0 for the C calling convention.
CallConvention::C => 0,

// For the time being the Inko calling convention is the same as
// the C calling convention, but this may change in the future.
CallConvention::Inko => 0,
};

func.set_call_conventions(conv);

if let Some(typ) = info.struct_return {
let sret = self.context.type_attribute("sret", typ.into());
Expand Down
8 changes: 8 additions & 0 deletions compiler/src/llvm/passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2347,6 +2347,14 @@ impl<'a, 'b, 'mir, 'ctx> LowerMethod<'a, 'b, 'mir, 'ctx> {

self.builder.store(reg_var, addr);
}
Instruction::MethodPointer(ins) => {
let reg_var = self.variables[&ins.register];
let func_name = &self.names.methods[&ins.method];
let func = self.module.add_method(func_name, ins.method);
let ptr = func.as_global_value().as_pointer_value();

self.builder.store(reg_var, ptr);
}
Instruction::SetField(ins) => {
let rec_var = self.variables[&ins.receiver];
let rec_typ = self.variable_types[&ins.receiver];
Expand Down
27 changes: 27 additions & 0 deletions compiler/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,17 @@ impl Block {
)));
}

pub(crate) fn method_pointer(
&mut self,
register: RegisterId,
method: types::MethodId,
location: LocationId,
) {
self.instructions.push(Instruction::MethodPointer(Box::new(
MethodPointer { register, method, location },
)))
}

pub(crate) fn read_pointer(
&mut self,
register: RegisterId,
Expand Down Expand Up @@ -1035,6 +1046,13 @@ pub(crate) struct Pointer {
pub(crate) location: LocationId,
}

#[derive(Clone, Debug, Copy)]
pub(crate) struct MethodPointer {
pub(crate) register: RegisterId,
pub(crate) method: types::MethodId,
pub(crate) location: LocationId,
}

#[derive(Clone)]
pub(crate) struct FieldPointer {
pub(crate) class: types::ClassId,
Expand Down Expand Up @@ -1103,6 +1121,7 @@ pub(crate) enum Instruction {
ReadPointer(Box<ReadPointer>),
WritePointer(Box<WritePointer>),
FieldPointer(Box<FieldPointer>),
MethodPointer(Box<MethodPointer>),
}

impl Instruction {
Expand Down Expand Up @@ -1147,6 +1166,7 @@ impl Instruction {
Instruction::ReadPointer(ref v) => v.location,
Instruction::WritePointer(ref v) => v.location,
Instruction::FieldPointer(ref v) => v.location,
Instruction::MethodPointer(ref v) => v.location,
}
}

Expand Down Expand Up @@ -1344,6 +1364,13 @@ impl Instruction {
v.field.name(db)
)
}
Instruction::MethodPointer(v) => {
format!(
"r{} = method_pointer {}",
v.register.0,
method_name(db, v.method)
)
}
}
}
}
Expand Down
8 changes: 7 additions & 1 deletion compiler/src/mir/passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2442,7 +2442,13 @@ impl<'a> LowerMethod<'a> {
}

fn mut_expression(&mut self, node: hir::Mut) -> RegisterId {
if node.resolved_type.is_pointer(self.db()) {
if let Some(id) = node.pointer_to_method {
let loc = self.add_location(node.location);
let reg = self.new_register(node.resolved_type);

self.current_block_mut().method_pointer(reg, id, loc);
reg
} else if node.resolved_type.is_pointer(self.db()) {
let loc = self.add_location(node.location);
let val = self.expression(node.value);
let reg = self.new_register(node.resolved_type);
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/mir/specialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,12 @@ impl<'a, 'b> Specialize<'a, 'b> {
.class_id(&self.state.db)
.unwrap();
}
Instruction::MethodPointer(ins) => {
let rec = ins.method.receiver(&self.state.db);
let cls = rec.class_id(&self.state.db).unwrap();

ins.method = self.call_static(cls, ins.method, None);
}
Instruction::Cast(ins) => {
let from = method.registers.value_type(ins.source);
let to = method.registers.value_type(ins.register);
Expand Down
Loading

0 comments on commit 01bd04e

Please sign in to comment.