Skip to content

Commit 70fef49

Browse files
authored
Merge pull request #614 from czgdp1807/structs_03
Support for pointer to structs
2 parents 9e3d33c + beecbc4 commit 70fef49

11 files changed

+231
-110
lines changed

integration_tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ RUN(NAME test_unary_plus LABELS cpython llvm)
170170
RUN(NAME test_bool_binop LABELS cpython llvm)
171171
RUN(NAME test_issue_518 LABELS cpython llvm)
172172
RUN(NAME structs_01 LABELS cpython llvm)
173+
RUN(NAME structs_02 LABELS llvm)
174+
RUN(NAME structs_03 LABELS llvm)
173175

174176
# Just CPython
175177
RUN(NAME test_builtin_bin LABELS cpython)

integration_tests/structs_02.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from ltypes import i32, f32, dataclass, CPtr, Pointer, c_p_pointer, pointer
2+
3+
@dataclass
4+
class A:
5+
x: i32
6+
y: f32
7+
8+
@ccallable
9+
def f(a: CPtr) -> None:
10+
x: i32
11+
y: f32
12+
a1: A
13+
a2: Pointer[A]
14+
a1 = A(3, 3.25)
15+
a2 = pointer(a1)
16+
print(a2, pointer(a1))
17+
x = a2.x
18+
y = a2.y
19+
assert x == 3
20+
assert y == 3.25
21+
c_p_pointer(a, a2)
22+
print(a, a2, pointer(a1))
23+
24+
def g():
25+
b: CPtr
26+
f(b)
27+
28+
g()

integration_tests/structs_03.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from ltypes import i32, f32, dataclass, Pointer, pointer
2+
3+
@dataclass
4+
class A:
5+
x: i32
6+
y: f32
7+
8+
def f(pa: Pointer[A]):
9+
print(pa.x)
10+
print(pa.y)
11+
12+
def g():
13+
x: A
14+
x = A(5, 5.5)
15+
px: Pointer[A]
16+
px = pointer(x)
17+
px.x = 5
18+
px.y = 5.5
19+
f(px)
20+
21+
g()

src/libasr/asr_utils.h

Lines changed: 51 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,49 @@ static inline ASR::ttype_t* TYPE(const ASR::asr_t *f)
3737
return ASR::down_cast<ASR::ttype_t>(f);
3838
}
3939

40+
static inline char *symbol_name(const ASR::symbol_t *f)
41+
{
42+
switch (f->type) {
43+
case ASR::symbolType::Program: {
44+
return ASR::down_cast<ASR::Program_t>(f)->m_name;
45+
}
46+
case ASR::symbolType::Module: {
47+
return ASR::down_cast<ASR::Module_t>(f)->m_name;
48+
}
49+
case ASR::symbolType::Subroutine: {
50+
return ASR::down_cast<ASR::Subroutine_t>(f)->m_name;
51+
}
52+
case ASR::symbolType::Function: {
53+
return ASR::down_cast<ASR::Function_t>(f)->m_name;
54+
}
55+
case ASR::symbolType::GenericProcedure: {
56+
return ASR::down_cast<ASR::GenericProcedure_t>(f)->m_name;
57+
}
58+
case ASR::symbolType::DerivedType: {
59+
return ASR::down_cast<ASR::DerivedType_t>(f)->m_name;
60+
}
61+
case ASR::symbolType::Variable: {
62+
return ASR::down_cast<ASR::Variable_t>(f)->m_name;
63+
}
64+
case ASR::symbolType::ExternalSymbol: {
65+
return ASR::down_cast<ASR::ExternalSymbol_t>(f)->m_name;
66+
}
67+
case ASR::symbolType::ClassProcedure: {
68+
return ASR::down_cast<ASR::ClassProcedure_t>(f)->m_name;
69+
}
70+
case ASR::symbolType::CustomOperator: {
71+
return ASR::down_cast<ASR::CustomOperator_t>(f)->m_name;
72+
}
73+
case ASR::symbolType::AssociateBlock: {
74+
return ASR::down_cast<ASR::AssociateBlock_t>(f)->m_name;
75+
}
76+
case ASR::symbolType::Block: {
77+
return ASR::down_cast<ASR::Block_t>(f)->m_name;
78+
}
79+
default : throw LFortranException("Not implemented");
80+
}
81+
}
82+
4083
static inline ASR::symbol_t *symbol_get_past_external(ASR::symbol_t *f)
4184
{
4285
if (f->type == ASR::symbolType::ExternalSymbol) {
@@ -200,6 +243,14 @@ static inline std::string type_to_str_python(const ASR::ttype_t *t)
200243
case ASR::ttypeType::CPtr: {
201244
return "CPtr";
202245
}
246+
case ASR::ttypeType::Derived: {
247+
ASR::Derived_t* d = ASR::down_cast<ASR::Derived_t>(t);
248+
return symbol_name(d->m_derived_type);
249+
}
250+
case ASR::ttypeType::Pointer: {
251+
ASR::Pointer_t* p = ASR::down_cast<ASR::Pointer_t>(t);
252+
return "Pointer[" + type_to_str_python(p->m_type) + "]";
253+
}
203254
default : throw LFortranException("Not implemented " + std::to_string(t->type));
204255
}
205256
}
@@ -241,49 +292,6 @@ static inline ASR::expr_t* expr_value(ASR::expr_t *f)
241292
return ASR::expr_value0(f);
242293
}
243294

