Skip to content

Commit ec3b0d2

Browse files
authored
Merge pull request #1238 from czgdp1807/arr_structs
Support arrays inside structs
2 parents fc08c06 + 61aa7ca commit ec3b0d2

File tree

7 files changed

+165
-59
lines changed

7 files changed

+165
-59
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ RUN(NAME structs_07 LABELS llvm c
265265
EXTRAFILES structs_07b.c)
266266
RUN(NAME structs_08 LABELS cpython llvm c)
267267
RUN(NAME structs_09 LABELS cpython llvm c)
268+
RUN(NAME structs_10 LABELS cpython llvm c)
268269
RUN(NAME sizeof_01 LABELS llvm c
269270
EXTRAFILES sizeof_01b.c)
270271
RUN(NAME enum_01 LABELS cpython llvm c)

integration_tests/structs_10.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from ltypes import i32, f64, dataclass
2+
from numpy import empty, float64
3+
4+
@dataclass
5+
class MatVec:
6+
mat: f64[2, 2]
7+
vec: f64[2]
8+
9+
def rotate(mat_vec: MatVec) -> f64[2]:
10+
rotated_vec: f64[2] = empty(2, dtype=float64)
11+
rotated_vec[0] = mat_vec.mat[0, 0] * mat_vec.vec[0] + mat_vec.mat[0, 1] * mat_vec.vec[1]
12+
rotated_vec[1] = mat_vec.mat[1, 0] * mat_vec.vec[0] + mat_vec.mat[1, 1] * mat_vec.vec[1]
13+
return rotated_vec
14+
15+
def test_rotate_by_90():
16+
mat: f64[2, 2] = empty((2, 2), dtype=float64)
17+
vec: f64[2] = empty(2, dtype=float64)
18+
mat[0, 0] = 0.0
19+
mat[0, 1] = -1.0
20+
mat[1, 0] = 1.0
21+
mat[1, 1] = 0.0
22+
vec[0] = 1.0
23+
vec[1] = 0.0
24+
mat_vec: MatVec = MatVec(mat, vec)
25+
print(mat_vec.mat[0, 0], mat_vec.mat[0, 1], mat_vec.mat[1, 0], mat_vec.mat[1, 1])
26+
print(mat_vec.vec[0], mat_vec.vec[1])
27+
rotated_vec: f64[2] = rotate(mat_vec)
28+
print(rotated_vec[0], rotated_vec[1])
29+
assert abs(rotated_vec[0] - 0.0) <= 1e-12
30+
assert abs(rotated_vec[1] - 1.0) <= 1e-12
31+
32+
test_rotate_by_90()

src/libasr/asr_utils.h

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -726,40 +726,50 @@ static inline void encode_dimensions(size_t n_dims, std::string& res,
726726
}
727727

