Skip to content

Handle structs as return type in LLVM backend #1295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,9 @@ RUN(NAME structs_07 LABELS llvm c
EXTRAFILES structs_07b.c)
RUN(NAME structs_08 LABELS cpython llvm c)
RUN(NAME structs_09 LABELS cpython llvm c)
RUN(NAME structs_10 LABELS cpython llvm c)
# TODO: Re-enable c in structs_10
RUN(NAME structs_10 LABELS cpython llvm)
RUN(NAME structs_11 LABELS cpython llvm)
RUN(NAME structs_12 LABELS cpython llvm c)
RUN(NAME structs_13 LABELS llvm c
EXTRAFILES structs_13b.c)
Expand Down
28 changes: 21 additions & 7 deletions integration_tests/structs_10.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,25 @@
from numpy import empty, float64

@dataclass
class MatVec:
class Mat:
mat: f64[2, 2]

@dataclass
class Vec:
vec: f64[2]

@dataclass
class MatVec:
mat: Mat
vec: Vec

def rotate(mat_vec: MatVec) -> f64[2]:
rotated_vec: f64[2] = empty(2, dtype=float64)
rotated_vec[0] = mat_vec.mat[0, 0] * mat_vec.vec[0] + mat_vec.mat[0, 1] * mat_vec.vec[1]
rotated_vec[1] = mat_vec.mat[1, 0] * mat_vec.vec[0] + mat_vec.mat[1, 1] * mat_vec.vec[1]
rotated_vec[0] = mat_vec.mat.mat[0, 0] * mat_vec.vec.vec[0] + mat_vec.mat.mat[0, 1] * mat_vec.vec.vec[1]
rotated_vec[1] = mat_vec.mat.mat[1, 0] * mat_vec.vec.vec[0] + mat_vec.mat.mat[1, 1] * mat_vec.vec.vec[1]
return rotated_vec

def test_rotate_by_90():
def create_MatVec_obj() -> MatVec:
mat: f64[2, 2] = empty((2, 2), dtype=float64)
vec: f64[2] = empty(2, dtype=float64)
mat[0, 0] = 0.0
Expand All @@ -21,9 +29,15 @@ def test_rotate_by_90():
mat[1, 1] = 0.0
vec[0] = 1.0
vec[1] = 0.0
mat_vec: MatVec = MatVec(mat, vec)
print(mat_vec.mat[0, 0], mat_vec.mat[0, 1], mat_vec.mat[1, 0], mat_vec.mat[1, 1])
print(mat_vec.vec[0], mat_vec.vec[1])
mat_s: Mat = Mat(mat)
vec_s: Vec = Vec(vec)
mat_vec: MatVec = MatVec(mat_s, vec_s)
return mat_vec

def test_rotate_by_90():
mat_vec: MatVec = create_MatVec_obj()
print(mat_vec.mat.mat[0, 0], mat_vec.mat.mat[0, 1], mat_vec.mat.mat[1, 0], mat_vec.mat.mat[1, 1])
print(mat_vec.vec.vec[0], mat_vec.vec.vec[1])
rotated_vec: f64[2] = rotate(mat_vec)
print(rotated_vec[0], rotated_vec[1])
assert abs(rotated_vec[0] - 0.0) <= 1e-12
Expand Down
18 changes: 18 additions & 0 deletions integration_tests/structs_11.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from ltypes import i32, f64, dataclass

@dataclass
class A:
x: i32
y: f64

def f(x_: i32, y_: f64) -> A:
a_struct: A = A(x_, y_)
return a_struct

def test_struct_return():
b: A = f(0, 1.0)
print(b.x, b.y)
assert b.x == 0
assert b.y == 1.0

test_struct_return()
1 change: 1 addition & 0 deletions src/libasr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ set(SRC
pass/select_case.cpp
pass/implied_do_loops.cpp
pass/array_op.cpp
pass/subroutine_from_function.cpp
pass/class_constructor.cpp
pass/arr_slice.cpp
pass/print_arr.cpp
Expand Down
100 changes: 66 additions & 34 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm_utils->tuple_api = tuple_api.get();
llvm_utils->list_api = list_api.get();
llvm_utils->dict_api = nullptr;
llvm_utils->arr_api = arr_descr.get();
}

