Skip to content

Commit 6744ac5

Browse files
[ASRPass] Wrap all the global symbols into a module
1 parent ba525e6 commit 6744ac5

File tree

179 files changed

+704
-388
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

179 files changed

+704
-388
lines changed

integration_tests/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,4 +426,5 @@ RUN(NAME comp_01 LABELS cpython llvm wasm wasm_x64)
426426
RUN(NAME bit_operations_i32 LABELS cpython llvm wasm wasm_x64)
427427
RUN(NAME bit_operations_i64 LABELS cpython llvm wasm)
428428

429-
RUN(NAME test_argv_01 LABELS llvm) # TODO: Test using CPython
429+
RUN(NAME test_argv_01 LABELS llvm) # TODO: Test using CPython
430+
RUN(NAME global_syms_01 LABELS cpython)

integration_tests/global_syms_01.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from ltypes import i32
2+
3+
x: list[i32]
4+
x = [1, 2]
5+
i: i32
6+
i = x[0]
7+
8+
def test_global_symbols():
9+
assert i == 1
10+
assert x[1] == 2
11+
12+
test_global_symbols()

src/libasr/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ set(SRC
3232
pass/for_all.cpp
3333
pass/global_stmts.cpp
3434
pass/global_stmts_program.cpp
35+
pass/global_symbols.cpp
3536
pass/select_case.cpp
3637
pass/implied_do_loops.cpp
3738
pass/array_op.cpp

src/libasr/asr_scopes.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,4 +135,104 @@ std::string SymbolTable::get_unique_name(const std::string &name) {
135135
return unique_name;
136136
}
137137

138+
void SymbolTable::move_symbols_from_global_scope(Allocator &al,
139+
SymbolTable *module_scope, Vec<char *> &syms,
140+
Vec<char *> &mod_dependencies) {
141+
syms.reserve(al, 4);
142+
mod_dependencies.reserve(al, 4);
143+
for (auto &a : scope) {
144+
switch (a.second->type) {
145+
case (ASR::symbolType::Module): {
146+
// Pass
147+
break;
148+
} case (ASR::symbolType::Function) : {
149+
ASR::Function_t *fn = ASR::down_cast<ASR::Function_t>(a.second);
150+
for (size_t i = 0; i < fn->n_dependencies; i++ ) {
151+
ASR::symbol_t *s = fn->m_symtab->get_symbol(
152+
fn->m_dependencies[i]);
153+
if (s == nullptr) {
154+
std::string block_name = "block";
155+
ASR::symbol_t *block_s = fn->m_symtab->get_symbol(block_name);
156+
int32_t j = 1;
157+
while(block_s != nullptr) {
158+
while(block_s != nullptr) {
159+
ASR::Block_t *b = ASR::down_cast<ASR::Block_t>(block_s);
160+
s = b->m_symtab->get_symbol(fn->m_dependencies[i]);
161+
if (s == nullptr) {
162+
block_s = b->m_symtab->get_symbol("block");
163+
} else {
164+
break;
165+
}
166+
}
167+
if (s == nullptr) {
168+
block_s = fn->m_symtab->get_symbol(block_name +
169+
std::to_string(j));
170+
j++;
171+
} else {
172+
break;
173+
}
174+
}
175+
}
176+
if (s == nullptr) {
177+
s = fn->m_symtab->parent->get_symbol(fn->m_dependencies[i]);
178+
}
179+
if (s != nullptr && ASR::is_a<ASR::ExternalSymbol_t>(*s)) {
180+
char *es_name = ASR::down_cast<
181+
ASR::ExternalSymbol_t>(s)->m_module_name;
182+
if (!present(mod_dependencies, es_name)) {
183+
mod_dependencies.push_back(al, es_name);
184+
}
185+
}
186+
}
187+
fn->m_symtab->parent = module_scope;
188+
module_scope->add_symbol(a.first, (ASR::symbol_t *) fn);
189+
syms.push_back(al, s2c(al, a.first));
190+
break;
191+
} case (ASR::symbolType::GenericProcedure) : {
192+
ASR::GenericProcedure_t *es = ASR::down_cast<ASR::GenericProcedure_t>(a.second);
193+
es->m_parent_symtab = module_scope;
194+
module_scope->add_symbol(a.first, (ASR::symbol_t *) es);
195+
syms.push_back(al, s2c(al, a.first));
196+
break;
197+
} case (ASR::symbolType::ExternalSymbol) : {
198+
ASR::ExternalSymbol_t *es = ASR::down_cast<ASR::ExternalSymbol_t>(a.second);
199+
if (!present(mod_dependencies, es->m_module_name)) {
200+
mod_dependencies.push_back(al, es->m_module_name);
201+
}
202+
es->m_parent_symtab = module_scope;
203+
module_scope->add_symbol(a.first, (ASR::symbol_t *) es);
204+
syms.push_back(al, s2c(al, a.first));
205+
break;
206+
} case (ASR::symbolType::StructType) : {
207+
ASR::StructType_t *st = ASR::down_cast<ASR::StructType_t>(a.second);
208+
st->m_symtab->parent = module_scope;
209+
module_scope->add_symbol(a.first, (ASR::symbol_t *) st);
210+
syms.push_back(al, s2c(al, a.first));
211+
break;
212+
} case (ASR::symbolType::EnumType) : {
213+
ASR::EnumType_t *et = ASR::down_cast<ASR::EnumType_t>(a.second);
214+
et->m_symtab->parent = module_scope;
215+
module_scope->add_symbol(a.first, (ASR::symbol_t *) et);
216+
syms.push_back(al, s2c(al, a.first));
217+
break;
218+
} case (ASR::symbolType::UnionType) : {
219+
ASR::UnionType_t *ut = ASR::down_cast<ASR::UnionType_t>(a.second);
220+
ut->m_symtab->parent = module_scope;
221+
module_scope->add_symbol(a.first, (ASR::symbol_t *) ut);
222+
syms.push_back(al, s2c(al, a.first));
223+
break;
224+
} case (ASR::symbolType::Variable) : {
225+
ASR::Variable_t *v = ASR::down_cast<ASR::Variable_t>(a.second);
226+
v->m_parent_symtab = module_scope;
227+
module_scope->add_symbol(a.first, (ASR::symbol_t *) v);
228+
syms.push_back(al, s2c(al, a.first));
229+
break;
230+
} default : {
231+
throw LCompilersException("Moving the symbol:`" + a.first +
232+
"` from global scope is not implemented yet");
233+
};
234+
}
235+
}
236+
}
237+
138238
} // namespace LCompilers

