Skip to content

Commit 3c8d70c

Browse files
XX Wrap symbols from global scope into a module
1 parent a5aa754 commit 3c8d70c

File tree

7 files changed

+130
-5
lines changed

7 files changed

+130
-5
lines changed

src/libasr/asr_scopes.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,33 @@ std::string SymbolTable::get_unique_name(const std::string &name) {
134134
return unique_name;
135135
}
136136

137+
void SymbolTable::move_symbols_from_global_scope(Allocator &al,
138+
SymbolTable *module_scope, std::vector<std::string> &syms) {
139+
for (auto &a : scope) {
140+
switch (a.second->type) {
141+
case (ASR::symbolType::Variable) : {
142+
ASR::Variable_t *v = ASR::down_cast<ASR::Variable_t>(a.second);
143+
v->m_parent_symtab = module_scope;
144+
module_scope->add_symbol(a.first, (ASR::symbol_t *)v);
145+
syms.push_back(a.first);
146+
break;
147+
} case (ASR::symbolType::Function) : {
148+
ASR::Function_t *fn = ASR::down_cast<ASR::Function_t>(a.second);
149+
150+
// Replace Var_t's m_v with the symbols in the current scope
151+
ASRUtils::VarExprVisitor replace_symbols(al, fn->m_symtab);
152+
replace_symbols.visit_Function(*fn);
153+
154+
fn->m_symtab->parent = module_scope;
155+
module_scope->add_symbol(a.first, (ASR::symbol_t *) fn);
156+
syms.push_back(a.first);
157+
break;
158+
} default : {
159+
LCompilersException("Moving the symbol:`" + a.first +
160+
"` from global scope is not implemented yet");
161+
};
162+
}
163+
}
164+
}
165+
137166
} // namespace LCompilers

src/libasr/asr_scopes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ struct SymbolTable {
8080
size_t n_scope_names, char **m_scope_names);
8181

8282
std::string get_unique_name(const std::string &name);
83+
84+
void move_symbols_from_global_scope(Allocator &al, SymbolTable *fn_scope,
85+
std::vector<std::string> &syms);
8386
};
8487

8588
} // namespace LCompilers

src/libasr/asr_utils.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2090,6 +2090,60 @@ class ReplaceReturnWithGotoVisitor: public ASR::BaseStmtReplacer<ReplaceReturnWi
20902090

20912091
};
20922092

2093+
class VarExprReplacer: public ASR::BaseExprReplacer<VarExprReplacer>
2094+
{
2095+
private:
2096+
Allocator &al;
2097+
SymbolTable *current_scope;
2098+
2099+
public:
2100+
VarExprReplacer(Allocator& al_, SymbolTable *current_scope_) :
2101+
al(al_), current_scope(current_scope_)
2102+
{}
2103+
2104+
void replace_Var(ASR::Var_t *x) {
2105+
std::string sym_name = ASRUtils::symbol_name(x->m_v);
2106+
// Check for the symbol in the current scope; if available, return
2107+
if (current_scope->get_symbol(sym_name)) {
2108+
return;
2109+
}
2110+
// otherwise, get the symbol from the module scope.
2111+
// Create an ExternalSymbol and add it to the current scope.
2112+
ASR::symbol_t *var_s = current_scope->resolve_symbol(sym_name);
2113+
LCOMPILERS_ASSERT(var_s != nullptr)
2114+
LCOMPILERS_ASSERT(ASR::is_a<ASR::Variable_t>(*var_s))
2115+
ASR::Variable_t *v = ASR::down_cast<ASR::Variable_t>(var_s);
2116+
var_s = (ASR::symbol_t *) ASR::make_ExternalSymbol_t(al,
2117+
v->base.base.loc, current_scope, v->m_name, var_s,
2118+
s2c(al, "_global_symbols"), nullptr, 0, v->m_name,
2119+
ASR::accessType::Public);
2120+
current_scope->add_symbol(v->m_name, var_s);
2121+
// Replace the symbol with an ExternalSymbol
2122+
*current_expr = ASRUtils::EXPR(ASR::make_Var_t(al,
2123+
x->base.base.loc, var_s));
2124+
}
2125+
};
2126+
2127+
class VarExprVisitor : public ASR::CallReplacerOnExpressionsVisitor
2128+
<VarExprVisitor>
2129+
{
2130+
private:
2131+
2132+
VarExprReplacer replacer;
2133+
2134+
public:
2135+
2136+
VarExprVisitor(Allocator& al_, SymbolTable *current_scope) :
2137+
replacer(al_, current_scope)
2138+
{ }
2139+
2140+
void call_replacer() {
2141+
replacer.current_expr = current_expr;
2142+
replacer.replace_expr(*current_expr);
2143+
}
2144+
2145+
};
2146+
20932147
// Singleton LabelGenerator so that it generates
20942148
// unique labels for different statements, from
20952149
// whereever it is called (be it ASR passes, be it

