Skip to content

Commit

Permalink
Merge pull request #5046 from JuliaLang/kf/llvmcall
Browse files Browse the repository at this point in the history
Add llvmcall
  • Loading branch information
Keno committed Aug 12, 2014
2 parents 631b098 + 38bb525 commit a636295
Show file tree
Hide file tree
Showing 8 changed files with 250 additions and 4 deletions.
2 changes: 1 addition & 1 deletion base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ export
JULIA_HOME, nothing, Main,
# intrinsics module
Intrinsics
#ccall, cglobal, abs_float, add_float, add_int, and_int, ashr_int,
#ccall, cglobal, llvmcall, abs_float, add_float, add_int, and_int, ashr_int,
#box, bswap_int, checked_fptosi, checked_fptoui, checked_sadd,
#checked_smul, checked_ssub, checked_uadd, checked_umul, checked_usub,
#checked_trunc_sint, checked_trunc_uint,
Expand Down
4 changes: 4 additions & 0 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ t_func[nan_dom_err] = (2, 2, (a, b)->a)
t_func[eval(Core.Intrinsics,:ccall)] =
(3, Inf, (fptr, rt, at, a...)->(is(rt,Type{Void}) ? Nothing :
isType(rt) ? rt.parameters[1] : Any))
t_func[eval(Core.Intrinsics,:llvmcall)] =
(3, Inf, (fptr, rt, at, a...)->(is(rt,Type{Void}) ? Nothing :
isType(rt) ? rt.parameters[1] :
isa(rt,Tuple) ? map(x->x.parameters[1],rt) : Any))
t_func[eval(Core.Intrinsics,:cglobal)] =
(1, 2, (fptr, t...)->(isempty(t) ? Ptr{Void} :
isType(t[1]) ? Ptr{t[1].parameters[1]} : Ptr))
Expand Down
206 changes: 206 additions & 0 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,212 @@ static Value *emit_cglobal(jl_value_t **args, size_t nargs, jl_codectx_t *ctx)
return mark_julia_type(res, rt);
}

