Skip to content

Commit

Permalink
Generate generic wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
yuyichao committed May 28, 2015
1 parent d0aa4ca commit 93776b0
Showing 1 changed file with 153 additions and 26 deletions.
179 changes: 153 additions & 26 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ static Function *resetstkoflw_func;
#endif
static Function *diff_gc_total_bytes_func;
static Function *show_execution_point_func;
static Function *get_specfuncptr_func;

static std::vector<Type *> two_pvalue_llvmt;
static std::vector<Type *> three_pvalue_llvmt;
Expand Down Expand Up @@ -717,6 +718,17 @@ static void jl_setup_module(Module *m, bool add)
}
}

extern "C" DLLEXPORT void*
jl_get_specfunction_ptr(jl_function_t *f)
{
if (!f->linfo->specFunctionPtr) {
jl_generate_fptr(f);
((Function*)f->linfo->functionObject)->dump();
assert(f->linfo->specFunctionPtr);
}
return f->linfo->specFunctionPtr;
}

extern "C" void jl_generate_fptr(jl_function_t *f)
{
// objective: assign li->fptr
Expand Down Expand Up @@ -3731,75 +3743,104 @@ static Function *gen_cfun_wrapper(jl_function_t *ff, jl_value_t *jlrettype, jl_t
return cw;
}

// generate a julia-callable function that calls f (AKA lam)
static Function *gen_jlcall_wrapper(jl_lambda_info_t *lam, jl_expr_t *ast, Function *f)
// generate a generic julia-callable function
// It is uniquely determined by (ftype, nargs, specLTypes, jl_rettype)
static Function *gen_jlcall_wrapper_real(FunctionType *ftype,
size_t nargs,
Type **specLTypes,
jl_value_t *jl_rettype,
jl_value_t **specTypes,
Module *m)
{
std::stringstream funcName;
const std::string &fname = f->getName().str();
funcName << "jlcall_";
if (fname.compare(0, 6, "julia_") == 0)
funcName << fname.substr(6);
else
funcName << fname;
funcName << "jlcall___generic_wrapper__" << globalUnique++;

Function *w = Function::Create(jl_func_sig, imaging_mode ? GlobalVariable::InternalLinkage : GlobalVariable::ExternalLinkage,
funcName.str(), f->getParent());
funcName.str(), m);
addComdat(w);
Function::arg_iterator AI = w->arg_begin();
// const Argument &fArg = *AI++;
// AI++;
Value *fArg = AI++;
Value *argArray = AI++;
//const Argument &argCount = *AI++;
// const Argument &argCount = *AI++;
BasicBlock *b0 = BasicBlock::Create(jl_LLVMContext, "top", w);

builder.SetInsertPoint(b0);
DebugLoc noDbg;
builder.SetCurrentDebugLocation(noDbg);

jl_lambda_info_t *lam = jl_new_lambda_info(NULL, jl_emptysvec);
JL_GC_PUSH1(&lam);
jl_codectx_t ctx;
ctx.linfo = lam;
allocate_gc_frame(0, b0, &ctx);

size_t nargs = jl_array_dim0(jl_lam_args(ast));
size_t nfargs = f->getFunctionType()->getNumParams();
size_t nfargs = ftype->getNumParams();
Value **args = (Value**) alloca(nfargs*sizeof(Value*));
unsigned idx = 0;
for(size_t i=0; i < nargs; i++) {
jl_value_t *ty = jl_nth_slot_type(lam->specTypes, i);
Type *lty = julia_type_to_llvm(ty);
for (size_t i=0; i < nargs; i++) {
Type *lty = specLTypes[i];
if (lty != NULL && type_is_ghost(lty))
continue;
Value *argPtr = builder.CreateGEP(argArray,
ConstantInt::get(T_size, i));
Value *theArg = builder.CreateLoad(argPtr, false);
Value *theNewArg = theArg;
if (lty != NULL && lty != jl_pvalue_llvmt) {
theNewArg = emit_unbox(lty, theArg, ty);
if (specTypes[i]) {
// specTypes[i] is not used
theNewArg = emit_unbox(lty, theArg, specTypes[i]);
}
assert(dyn_cast<UndefValue>(theNewArg) == NULL);
args[idx] = theNewArg;
idx++;
}
// TODO: consider pulling the function pointer out of fArg so these
// wrappers can be reused for different functions of the same type.
Value *theLam = emit_nthptr(
fArg,
(ssize_t)(offsetof(jl_function_t, linfo)/sizeof(void*)),
tbaa_value);

FunctionType *ftype = f->getFunctionType();
Type *fptrtype = PointerType::get(PointerType::get(ftype, 0), 0);

Type *faddrtype = PointerType::get(ftype, 0);
Type *fptrtype = PointerType::get(faddrtype, 0);
Value *theFptr = emit_nthptr_recast(
theLam,
(ssize_t)(offsetof(jl_lambda_info_t, specFunctionPtr)/sizeof(void*)),
tbaa_func,
fptrtype);

Value *fptrInit = builder.CreateIsNotNull(theFptr);
#ifdef LLVM37
fptrInit = builder.CreateCall(expect_func, {fptrInit,
ConstantInt::get(T_int1, 1)});
#else
fptrInit = builder.CreateCall2(expect_func, fptrInit,
ConstantInt::get(T_int1, 1));
#endif

b0 = builder.GetInsertBlock();
BasicBlock *slowBB = BasicBlock::Create(jl_LLVMContext, "fptr_slow", w);
BasicBlock *outBB = BasicBlock::Create(jl_LLVMContext, "fptr_out");
builder.CreateCondBr(fptrInit, outBB, slowBB);

builder.SetInsertPoint(slowBB);
#ifdef LLVM37
Value *fptrSlow = builder.CreateCall(get_specfuncptr_func, {fArg});
#else
Value *fptrSlow = builder.CreateCall(get_specfuncptr_func, fArg);
#endif
fptrSlow = builder.CreateBitCast(fptrSlow, faddrtype);
builder.CreateBr(outBB);

w->getBasicBlockList().push_back(outBB);
builder.SetInsertPoint(outBB);
PHINode *fptrMerge = builder.CreatePHI(faddrtype, 2);
fptrMerge->addIncoming(theFptr, b0);
fptrMerge->addIncoming(fptrSlow, slowBB);

theFptr = fptrMerge;

Value *r = builder.CreateCall(prepare_call(theFptr),
ArrayRef<Value*>(&args[0], nfargs));
if (r->getType() != jl_pvalue_llvmt) {
r = boxed(r, &ctx, jl_ast_rettype(lam, (jl_value_t*)ast));
if (jl_rettype) {
r = boxed(r, &ctx, jl_rettype);
}

// gc pop. Usually this is done when we encounter the return statement
Expand All @@ -3815,9 +3856,85 @@ static Function *gen_jlcall_wrapper(jl_lambda_info_t *lam, jl_expr_t *ast, Funct

FPM->run(*w);

JL_GC_POP();
return w;
}

struct JLCallWrapperKey {
FunctionType *ftype;
std::vector<Type*> specLTypes;
// Do we need to make sure jl_rettype is never garbage collected?
jl_value_t *jl_rettype;
JLCallWrapperKey(FunctionType *_ftype, size_t nargs, Type **_specLTypes,
jl_value_t *_jl_rettype)
: ftype(_ftype),
specLTypes(nargs),
jl_rettype(_jl_rettype)
{
for (size_t i = 0;i < nargs;i++) {
specLTypes[i] = _specLTypes[i];
}
}
bool
operator<(const JLCallWrapperKey &rhs) const
{
if (ftype != rhs.ftype)
return ftype < rhs.ftype;
if (jl_rettype != rhs.jl_rettype)
return jl_rettype < rhs.jl_rettype;
if (specLTypes.size() != rhs.specLTypes.size())
return specLTypes.size() < rhs.specLTypes.size();
for (size_t i = 0;i < specLTypes.size();i++) {
if (specLTypes[i] != rhs.specLTypes[i]) {
return specLTypes[i] < rhs.specLTypes[i];
}
}
return false;
}
};

typedef std::map<JLCallWrapperKey, Function*> JLCallWrapperCache;
static JLCallWrapperCache jlcall_wrapper_cache;

static Function *gen_jlcall_wrapper_cached(FunctionType *ftype,
size_t nargs,
Type **specLTypes,
jl_value_t *jl_rettype,
jl_value_t **specTypes,
Module *m)
{
JLCallWrapperKey key(ftype, nargs, specLTypes, jl_rettype);
JLCallWrapperCache::iterator it = jlcall_wrapper_cache.find(key);
if (it == jlcall_wrapper_cache.end()) {
Function *wrapper = gen_jlcall_wrapper_real(ftype, nargs, specLTypes,
jl_rettype, specTypes, m);
jlcall_wrapper_cache[key] = wrapper;
return wrapper;
}
Function *wrapper = (*it).second;
return wrapper;
}

// generate a julia-callable function that calls f (AKA lam)
static Function *gen_jlcall_wrapper(jl_lambda_info_t *lam, jl_expr_t *ast,
Function *f)
{
size_t nargs = jl_array_dim0(jl_lam_args(ast));
FunctionType *ftype = f->getFunctionType();
jl_value_t *jl_rettype = (ftype->getReturnType() != jl_pvalue_llvmt ?
jl_ast_rettype(lam, (jl_value_t*)ast) : NULL);
jl_value_t **specTypes = (jl_value_t**)alloca(nargs * sizeof(jl_value_t*));
Type **specLTypes = (Type**)alloca(nargs * sizeof(Type*));
for (size_t i = 0;i < nargs;i++) {
jl_value_t *ty = jl_nth_slot_type(lam->specTypes, i);
Type *lty = julia_type_to_llvm(ty);
specLTypes[i] = lty;
specTypes[i] = (lty != NULL && lty != jl_pvalue_llvmt) ? ty : NULL;
}
return gen_jlcall_wrapper_cached(ftype, nargs, specLTypes,
jl_rettype, specTypes, f->getParent());
}

// cstyle = compile with c-callable signature, not jlcall
static Function *emit_function(jl_lambda_info_t *lam)
{
Expand Down Expand Up @@ -4755,6 +4872,8 @@ extern "C" void jl_fptr_to_llvm(void *fptr, jl_lambda_info_t *lam, int specsig)
if (imaging_mode) {
if (!specsig) {
lam->fptr = (jl_fptr_t)fptr; // in imaging mode, it's fine to use the fptr, but we don't want it in the shadow_module
} else {
lam->specFunctionPtr = fptr;
}
}
else {
Expand Down Expand Up @@ -5385,6 +5504,14 @@ static void init_julia_llvm_env(Module *m)
"show_execution_point", m);
add_named_global(show_execution_point_func, (void*)*show_execution_point);

std::vector<Type *> getspec_args(0);
getspec_args.push_back(jl_pvalue_llvmt);
get_specfuncptr_func =
Function::Create(FunctionType::get(T_pint8, getspec_args, false),
Function::ExternalLinkage,
"jl_get_specfunction_ptr", m);
add_named_global(get_specfuncptr_func, (void*)&jl_get_specfunction_ptr);

// set up optimization passes
FPM = new FunctionPassManager(m);

Expand Down

0 comments on commit 93776b0

Please sign in to comment.