src/libasr/pass/global_stmts.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,4 +162,26 @@ void pass_wrap_global_stmts_into_function(Allocator &al,
162162
}
163163
}
164164

165+
void pass_wrap_global_vars_into_function(Allocator &al,
166+
ASR::TranslationUnit_t &unit) {
167+
Location loc = unit.base.base.loc;
168+
char *module_name = s2c(al, "_global_symbols");
169+
SymbolTable *module_scope = al.make_new<SymbolTable>(unit.m_global_scope);
170+
std::vector<std::string> moved_symbols;
171+
172+
// Move all the symbols from global into the module scope
173+
unit.m_global_scope->move_symbols_from_global_scope(
174+
al, module_scope, moved_symbols);
175+
176+
// Erase the symbols that are moved into the function
177+
for (auto &sym: moved_symbols) {
178+
unit.m_global_scope->erase_symbol(sym);
179+
}
180+
181+
ASR::symbol_t *module = (ASR::symbol_t *) ASR::make_Module_t(al, loc,
182+
module_scope, module_name, nullptr, 0, false,false);
183+
unit.m_global_scope->add_symbol(module_name, module);
184+
185+
}
186+
165187
} // namespace LCompilers

src/libasr/pass/global_stmts.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ namespace LCompilers {
88

99
void pass_wrap_global_stmts_into_function(Allocator &al, ASR::TranslationUnit_t &unit,
1010
const LCompilers::PassOptions& pass_options);
11+
void pass_wrap_global_vars_into_function(Allocator &al,
12+
ASR::TranslationUnit_t &unit);
1113

1214
} // namespace LCompilers
1315

src/libasr/pass/global_stmts_program.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,21 @@ void pass_wrap_global_stmts_into_program(Allocator &al,
2626
prog_body.reserve(al, 1);
2727
if (unit.n_items > 0) {
2828
pass_wrap_global_stmts_into_function(al, unit, pass_options);
29-
ASR::symbol_t *fn = unit.m_global_scope->get_symbol(program_fn_name);
30-
if (ASR::is_a<ASR::Function_t>(*fn)
31-
&& ASR::down_cast<ASR::Function_t>(fn)->m_return_var == nullptr) {
29+
pass_wrap_global_vars_into_function(al, unit);
30+
ASR::Module_t *mod = ASR::down_cast<ASR::Module_t>(
31+
unit.m_global_scope->get_symbol("_global_symbols"));
32+
ASR::symbol_t *fn_s = mod->m_symtab->get_symbol(program_fn_name);
33+
if (ASR::is_a<ASR::Function_t>(*fn_s)
34+
&& ASR::down_cast<ASR::Function_t>(fn_s)->m_return_var == nullptr) {
35+
ASR::Function_t *fn = ASR::down_cast<ASR::Function_t>(fn_s);
36+
ASR::symbol_t *es = (ASR::symbol_t *) ASR::make_ExternalSymbol_t(
37+
al, fn->base.base.loc, current_scope, s2c(al, program_fn_name),
38+
fn_s, mod->m_name, nullptr, 0, s2c(al, program_fn_name),
39+
ASR::accessType::Public);
40+
current_scope->add_symbol(program_fn_name, es);
3241
ASR::asr_t *stmt = ASR::make_SubroutineCall_t(
3342
al, unit.base.base.loc,
34-
fn, nullptr,
43+
es, nullptr,
3544
nullptr, 0,
3645
nullptr);
3746
prog_body.push_back(al, ASR::down_cast<ASR::stmt_t>(stmt));

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6248,7 +6248,13 @@ Result<ASR::TranslationUnit_t*> python_ast_to_asr(Allocator &al, LocationManager
62486248
pass_options.run_fun = "_lpython_main_program";
62496249
pass_options.runtime_library_dir = get_runtime_library_dir();
62506250
pass_wrap_global_stmts_into_program(al, *tu, pass_options);
6251-
LCOMPILERS_ASSERT(asr_verify(*tu, true, diagnostics));
6251+
#if defined(WITH_LFORTRAN_ASSERT)
6252+
diag::Diagnostics diagnostics;
6253+
if (!asr_verify(*tu, true, diagnostics)) {
6254+
std::cerr << diagnostics.render2();
6255+
throw LCompilersException("Verify failed");
6256+
};
6257+
#endif
62526258
}
62536259
}
62546260

0 commit comments

Comments
 (0)