diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index a362f1640c2e0..657db58831c1b 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -860,9 +860,10 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, LLVMRustEraseInstBefore(bb, last_inst); let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src)); + let f_is_struct = llvm::LLVMRustIsStructType(f_return_type); let void_type = LLVMVoidTypeInContext(llcx); // Now unwrap the struct_ret if it's actually a struct - if f_return_type != void_type { + if f_is_struct { let num_elem_in_ret_struct = LLVMCountStructElementTypes(f_return_type); if num_elem_in_ret_struct == 1 { let inner_grad_name = "foo".to_string(); diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 69049ca752d26..5b94b80502109 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1035,6 +1035,7 @@ extern "C" { pub fn LLVMRustEraseInstFromParent(V: &Value); pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value; pub fn LLVMGetReturnType(T: &Type) -> &Type; + pub fn LLVMRustIsStructType(T: &Type) -> bool; pub fn LLVMDumpModule(M: &Module); pub fn LLVMCountStructElementTypes(T: &Type) -> c_uint; pub fn LLVMDeleteFunction(V: &Value); diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 548040579b392..078c8918939b0 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -300,6 +300,10 @@ extern "C" void LLVMRustAddFunctionAttributes(LLVMValueRef Fn, unsigned Index, AddAttributes(F, Index, Attrs, AttrsLen); } +extern "C" bool LLVMRustIsStructType(LLVMTypeRef Ty) { + return unwrap(Ty)->isStructTy(); +} + extern "C" void LLVMRustAddCallSiteAttributes(LLVMValueRef Instr, unsigned Index, LLVMAttributeRef *Attrs,