Skip to content

Commit 1be6ee4

Browse files
committed
cfunction: reimplement, as originally planned, for reliable performance
This implements several bugfixes and improvements:
- direct edges are correctly represented
- performance does not degrade when edges trigger
- the JIT does not call `jl_infer_type` until actually required
- constant return can handle invalidation and emit efficient code

This generates the code for the exact signature specified by the user, instead of using `jl_infer_type` to compute a different signature, which previously could cause unnecessary boxing and introduced unstable performance characteristics. This lets us defer generating the actual thunk until runtime, when the JIT already has the required information available, and also lets us validate that the information is still correct and regenerate the thunk when it is no longer correct.
1 parent fea26dd commit 1be6ee4

File tree

13 files changed

+924
-325
lines changed

13 files changed

+924
-325
lines changed

Compiler/src/typeinfer.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,6 +1266,7 @@ function typeinf_ext_toplevel(methods::Vector{Any}, worlds::Vector{UInt}, trim::
12661266
tocompile = Vector{CodeInstance}()
12671267
codeinfos = []
12681268
# first compute the ABIs of everything
1269+
latest = true # whether this_world == world_counter()
12691270
for this_world in reverse(sort!(worlds))
12701271
interp = NativeInterpreter(this_world)
12711272
for i = 1:length(methods)
@@ -1278,18 +1279,18 @@ function typeinf_ext_toplevel(methods::Vector{Any}, worlds::Vector{UInt}, trim::
12781279
# then we want to compile and emit this
12791280
if item.def.primary_world <= this_world <= item.def.deleted_world
12801281
ci = typeinf_ext(interp, item, SOURCE_MODE_NOT_REQUIRED)
1281-
ci isa CodeInstance && !use_const_api(ci) && push!(tocompile, ci)
1282+
ci isa CodeInstance && push!(tocompile, ci)
12821283
end
1283-
elseif item isa SimpleVector
1284+
elseif item isa SimpleVector && latest
12841285
(rt::Type, sig::Type) = item
12851286
# make a best-effort attempt to enqueue the relevant code for the ccallable
12861287
ptr = ccall(:jl_get_specialization1,
12871288
#= MethodInstance =# Ptr{Cvoid}, (Any, Csize_t, Cint),
12881289
sig, this_world, #= mt_cache =# 0)
12891290
if ptr !== C_NULL
1290-
mi = unsafe_pointer_to_objref(ptr)
1291+
mi = unsafe_pointer_to_objref(ptr)::MethodInstance
12911292
ci = typeinf_ext(interp, mi, SOURCE_MODE_NOT_REQUIRED)
1292-
ci isa CodeInstance && !use_const_api(ci) && push!(tocompile, ci)
1293+
ci isa CodeInstance && push!(tocompile, ci)
12931294
end
12941295
# additionally enqueue the ccallable entrypoint / adapter, which implicitly
12951296
# invokes the above ci
@@ -1305,7 +1306,7 @@ function typeinf_ext_toplevel(methods::Vector{Any}, worlds::Vector{UInt}, trim::
13051306
mi = get_ci_mi(callee)
13061307
def = mi.def
13071308
if use_const_api(callee)
1308-
src = codeinfo_for_const(interp, mi, code.rettype_const)
1309+
src = codeinfo_for_const(interp, mi, callee.rettype_const)
13091310
elseif haskey(interp.codegen, callee)
13101311
src = interp.codegen[callee]
13111312
elseif isa(def, Method) && ccall(:jl_get_module_infer, Cint, (Any,), def.module) == 0 && !trim
@@ -1327,6 +1328,7 @@ function typeinf_ext_toplevel(methods::Vector{Any}, worlds::Vector{UInt}, trim::
13271328
println("warning: failed to get code for ", mi)
13281329
end
13291330
end
1331+
latest = false
13301332
end
13311333
return codeinfos
13321334
end

src/aotcompile.cpp

Lines changed: 174 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ static void resolve_workqueue(jl_codegen_params_t &params, egal_set &method_root
423423
if (decls.functionObject == "jl_fptr_args") {
424424
preal_decl = decls.specFunctionObject;
425425
}
426-
else if (decls.functionObject != "jl_fptr_sparam" && decls.functionObject != "jl_f_opaque_closure_call") {
426+
else if (decls.functionObject != "jl_fptr_sparam" && decls.functionObject != "jl_f_opaque_closure_call" && decls.functionObject != "jl_fptr_const_return") {
427427
preal_decl = decls.specFunctionObject;
428428
preal_specsig = true;
429429
}
@@ -439,6 +439,13 @@ static void resolve_workqueue(jl_codegen_params_t &params, egal_set &method_root
439439
Module *mod = proto.decl->getParent();
440440
assert(proto.decl->isDeclaration());
441441
Function *pinvoke = nullptr;
442+
if (preal_decl.empty() && jl_atomic_load_relaxed(&codeinst->invoke) == jl_fptr_const_return_addr) {
443+
std::string gf_thunk_name = emit_abi_constreturn(mod, params, proto.specsig, codeinst);
444+
preal_specsig = proto.specsig;
445+
if (invokeName.empty())
446+
invokeName = "jl_fptr_const_return";
447+
preal_decl = mod->getNamedValue(gf_thunk_name)->getName();
448+
}
442449
if (preal_decl.empty()) {
443450
if (invokeName.empty() && params.params->trim) {
444451
jl_safe_printf("warning: bailed out to invoke when compiling: ");
@@ -483,6 +490,7 @@ static void resolve_workqueue(jl_codegen_params_t &params, egal_set &method_root
483490
ocinvokeDecl = pinvoke->getName();
484491
assert(!ocinvokeDecl.empty());
485492
assert(ocinvokeDecl != "jl_fptr_args");
493+
assert(ocinvokeDecl != "jl_fptr_const_return");
486494
assert(ocinvokeDecl != "jl_fptr_sparam");
487495
// merge and/or rename this prototype to the real function
488496
if (Value *specfun = mod->getNamedValue(ocinvokeDecl)) {
@@ -499,6 +507,134 @@ static void resolve_workqueue(jl_codegen_params_t &params, egal_set &method_root
499507
JL_GC_POP();
500508
}
501509

510+
/// Link the function in the source module into the destination module if
511+
/// needed, setting up mapping information.
512+
/// Similar to orc::cloneFunctionDecl, but more complete for greater correctness
513+
Function *IRLinker_copyFunctionProto(Module *DstM, Function *SF) {
514+
// If there is no linkage to be performed or we are linking from the source,
515+
// bring SF over, if we haven't already.
516+
if (SF->getParent() == DstM)
517+
return SF;
518+
if (auto *F = DstM->getNamedValue(SF->getName()))
519+
return cast<Function>(F);
520+
auto *F = Function::Create(SF->getFunctionType(), SF->getLinkage(),
521+
SF->getAddressSpace(), SF->getName(), DstM);
522+
F->copyAttributesFrom(SF);
523+
F->IsNewDbgInfoFormat = SF->IsNewDbgInfoFormat;
524+
525+
// Remove these copied constants since they point to the source module.
526+
F->setPersonalityFn(nullptr);
527+
F->setPrefixData(nullptr);
528+
F->setPrologueData(nullptr);
529+
return F;
530+
}
531+
532+
// Emit an ABI adapter thunk into module `M` and return it.
//
// When `specfunc` is non-empty, the function named `specfunc` in `defM` is
// declared in `M` and wrapped with `emit_abi_converter` (forwarding
// `target_specsig`). Otherwise `emit_abi_dispatcher` is used, targeting the
// function named `func` in `defM` when `func` is non-empty, or no target
// (nullptr) when it is empty. In the latter case `defM` is not dereferenced.
//
// The returned function is looked up in `M` by the thunk name the emitter
// reports; it is asserted to exist.
static Function *aot_abi_converter(jl_codegen_params_t &params, Module *M, jl_value_t *declrt, jl_value_t *sigt, size_t nargs, bool specsig, jl_code_instance_t *codeinst, Module *defM, StringRef func, StringRef specfunc, bool target_specsig)
{
    std::string thunk_name;
    if (specfunc.empty()) {
        // No specialized entry point: go through the dispatcher, optionally
        // handing it the generic entry point as its target.
        Value *target = nullptr;
        if (!func.empty())
            target = IRLinker_copyFunctionProto(M, defM->getFunction(func));
        thunk_name = emit_abi_dispatcher(M, params, declrt, sigt, nargs, specsig, codeinst, target);
    }
    else {
        Value *target = IRLinker_copyFunctionProto(M, defM->getFunction(specfunc));
        thunk_name = emit_abi_converter(M, params, declrt, sigt, nargs, specsig, codeinst, target, target_specsig);
    }
    Function *thunk = M->getFunction(thunk_name);
    assert(thunk);
    return thunk;
}
547+
548+
// Fill in the initializers of every cfunction recorded in `params.cfuncs`.
//
// For each entry this picks the function to store in `cfunc.theFptr` and
// updates the `cfunc.cfuncdata` array (slot 0: the code instance, when one was
// found; slot 2: the unspecialized dispatcher thunk).
//
// The target is chosen from `compiled_functions` by looking up the method
// instance for `cfunc.sigt` in the latest world:
//   - "jl_fptr_const_return": a constant-returning thunk is emitted;
//   - "jl_fptr_args": the specfunc is used directly when possible, otherwise
//     through an ABI converter;
//   - "jl_fptr_sparam" / "jl_f_opaque_closure_call": falls through to the
//     dispatcher with no direct target;
//   - anything else (specsig): the specfunc is used directly when signature
//     and return type match exactly, otherwise through an ABI converter.
// When no compiled code instance is found, the unspecialized thunk is used.
//
// Every entry of `params.cfuncs` is processed; the fall-through case must not
// leave the function early, or later cfunctions would keep uninitialized
// `cfuncdata` / `theFptr` globals.
static void generate_cfunc_thunks(jl_codegen_params_t &params, jl_compiled_functions_t &compiled_functions)
{
    // Index the compiled code instances by method instance, keeping only those
    // with no owner, an unbounded max_world, and whose `def` is the method
    // instance itself.
    DenseMap<jl_method_instance_t*, jl_code_instance_t*> compiled_mi;
    for (auto &def : compiled_functions) {
        jl_code_instance_t *this_code = def.first;
        jl_method_instance_t *mi = jl_get_ci_mi(this_code);
        if (this_code->owner == jl_nothing && jl_atomic_load_relaxed(&this_code->max_world) == ~(size_t)0 && this_code->def == (jl_value_t*)mi)
            compiled_mi[mi] = this_code;
    }
    size_t latestworld = jl_atomic_load_acquire(&jl_world_counter);
    for (cfunc_decl_t &cfunc : params.cfuncs) {
        Module *M = cfunc.theFptr->getParent();
        jl_value_t *sigt = cfunc.sigt;
        JL_GC_PROMISE_ROOTED(sigt);
        jl_value_t *declrt = cfunc.declrt;
        JL_GC_PROMISE_ROOTED(declrt);
        // Dispatcher thunk with no code instance and no direct target.
        Function *unspec = aot_abi_converter(params, M, declrt, sigt, cfunc.nargs, cfunc.specsig, nullptr, nullptr, "", "", false);
        jl_code_instance_t *codeinst = nullptr;
        // Rewrites the 6-element cfuncdata initializer and points theFptr at `f`.
        // `codeinst` is captured by reference so the value assigned below is
        // the one observed when the lambda runs.
        auto assign_fptr = [&params, &cfunc, &codeinst, &unspec](Function *f) {
            ConstantArray *init = cast<ConstantArray>(cfunc.cfuncdata->getInitializer());
            SmallVector<Constant*,6> initvals;
            for (unsigned i = 0; i < init->getNumOperands(); ++i)
                initvals.push_back(init->getOperand(i));
            assert(initvals.size() == 6);
            assert(initvals[0]->isNullValue());
            if (codeinst) {
                Constant *llvmcodeinst = literal_pointer_val_slot(params, f->getParent(), (jl_value_t*)codeinst);
                initvals[0] = llvmcodeinst; // plast_codeinst
            }
            assert(initvals[2]->isNullValue());
            initvals[2] = unspec;
            cfunc.cfuncdata->setInitializer(ConstantArray::get(init->getType(), initvals));
            cfunc.theFptr->setInitializer(f);
        };
        Module *defM = nullptr;
        StringRef func;
        jl_method_instance_t *mi = jl_get_specialization1((jl_tupletype_t*)sigt, latestworld, 0);
        if (mi) {
            auto it = compiled_mi.find(mi);
            if (it != compiled_mi.end()) {
                codeinst = it->second;
                JL_GC_PROMISE_ROOTED(codeinst);
                auto defs = compiled_functions.find(codeinst);
                defM = std::get<0>(defs->second).getModuleUnlocked();
                const jl_llvm_functions_t &decls = std::get<1>(defs->second);
                func = decls.functionObject;
                StringRef specfunc = decls.specFunctionObject;
                jl_value_t *astrt = codeinst->rettype;
                if (astrt != (jl_value_t*)jl_bottom_type &&
                    jl_type_intersection(astrt, declrt) == jl_bottom_type) {
                    // Do not warn if the function never returns since it is
                    // occasionally required by the C API (typically error callbacks)
                    // even though we're likely to encounter memory errors in that case
                    jl_printf(JL_STDERR, "WARNING: cfunction: return type of %s does not match\n", name_from_method_instance(mi));
                }
                if (func == "jl_fptr_const_return") {
                    std::string gf_thunk_name = emit_abi_constreturn(M, params, declrt, sigt, cfunc.nargs, cfunc.specsig, codeinst->rettype_const);
                    auto F = M->getFunction(gf_thunk_name);
                    assert(F);
                    assign_fptr(F);
                    continue;
                }
                else if (func == "jl_fptr_args") {
                    assert(!specfunc.empty());
                    if (!cfunc.specsig && jl_subtype(astrt, declrt)) {
                        // ABIs already agree: point directly at the compiled function.
                        assign_fptr(IRLinker_copyFunctionProto(M, defM->getFunction(specfunc)));
                        continue;
                    }
                    assign_fptr(aot_abi_converter(params, M, declrt, sigt, cfunc.nargs, cfunc.specsig, codeinst, defM, func, specfunc, false));
                    continue;
                }
                else if (func == "jl_fptr_sparam" || func == "jl_f_opaque_closure_call") {
                    func = ""; // use jl_invoke instead for these, since we don't declare these prototypes
                }
                else {
                    assert(!specfunc.empty());
                    if (jl_egal(mi->specTypes, sigt) && jl_egal(declrt, astrt)) {
                        // Exact signature and return type match: no adapter needed.
                        assign_fptr(IRLinker_copyFunctionProto(M, defM->getFunction(specfunc)));
                        continue;
                    }
                    assign_fptr(aot_abi_converter(params, M, declrt, sigt, cfunc.nargs, cfunc.specsig, codeinst, defM, func, specfunc, true));
                    continue;
                }
            }
        }
        // Fallback: dispatch through the code instance if we have one, else use
        // the unspecialized thunk. This must not `return`, since that would
        // skip every remaining entry in params.cfuncs.
        Function *f = codeinst ? aot_abi_converter(params, M, declrt, sigt, cfunc.nargs, cfunc.specsig, codeinst, defM, func, "", false) : unspec;
        assign_fptr(f);
    }
}
637+
502638

503639
// takes the running content that has collected in the shadow module and dump it to disk
504640
// this builds the object file portion of the sysimage files for fast startup
@@ -651,7 +787,11 @@ void *jl_emit_native_impl(jl_array_t *codeinfos, LLVMOrcThreadSafeModuleRef llvm
651787
orc::ThreadSafeModule result_m = jl_create_ts_module(name_from_method_instance(jl_get_ci_mi(codeinst)),
652788
params.tsctx, clone.getModuleUnlocked()->getDataLayout(),
653789
Triple(clone.getModuleUnlocked()->getTargetTriple()));
654-
jl_llvm_functions_t decls = jl_emit_codeinst(result_m, codeinst, src, params);
790+
jl_llvm_functions_t decls;
791+
if (jl_atomic_load_relaxed(&codeinst->invoke) == jl_fptr_const_return_addr)
792+
decls.functionObject = "jl_fptr_const_return";
793+
else
794+
decls = jl_emit_codeinst(result_m, codeinst, src, params);
655795
record_method_roots(method_roots, jl_get_ci_mi(codeinst));
656796
if (result_m)
657797
compiled_functions[codeinst] = {std::move(result_m), std::move(decls)};
@@ -671,6 +811,8 @@ void *jl_emit_native_impl(jl_array_t *codeinfos, LLVMOrcThreadSafeModuleRef llvm
671811
}
672812
// finally, make sure all referenced methods get fixed up, particularly if the user declined to compile them
673813
resolve_workqueue(params, method_roots, compiled_functions);
814+
// including generating cfunction thunks
815+
generate_cfunc_thunks(params, compiled_functions);
674816
aot_optimize_roots(params, method_roots, compiled_functions);
675817
params.temporary_roots = nullptr;
676818
JL_GC_POP();
@@ -728,9 +870,12 @@ void *jl_emit_native_impl(jl_array_t *codeinfos, LLVMOrcThreadSafeModuleRef llvm
728870
else if (func == "jl_fptr_sparam") {
729871
func_id = -2;
730872
}
731-
else if (decls.functionObject == "jl_f_opaque_closure_call") {
873+
else if (func == "jl_f_opaque_closure_call") {
732874
func_id = -4;
733875
}
876+
else if (func == "jl_fptr_const_return") {
877+
func_id = -5;
878+
}
734879
else {
735880
//Safe b/c context is locked by params
736881
data->jl_sysimg_fvars.push_back(cast<Function>(clone.getModuleUnlocked()->getNamedValue(func)));
@@ -2201,7 +2346,7 @@ extern "C" JL_DLLEXPORT_CODEGEN jl_code_info_t *jl_gdbdumpcode(jl_method_instanc
22012346
// for use in reflection from Julia.
22022347
// This is paired with jl_dump_function_ir and jl_dump_function_asm, either of which will free all memory allocated here
22032348
extern "C" JL_DLLEXPORT_CODEGEN
2204-
void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, jl_code_info_t *src, char getwrapper, char optimize, const jl_cgparams_t params)
2349+
void jl_get_llvmf_defn_impl(jl_llvmf_dump_t *dump, jl_method_instance_t *mi, jl_code_info_t *src, char getwrapper, char optimize, const jl_cgparams_t params)
22052350
{
22062351
// emit this function into a new llvm module
22072352
dump->F = nullptr;
@@ -2223,7 +2368,31 @@ void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, jl_
22232368
output.imaging_mode = jl_options.image_codegen;
22242369
output.temporary_roots = jl_alloc_array_1d(jl_array_any_type, 0);
22252370
JL_GC_PUSH1(&output.temporary_roots);
2226-
auto decls = jl_emit_code(m, mi, src, mi->specTypes, src->rettype, output);
2371+
jl_llvm_functions_t decls = jl_emit_code(m, mi, src, mi->specTypes, src->rettype, output);
2372+
// while not required, also emit the cfunc thunks, based on the
2373+
// inferred ABIs of their targets in the current latest world,
2374+
// since otherwise it is challenging to see all relevant codes
2375+
jl_compiled_functions_t compiled_functions;
2376+
size_t latestworld = jl_atomic_load_acquire(&jl_world_counter);
2377+
for (cfunc_decl_t &cfunc : output.cfuncs) {
2378+
jl_value_t *sigt = cfunc.sigt;
2379+
JL_GC_PROMISE_ROOTED(sigt);
2380+
jl_method_instance_t *mi = jl_get_specialization1((jl_tupletype_t*)sigt, latestworld, 0);
2381+
if (mi == nullptr)
2382+
continue;
2383+
jl_code_instance_t *codeinst = jl_type_infer(mi, latestworld, SOURCE_MODE_NOT_REQUIRED);
2384+
if (codeinst == nullptr || compiled_functions.count(codeinst))
2385+
continue;
2386+
orc::ThreadSafeModule decl_m = jl_create_ts_module("extern", ctx);
2387+
jl_llvm_functions_t decls;
2388+
if (jl_atomic_load_relaxed(&codeinst->invoke) == jl_fptr_const_return_addr)
2389+
decls.functionObject = "jl_fptr_const_return";
2390+
else
2391+
decls = jl_emit_codedecls(decl_m, codeinst, output);
2392+
compiled_functions[codeinst] = {std::move(decl_m), std::move(decls)};
2393+
}
2394+
generate_cfunc_thunks(output, compiled_functions);
2395+
compiled_functions.clear();
22272396
output.temporary_roots = nullptr;
22282397
JL_GC_POP(); // GC the global_targets array contents now since reflection doesn't need it
22292398

src/ccall.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1970,6 +1970,8 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
19701970
return retval;
19711971
}
19721972

1973+
static inline Constant *literal_static_pointer_val(const void *p, Type *T);
1974+
19731975
jl_cgval_t function_sig_t::emit_a_ccall(
19741976
jl_codectx_t &ctx,
19751977
const native_sym_arg_t &symarg,

0 commit comments

Comments
 (0)