src/libasr/asr_scopes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <map>
55

66
#include <libasr/alloc.h>
7+
#include <libasr/containers.h>
78

89
namespace LCompilers {
910

@@ -80,6 +81,10 @@ struct SymbolTable {
8081
size_t n_scope_names, char **m_scope_names);
8182

8283
std::string get_unique_name(const std::string &name);
84+
85+
void move_symbols_from_global_scope(Allocator &al,
86+
SymbolTable *module_scope, Vec<char *> &syms,
87+
Vec<char *> &mod_dependencies);
8388
};
8489

8590
} // namespace LCompilers

src/libasr/asr_verify.cpp

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -578,21 +578,23 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
578578
require(x_m_module_name == asr_owner_name,
579579
"ExternalSymbol::m_module_name `" + x_m_module_name
580580
+ "` must match external's module name `" + asr_owner_name + "`");
581-
ASR::symbol_t *s = nullptr;
582-
if( m != nullptr && ((ASR::symbol_t*) m == ASRUtils::get_asr_owner(x.m_external)) ) {
583-
s = m->m_symtab->find_scoped_symbol(x.m_original_name, x.n_scope_names, x.m_scope_names);
584-
} else if( sm ) {
585-
s = sm->m_symtab->resolve_symbol(std::string(x.m_original_name));
586-
} else if( em ) {
587-
s = em->m_symtab->resolve_symbol(std::string(x.m_original_name));
581+
if (asr_owner_name != "_global_symbols") {
582+
ASR::symbol_t *s = nullptr;
583+
if( m != nullptr && ((ASR::symbol_t*) m == ASRUtils::get_asr_owner(x.m_external)) ) {
584+
s = m->m_symtab->find_scoped_symbol(x.m_original_name, x.n_scope_names, x.m_scope_names);
585+
} else if( sm ) {
586+
s = sm->m_symtab->resolve_symbol(std::string(x.m_original_name));
587+
} else if( em ) {
588+
s = em->m_symtab->resolve_symbol(std::string(x.m_original_name));
589+
}
590+
require(s != nullptr,
591+
"ExternalSymbol::m_original_name ('"
592+
+ std::string(x.m_original_name)
593+
+ "') + scope_names not found in a module '"
594+
+ asr_owner_name + "'");
595+
require(s == x.m_external,
596+
"ExternalSymbol::m_name + scope_names found but not equal to m_external");
588597
}
589-
require(s != nullptr,
590-
"ExternalSymbol::m_original_name ('"
591-
+ std::string(x.m_original_name)
592-
+ "') + scope_names not found in a module '"
593-
+ asr_owner_name + "'");
594-
require(s == x.m_external,
595-
"ExternalSymbol::m_name + scope_names found but not equal to m_external");
596598
}
597599
}
598600

