diff --git a/src/codegen.cpp b/src/codegen.cpp index 8d415b3376e7f..f16c669486c7f 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -334,6 +334,7 @@ static Function *resetstkoflw_func; #endif static Function *diff_gc_total_bytes_func; static Function *show_execution_point_func; +static Function *get_specfuncptr_func; static std::vector two_pvalue_llvmt; static std::vector three_pvalue_llvmt; @@ -717,6 +718,17 @@ static void jl_setup_module(Module *m, bool add) } } +extern "C" DLLEXPORT void* +jl_get_specfunction_ptr(jl_function_t *f) +{ + if (!f->linfo->specFunctionPtr) { + jl_generate_fptr(f); + ((Function*)f->linfo->functionObject)->dump(); + assert(f->linfo->specFunctionPtr); + } + return f->linfo->specFunctionPtr; +} + extern "C" void jl_generate_fptr(jl_function_t *f) { // objective: assign li->fptr @@ -3731,75 +3743,104 @@ static Function *gen_cfun_wrapper(jl_function_t *ff, jl_value_t *jlrettype, jl_t return cw; } -// generate a julia-callable function that calls f (AKA lam) -static Function *gen_jlcall_wrapper(jl_lambda_info_t *lam, jl_expr_t *ast, Function *f) +// generate a generic julia-callable function +// It is uniquely determined by (ftype, nargs, specLTypes, jl_rettype) +static Function *gen_jlcall_wrapper_real(FunctionType *ftype, + size_t nargs, + Type **specLTypes, + jl_value_t *jl_rettype, + jl_value_t **specTypes, + Module *m) { std::stringstream funcName; - const std::string &fname = f->getName().str(); - funcName << "jlcall_"; - if (fname.compare(0, 6, "julia_") == 0) - funcName << fname.substr(6); - else - funcName << fname; + funcName << "jlcall___generic_wrapper__" << globalUnique++; Function *w = Function::Create(jl_func_sig, imaging_mode ? GlobalVariable::InternalLinkage : GlobalVariable::ExternalLinkage, - funcName.str(), f->getParent()); + funcName.str(), m); addComdat(w); Function::arg_iterator AI = w->arg_begin(); - // const Argument &fArg = *AI++; - // AI++; Value *fArg = AI++; Value *argArray = AI++; - //const Argument &argCount = *AI++; + // const Argument &argCount = *AI++; BasicBlock *b0 = BasicBlock::Create(jl_LLVMContext, "top", w); builder.SetInsertPoint(b0); DebugLoc noDbg; builder.SetCurrentDebugLocation(noDbg); + jl_lambda_info_t *lam = jl_new_lambda_info(NULL, jl_emptysvec); + JL_GC_PUSH1(&lam); jl_codectx_t ctx; ctx.linfo = lam; allocate_gc_frame(0, b0, &ctx); - size_t nargs = jl_array_dim0(jl_lam_args(ast)); - size_t nfargs = f->getFunctionType()->getNumParams(); + size_t nfargs = ftype->getNumParams(); Value **args = (Value**) alloca(nfargs*sizeof(Value*)); unsigned idx = 0; - for(size_t i=0; i < nargs; i++) { - jl_value_t *ty = jl_nth_slot_type(lam->specTypes, i); - Type *lty = julia_type_to_llvm(ty); + for (size_t i=0; i < nargs; i++) { + Type *lty = specLTypes[i]; if (lty != NULL && type_is_ghost(lty)) continue; Value *argPtr = builder.CreateGEP(argArray, ConstantInt::get(T_size, i)); Value *theArg = builder.CreateLoad(argPtr, false); Value *theNewArg = theArg; - if (lty != NULL && lty != jl_pvalue_llvmt) { - theNewArg = emit_unbox(lty, theArg, ty); + if (specTypes[i]) { + // specTypes[i] is not used + theNewArg = emit_unbox(lty, theArg, specTypes[i]); } assert(dyn_cast(theNewArg) == NULL); args[idx] = theNewArg; idx++; } - // TODO: consider pulling the function pointer out of fArg so these - // wrappers can be reused for different functions of the same type. Value *theLam = emit_nthptr( fArg, (ssize_t)(offsetof(jl_function_t, linfo)/sizeof(void*)), tbaa_value); - FunctionType *ftype = f->getFunctionType(); - Type *fptrtype = PointerType::get(PointerType::get(ftype, 0), 0); - + Type *faddrtype = PointerType::get(ftype, 0); + Type *fptrtype = PointerType::get(faddrtype, 0); Value *theFptr = emit_nthptr_recast( theLam, (ssize_t)(offsetof(jl_lambda_info_t, specFunctionPtr)/sizeof(void*)), tbaa_func, fptrtype); + + Value *fptrInit = builder.CreateIsNotNull(theFptr); +#ifdef LLVM37 + fptrInit = builder.CreateCall(expect_func, {fptrInit, + ConstantInt::get(T_int1, 1)}); +#else + fptrInit = builder.CreateCall2(expect_func, fptrInit, + ConstantInt::get(T_int1, 1)); +#endif + + b0 = builder.GetInsertBlock(); + BasicBlock *slowBB = BasicBlock::Create(jl_LLVMContext, "fptr_slow", w); + BasicBlock *outBB = BasicBlock::Create(jl_LLVMContext, "fptr_out"); + builder.CreateCondBr(fptrInit, outBB, slowBB); + + builder.SetInsertPoint(slowBB); +#ifdef LLVM37 + Value *fptrSlow = builder.CreateCall(get_specfuncptr_func, {fArg}); +#else + Value *fptrSlow = builder.CreateCall(get_specfuncptr_func, fArg); +#endif + fptrSlow = builder.CreateBitCast(fptrSlow, faddrtype); + builder.CreateBr(outBB); + + w->getBasicBlockList().push_back(outBB); + builder.SetInsertPoint(outBB); + PHINode *fptrMerge = builder.CreatePHI(faddrtype, 2); + fptrMerge->addIncoming(theFptr, b0); + fptrMerge->addIncoming(fptrSlow, slowBB); + + theFptr = fptrMerge; + Value *r = builder.CreateCall(prepare_call(theFptr), ArrayRef(&args[0], nfargs)); - if (r->getType() != jl_pvalue_llvmt) { - r = boxed(r, &ctx, jl_ast_rettype(lam, (jl_value_t*)ast)); + if (jl_rettype) { + r = boxed(r, &ctx, jl_rettype); } // gc pop. Usually this is done when we encounter the return statement @@ -3815,9 +3856,85 @@ static Function *gen_jlcall_wrapper(jl_lambda_info_t *lam, jl_expr_t *ast, Funct FPM->run(*w); + JL_GC_POP(); return w; } +struct JLCallWrapperKey { + FunctionType *ftype; + std::vector specLTypes; + // Do we need to make sure jl_rettype is never garbage collected? + jl_value_t *jl_rettype; + JLCallWrapperKey(FunctionType *_ftype, size_t nargs, Type **_specLTypes, + jl_value_t *_jl_rettype) + : ftype(_ftype), + specLTypes(nargs), + jl_rettype(_jl_rettype) + { + for (size_t i = 0;i < nargs;i++) { + specLTypes[i] = _specLTypes[i]; + } + } + bool + operator<(const JLCallWrapperKey &rhs) const + { + if (ftype != rhs.ftype) + return ftype < rhs.ftype; + if (jl_rettype != rhs.jl_rettype) + return jl_rettype < rhs.jl_rettype; + if (specLTypes.size() != rhs.specLTypes.size()) + return specLTypes.size() < rhs.specLTypes.size(); + for (size_t i = 0;i < specLTypes.size();i++) { + if (specLTypes[i] != rhs.specLTypes[i]) { + return specLTypes[i] < rhs.specLTypes[i]; + } + } + return false; + } +}; + +typedef std::map JLCallWrapperCache; +static JLCallWrapperCache jlcall_wrapper_cache; + +static Function *gen_jlcall_wrapper_cached(FunctionType *ftype, + size_t nargs, + Type **specLTypes, + jl_value_t *jl_rettype, + jl_value_t **specTypes, + Module *m) +{ + JLCallWrapperKey key(ftype, nargs, specLTypes, jl_rettype); + JLCallWrapperCache::iterator it = jlcall_wrapper_cache.find(key); + if (it == jlcall_wrapper_cache.end()) { + Function *wrapper = gen_jlcall_wrapper_real(ftype, nargs, specLTypes, + jl_rettype, specTypes, m); + jlcall_wrapper_cache[key] = wrapper; + return wrapper; + } + Function *wrapper = (*it).second; + return wrapper; +} + +// generate a julia-callable function that calls f (AKA lam) +static Function *gen_jlcall_wrapper(jl_lambda_info_t *lam, jl_expr_t *ast, + Function *f) +{ + size_t nargs = jl_array_dim0(jl_lam_args(ast)); + FunctionType *ftype = f->getFunctionType(); + jl_value_t *jl_rettype = (ftype->getReturnType() != jl_pvalue_llvmt ? + jl_ast_rettype(lam, (jl_value_t*)ast) : NULL); + jl_value_t **specTypes = (jl_value_t**)alloca(nargs * sizeof(jl_value_t*)); + Type **specLTypes = (Type**)alloca(nargs * sizeof(Type*)); + for (size_t i = 0;i < nargs;i++) { + jl_value_t *ty = jl_nth_slot_type(lam->specTypes, i); + Type *lty = julia_type_to_llvm(ty); + specLTypes[i] = lty; + specTypes[i] = (lty != NULL && lty != jl_pvalue_llvmt) ? ty : NULL; + } + return gen_jlcall_wrapper_cached(ftype, nargs, specLTypes, + jl_rettype, specTypes, f->getParent()); +} + // cstyle = compile with c-callable signature, not jlcall static Function *emit_function(jl_lambda_info_t *lam) { @@ -4755,6 +4872,8 @@ extern "C" void jl_fptr_to_llvm(void *fptr, jl_lambda_info_t *lam, int specsig) if (imaging_mode) { if (!specsig) { lam->fptr = (jl_fptr_t)fptr; // in imaging mode, it's fine to use the fptr, but we don't want it in the shadow_module + } else { + lam->specFunctionPtr = fptr; } } else { @@ -5385,6 +5504,14 @@ static void init_julia_llvm_env(Module *m) "show_execution_point", m); add_named_global(show_execution_point_func, (void*)*show_execution_point); + std::vector getspec_args(0); + getspec_args.push_back(jl_pvalue_llvmt); + get_specfuncptr_func = + Function::Create(FunctionType::get(T_pint8, getspec_args, false), + Function::ExternalLinkage, + "jl_get_specfunction_ptr", m); + add_named_global(get_specfuncptr_func, (void*)&jl_get_specfunction_ptr); + // set up optimization passes FPM = new FunctionPassManager(m);