Skip to content

Commit 6e870f5

Browse files
committed
Improve class_constructor.cpp pass and StructTypeConstructor node in ASR
1 parent 9da3fce commit 6e870f5

14 files changed

+218
-82
lines changed

src/libasr/ASR.asdl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ expr
224224
| NamedExpr(expr target, expr value, ttype type)
225225
| FunctionCall(symbol name, symbol? original_name, call_arg* args,
226226
ttype type, expr? value, expr? dt)
227-
| StructTypeConstructor(symbol dt_sym, expr* args, ttype type, expr? value)
227+
| StructTypeConstructor(symbol dt_sym, call_arg* args, ttype type, expr? value)
228228
| EnumTypeConstructor(symbol dt_sym, expr* args, ttype type, expr? value)
229229
| UnionTypeConstructor(symbol dt_sym, expr* args, ttype type, expr? value)
230230
| ImpliedDoLoop(expr* values, expr var, expr start, expr end,

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5555,8 +5555,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
55555555
break;
55565556
}
55575557
case ASR::ttypeType::Struct: {
5558-
ASR::Struct_t* der = (ASR::Struct_t*)(&(x->m_type->base));
5559-
ASR::StructType_t* der_type = (ASR::StructType_t*)(&(der->m_derived_type->base));
5558+
ASR::Struct_t* der = ASR::down_cast<ASR::Struct_t>(x->m_type);
5559+
ASR::StructType_t* der_type = ASR::down_cast<ASR::StructType_t>(
5560+
ASRUtils::symbol_get_past_external(der->m_derived_type));
55605561
der_type_name = std::string(der_type->m_name);
55615562
uint32_t h = get_hash((ASR::asr_t*)x);
55625563
if( llvm_symtab.find(h) != llvm_symtab.end() ) {
@@ -5576,8 +5577,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
55765577
break;
55775578
}
55785579
case ASR::ttypeType::Class: {
5579-
ASR::Class_t* der = (ASR::Class_t*)(&(x->m_type->base));
5580-
ASR::ClassType_t* der_type = (ASR::ClassType_t*)(&(der->m_class_type->base));
5580+
ASR::Class_t* der = ASR::down_cast<ASR::Class_t>(x->m_type);
5581+
ASR::ClassType_t* der_type = ASR::down_cast<ASR::ClassType_t>(
5582+
ASRUtils::symbol_get_past_external(der->m_class_type));
55815583
der_type_name = std::string(der_type->m_name);
55825584
uint32_t h = get_hash((ASR::asr_t*)x);
55835585
if( llvm_symtab.find(h) != llvm_symtab.end() ) {

src/libasr/pass/class_constructor.cpp

Lines changed: 198 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3,91 +3,228 @@
33
#include <libasr/exception.h>
44
#include <libasr/asr_utils.h>
55
#include <libasr/asr_verify.h>
6-
#include <libasr/pass/pass_utils.h>
76
#include <libasr/pass/class_constructor.h>
7+
#include <libasr/pass/pass_utils.h>
88

9-
#include <cstring>
10-
9+
#include <vector>
10+
#include <utility>
11+
#include <queue>
1112

1213
namespace LCompilers {
1314

1415
using ASR::down_cast;
1516
using ASR::is_a;
1617

17-
class ClassConstructorVisitor : public PassUtils::PassVisitor<ClassConstructorVisitor>
18-
{
19-
private:
18+
class ReplaceStructTypeConstructor: public ASR::BaseExprReplacer<ReplaceStructTypeConstructor> {
19+
20+
private:
21+
22+
Allocator& al;
23+
Vec<ASR::stmt_t*>& pass_result;
24+
bool& remove_original_statement;
25+
bool& inside_symtab;
26+
bool& apply_again;
27+
std::map<SymbolTable*, Vec<ASR::stmt_t*>>& symtab2decls;
28+
29+
public:
30+
31+
SymbolTable* current_scope;
2032
ASR::expr_t* result_var;
2133

22-
public:
34+
ReplaceStructTypeConstructor(Allocator& al_, Vec<ASR::stmt_t*>& pass_result_,
35+
bool& remove_original_statement_, bool& inside_symtab_, bool& apply_again_,
36+
std::map<SymbolTable*, Vec<ASR::stmt_t*>>& symtab2decls_) :
37+
al(al_), pass_result(pass_result_),
38+
remove_original_statement(remove_original_statement_),
39+
inside_symtab(inside_symtab_), apply_again(apply_again_),
40+
symtab2decls(symtab2decls_), current_scope(nullptr),
41+
result_var(nullptr) {}
2342

24-
bool is_constructor_present;
43+
void replace_StructTypeConstructor(ASR::StructTypeConstructor_t* x) {
44+
if( x->n_args == 0 ) {
45+
remove_original_statement = true;
46+
return ;
47+
}
48+
if( result_var == nullptr ) {
49+
std::string result_var_name = current_scope->get_unique_name("temp_struct_var__");
50+
result_var = PassUtils::create_auxiliary_variable(x->base.base.loc,
51+
result_var_name, al, current_scope, x->m_type);
52+
*current_expr = result_var;
53+
} else {
54+
if( inside_symtab ) {
55+
*current_expr = nullptr;
56+
} else {
57+
remove_original_statement = true;
58+
}
59+
}
2560

26-
ClassConstructorVisitor(Allocator &al) : PassVisitor(al, nullptr),
27-
result_var(nullptr), is_constructor_present(false) {
28-
pass_result.reserve(al, 0);
29-
}
61+
std::deque<ASR::symbol_t*> constructor_arg_syms;
62+
ASR::Struct_t* dt_der = ASR::down_cast<ASR::Struct_t>(x->m_type);
63+
ASR::StructType_t* dt_dertype = ASR::down_cast<ASR::StructType_t>(
64+
ASRUtils::symbol_get_past_external(dt_der->m_derived_type));
65+
while( dt_dertype ) {
66+
for( int i = (int) dt_dertype->n_members - 1; i >= 0; i-- ) {
67+
constructor_arg_syms.push_front(
68+
dt_dertype->m_symtab->get_symbol(
69+
dt_dertype->m_members[i]));
70+
}
71+
if( dt_dertype->m_parent != nullptr ) {
72+
ASR::symbol_t* dt_der_sym = ASRUtils::symbol_get_past_external(
73+
dt_dertype->m_parent);
74+
LCOMPILERS_ASSERT(ASR::is_a<ASR::StructType_t>(*dt_der_sym));
75+
dt_dertype = ASR::down_cast<ASR::StructType_t>(dt_der_sym);
76+
} else {
77+
dt_dertype = nullptr;
78+
}
79+
}
80+
LCOMPILERS_ASSERT(constructor_arg_syms.size() == x->n_args);
3081

31-
void visit_Assignment(const ASR::Assignment_t& x) {
32-
ASR::Assignment_t& xx = const_cast<ASR::Assignment_t&>(x);
33-
if( x.m_value->type == ASR::exprType::StructTypeConstructor ) {
34-
is_constructor_present = true;
35-
if( x.m_overloaded == nullptr ) {
36-
result_var = x.m_target;
37-
visit_expr(*x.m_value);
82+
for( size_t i = 0; i < x->n_args; i++ ) {
83+
if( x->m_args[i].m_value == nullptr ) {
84+
continue ;
85+
}
86+
ASR::symbol_t* member = constructor_arg_syms[i];
87+
if( ASR::is_a<ASR::StructTypeConstructor_t>(*x->m_args[i].m_value) ) {
88+
ASR::expr_t* result_var_copy = result_var;
89+
ASR::symbol_t *v = nullptr;
90+
if (ASR::is_a<ASR::Var_t>(*result_var_copy)) {
91+
v = ASR::down_cast<ASR::Var_t>(result_var_copy)->m_v;
92+
}
93+
result_var = ASRUtils::EXPR(ASRUtils::getStructInstanceMember_t(al,
94+
x->base.base.loc, (ASR::asr_t*) result_var_copy, v,
95+
member, current_scope));
96+
ASR::expr_t** current_expr_copy = current_expr;
97+
current_expr = &(x->m_args[i].m_value);
98+
replace_expr(x->m_args[i].m_value);
99+
current_expr = current_expr_copy;
100+
result_var = result_var_copy;
38101
} else {
39-
std::string result_var_name = current_scope->get_unique_name("temp_struct_var__");
40-
result_var = PassUtils::create_auxiliary_variable(x.m_value->base.loc, result_var_name,
41-
al, current_scope, ASRUtils::expr_type(x.m_target));
42-
visit_expr(*x.m_value);
43-
ASR::stmt_t* x_m_overloaded = x.m_overloaded;
44-
if( ASR::is_a<ASR::SubroutineCall_t>(*x.m_overloaded) ) {
45-
ASR::SubroutineCall_t* assign_call = ASR::down_cast<ASR::SubroutineCall_t>(xx.m_overloaded);
46-
Vec<ASR::call_arg_t> assign_call_args;
47-
assign_call_args.reserve(al, 2);
48-
assign_call_args.push_back(al, assign_call->m_args[0]);
49-
ASR::call_arg_t arg_1;
50-
arg_1.loc = assign_call->m_args[1].loc;
51-
arg_1.m_value = result_var;
52-
assign_call_args.push_back(al, arg_1);
53-
x_m_overloaded = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x_m_overloaded->base.loc,
54-
assign_call->m_name, assign_call->m_original_name, assign_call_args.p,
55-
assign_call_args.size(), assign_call->m_dt));
102+
Vec<ASR::stmt_t*>* result_vec = nullptr;
103+
if( inside_symtab ) {
104+
if( symtab2decls.find(current_scope) == symtab2decls.end() ) {
105+
Vec<ASR::stmt_t*> result_vec_;
106+
result_vec_.reserve(al, 0);
107+
symtab2decls[current_scope] = result_vec_;
108+
}
109+
result_vec = &symtab2decls[current_scope];
110+
} else {
111+
result_vec = &pass_result;
56112
}
57-
pass_result.push_back(al, ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target,
58-
result_var, x_m_overloaded)));
113+
ASR::symbol_t *v = nullptr;
114+
if (ASR::is_a<ASR::Var_t>(*result_var)) {
115+
v = ASR::down_cast<ASR::Var_t>(result_var)->m_v;
116+
}
117+
ASR::expr_t* derived_ref = ASRUtils::EXPR(ASRUtils::getStructInstanceMember_t(al,
118+
x->base.base.loc, (ASR::asr_t*) result_var, v,
119+
member, current_scope));
120+
ASR::stmt_t* assign = ASRUtils::STMT(ASR::make_Assignment_t(al,
121+
x->base.base.loc, derived_ref,
122+
x->m_args[i].m_value, nullptr));
123+
result_vec->push_back(al, assign);
59124
}
60125
}
61-
62126
}
127+
};
63128

64-
void visit_StructTypeConstructor(const ASR::StructTypeConstructor_t &x) {
65-
if( x.n_args == 0 ) {
66-
remove_original_stmt = true;
129+
class StructTypeConstructorVisitor : public ASR::CallReplacerOnExpressionsVisitor<StructTypeConstructorVisitor>
130+
{
131+
private:
132+
133+
Allocator& al;
134+
bool remove_original_statement;
135+
bool inside_symtab;
136+
bool apply_again;
137+
ReplaceStructTypeConstructor replacer;
138+
Vec<ASR::stmt_t*> pass_result;
139+
std::map<SymbolTable*, Vec<ASR::stmt_t*>> symtab2decls;
140+
141+
public:
142+
143+
StructTypeConstructorVisitor(Allocator& al_) :
144+
al(al_), remove_original_statement(false),
145+
inside_symtab(true), apply_again(false),
146+
replacer(al_, pass_result,
147+
remove_original_statement, inside_symtab,
148+
apply_again, symtab2decls) {
149+
pass_result.n = 0;
150+
pass_result.reserve(al, 0);
67151
}
68-
ASR::Struct_t* dt_der = down_cast<ASR::Struct_t>(x.m_type);
69-
ASR::StructType_t* dt_dertype = ASR::down_cast<ASR::StructType_t>(
70-
ASRUtils::symbol_get_past_external(dt_der->m_derived_type));
71-
for( size_t i = 0; i < std::min(dt_dertype->n_members, x.n_args); i++ ) {
72-
ASR::symbol_t* member = dt_dertype->m_symtab->resolve_symbol(std::string(dt_dertype->m_members[i], strlen(dt_dertype->m_members[i])));
73-
ASR::symbol_t *v = nullptr;
74-
if (is_a<ASR::Var_t>(*result_var)) {
75-
v = down_cast<ASR::Var_t>(result_var)->m_v;
152+
153+
void call_replacer() {
154+
replacer.current_expr = current_expr;
155+
replacer.current_scope = current_scope;
156+
replacer.replace_expr(*current_expr);
157+
}
158+
159+
void transform_stmts(ASR::stmt_t **&m_body, size_t &n_body) {
160+
bool inside_symtab_copy = inside_symtab;
161+
inside_symtab = false;
162+
Vec<ASR::stmt_t*> body;
163+
body.reserve(al, n_body);
164+
165+
if( symtab2decls.find(current_scope) != symtab2decls.end() ) {
166+
Vec<ASR::stmt_t*>& decls = symtab2decls[current_scope];
167+
for (size_t j = 0; j < decls.size(); j++) {
168+
body.push_back(al, decls[j]);
169+
}
170+
symtab2decls.erase(current_scope);
171+
}
172+
173+
for (size_t i = 0; i < n_body; i++) {
174+
pass_result.n = 0;
175+
pass_result.reserve(al, 1);
176+
apply_again = false;
177+
remove_original_statement = false;
178+
replacer.result_var = nullptr;
179+
visit_stmt(*m_body[i]);
180+
for (size_t j = 0; j < pass_result.size(); j++) {
181+
body.push_back(al, pass_result[j]);
182+
}
183+
if( !remove_original_statement ) {
184+
body.push_back(al, m_body[i]);
185+
}
186+
remove_original_statement = false;
76187
}
77-
ASR::expr_t* derived_ref = ASRUtils::EXPR(ASRUtils::getStructInstanceMember_t(al, x.base.base.loc, (ASR::asr_t*)result_var, v, member, current_scope));
78-
ASR::stmt_t* assign = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, derived_ref, x.m_args[i], nullptr));
79-
pass_result.push_back(al, assign);
188+
m_body = body.p;
189+
n_body = body.size();
190+
replacer.result_var = nullptr;
191+
pass_result.n = 0;
192+
pass_result.reserve(al, 0);
193+
inside_symtab = inside_symtab_copy;
80194
}
81-
}
195+
196+
void visit_Variable(const ASR::Variable_t &x) {
197+
ASR::Variable_t& xx = const_cast<ASR::Variable_t&>(x);
198+
replacer.result_var = ASRUtils::EXPR(ASR::make_Var_t(al,
199+
x.base.base.loc, &(xx.base)));
200+
ASR::CallReplacerOnExpressionsVisitor<
201+
StructTypeConstructorVisitor>::visit_Variable(x);
202+
}
203+
204+
void visit_Assignment(const ASR::Assignment_t &x) {
205+
if (x.m_overloaded) {
206+
this->visit_stmt(*x.m_overloaded);
207+
remove_original_statement = false;
208+
return ;
209+
}
210+
211+
replacer.result_var = x.m_target;
212+
ASR::expr_t** current_expr_copy_9 = current_expr;
213+
current_expr = const_cast<ASR::expr_t**>(&(x.m_value));
214+
this->call_replacer();
215+
current_expr = current_expr_copy_9;
216+
if( !remove_original_statement ) {
217+
this->visit_expr(*x.m_value);
218+
}
219+
}
220+
82221
};
83222

84-
void pass_replace_class_constructor(Allocator &al, ASR::TranslationUnit_t &unit,
85-
const LCompilers::PassOptions& /*pass_options*/) {
86-
ClassConstructorVisitor v(al);
87-
do {
88-
v.is_constructor_present = false;
89-
v.visit_TranslationUnit(unit);
90-
} while( v.is_constructor_present );
223+
void pass_replace_class_constructor(Allocator &al,
224+
ASR::TranslationUnit_t &unit,
225+
const LCompilers::PassOptions& /*pass_options*/) {
226+
StructTypeConstructorVisitor v(al);
227+
v.visit_TranslationUnit(unit);
91228
}
92229

93230

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,15 +1138,12 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
11381138
s_generic, args_new.p, args_new.size(), nullptr);
11391139
}
11401140
} else if(ASR::is_a<ASR::StructType_t>(*s)) {
1141-
Vec<ASR::expr_t*> args_new;
1142-
args_new.reserve(al, args.size());
1143-
visit_expr_list(args, args.size(), args_new);
11441141
ASR::StructType_t* StructType = ASR::down_cast<ASR::StructType_t>(s);
11451142
for( size_t i = 0; i < std::min(args.size(), StructType->n_members); i++ ) {
11461143
std::string member_name = StructType->m_members[i];
11471144
ASR::Variable_t* member_var = ASR::down_cast<ASR::Variable_t>(
11481145
StructType->m_symtab->resolve_symbol(member_name));
1149-
ASR::expr_t* arg_new_i = args_new[i];
1146+
ASR::expr_t* arg_new_i = args[i].m_value;
11501147
cast_helper(member_var->m_type, arg_new_i, arg_new_i->base.loc);
11511148
ASR::ttype_t* left_type = member_var->m_type;
11521149
ASR::ttype_t* right_type = ASRUtils::expr_type(arg_new_i);
@@ -1162,10 +1159,10 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
11621159
);
11631160
throw SemanticAbort();
11641161
}
1165-
args_new.p[i] = arg_new_i;
1162+
args.p[i].m_value = arg_new_i;
11661163
}
11671164
ASR::ttype_t* der_type = ASRUtils::TYPE(ASR::make_Struct_t(al, loc, stemp, nullptr, 0));
1168-
return ASR::make_StructTypeConstructor_t(al, loc, stemp, args_new.p, args_new.size(), der_type, nullptr);
1165+
return ASR::make_StructTypeConstructor_t(al, loc, stemp, args.p, args.size(), der_type, nullptr);
11691166
} else if( ASR::is_a<ASR::EnumType_t>(*s) ) {
11701167
Vec<ASR::expr_t*> args_new;
11711168
args_new.reserve(al, args.size());

tests/reference/asr-structs_01-be14d49.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-structs_01-be14d49.stdout",
9-
"stdout_hash": "bd6651b101dc97e717d0604af576a8492e549d6830a89ae0b2d2a83f",
9+
"stdout_hash": "7fc590bac6faf35bb632814509f73079edbd4c868b32dcfcf5efe3f2",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

0 commit comments

Comments
 (0)