src/libasr/codegen/asr_to_c.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,32 @@ R"(
645645
array_types_decls += src;
646646
}
647647

648+
if (x.m_global_scope->get_symbol("_global_symbols") != nullptr) {
649+
struct_dep_graph.clear();
650+
SymbolTable *global_symbols = ASR::down_cast<ASR::Module_t>(
651+
x.m_global_scope->get_symbol("_global_symbols"))->m_symtab;
652+
for (auto &item : global_symbols->get_scope()) {
653+
if (ASR::is_a<ASR::StructType_t>(*item.second) ||
654+
ASR::is_a<ASR::EnumType_t>(*item.second) ||
655+
ASR::is_a<ASR::UnionType_t>(*item.second)) {
656+
std::vector<std::string> struct_deps_vec;
657+
std::pair<char**, size_t> struct_deps_ptr = ASRUtils::symbol_dependencies(item.second);
658+
for( size_t i = 0; i < struct_deps_ptr.second; i++ ) {
659+
struct_deps_vec.push_back(std::string(struct_deps_ptr.first[i]));
660+
}
661+
struct_dep_graph[item.first] = struct_deps_vec;
662+
}
663+
}
664+
665+
std::vector<std::string> struct_deps = ASRUtils::order_deps(struct_dep_graph);
666+
667+
for (auto &item : struct_deps) {
668+
ASR::symbol_t* struct_sym = global_symbols->get_symbol(item);
669+
visit_symbol(*struct_sym);
670+
array_types_decls += src;
671+
}
672+
}
673+
648674
// Topologically sort all global functions
649675
// and then define them in the right order
650676
std::vector<std::string> global_func_order = ASRUtils::determine_function_definition_order(x.m_global_scope);

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,6 +1228,16 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
12281228
fname2arg_type["lbound"] = std::make_pair(bound_arg, bound_arg->getPointerTo());
12291229
fname2arg_type["ubound"] = std::make_pair(bound_arg, bound_arg->getPointerTo());
12301230