// llvmcall(ir, (rettypes...), (argtypes...), args...)
static Value *emit_llvmcall(jl_value_t **args, size_t nargs, jl_codectx_t *ctx)
{

JL_NARGSV(llvmcall, 3)
jl_value_t *rt = NULL, *at = NULL, *ir = NULL;
JL_GC_PUSH3(&ir, &rt, &at);
{
JL_TRY {
at = jl_interpret_toplevel_expr_in(ctx->module, args[3],
&jl_tupleref(ctx->sp,0),
jl_tuple_len(ctx->sp)/2);
}
JL_CATCH {
jl_rethrow_with_add("error interpreting llvmcall return type");
}
}
{
JL_TRY {
rt = jl_interpret_toplevel_expr_in(ctx->module, args[2],
&jl_tupleref(ctx->sp,0),
jl_tuple_len(ctx->sp)/2);
}
JL_CATCH {
jl_rethrow_with_add("error interpreting llvmcall argument tuple");
}
}
{
JL_TRY {
ir = jl_interpret_toplevel_expr_in(ctx->module, args[1],
&jl_tupleref(ctx->sp,0),
jl_tuple_len(ctx->sp)/2);
}
JL_CATCH {
jl_rethrow_with_add("error interpreting IR argument");
}
}
int i = 1;
if (ir == NULL) {
jl_error("Cannot statically evaluate first argument to llvmcall");
}
bool isString = jl_is_byte_string(ir);
bool isPtr = jl_is_cpointer(ir);
if (!isString && !isPtr)
{
jl_error("First argument to llvmcall must be a string or pointer to an LLVM Function");
}

JL_TYPECHK(llvmcall, type, rt);
JL_TYPECHK(llvmcall, tuple, at);
JL_TYPECHK(llvmcall, type, at);

std::stringstream ir_stream;

jl_tuple_t *stt = jl_alloc_tuple(nargs - 3);

for (size_t i = 0; i < nargs-3; ++i)
{
jl_tupleset(stt,i,expr_type(args[4+i],ctx));
}

// Generate arguments
std::string arguments;
llvm::raw_string_ostream argstream(arguments);
jl_tuple_t *tt = (jl_tuple_t*)at;
jl_value_t *rtt = rt;

size_t nargt = jl_tuple_len(tt);
Value *argvals[nargt];
std::vector<llvm::Type*> argtypes;
/*
* Semantics for arguments are as follows:
* If the argument type is immutable (including bitstype), we pass the loaded llvm value
* type. Otherwise we pass a pointer to a jl_value_t.
*/
for (size_t i = 0; i < nargt; ++i)
{
jl_value_t *tti = jl_tupleref(tt,i);
Type *t = julia_type_to_llvm(tti);
argtypes.push_back(t);
if (4+i > nargs)
{
jl_error("Missing arguments to llvmcall!");
}
jl_value_t *argi = args[4+i];
Value *arg;
bool needroot = false;
if (t == jl_pvalue_llvmt || !jl_isbits(tti)) {
arg = emit_expr(argi, ctx, true);
if (t == jl_pvalue_llvmt && arg->getType() != jl_pvalue_llvmt) {
arg = boxed(arg, ctx);
needroot = true;
}
}
else {
arg = emit_unboxed(argi, ctx);
if (jl_is_bitstype(expr_type(argi, ctx))) {
arg = emit_unbox(t, arg, tti);
}
}

#ifdef JL_GC_MARKSWEEP
// make sure args are rooted
if (t == jl_pvalue_llvmt && (needroot || might_need_root(argi))) {
make_gcroot(arg, ctx);
}
#endif
bool mightNeedTempSpace = false;
argvals[i] = julia_to_native(t,tti,arg,argi,false,i,ctx,&mightNeedTempSpace,&mightNeedTempSpace);
}

Function *f;
Type *rettype = julia_type_to_llvm(rtt);
if (isString) {
// Make sure to find a unique name
std::string ir_name;
while(true) {
std::stringstream name;
name << (ctx->f->getName().str()) << i++;
ir_name = name.str();
if(jl_Module->getFunction(ir_name) == NULL)
break;
}

bool first = true;
for (std::vector<Type *>::iterator it = argtypes.begin(); it != argtypes.end(); ++it) {
if(!first)
argstream << ",";
else
first = false;
(*it)->print(argstream);
argstream << " ";
}

std::string rstring;
llvm::raw_string_ostream rtypename(rstring);
rettype->print(rtypename);

ir_stream << "; Number of arguments: " << nargt << "\n"
<< "define "<<rtypename.str()<<" @\"" << ir_name << "\"("<<argstream.str()<<") {\n"
<< jl_string_data(ir) << "\n}";
SMDiagnostic Err = SMDiagnostic();
std::string ir_string = ir_stream.str();
Module *m = ParseAssemblyString(ir_string.data(),jl_Module,Err,jl_LLVMContext);
if (m == NULL) {
std::string message = "Failed to parse LLVM Assembly: \n";
llvm::raw_string_ostream stream(message);
Err.print("julia",stream,true);
jl_error(stream.str().c_str());
}
f = m->getFunction(ir_name);
} else {
assert(isPtr);
// Create Function skeleton
f = (llvm::Function*)jl_unbox_voidpointer(ir);
assert(f->getReturnType() == rettype);
int i = 0;
for (std::vector<Type *>::iterator it = argtypes.begin();
it != argtypes.end(); ++it, ++i)
assert(*it == f->getFunctionType()->getParamType(i));

#ifdef USE_MCJIT
if (f->getParent() != jl_Module)
{
FunctionMover mover(jl_Module,f->getParent());
f = (llvm::Function*)MapValue(f,mover.VMap,RF_None,NULL,&mover);
}
#endif

//f->dump();
#ifndef LLVM35
if (verifyFunction(*f,PrintMessageAction)) {
#else
llvm::raw_fd_ostream out(1,false);
if (verifyFunction(*f,&out))
{
#endif
f->dump();
jl_error("Malformed LLVM Function");
}
}

/*
* It might be tempting to just try to set the Always inline attribute on the function
* and hope for the best. However, this doesn't work since that would require an inlining
* pass (which is a Call Graph pass and cannot be managed by a FunctionPassManager). Instead
* We are sneaky and call the inliner directly. This however doesn't work until we've actually
* generated the entire function, so we need to store it in the context until the end of the
* function. This also has the benefit of looking exactly like we cut/pasted it in in `code_llvm`.
*/
f->setLinkage(GlobalValue::LinkOnceODRLinkage);

// the actual call
CallInst *inst = builder.CreateCall(prepare_call(f),ArrayRef<Value*>(&argvals[0],nargt));
ctx->to_inline.push_back(inst);

JL_GC_POP();

if(inst->getType() != rettype)
{
jl_error("Return type of llvmcall'ed function does not match declared return type");
}

return mark_julia_type(inst,rtt);
}

// --- code generator for ccall itself ---

// ccall(pointer, rettype, (argtypes...), args...)
Expand Down
1 change: 1 addition & 0 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include "llvm/Target/TargetMachine.h"
#else
#include "llvm/Analysis/Verifier.h"
#include "llvm/Assembly/Parser.h"
#endif
#include "llvm/DebugInfo/DIContext.h"
#if defined(LLVM_VERSION_MAJOR) && LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR >= 4
Expand Down
4 changes: 3 additions & 1 deletion src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace JL_I {
// pointer access
pointerref, pointerset, pointertoref,
// c interface
ccall, cglobal, jl_alloca
ccall, cglobal, jl_alloca, llvmcall
};
};