llvm::Value* CreateLoad(llvm::Value *x) {
Expand Down Expand Up @@ -1383,7 +1384,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
this->visit_expr(*x.m_args[i]);
llvm::Value* item = tmp;
llvm::Value* pos = llvm::ConstantInt::get(context, llvm::APInt(32, i));
list_api->write_item(const_list, pos, item, list_type->m_type, false, *module);
list_api->write_item(const_list, pos, item, list_type->m_type,
false, module.get(), name2memidx);
}
ptr_loads = ptr_loads_copy;
tmp = const_list;
Expand Down Expand Up @@ -1416,7 +1418,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
visit_expr(*x.m_values[i]);
llvm::Value* value = tmp;
llvm_utils->dict_api->write_item(const_dict, key, value, module.get(),
x_dict->m_key_type, x_dict->m_value_type);
x_dict->m_key_type, x_dict->m_value_type, name2memidx);
}
ptr_loads = ptr_loads_copy;
tmp = const_dict;
Expand Down Expand Up @@ -1484,7 +1486,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm::Value *item = tmp;
ptr_loads = ptr_loads_copy;

list_api->append(plist, item, asr_list->m_type, *module);
list_api->append(plist, item, asr_list->m_type, module.get(), name2memidx);
}

void visit_UnionRef(const ASR::UnionRef_t& x) {
Expand Down Expand Up @@ -1609,7 +1611,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm::Value *item = tmp;
ptr_loads = ptr_loads_copy;

list_api->insert_item(plist, pos, item, asr_list->m_type, *module);
list_api->insert_item(plist, pos, item, asr_list->m_type, module.get(), name2memidx);
}

void visit_DictInsert(const ASR::DictInsert_t& x) {
Expand All @@ -1631,7 +1633,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
set_dict_api(dict_type);
llvm_utils->dict_api->write_item(pdict, key, value, module.get(),
dict_type->m_key_type,
dict_type->m_value_type);
dict_type->m_value_type, name2memidx);
}

void visit_ListRemove(const ASR::ListRemove_t& x) {
Expand Down Expand Up @@ -2571,6 +2573,49 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
m_dims_local, n_dims_local, a_kind_local);
}

void fill_array_details_(llvm::Value* ptr, ASR::dimension_t* m_dims,
size_t n_dims, bool is_malloc_array_type, bool is_array_type,
bool is_list, ASR::ttype_t* m_type) {
if( is_malloc_array_type &&
m_type->type != ASR::ttypeType::Pointer &&
!is_list ) {
arr_descr->fill_dimension_descriptor(ptr, n_dims);
}
if( is_array_type && !is_malloc_array_type &&
m_type->type != ASR::ttypeType::Pointer &&
!is_list ) {
ASR::ttype_t* asr_data_type = ASRUtils::duplicate_type_without_dims(al, m_type, m_type->base.loc);
llvm::Type* llvm_data_type = get_type_from_ttype_t_util(asr_data_type);
fill_array_details(ptr, llvm_data_type, m_dims, n_dims);
}
if( is_array_type && is_malloc_array_type &&
m_type->type != ASR::ttypeType::Pointer &&
!is_list ) {
// Set allocatable arrays as unallocated
arr_descr->set_is_allocated_flag(ptr, 0);
}
}

