Skip to content

Commit cd7217a

Browse files
authored
Improve class_constructor.cpp pass and StructTypeConstructor node in ASR (#1579)
1 parent fe61c4d commit cd7217a

14 files changed

+215
-82
lines changed

src/libasr/ASR.asdl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ expr
224224
| NamedExpr(expr target, expr value, ttype type)
225225
| FunctionCall(symbol name, symbol? original_name, call_arg* args,
226226
ttype type, expr? value, expr? dt)
227-
| StructTypeConstructor(symbol dt_sym, expr* args, ttype type, expr? value)
227+
| StructTypeConstructor(symbol dt_sym, call_arg* args, ttype type, expr? value)
228228
| EnumTypeConstructor(symbol dt_sym, expr* args, ttype type, expr? value)
229229
| UnionTypeConstructor(symbol dt_sym, expr* args, ttype type, expr? value)
230230
| ImpliedDoLoop(expr* values, expr var, expr start, expr end,

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5563,8 +5563,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
55635563
break;
55645564
}
55655565
case ASR::ttypeType::Struct: {
5566-
ASR::Struct_t* der = (ASR::Struct_t*)(&(x->m_type->base));
5567-
ASR::StructType_t* der_type = (ASR::StructType_t*)(&(der->m_derived_type->base));
5566+
ASR::Struct_t* der = ASR::down_cast<ASR::Struct_t>(x->m_type);
5567+
ASR::StructType_t* der_type = ASR::down_cast<ASR::StructType_t>(
5568+
ASRUtils::symbol_get_past_external(der->m_derived_type));
55685569
der_type_name = std::string(der_type->m_name);
55695570
uint32_t h = get_hash((ASR::asr_t*)x);
55705571
if( llvm_symtab.find(h) != llvm_symtab.end() ) {
@@ -5584,8 +5585,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
55845585
break;
55855586
}
55865587
case ASR::ttypeType::Class: {
5587-
ASR::Class_t* der = (ASR::Class_t*)(&(x->m_type->base));
5588-
ASR::ClassType_t* der_type = (ASR::ClassType_t*)(&(der->m_class_type->base));
5588+
ASR::Class_t* der = ASR::down_cast<ASR::Class_t>(x->m_type);
5589+
ASR::ClassType_t* der_type = ASR::down_cast<ASR::ClassType_t>(
5590+
ASRUtils::symbol_get_past_external(der->m_class_type));
55895591
der_type_name = std::string(der_type->m_name);
55905592
uint32_t h = get_hash((ASR::asr_t*)x);
55915593
if( llvm_symtab.find(h) != llvm_symtab.end() ) {

src/libasr/pass/class_constructor.cpp

Lines changed: 195 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3,91 +3,225 @@
33
#include <libasr/exception.h>
44
#include <libasr/asr_utils.h>
55
#include <libasr/asr_verify.h>
6-
#include <libasr/pass/pass_utils.h>
76
#include <libasr/pass/class_constructor.h>
7+
#include <libasr/pass/pass_utils.h>
88

9-
#include <cstring>
10-
9+
#include <vector>
10+
#include <utility>
11+
#include <queue>
1112

1213
namespace LCompilers {
1314

1415
using ASR::down_cast;
1516
using ASR::is_a;
1617

17-
class ClassConstructorVisitor : public PassUtils::PassVisitor<ClassConstructorVisitor>
18-
{
19-
private:
18+
class ReplaceStructTypeConstructor: public ASR::BaseExprReplacer<ReplaceStructTypeConstructor> {
19+
20+
private:
21+
22+
Allocator& al;
23+
Vec<ASR::stmt_t*>& pass_result;
24+
bool& remove_original_statement;
25+
bool& inside_symtab;
26+
std::map<SymbolTable*, Vec<ASR::stmt_t*>>& symtab2decls;
27+
28+
public:
29+
30+
SymbolTable* current_scope;
2031
ASR::expr_t* result_var;
2132

22-
public:
33+
ReplaceStructTypeConstructor(Allocator& al_, Vec<ASR::stmt_t*>& pass_result_,
34+
bool& remove_original_statement_, bool& inside_symtab_,
35+
std::map<SymbolTable*, Vec<ASR::stmt_t*>>& symtab2decls_) :
36+
al(al_), pass_result(pass_result_),
37+
remove_original_statement(remove_original_statement_),
38+
inside_symtab(inside_symtab_), symtab2decls(symtab2decls_),
39+
current_scope(nullptr), result_var(nullptr) {}
2340

24-
bool is_constructor_present;
41+
void replace_StructTypeConstructor(ASR::StructTypeConstructor_t* x) {
42+
if( x->n_args == 0 ) {
43+
remove_original_statement = true;
44+
return ;
45+
}
46+
if( result_var == nullptr ) {
47+
std::string result_var_name = current_scope->get_unique_name("temp_struct_var__");
48+
result_var = PassUtils::create_auxiliary_variable(x->base.base.loc,
49+
result_var_name, al, current_scope, x->m_type);
50+
*current_expr = result_var;
51+
} else {
52+
if( inside_symtab ) {
53+
*current_expr = nullptr;
54+
} else {
55+
remove_original_statement = true;
56+
}
57+
}
2558

26-
ClassConstructorVisitor(Allocator &al) : PassVisitor(al, nullptr),
27-
result_var(nullptr), is_constructor_present(false) {
28-
pass_result.reserve(al, 0);
29-
}
59+
std::deque<ASR::symbol_t*> constructor_arg_syms;
60+
ASR::Struct_t* dt_der = ASR::down_cast<ASR::Struct_t>(x->m_type);
61+
ASR::StructType_t* dt_dertype = ASR::down_cast<ASR::StructType_t>(
62+
ASRUtils::symbol_get_past_external(dt_der->m_derived_type));
63+
while( dt_dertype ) {
64+
for( int i = (int) dt_dertype->n_members - 1; i >= 0; i-- ) {
65+
constructor_arg_syms.push_front(
66+
dt_dertype->m_symtab->get_symbol(
67+
dt_dertype->m_members[i]));
68+
}
69+
if( dt_dertype->m_parent != nullptr ) {
70+
ASR::symbol_t* dt_der_sym = ASRUtils::symbol_get_past_external(
71+
dt_dertype->m_parent);
72+
LCOMPILERS_ASSERT(ASR::is_a<ASR::StructType_t>(*dt_der_sym));
73+
dt_dertype = ASR::down_cast<ASR::StructType_t>(dt_der_sym);
74+
} else {
75+
dt_dertype = nullptr;
76+
}
77+
}
78+
LCOMPILERS_ASSERT(constructor_arg_syms.size() == x->n_args);
3079

31-
void visit_Assignment(const ASR::Assignment_t& x) {
32-
ASR::Assignment_t& xx = const_cast<ASR::Assignment_t&>(x);
33-
if( x.m_value->type == ASR::exprType::StructTypeConstructor ) {
34-
is_constructor_present = true;
35-
if( x.m_overloaded == nullptr ) {
36-
result_var = x.m_target;
37-
visit_expr(*x.m_value);
80+
for( size_t i = 0; i < x->n_args; i++ ) {
81+
if( x->m_args[i].m_value == nullptr ) {
82+
continue ;
83+
}
84+
ASR::symbol_t* member = constructor_arg_syms[i];
85+
if( ASR::is_a<ASR::StructTypeConstructor_t>(*x->m_args[i].m_value) ) {
86+
ASR::expr_t* result_var_copy = result_var;
87+
ASR::symbol_t *v = nullptr;
88+
if (ASR::is_a<ASR::Var_t>(*result_var_copy)) {
89+
v = ASR::down_cast<ASR::Var_t>(result_var_copy)->m_v;
90+
}
91+
result_var = ASRUtils::EXPR(ASRUtils::getStructInstanceMember_t(al,
92+
x->base.base.loc, (ASR::asr_t*) result_var_copy, v,
93+
member, current_scope));
94+
ASR::expr_t** current_expr_copy = current_expr;
95+
current_expr = &(x->m_args[i].m_value);
96+
replace_expr(x->m_args[i].m_value);
97+
current_expr = current_expr_copy;
98+
result_var = result_var_copy;
3899
} else {
39-
std::string result_var_name = current_scope->get_unique_name("temp_struct_var__");
40-
result_var = PassUtils::create_auxiliary_variable(x.m_value->base.loc, result_var_name,
41-
al, current_scope, ASRUtils::expr_type(x.m_target));
42-
visit_expr(*x.m_value);
43-
ASR::stmt_t* x_m_overloaded = x.m_overloaded;
44-
if( ASR::is_a<ASR::SubroutineCall_t>(*x.m_overloaded) ) {
45-
ASR::SubroutineCall_t* assign_call = ASR::down_cast<ASR::SubroutineCall_t>(xx.m_overloaded);
46-
Vec<ASR::call_arg_t> assign_call_args;
47-
assign_call_args.reserve(al, 2);
48-
assign_call_args.push_back(al, assign_call->m_args[0]);
49-
ASR::call_arg_t arg_1;
50-
arg_1.loc = assign_call->m_args[1].loc;
51-
arg_1.m_value = result_var;
52-
assign_call_args.push_back(al, arg_1);
53-
x_m_overloaded = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x_m_overloaded->base.loc,
54-
assign_call->m_name, assign_call->m_original_name, assign_call_args.p,
55-
assign_call_args.size(), assign_call->m_dt));
100+
Vec<ASR::stmt_t*>* result_vec = nullptr;
101+
if( inside_symtab ) {
102+
if( symtab2decls.find(current_scope) == symtab2decls.end() ) {
103+
Vec<ASR::stmt_t*> result_vec_;
104+
result_vec_.reserve(al, 0);
105+
symtab2decls[current_scope] = result_vec_;
106+
}
107+
result_vec = &symtab2decls[current_scope];
108+
} else {
109+
result_vec = &pass_result;
56110
}
57-
pass_result.push_back(al, ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target,
58-
result_var, x_m_overloaded)));
111+
ASR::symbol_t *v = nullptr;
112+
if (ASR::is_a<ASR::Var_t>(*result_var)) {
113+
v = ASR::down_cast<ASR::Var_t>(result_var)->m_v;
114+
}
115+
ASR::expr_t* derived_ref = ASRUtils::EXPR(ASRUtils::getStructInstanceMember_t(al,
116+
x->base.base.loc, (ASR::asr_t*) result_var, v,
117+
member, current_scope));
118+
ASR::stmt_t* assign = ASRUtils::STMT(ASR::make_Assignment_t(al,
119+
x->base.base.loc, derived_ref,
120+
x->m_args[i].m_value, nullptr));
121+
result_vec->push_back(al, assign);
59122
}
60123
}
61-
62124
}
125+
};
63126

64-
void visit_StructTypeConstructor(const ASR::StructTypeConstructor_t &x) {
65-
if( x.n_args == 0 ) {
66-
remove_original_stmt = true;
127+
class StructTypeConstructorVisitor : public ASR::CallReplacerOnExpressionsVisitor<StructTypeConstructorVisitor>
128+
{
129+
private:
130+
131+
Allocator& al;
132+
bool remove_original_statement;
133+
bool inside_symtab;
134+
ReplaceStructTypeConstructor replacer;
135+
Vec<ASR::stmt_t*> pass_result;
136+
std::map<SymbolTable*, Vec<ASR::stmt_t*>> symtab2decls;
137+
138+
public:
139+
140+
StructTypeConstructorVisitor(Allocator& al_) :
141+
al(al_), remove_original_statement(false),
142+
inside_symtab(true), replacer(al_, pass_result,
143+
remove_original_statement, inside_symtab,
144+
symtab2decls) {
145+
pass_result.n = 0;
146+
pass_result.reserve(al, 0);
67147
}
68-
ASR::Struct_t* dt_der = down_cast<ASR::Struct_t>(x.m_type);
69-
ASR::StructType_t* dt_dertype = ASR::down_cast<ASR::StructType_t>(
70-
ASRUtils::symbol_get_past_external(dt_der->m_derived_type));
71-
for( size_t i = 0; i < std::min(dt_dertype->n_members, x.n_args); i++ ) {
72-
ASR::symbol_t* member = dt_dertype->m_symtab->resolve_symbol(std::string(dt_dertype->m_members[i], strlen(dt_dertype->m_members[i])));
73-
ASR::symbol_t *v = nullptr;
74-
if (is_a<ASR::Var_t>(*result_var)) {
75-
v = down_cast<ASR::Var_t>(result_var)->m_v;
148+
149+
void call_replacer() {
150+
replacer.current_expr = current_expr;
151+
replacer.current_scope = current_scope;
152+
replacer.replace_expr(*current_expr);
153+
}
154+
155+
void transform_stmts(ASR::stmt_t **&m_body, size_t &n_body) {
156+
bool inside_symtab_copy = inside_symtab;
157+
inside_symtab = false;
158+
Vec<ASR::stmt_t*> body;
159+
body.reserve(al, n_body);
160+
161+
if( symtab2decls.find(current_scope) != symtab2decls.end() ) {
162+
Vec<ASR::stmt_t*>& decls = symtab2decls[current_scope];
163+
for (size_t j = 0; j < decls.size(); j++) {
164+
body.push_back(al, decls[j]);
165+
}
166+
symtab2decls.erase(current_scope);
167+
}
168+
169+
for (size_t i = 0; i < n_body; i++) {
170+
pass_result.n = 0;
171+
pass_result.reserve(al, 1);
172+
remove_original_statement = false;
173+
replacer.result_var = nullptr;
174+
visit_stmt(*m_body[i]);
175+
for (size_t j = 0; j < pass_result.size(); j++) {
176+
body.push_back(al, pass_result[j]);
177+
}
178+
if( !remove_original_statement ) {
179+
body.push_back(al, m_body[i]);
180+
}
181+
remove_original_statement = false;
76182
}
77-
ASR::expr_t* derived_ref = ASRUtils::EXPR(ASRUtils::getStructInstanceMember_t(al, x.base.base.loc, (ASR::asr_t*)result_var, v, member, current_scope));
78-
ASR::stmt_t* assign = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, derived_ref, x.m_args[i], nullptr));
79-
pass_result.push_back(al, assign);
183+
m_body = body.p;
184+
n_body = body.size();
185+
replacer.result_var = nullptr;
186+
pass_result.n = 0;
187+
pass_result.reserve(al, 0);
188+
inside_symtab = inside_symtab_copy;
80189
}
81-
}
190+
191+
void visit_Variable(const ASR::Variable_t &x) {
192+
ASR::Variable_t& xx = const_cast<ASR::Variable_t&>(x);
193+
replacer.result_var = ASRUtils::EXPR(ASR::make_Var_t(al,
194+
x.base.base.loc, &(xx.base)));
195+
ASR::CallReplacerOnExpressionsVisitor<
196+
StructTypeConstructorVisitor>::visit_Variable(x);
197+
}
198+
199+
void visit_Assignment(const ASR::Assignment_t &x) {
200+
if (x.m_overloaded) {
201+
this->visit_stmt(*x.m_overloaded);
202+
remove_original_statement = false;
203+
return ;
204+
}
205+
206+
replacer.result_var = x.m_target;
207+
ASR::expr_t** current_expr_copy_9 = current_expr;
208+
current_expr = const_cast<ASR::expr_t**>(&(x.m_value));
209+
this->call_replacer();
210+
current_expr = current_expr_copy_9;
211+
if( !remove_original_statement ) {
212+
this->visit_expr(*x.m_value);
213+
}
214+
}
215+
82216
};
83217

84-
void pass_replace_class_constructor(Allocator &al, ASR::TranslationUnit_t &unit,
85-
const LCompilers::PassOptions& /*pass_options*/) {
86-
ClassConstructorVisitor v(al);
87-
do {
88-
v.is_constructor_present = false;
89-
v.visit_TranslationUnit(unit);
90-
} while( v.is_constructor_present );
218+
void pass_replace_class_constructor(Allocator &al,
219+
ASR::TranslationUnit_t &unit,
220+
const LCompilers::PassOptions& /*pass_options*/) {
221+
StructTypeConstructorVisitor v(al);
222+
v.visit_TranslationUnit(unit);
223+
PassUtils::UpdateDependenciesVisitor w(al);
224+
w.visit_TranslationUnit(unit);
91225
}
92226

93227

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,15 +1138,12 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
11381138
s_generic, args_new.p, args_new.size(), nullptr);
11391139
}
11401140
} else if(ASR::is_a<ASR::StructType_t>(*s)) {
1141-
Vec<ASR::expr_t*> args_new;
1142-
args_new.reserve(al, args.size());
1143-
visit_expr_list(args, args.size(), args_new);
11441141
ASR::StructType_t* StructType = ASR::down_cast<ASR::StructType_t>(s);
11451142
for( size_t i = 0; i < std::min(args.size(), StructType->n_members); i++ ) {
11461143
std::string member_name = StructType->m_members[i];
11471144
ASR::Variable_t* member_var = ASR::down_cast<ASR::Variable_t>(
11481145
StructType->m_symtab->resolve_symbol(member_name));
1149-
ASR::expr_t* arg_new_i = args_new[i];
1146+
ASR::expr_t* arg_new_i = args[i].m_value;
11501147
cast_helper(member_var->m_type, arg_new_i, arg_new_i->base.loc);
11511148
ASR::ttype_t* left_type = member_var->m_type;
11521149
ASR::ttype_t* right_type = ASRUtils::expr_type(arg_new_i);
@@ -1162,10 +1159,10 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
11621159
);
11631160
throw SemanticAbort();
11641161
}
1165-
args_new.p[i] = arg_new_i;
1162+
args.p[i].m_value = arg_new_i;
11661163
}
11671164
ASR::ttype_t* der_type = ASRUtils::TYPE(ASR::make_Struct_t(al, loc, stemp, nullptr, 0));
1168-
return ASR::make_StructTypeConstructor_t(al, loc, stemp, args_new.p, args_new.size(), der_type, nullptr);
1165+
return ASR::make_StructTypeConstructor_t(al, loc, stemp, args.p, args.size(), der_type, nullptr);
11691166
} else if( ASR::is_a<ASR::EnumType_t>(*s) ) {
11701167
Vec<ASR::expr_t*> args_new;
11711168
args_new.reserve(al, args.size());

tests/reference/asr-structs_01-be14d49.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-structs_01-be14d49.stdout",
9-
"stdout_hash": "bd6651b101dc97e717d0604af576a8492e549d6830a89ae0b2d2a83f",
9+
"stdout_hash": "7fc590bac6faf35bb632814509f73079edbd4c868b32dcfcf5efe3f2",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

0 commit comments

Comments
 (0)