Expand Down Expand Up @@ -831,6 +831,7 @@ static Value *emit_intrinsic(intrinsic f, jl_value_t **args, size_t nargs,
switch (f) {
case ccall: return emit_ccall(args, nargs, ctx);
case cglobal: return emit_cglobal(args, nargs, ctx);
case llvmcall: return emit_llvmcall(args, nargs, ctx);

HANDLE(box,2) return generic_box(args[1], args[2], ctx);
HANDLE(unbox,2) return generic_unbox(args[1], args[2], ctx);
Expand Down Expand Up @@ -1459,4 +1460,5 @@ extern "C" void jl_init_intrinsic_functions(void)
ADD_I(nan_dom_err);
ADD_I(ccall); ADD_I(cglobal);
ADD_I(jl_alloca);
ADD_I(llvmcall);
}
2 changes: 1 addition & 1 deletion test/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ TESTS = all core keywordargs numbers strings unicode collections hashing \
git pkg resolve suitesparse complex version pollfd mpfr broadcast \
socket floatapprox priorityqueue readdlm regex float16 combinatorics \
sysinfo rounding ranges mod2pi euler show lineedit \
replcompletions backtrace repl test goto
replcompletions backtrace repl test goto llvmcall

default: all

Expand Down
33 changes: 33 additions & 0 deletions test/llvmcall.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using Base.llvmcall

function add1234(x::(Int32,Int32,Int32,Int32))
llvmcall("""%3 = add <4 x i32> %1, %0
ret <4 x i32> %3""",(Int32,Int32,Int32,Int32),
((Int32,Int32,Int32,Int32),(Int32,Int32,Int32,Int32)),
(int32(1),int32(2),int32(3),int32(4)),
x)
end

function add1234(x::NTuple{4,Int64})
llvmcall("""%3 = add <4 x i64> %1, %0
ret <4 x i64> %3""",NTuple{4,Int64},
(NTuple{4,Int64},NTuple{4,Int64}),
(int64(1),int64(2),int64(3),int64(4)),
x)
end

@test add1234(map(int32,(2,3,4,5))) === map(int32,(3,5,7,9))
@test add1234(map(int64,(2,3,4,5))) === map(int64,(3,5,7,9))

# Test whether llvmcall escapes the function name correctly
baremodule PlusTest
using Base.llvmcall
using Base.Test
using Base

function +(x::Int32, y::Int32)
llvmcall("""%3 = add i32 %1, %0
ret i32 %3""", Int32, (Int32, Int32), x, y)
end
@test int32(1)+int32(2)==int32(3)
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ testnames = [
"resolve", "pollfd", "mpfr", "broadcast", "complex", "socket",
"floatapprox", "readdlm", "regex", "float16", "combinatorics",
"sysinfo", "rounding", "ranges", "mod2pi", "euler", "show",
"lineedit", "replcompletions", "repl", "test", "examples", "goto"
"lineedit", "replcompletions", "repl", "test", "examples", "goto", "llvmcall"
]
@unix_only push!(testnames, "unicode")

Expand Down

0 comments on commit a636295

Please sign in to comment.