void allocate_array_members_of_struct(llvm::Value* ptr, ASR::ttype_t* asr_type) {
LFORTRAN_ASSERT(ASR::is_a<ASR::Struct_t>(*asr_type));
ASR::Struct_t* struct_t = ASR::down_cast<ASR::Struct_t>(asr_type);
ASR::StructType_t* struct_type_t = ASR::down_cast<ASR::StructType_t>(struct_t->m_derived_type);
std::string struct_type_name = struct_type_t->m_name;
for( auto item: struct_type_t->m_symtab->get_scope() ) {
ASR::ttype_t* symbol_type = ASRUtils::symbol_type(item.second);
int idx = name2memidx[struct_type_name][item.first];
llvm::Value* ptr_member = llvm_utils->create_gep(ptr, idx);
if( ASRUtils::is_array(symbol_type) ) {
// Assume that struct member array is not allocatable
ASR::dimension_t* m_dims = nullptr;
size_t n_dims = ASRUtils::extract_dimensions_from_ttype(symbol_type, m_dims);
fill_array_details_(ptr_member, m_dims, n_dims, false, true, false, symbol_type);
} else if( ASR::is_a<ASR::Struct_t>(*symbol_type) ) {
allocate_array_members_of_struct(ptr_member, symbol_type);
}
}
}

template<typename T>
void declare_vars(const T &x) {
llvm::Value *target_var;
Expand Down Expand Up @@ -2626,6 +2671,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
}
llvm::AllocaInst *ptr = builder->CreateAlloca(type, nullptr, v->m_name);
if( ASR::is_a<ASR::Struct_t>(*v->m_type) ) {
allocate_array_members_of_struct(ptr, v->m_type);
}
if (emit_debug_info) {
// Reset the debug location
builder->SetCurrentDebugLocation(nullptr);
Expand Down Expand Up @@ -2659,24 +2707,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
}
llvm_symtab[h] = ptr;
if( is_malloc_array_type &&
v->m_type->type != ASR::ttypeType::Pointer &&
!is_list ) {
arr_descr->fill_dimension_descriptor(ptr, n_dims);
}
if( is_array_type && !is_malloc_array_type &&
v->m_type->type != ASR::ttypeType::Pointer &&
!is_list ) {
ASR::ttype_t* asr_data_type = ASRUtils::duplicate_type_without_dims(al, v->m_type, v->m_type->base.loc);
llvm::Type* llvm_data_type = get_type_from_ttype_t_util(asr_data_type);
fill_array_details(ptr, llvm_data_type, m_dims, n_dims);
}
if( is_array_type && is_malloc_array_type &&
v->m_type->type != ASR::ttypeType::Pointer &&
!is_list ) {
// Set allocatable arrays as unallocated
arr_descr->set_is_allocated_flag(ptr, 0);
}
fill_array_details_(ptr, m_dims, n_dims,
is_malloc_array_type,
is_array_type, is_list, v->m_type);
if( v->m_symbolic_value != nullptr &&
!ASR::is_a<ASR::List_t>(*v->m_type)) {
target_var = ptr;
Expand Down Expand Up @@ -3776,7 +3809,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
ASRUtils::expr_type(x.m_value));
std::string value_type_code = ASRUtils::get_type_code(value_asr_list->m_type);
list_api->list_deepcopy(value_list, target_list,
value_asr_list, *module);
value_asr_list, module.get(),
name2memidx);
return ;
} else if( is_target_tuple && is_value_tuple ) {
uint64_t ptr_loads_copy = ptr_loads;
Expand Down Expand Up @@ -3805,7 +3839,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm::Value* llvm_tuple_i = builder->CreateAlloca(llvm_tuple_i_type, nullptr);
ptr_loads = !LLVM::is_llvm_struct(asr_tuple_i_type);
visit_expr(*asr_value_tuple->m_elements[i]);
llvm_utils->deepcopy(tmp, llvm_tuple_i, asr_tuple_i_type, *module);
llvm_utils->deepcopy(tmp, llvm_tuple_i, asr_tuple_i_type, module.get(), name2memidx);
src_deepcopies.push_back(al, llvm_tuple_i);
}
ASR::TupleConstant_t* asr_target_tuple = ASR::down_cast<ASR::TupleConstant_t>(x.m_target);
Expand All @@ -3829,7 +3863,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
std::string type_code = ASRUtils::get_type_code(value_tuple_type->m_type,
value_tuple_type->n_type);
tuple_api->tuple_deepcopy(value_tuple, target_tuple,
value_tuple_type, *module);
value_tuple_type, module.get(),
name2memidx);
}
return ;
} else if( is_target_dict && is_value_dict ) {
Expand All @@ -3843,7 +3878,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
ASR::Dict_t* value_dict_type = ASR::down_cast<ASR::Dict_t>(asr_value_type);
set_dict_api(value_dict_type);
llvm_utils->dict_api->dict_deepcopy(value_dict, target_dict,
value_dict_type, module.get());
value_dict_type, module.get(), name2memidx);
return ;
} else if( is_target_struct && is_value_struct ) {
uint64_t ptr_loads_copy = ptr_loads;
Expand All @@ -3856,10 +3891,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
is_assignment_target = is_assignment_target_copy;
llvm::Value* target_struct = tmp;
ptr_loads = ptr_loads_copy;
LLVM::CreateStore(*builder,
LLVM::CreateLoad(*builder, value_struct),
target_struct
);
llvm_utils->deepcopy(value_struct, target_struct,
asr_target_type, module.get(), name2memidx);
return ;
}