1231+
if (x.m_global_scope->get_symbol("_global_symbols") != nullptr) {
1232+
SymbolTable *global_symbols = ASR::down_cast<ASR::Module_t>(
1233+
x.m_global_scope->get_symbol("_global_symbols"))->m_symtab;
1234+
for (auto &item : global_symbols->get_scope()) {
1235+
if (is_a<ASR::EnumType_t>(*item.second)) {
1236+
visit_symbol(*item.second);
1237+
}
1238+
}
1239+
}
1240+
12311241
// Process Variables first:
12321242
for (auto &item : x.m_global_scope->get_scope()) {
12331243
if (is_a<ASR::Variable_t>(*item.second) ||

src/libasr/codegen/asr_to_x86.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,21 @@ class ASRToX86Visitor : public ASR::BaseVisitor<ASRToX86Visitor>
7979
visit_symbol(*sym);
8080
}
8181

82+
if (x.m_global_scope->get_symbol("_global_symbols") != nullptr) {
83+
SymbolTable *global_symbols = ASR::down_cast<ASR::Module_t>(
84+
x.m_global_scope->get_symbol("_global_symbols"))->m_symtab;
85+
std::vector<std::string> global_func_order
86+
= ASRUtils::determine_function_definition_order(global_symbols);
87+
for (size_t i = 0; i < global_func_order.size(); i++) {
88+
ASR::symbol_t* sym = global_symbols->get_symbol(global_func_order[i]);
89+
// Ignore external symbols because they are already defined by the loop above.
90+
if( !sym || ASR::is_a<ASR::ExternalSymbol_t>(*sym) ) {
91+
continue;
92+
}
93+
visit_symbol(*sym);
94+
}
95+
}
96+
8297
// Then the main program:
8398
for (auto &item : x.m_global_scope->get_scope()) {
8499
if (ASR::is_a<ASR::Program_t>(*item.second)) {
@@ -504,7 +519,8 @@ class ASRToX86Visitor : public ASR::BaseVisitor<ASRToX86Visitor>
504519
}
505520

506521
void visit_SubroutineCall(const ASR::SubroutineCall_t &x) {
507-
ASR::Function_t *s = ASR::down_cast<ASR::Function_t>(x.m_name);
522+
ASR::Function_t *s = ASR::down_cast<ASR::Function_t>(
523+
ASRUtils::symbol_get_past_external(x.m_name));
508524

509525
uint32_t h = get_hash((ASR::asr_t*)s);
510526
if (x86_symtab.find(h) == x86_symtab.end()) {

src/libasr/pass/global_stmts_program.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <libasr/asr_utils.h>
55
#include <libasr/asr_verify.h>
66
#include <libasr/pass/global_stmts.h>
7+
#include <libasr/pass/global_symbols.h>
78

89

910
namespace LCompilers {
@@ -26,12 +27,22 @@ void pass_wrap_global_stmts_into_program(Allocator &al,
2627
prog_body.reserve(al, 1);
2728
if (unit.n_items > 0) {
2829
pass_wrap_global_stmts_into_function(al, unit, pass_options);
29-
ASR::symbol_t *fn = unit.m_global_scope->get_symbol(program_fn_name);
30-
if (ASR::is_a<ASR::Function_t>(*fn)
31-
&& ASR::down_cast<ASR::Function_t>(fn)->m_return_var == nullptr) {
30+
pass_wrap_global_syms_into_module(al, unit, pass_options);
31+
ASR::Module_t *mod = ASR::down_cast<ASR::Module_t>(
32+
unit.m_global_scope->get_symbol("_global_symbols"));
33+
// Call `_lpython_main_program` function
34+
ASR::symbol_t *fn_s = mod->m_symtab->get_symbol(program_fn_name);
35+
if (ASR::is_a<ASR::Function_t>(*fn_s)
36+
&& ASR::down_cast<ASR::Function_t>(fn_s)->m_return_var == nullptr) {
37+
ASR::Function_t *fn = ASR::down_cast<ASR::Function_t>(fn_s);
38+
fn_s = ASR::down_cast<ASR::symbol_t>(ASR::make_ExternalSymbol_t(
39+
al, fn->base.base.loc, current_scope, s2c(al, program_fn_name),
40+
fn_s, mod->m_name, nullptr, 0, s2c(al, program_fn_name),
41+
ASR::accessType::Public));
42+
current_scope->add_symbol(program_fn_name, fn_s);
3243
ASR::asr_t *stmt = ASR::make_SubroutineCall_t(
3344
al, unit.base.base.loc,
34-
fn, nullptr,
45+
fn_s, nullptr,
3546
nullptr, 0,
3647
nullptr);
3748
prog_body.push_back(al, ASR::down_cast<ASR::stmt_t>(stmt));

src/libasr/pass/global_symbols.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include <libasr/asr.h>
2+
#include <libasr/containers.h>
3+
#include <libasr/exception.h>
4+
#include <libasr/asr_utils.h>
5+
#include <libasr/asr_verify.h>
6+
#include <libasr/pass/global_symbols.h>
7+
#include <libasr/pass/pass_utils.h>
8+
9+
10+
namespace LCompilers {
11+
12+
/*
13+
* This ASR pass transforms (in-place) the ASR tree
14+
* and wraps all global symbols into a module
15+
*/
16+
17+
void pass_wrap_global_syms_into_module(Allocator &al,
18+
ASR::TranslationUnit_t &unit,
19+
const LCompilers::PassOptions& /*pass_options*/) {
20+
Location loc = unit.base.base.loc;
21+
char *module_name = s2c(al, "_global_symbols");
22+
SymbolTable *module_scope = al.make_new<SymbolTable>(unit.m_global_scope);
23+
Vec<char *> moved_symbols;
24+
Vec<char *> mod_dependencies;
25+
26+
// Move all the symbols from global into the module scope
27+
unit.m_global_scope->move_symbols_from_global_scope(al, module_scope,
28+
moved_symbols, mod_dependencies);
29+
30+
// Erase the symbols that are moved into the module
31+
for (auto &sym: moved_symbols) {
32+
unit.m_global_scope->erase_symbol(sym);
33+
}
34+
35+
Vec<char *> m_dependencies;
36+
m_dependencies.reserve(al, mod_dependencies.size());
37+
for( auto &dep: mod_dependencies) {
38+
m_dependencies.push_back(al, dep);
39+
}
40+
41+
ASR::symbol_t *module = (ASR::symbol_t *) ASR::make_Module_t(al, loc,
42+
module_scope, module_name, m_dependencies.p, m_dependencies.n,
43+
false, false);
44+
unit.m_global_scope->add_symbol(module_name, module);
45+
}
46+
47+
} // namespace LCompilers

src/libasr/pass/global_symbols.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef LFORTRAN_PASS_GLOBAL_SYMBOLS_H
2+
#define LFORTRAN_PASS_GLOBAL_SYMBOLS_H
3+
4+
#include <libasr/asr.h>
5+
#include <libasr/utils.h>
6+
7+
namespace LCompilers {
8+
9+
void pass_wrap_global_syms_into_module(Allocator &al,
10+
ASR::TranslationUnit_t &unit,
11+
const LCompilers::PassOptions& pass_options);
12+
13+
} // namespace LCompilers
14+
15+
#endif // LFORTRAN_PASS_GLOBAL_SYMBOLS_H

0 commit comments

Comments
 (0)