728728
static inline std::string get_type_code(const ASR::ttype_t *t, bool use_underscore_sep=false,
729-
bool encode_dimensions_=true)
729+
bool encode_dimensions_=true, bool set_dimensional_hint=true)
730730
{
731+
bool is_dimensional = false;
732+
std::string res = "";
731733
switch (t->type) {
732734
case ASR::ttypeType::Integer: {
733735
ASR::Integer_t *integer = ASR::down_cast<ASR::Integer_t>(t);
734-
std::string res = "i" + std::to_string(integer->m_kind * 8);
736+
res = "i" + std::to_string(integer->m_kind * 8);
735737
if( encode_dimensions_ ) {
736738
encode_dimensions(integer->n_dims, res, use_underscore_sep);
739+
return res;
737740
}
738-
return res;
741+
is_dimensional = integer->n_dims > 0;
742+
break;
739743
}
740744
case ASR::ttypeType::Real: {
741745
ASR::Real_t *real = ASR::down_cast<ASR::Real_t>(t);
742-
std::string res = "r" + std::to_string(real->m_kind * 8);
746+
res = "r" + std::to_string(real->m_kind * 8);
743747
if( encode_dimensions_ ) {
744748
encode_dimensions(real->n_dims, res, use_underscore_sep);
749+
return res;
745750
}
746-
return res;
751+
is_dimensional = real->n_dims > 0;
752+
break;
747753
}
748754
case ASR::ttypeType::Complex: {
749755
ASR::Complex_t *complx = ASR::down_cast<ASR::Complex_t>(t);
750-
std::string res = "c" + std::to_string(complx->m_kind * 8);
756+
res = "c" + std::to_string(complx->m_kind * 8);
751757
if( encode_dimensions_ ) {
752758
encode_dimensions(complx->n_dims, res, use_underscore_sep);
759+
return res;
753760
}
754-
return res;
761+
is_dimensional = complx->n_dims > 0;
762+
break;
755763
}
756764
case ASR::ttypeType::Logical: {
757765
ASR::Logical_t* bool_ = ASR::down_cast<ASR::Logical_t>(t);
758766
std::string res = "bool";
759767
if( encode_dimensions_ ) {
760768
encode_dimensions(bool_->n_dims, res, use_underscore_sep);
769+
return res;
761770
}
762-
return res;
771+
is_dimensional = bool_->n_dims > 0;
772+
break;
763773
}
764774
case ASR::ttypeType::Character: {
765775
return "str";
@@ -773,7 +783,8 @@ static inline std::string get_type_code(const ASR::ttype_t *t, bool use_undersco
773783
result += "[";
774784
}
775785
for (size_t i = 0; i < tup->n_type; i++) {
776-
result += get_type_code(tup->m_type[i], use_underscore_sep, encode_dimensions_);
786+
result += get_type_code(tup->m_type[i], use_underscore_sep,
787+
encode_dimensions_, set_dimensional_hint);
777788
if (i + 1 != tup->n_type) {
778789
if( use_underscore_sep ) {
779790
result += "_";
@@ -792,64 +803,80 @@ static inline std::string get_type_code(const ASR::ttype_t *t, bool use_undersco
792803
case ASR::ttypeType::Set: {
793804
ASR::Set_t *s = ASR::down_cast<ASR::Set_t>(t);
794805
if( use_underscore_sep ) {
795-
return "set_" + get_type_code(s->m_type, use_underscore_sep, encode_dimensions_) + "_";
806+
return "set_" + get_type_code(s->m_type, use_underscore_sep,
807+
encode_dimensions_, set_dimensional_hint) + "_";
796808
}
797-
return "set[" + get_type_code(s->m_type, use_underscore_sep, encode_dimensions_) + "]";
809+
return "set[" + get_type_code(s->m_type, use_underscore_sep,
810+
encode_dimensions_, set_dimensional_hint) + "]";
798811
}
799812
case ASR::ttypeType::Dict: {
800813
ASR::Dict_t *d = ASR::down_cast<ASR::Dict_t>(t);
801814
if( use_underscore_sep ) {
802-
return "dict_" + get_type_code(d->m_key_type, use_underscore_sep, encode_dimensions_) +
803-
"_" + get_type_code(d->m_value_type, use_underscore_sep, encode_dimensions_) + "_";
815+
return "dict_" + get_type_code(d->m_key_type, use_underscore_sep,
816+
encode_dimensions_, set_dimensional_hint) +
817+
"_" + get_type_code(d->m_value_type, use_underscore_sep,
818+
encode_dimensions_, set_dimensional_hint) + "_";
804819
}
805-
return "dict[" + get_type_code(d->m_key_type, use_underscore_sep, encode_dimensions_) +
806-
", " + get_type_code(d->m_value_type, use_underscore_sep, encode_dimensions_) + "]";
820+
return "dict[" + get_type_code(d->m_key_type, use_underscore_sep,
821+
encode_dimensions_, set_dimensional_hint) +
822+
", " + get_type_code(d->m_value_type, use_underscore_sep,
823+
encode_dimensions_, set_dimensional_hint) + "]";
807824
}
808825
case ASR::ttypeType::List: {
809826
ASR::List_t *l = ASR::down_cast<ASR::List_t>(t);
810827
if( use_underscore_sep ) {
811-
return "list_" + get_type_code(l->m_type, use_underscore_sep, encode_dimensions_) + "_";
828+
return "list_" + get_type_code(l->m_type, use_underscore_sep,
829+
encode_dimensions_, set_dimensional_hint) + "_";
812830
}
813-
return "list[" + get_type_code(l->m_type, use_underscore_sep, encode_dimensions_) + "]";
831+
return "list[" + get_type_code(l->m_type, use_underscore_sep,
832+
encode_dimensions_, set_dimensional_hint) + "]";
814833
}
815834
case ASR::ttypeType::CPtr: {
816835
return "CPtr";
817836
}
818837
case ASR::ttypeType::Struct: {
819838
ASR::Struct_t* d = ASR::down_cast<ASR::Struct_t>(t);
820-
std::string res = symbol_name(d->m_derived_type);
839+
res = symbol_name(d->m_derived_type);
821840
if( encode_dimensions_ ) {
822841
encode_dimensions(d->n_dims, res, use_underscore_sep);
842+
return res;
823843
}
824-
return res;
844+
is_dimensional = d->n_dims > 0;
845+
break;
825846
}
826847
case ASR::ttypeType::Union: {
827848
ASR::Union_t* d = ASR::down_cast<ASR::Union_t>(t);
828-
std::string res = symbol_name(d->m_union_type);
849+
res = symbol_name(d->m_union_type);
829850
if( encode_dimensions_ ) {
830851
encode_dimensions(d->n_dims, res, use_underscore_sep);
852+
return res;
831853
}
832-
return res;
854+
is_dimensional = d->n_dims > 0;
855+
break;
833856
}
834857
case ASR::ttypeType::Pointer: {
835858
ASR::Pointer_t* p = ASR::down_cast<ASR::Pointer_t>(t);
836859
if( use_underscore_sep ) {
837-
return "Pointer_" + get_type_code(p->m_type, use_underscore_sep, encode_dimensions_) + "_";
860+
return "Pointer_" + get_type_code(p->m_type, use_underscore_sep, encode_dimensions_, set_dimensional_hint) + "_";
838861
}
839-
return "Pointer[" + get_type_code(p->m_type, use_underscore_sep, encode_dimensions_) + "]";
862+
return "Pointer[" + get_type_code(p->m_type, use_underscore_sep, encode_dimensions_, set_dimensional_hint) + "]";
840863
}
841864
case ASR::ttypeType::Const: {
842865
ASR::Const_t* p = ASR::down_cast<ASR::Const_t>(t);
843866
if( use_underscore_sep ) {
844-
return "Const_" + get_type_code(p->m_type, use_underscore_sep, encode_dimensions_) + "_";
867+
return "Const_" + get_type_code(p->m_type, use_underscore_sep, encode_dimensions_, set_dimensional_hint) + "_";
845868
}
846-
return "Const[" + get_type_code(p->m_type, use_underscore_sep, encode_dimensions_) + "]";
869+
return "Const[" + get_type_code(p->m_type, use_underscore_sep, encode_dimensions_, set_dimensional_hint) + "]";
847870
}
848871
default: {
849872
throw LCompilersException("Type encoding not implemented for "
850873
+ std::to_string(t->type));
851874
}
852875
}
876+
if( is_dimensional && set_dimensional_hint ) {
877+
res += "dim";
878+
}
879+
return res;
853880
}
854881

855882
static inline std::string get_type_code(ASR::ttype_t** types, size_t n_types,

src/libasr/codegen/asr_to_c.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -381,12 +381,15 @@ class ASRToCVisitor : public BaseCCPPVisitor<ASRToCVisitor>
381381
bool is_fixed_size = true;
382382
dims = convert_dims_c(t->n_dims, t->m_dims, v.m_type, is_fixed_size, true);
383383
std::string encoded_type_name = "i" + std::to_string(t->m_kind * 8);
384+
bool is_struct_type_member = ASR::is_a<ASR::StructType_t>(
385+
*ASR::down_cast<ASR::symbol_t>(v.m_parent_symtab->asr_owner));
384386
generate_array_decl(sub, std::string(v.m_name), type_name, dims,
385387
encoded_type_name, t->m_dims, t->n_dims,
386388
use_ref, dummy,
387389
v.m_intent != ASRUtils::intent_in &&
388390
v.m_intent != ASRUtils::intent_inout &&
389-
v.m_intent != ASRUtils::intent_out, is_fixed_size);
391+
v.m_intent != ASRUtils::intent_out &&
392+
!is_struct_type_member, is_fixed_size);
390393
} else {
391394
bool is_fixed_size = true;
392395
dims = convert_dims_c(t->n_dims, t->m_dims, v.m_type, is_fixed_size);
@@ -400,12 +403,15 @@ class ASRToCVisitor : public BaseCCPPVisitor<ASRToCVisitor>
400403
bool is_fixed_size = true;
401404
dims = convert_dims_c(t->n_dims, t->m_dims, v.m_type, is_fixed_size, true);
402405
std::string encoded_type_name = "r" + std::to_string(t->m_kind * 8);
406+
bool is_struct_type_member = ASR::is_a<ASR::StructType_t>(
407+
*ASR::down_cast<ASR::symbol_t>(v.m_parent_symtab->asr_owner));
403408
generate_array_decl(sub, std::string(v.m_name), type_name, dims,
404409
encoded_type_name, t->m_dims, t->n_dims,
405410
use_ref, dummy,
406411
v.m_intent != ASRUtils::intent_in &&
407412
v.m_intent != ASRUtils::intent_inout &&
408-
v.m_intent != ASRUtils::intent_out, is_fixed_size);
413+
v.m_intent != ASRUtils::intent_out &&
414+
!is_struct_type_member, is_fixed_size);
409415
} else {
410416
bool is_fixed_size = true;
411417
dims = convert_dims_c(t->n_dims, t->m_dims, v.m_type, is_fixed_size);
@@ -819,6 +825,7 @@ R"(
819825
}
820826

821827
void visit_StructType(const ASR::StructType_t& x) {
828+
src = "";
822829
std::string c_type_name = "struct";
823830
if( x.m_is_packed ) {
824831
std::string attr_args = "(packed";
@@ -835,6 +842,7 @@ R"(
835842
c_type_name += " __attribute__(" + attr_args + ")";
836843
}
837844
visit_AggregateTypeUtil(x, c_type_name);
845+
src = "";
838846
}
839847

840848
void visit_UnionType(const ASR::UnionType_t& x) {
@@ -1102,13 +1110,13 @@ R"(
11021110

11031111
ASR::ttype_t* array_type_asr = ASRUtils::expr_type(x.m_array);
11041112
std::string array_type_name = CUtils::get_c_type_from_ttype_t(array_type_asr);
1105-
std::string array_encoded_type_name = ASRUtils::get_type_code(array_type_asr, true, false);
1113+
std::string array_encoded_type_name = ASRUtils::get_type_code(array_type_asr, true, false, false);
11061114
std::string array_type = get_array_type(array_type_name, array_encoded_type_name, true);
11071115
std::string return_type = get_array_type(array_type_name, array_encoded_type_name, false);
11081116

11091117
ASR::ttype_t* shape_type_asr = ASRUtils::expr_type(x.m_shape);
11101118
std::string shape_type_name = CUtils::get_c_type_from_ttype_t(shape_type_asr);
1111-
std::string shape_encoded_type_name = ASRUtils::get_type_code(shape_type_asr, true, false);
1119+
std::string shape_encoded_type_name = ASRUtils::get_type_code(shape_type_asr, true, false, false);
11121120
std::string shape_type = get_array_type(shape_type_name, shape_encoded_type_name, true);
11131121

11141122
std::string array_reshape_func = c_utils_functions->get_array_reshape(array_type, shape_type,

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
617617
const std::map<std::string, ASR::symbol_t*>& scope = der_type->m_symtab->get_scope();
618618
for( auto itr = scope.begin(); itr != scope.end(); itr++ ) {
619619
ASR::Variable_t* member = (ASR::Variable_t*)(&(itr->second->base));
620-
llvm::Type* llvm_mem_type = getMemberType(member->m_type, member);
620+
llvm::Type* llvm_mem_type = get_type_from_ttype_t_util(member->m_type);
621621
member_types.push_back(llvm_mem_type);
622622
name2memidx[der_type_name][std::string(member->m_name)] = member_idx;
623623
member_idx++;
@@ -3790,9 +3790,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
37903790
if( ASR::is_a<ASR::UnionTypeConstructor_t>(*x.m_value) ) {
37913791
return ;
37923792
}
3793+
ASR::ttype_t* target_type = ASRUtils::expr_type(x.m_target);
3794+
ASR::ttype_t* value_type = ASRUtils::expr_type(x.m_value);
37933795
this->visit_expr_wrapper(x.m_value, true);
37943796
if( ASR::is_a<ASR::Var_t>(*x.m_value) &&
3795-
ASR::is_a<ASR::Union_t>(*ASRUtils::expr_type(x.m_value)) ) {
3797+
ASR::is_a<ASR::Union_t>(*value_type) ) {
37963798
tmp = LLVM::CreateLoad(*builder, tmp);
37973799
}
37983800
value = tmp;
@@ -3804,12 +3806,12 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
38043806
}
38053807
}
38063808
}
3807-
ASR::ttype_t* target_type = ASRUtils::expr_type(x.m_target);
3808-
ASR::ttype_t* value_type = ASRUtils::expr_type(x.m_value);
38093809
if( ASRUtils::is_array(target_type) &&
38103810
ASRUtils::is_array(value_type) &&
38113811
ASRUtils::check_equal_type(target_type, value_type) ) {
3812-
arr_descr->copy_array(value, target);
3812+
bool create_dim_des_array = !ASR::is_a<ASR::Var_t>(*x.m_target);
3813+
arr_descr->copy_array(value, target, module.get(),
3814+
target_type, create_dim_des_array);
38133815
} else {
38143816
builder->CreateStore(value, target);
38153817
}

0 commit comments

Comments
 (0)