Expand Down Expand Up @@ -3973,9 +4006,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
if( ASRUtils::is_array(target_type) &&
ASRUtils::is_array(value_type) &&
ASRUtils::check_equal_type(target_type, value_type) ) {
bool create_dim_des_array = !ASR::is_a<ASR::Var_t>(*x.m_target);
arr_descr->copy_array(value, target, module.get(),
target_type, create_dim_des_array);
target_type, false, false);
} else {
builder->CreateStore(value, target);
}
Expand Down Expand Up @@ -5975,7 +6007,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
if( ASR::is_a<ASR::Tuple_t>(*arg_type) ||
ASR::is_a<ASR::List_t>(*arg_type) ) {
llvm_utils->deepcopy(value, target, arg_type, *module);
llvm_utils->deepcopy(value, target, arg_type, module.get(), name2memidx);
} else {
builder->CreateStore(value, target);
}
Expand Down
9 changes: 6 additions & 3 deletions src/libasr/codegen/llvm_array_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,13 +571,16 @@ namespace LFortran {

// Shallow copies source array descriptor to destination descriptor
void SimpleCMODescriptor::copy_array(llvm::Value* src, llvm::Value* dest,
llvm::Module* module, ASR::ttype_t* asr_data_type, bool create_dim_des_array) {
llvm::Module* module, ASR::ttype_t* asr_data_type, bool create_dim_des_array,
bool reserve_memory) {
llvm::Value* num_elements = this->get_array_size(src, nullptr, 4);

llvm::Value* first_ptr = this->get_pointer_to_data(dest);
llvm::Type* llvm_data_type = tkr2array[ASRUtils::get_type_code(asr_data_type, false, false)].second;
llvm::Value* arr_first = builder->CreateAlloca(llvm_data_type, num_elements);
builder->CreateStore(arr_first, first_ptr);
if( reserve_memory ) {
llvm::Value* arr_first = builder->CreateAlloca(llvm_data_type, num_elements);
builder->CreateStore(arr_first, first_ptr);
}

llvm::Value* ptr2firstptr = this->get_pointer_to_data(src);
llvm::DataLayout data_layout(module);
Expand Down
4 changes: 2 additions & 2 deletions src/libasr/codegen/llvm_array_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ namespace LFortran {
virtual
void copy_array(llvm::Value* src, llvm::Value* dest,
llvm::Module* module, ASR::ttype_t* asr_data_type,
bool create_dim_des_array) = 0;
bool create_dim_des_array, bool reserve_memory) = 0;

virtual
llvm::Value* get_array_size(llvm::Value* array, llvm::Value* dim,
Expand Down Expand Up @@ -394,7 +394,7 @@ namespace LFortran {
virtual
void copy_array(llvm::Value* src, llvm::Value* dest,
llvm::Module* module, ASR::ttype_t* asr_data_type,
bool create_dim_des_array);
bool create_dim_des_array, bool reserve_memory);

virtual
llvm::Value* get_array_size(llvm::Value* array, llvm::Value* dim,
Expand Down
Loading