244-
static inline char *symbol_name(const ASR::symbol_t *f)
245-
{
246-
switch (f->type) {
247-
case ASR::symbolType::Program: {
248-
return ASR::down_cast<ASR::Program_t>(f)->m_name;
249-
}
250-
case ASR::symbolType::Module: {
251-
return ASR::down_cast<ASR::Module_t>(f)->m_name;
252-
}
253-
case ASR::symbolType::Subroutine: {
254-
return ASR::down_cast<ASR::Subroutine_t>(f)->m_name;
255-
}
256-
case ASR::symbolType::Function: {
257-
return ASR::down_cast<ASR::Function_t>(f)->m_name;
258-
}
259-
case ASR::symbolType::GenericProcedure: {
260-
return ASR::down_cast<ASR::GenericProcedure_t>(f)->m_name;
261-
}
262-
case ASR::symbolType::DerivedType: {
263-
return ASR::down_cast<ASR::DerivedType_t>(f)->m_name;
264-
}
265-
case ASR::symbolType::Variable: {
266-
return ASR::down_cast<ASR::Variable_t>(f)->m_name;
267-
}
268-
case ASR::symbolType::ExternalSymbol: {
269-
return ASR::down_cast<ASR::ExternalSymbol_t>(f)->m_name;
270-
}
271-
case ASR::symbolType::ClassProcedure: {
272-
return ASR::down_cast<ASR::ClassProcedure_t>(f)->m_name;
273-
}
274-
case ASR::symbolType::CustomOperator: {
275-
return ASR::down_cast<ASR::CustomOperator_t>(f)->m_name;
276-
}
277-
case ASR::symbolType::AssociateBlock: {
278-
return ASR::down_cast<ASR::AssociateBlock_t>(f)->m_name;
279-
}
280-
case ASR::symbolType::Block: {
281-
return ASR::down_cast<ASR::Block_t>(f)->m_name;
282-
}
283-
default : throw LFortranException("Not implemented");
284-
}
285-
}
286-
287295
static inline SymbolTable *symbol_parent_symtab(const ASR::symbol_t *f)
288296
{
289297
switch (f->type) {

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,7 +1239,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
12391239
return;
12401240
}
12411241
der_type_name = "";
1242+
uint64_t ptr_loads_copy = ptr_loads;
1243+
ptr_loads = ptr_loads_copy - ASR::is_a<ASR::Pointer_t>(*ASRUtils::expr_type(x.m_v));
12421244
this->visit_expr(*x.m_v);
1245+
ptr_loads = ptr_loads_copy;
12431246
ASR::Variable_t* member = down_cast<ASR::Variable_t>(symbol_get_past_external(x.m_m));
12441247
std::string member_name = std::string(member->m_name);
12451248
LFORTRAN_ASSERT(der_type_name.size() != 0);
@@ -1608,15 +1611,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
16081611
}
16091612
case (ASR::ttypeType::Pointer) : {
16101613
ASR::ttype_t *t2 = ASR::down_cast<ASR::Pointer_t>(asr_type)->m_type;
1611-
switch (t2->type) {
1612-
case (ASR::ttypeType::Derived) : {
1613-
throw CodeGenError("Pointers for Derived type not implemented yet in conversion.");
1614-
}
1615-
default :
1616-
llvm_type = get_type_from_ttype_t(t2, m_storage, is_array_type,
1614+
llvm_type = get_type_from_ttype_t(t2, m_storage, is_array_type,
16171615
is_malloc_array_type, m_dims, n_dims, a_kind);
1618-
llvm_type = llvm_type->getPointerTo();
1619-
}
1616+
llvm_type = llvm_type->getPointerTo();
16201617
break;
16211618
}
16221619
case (ASR::ttypeType::List) : {
@@ -2654,9 +2651,14 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
26542651

26552652
void visit_CPtrToPointer(const ASR::CPtrToPointer_t& x) {
26562653
ASR::expr_t *cptr = x.m_cptr, *fptr = x.m_ptr, *shape = x.m_shape;
2654+
int reduce_loads = 0;
2655+
if( ASR::is_a<ASR::Var_t>(*cptr) ) {
2656+
ASR::Variable_t* cptr_var = ASRUtils::EXPR2VAR(cptr);
2657+
reduce_loads = cptr_var->m_intent == ASRUtils::intent_in;
2658+
}
26572659
if( ASRUtils::is_array(ASRUtils::expr_type(fptr)) ) {
26582660
uint64_t ptr_loads_copy = ptr_loads;
2659-
ptr_loads = 1;
2661+
ptr_loads = 1 - reduce_loads;
26602662
this->visit_expr(*cptr);
26612663
llvm::Value* llvm_cptr = tmp;
26622664
ptr_loads = 0;
@@ -2706,7 +2708,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
27062708
}
27072709
} else {
27082710
uint64_t ptr_loads_copy = ptr_loads;
2709-
ptr_loads = 1;
2711+
ptr_loads = 1 - reduce_loads;
27102712
this->visit_expr(*cptr);
27112713
llvm::Value* llvm_cptr = tmp;
27122714
ptr_loads = 0;
@@ -3704,13 +3706,17 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
37043706
switch (t2->type) {
37053707
case ASR::ttypeType::Integer:
37063708
case ASR::ttypeType::Real:
3707-
case ASR::ttypeType::Complex: {
3709+
case ASR::ttypeType::Complex:
3710+
case ASR::ttypeType::Derived: {
3711+
if( t2->type == ASR::ttypeType::Derived ) {
3712+
ASR::Derived_t* d = ASR::down_cast<ASR::Derived_t>(t2);
3713+
der_type_name = ASRUtils::symbol_name(d->m_derived_type);
3714+
}
37083715
fetch_ptr(x);
37093716
break;
37103717
}
37113718
case ASR::ttypeType::Character:
3712-
case ASR::ttypeType::Logical:
3713-
case ASR::ttypeType::Derived: {
3719+
case ASR::ttypeType::Logical: {
37143720
break;
37153721
}
37163722
default:
@@ -4095,7 +4101,16 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
40954101
std::vector<std::string> fmt;
40964102
for (size_t i=0; i<x.n_values; i++) {
40974103
uint64_t ptr_loads_copy = ptr_loads;
4098-
ptr_loads = 1;
4104+
int reduce_loads = 0;
4105+
ptr_loads = 2;
4106+
if( ASR::is_a<ASR::Var_t>(*x.m_values[i]) ) {
4107+
ASR::Variable_t* var = ASRUtils::EXPR2VAR(x.m_values[i]);
4108+
reduce_loads = var->m_intent == ASRUtils::intent_in;
4109+
if( ASR::is_a<ASR::Pointer_t>(*var->m_type) ) {
4110+
ptr_loads = 1;
4111+
}
4112+
}
4113+
ptr_loads = ptr_loads - reduce_loads;
40994114
this->visit_expr_wrapper(x.m_values[i], true);
41004115
ptr_loads = ptr_loads_copy;
41014116
ASR::expr_t *v = x.m_values[i];
@@ -4250,11 +4265,13 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
42504265
template <typename T>
42514266
inline void set_func_subrout_params(T* func_subrout, ASR::abiType& x_abi,
42524267
std::uint32_t& m_h, ASR::Variable_t*& orig_arg,
4253-
std::string& orig_arg_name, size_t arg_idx) {
4268+
std::string& orig_arg_name, ASR::intentType& arg_intent,
4269+
size_t arg_idx) {
42544270
m_h = get_hash((ASR::asr_t*)func_subrout);
42554271
orig_arg = EXPR2VAR(func_subrout->m_args[arg_idx]);
42564272
orig_arg_name = orig_arg->m_name;
42574273
x_abi = func_subrout->m_abi;
4274+
arg_intent = orig_arg->m_intent;
42584275
}
42594276

42604277

@@ -4270,6 +4287,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
42704287
ASR::Subroutine_t* sub = down_cast<ASR::Subroutine_t>(func_subrout);
42714288
x_abi = sub->m_abi;
42724289
}
4290+
// TODO: Below if check is dead. Remove.
42734291
if( x_abi == ASR::abiType::Intrinsic ) {
42744292
if( name == "lbound" || name == "ubound" ) {
42754293
ASR::Variable_t *arg = EXPR2VAR(x.m_args[0].m_value);
@@ -4298,23 +4316,24 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
42984316
tmp = llvm_symtab[h];
42994317
func_subrout = symbol_get_past_external(x.m_name);
43004318
x_abi = (ASR::abiType) 0;
4319+
ASR::intentType orig_arg_intent = ASR::intentType::Unspecified;
43014320
std::uint32_t m_h;
43024321
ASR::Variable_t *orig_arg = nullptr;
43034322
std::string orig_arg_name = "";
43044323
if( func_subrout->type == ASR::symbolType::Function ) {
43054324
ASR::Function_t* func = down_cast<ASR::Function_t>(func_subrout);
4306-
set_func_subrout_params(func, x_abi, m_h, orig_arg, orig_arg_name, i);
4325+
set_func_subrout_params(func, x_abi, m_h, orig_arg, orig_arg_name, orig_arg_intent, i);
43074326
} else if( func_subrout->type == ASR::symbolType::Subroutine ) {
43084327
ASR::Subroutine_t* sub = down_cast<ASR::Subroutine_t>(func_subrout);
4309-
set_func_subrout_params(sub, x_abi, m_h, orig_arg, orig_arg_name, i);
4328+
set_func_subrout_params(sub, x_abi, m_h, orig_arg, orig_arg_name, orig_arg_intent, i);
43104329
} else if( func_subrout->type == ASR::symbolType::ClassProcedure ) {
43114330
ASR::ClassProcedure_t* clss_proc = ASR::down_cast<ASR::ClassProcedure_t>(func_subrout);
43124331
if( clss_proc->m_proc->type == ASR::symbolType::Subroutine ) {
43134332
ASR::Subroutine_t* sub = down_cast<ASR::Subroutine_t>(clss_proc->m_proc);
4314-
set_func_subrout_params(sub, x_abi, m_h, orig_arg, orig_arg_name, i);
4333+
set_func_subrout_params(sub, x_abi, m_h, orig_arg, orig_arg_name, orig_arg_intent, i);
43154334
} else if( clss_proc->m_proc->type == ASR::symbolType::Function ) {
43164335
ASR::Function_t* func = down_cast<ASR::Function_t>(clss_proc->m_proc);
4317-
set_func_subrout_params(func, x_abi, m_h, orig_arg, orig_arg_name, i);
4336+
set_func_subrout_params(func, x_abi, m_h, orig_arg, orig_arg_name, orig_arg_intent, i);
43184337
}
43194338
} else {
43204339
LFORTRAN_ASSERT(false)

0 commit comments

Comments
 (0)