Skip to content

Commit 73d75d1

Browse files
Support Annotated assignment in the global_scope
1 parent 322f9c3 commit 73d75d1

7 files changed

+39
-11
lines changed

src/libasr/asr_scopes.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,11 @@ std::string SymbolTable::get_unique_name(const std::string &name) {
137137

138138
void SymbolTable::move_symbols_from_global_scope(Allocator &al,
139139
SymbolTable *module_scope, Vec<char *> &syms,
140-
Vec<char *> &mod_dependencies) {
140+
Vec<char *> &mod_dependencies, Vec<ASR::stmt_t*> &var_init) {
141141
// TODO: This isn't scalable. We have write a visitor in asdl_cpp.py
142142
syms.reserve(al, 4);
143143
mod_dependencies.reserve(al, 4);
144+
var_init.reserve(al, 4);
144145
for (auto &a : scope) {
145146
switch (a.second->type) {
146147
case (ASR::symbolType::Module): {
@@ -225,6 +226,15 @@ void SymbolTable::move_symbols_from_global_scope(Allocator &al,
225226
} case (ASR::symbolType::Variable) : {
226227
ASR::Variable_t *v = ASR::down_cast<ASR::Variable_t>(a.second);
227228
v->m_parent_symtab = module_scope;
229+
if (v->m_symbolic_value) {
230+
ASR::expr_t* v_expr = ASRUtils::EXPR(ASR::make_Var_t(
231+
al, v->base.base.loc, (ASR::symbol_t *) v));
232+
ASR::asr_t* assign = ASR::make_Assignment_t(al,
233+
v->base.base.loc, v_expr, v->m_symbolic_value, nullptr);
234+
var_init.push_back(al, ASRUtils::STMT(assign));
235+
v->m_symbolic_value = nullptr;
236+
v->m_value = nullptr;
237+
}
228238
module_scope->add_symbol(a.first, (ASR::symbol_t *) v);
229239
syms.push_back(al, s2c(al, a.first));
230240
break;

src/libasr/asr_scopes.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ namespace LCompilers {
1010

1111
namespace ASR {
1212
struct asr_t;
13+
struct stmt_t;
1314
struct symbol_t;
1415
}
1516

@@ -84,7 +85,7 @@ struct SymbolTable {
8485

8586
void move_symbols_from_global_scope(Allocator &al,
8687
SymbolTable *module_scope, Vec<char *> &syms,
87-
Vec<char *> &mod_dependencies);
88+
Vec<char *> &mod_dependencies, Vec<ASR::stmt_t*> &var_init);
8889
};
8990

9091
} // namespace LCompilers

src/libasr/pass/global_symbols.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,35 @@ namespace LCompilers {
1616

1717
void pass_wrap_global_syms_into_module(Allocator &al,
1818
ASR::TranslationUnit_t &unit,
19-
const LCompilers::PassOptions& /*pass_options*/) {
19+
const LCompilers::PassOptions& pass_options) {
2020
Location loc = unit.base.base.loc;
2121
char *module_name = s2c(al, "_global_symbols");
2222
SymbolTable *module_scope = al.make_new<SymbolTable>(unit.m_global_scope);
2323
Vec<char *> moved_symbols;
2424
Vec<char *> mod_dependencies;
25+
Vec<ASR::stmt_t*> var_init;
2526

2627
// Move all the symbols from global into the module scope
2728
unit.m_global_scope->move_symbols_from_global_scope(al, module_scope,
28-
moved_symbols, mod_dependencies);
29+
moved_symbols, mod_dependencies, var_init);
2930

3031
// Erase the symbols that are moved into the module
3132
for (auto &sym: moved_symbols) {
3233
unit.m_global_scope->erase_symbol(sym);
3334
}
3435

36+
if (module_scope->get_symbol(pass_options.run_fun) && var_init.n > 0) {
37+
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(
38+
module_scope->get_symbol(pass_options.run_fun));
39+
for (size_t i = 0; i < f->n_body; i++) {
40+
var_init.push_back(al, f->m_body[i]);
41+
}
42+
f->m_body = var_init.p;
43+
f->n_body = var_init.n;
44+
// Overwrites the function: `_lpython_main_program`
45+
module_scope->add_symbol(f->m_name, (ASR::symbol_t *) f);
46+
}
47+
3548
Vec<char *> m_dependencies;
3649
m_dependencies.reserve(al, mod_dependencies.size());
3750
for( auto &dep: mod_dependencies) {

tests/reference/asr-expr_07-7742668.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-expr_07-7742668.stdout",
9-
"stdout_hash": "36af7cdd0a8bed977355197e5fb0512e1b23f1b36966eef5e83ef244",
9+
"stdout_hash": "e419602ad57368da314d299a740459bc2fd912e9302ef68207e70105",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
(TranslationUnit (SymbolTable 1 {_global_symbols: (Module (SymbolTable 7 {_lpython_main_program: (Function (SymbolTable 6 {}) _lpython_main_program (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [f bool_to_str] [] [(SubroutineCall 7 f () [] ()) (SubroutineCall 7 bool_to_str () [] ())] () Public .false. .false.), bool_to_str: (Function (SymbolTable 4 {var: (Variable 4 var [] Local () () Default (Logical 4 []) Source Public Required .false.)}) bool_to_str (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [] [] [(= (Var 4 var) (LogicalConstant .true. (Logical 4 [])) ()) (Print () [(Cast (LogicalConstant .true. (Logical 4 [])) LogicalToCharacter (Character 1 -2 () []) (StringConstant "True" (Character 1 4 () [])))] () ()) (Assert (StringCompare (Cast (Var 4 var) LogicalToCharacter (Character 1 -2 () []) ()) Eq (StringConstant "True" (Character 1 4 () [])) (Logical 4 []) ()) ()) (= (Var 4 var) (LogicalConstant .false. (Logical 4 [])) ()) (Assert (StringCompare (Cast (Var 4 var) LogicalToCharacter (Character 1 -2 () []) ()) Eq (StringConstant "False" (Character 1 5 () [])) (Logical 4 []) ()) ()) (Assert (StringCompare (Cast (LogicalConstant .true. (Logical 4 [])) LogicalToCharacter (Character 1 -2 () []) (StringConstant "True" (Character 1 4 () []))) Eq (StringConstant "True" (Character 1 4 () [])) (Logical 4 []) (LogicalConstant .true. (Logical 4 []))) ())] () Public .false. .false.), f: (Function (SymbolTable 3 {a: (Variable 3 a [] Local (IntegerConstant 5 (Integer 4 [])) () Default (Integer 4 []) Source Public Required .false.), b: (Variable 3 b [x] Local (IntegerBinOp (Var 3 x) Add (IntegerConstant 1 (Integer 4 [])) (Integer 4 []) ()) () Default (Integer 4 []) Source Public Required .false.), x: (Variable 3 x [] Local (IntegerConstant 3 (Integer 4 [])) () Default (Integer 4 []) Source Public Required .false.)}) f (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [g] [] [(= (Var 3 a) (IntegerConstant 5 (Integer 4 [])) ()) (= (Var 3 x) (IntegerConstant 3 (Integer 4 [])) ()) (= (Var 3 x) (IntegerConstant 5 (Integer 4 [])) ()) (= (Var 3 b) (IntegerBinOp (Var 3 x) Add (IntegerConstant 1 (Integer 4 [])) (Integer 4 []) ()) ()) (Print () [(Var 3 a) (Var 3 b)] () ()) (Assert (IntegerCompare (Var 3 b) Eq (IntegerConstant 6 (Integer 4 [])) (Logical 4 []) ()) ()) (SubroutineCall 7 g () [((IntegerBinOp (IntegerBinOp (Var 3 a) Mul (Var 3 b) (Integer 4 []) ()) Add (IntegerConstant 3 (Integer 4 [])) (Integer 4 []) ()))] ())] () Public .false. .false.), g: (Function (SymbolTable 2 {x: (Variable 2 x [] In () () Default (Integer 4 []) Source Public Required .false.)}) g (FunctionType [(Integer 4 [])] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [] [(Var 2 x)] [(Print () [(Var 2 x)] () ())] () Public .false. .false.), x: (Variable 7 x [] Local (IntegerConstant 7 (Integer 4 [])) () Default (Integer 4 []) Source Public Required .false.)}) _global_symbols [] .false. .false.), main_program: (Program (SymbolTable 5 {_lpython_main_program: (ExternalSymbol 5 _lpython_main_program 7 _lpython_main_program _global_symbols [] _lpython_main_program Public)}) main_program [_global_symbols] [(SubroutineCall 5 _lpython_main_program () [] ())])}) [])
1+
(TranslationUnit (SymbolTable 1 {_global_symbols: (Module (SymbolTable 7 {_lpython_main_program: (Function (SymbolTable 6 {}) _lpython_main_program (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [f bool_to_str] [] [(= (Var 7 x) (IntegerConstant 7 (Integer 4 [])) ()) (SubroutineCall 7 f () [] ()) (SubroutineCall 7 bool_to_str () [] ())] () Public .false. .false.), bool_to_str: (Function (SymbolTable 4 {var: (Variable 4 var [] Local () () Default (Logical 4 []) Source Public Required .false.)}) bool_to_str (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [] [] [(= (Var 4 var) (LogicalConstant .true. (Logical 4 [])) ()) (Print () [(Cast (LogicalConstant .true. (Logical 4 [])) LogicalToCharacter (Character 1 -2 () []) (StringConstant "True" (Character 1 4 () [])))] () ()) (Assert (StringCompare (Cast (Var 4 var) LogicalToCharacter (Character 1 -2 () []) ()) Eq (StringConstant "True" (Character 1 4 () [])) (Logical 4 []) ()) ()) (= (Var 4 var) (LogicalConstant .false. (Logical 4 [])) ()) (Assert (StringCompare (Cast (Var 4 var) LogicalToCharacter (Character 1 -2 () []) ()) Eq (StringConstant "False" (Character 1 5 () [])) (Logical 4 []) ()) ()) (Assert (StringCompare (Cast (LogicalConstant .true. (Logical 4 [])) LogicalToCharacter (Character 1 -2 () []) (StringConstant "True" (Character 1 4 () []))) Eq (StringConstant "True" (Character 1 4 () [])) (Logical 4 []) (LogicalConstant .true. (Logical 4 []))) ())] () Public .false. .false.), f: (Function (SymbolTable 3 {a: (Variable 3 a [] Local (IntegerConstant 5 (Integer 4 [])) () Default (Integer 4 []) Source Public Required .false.), b: (Variable 3 b [x] Local (IntegerBinOp (Var 3 x) Add (IntegerConstant 1 (Integer 4 [])) (Integer 4 []) ()) () Default (Integer 4 []) Source Public Required .false.), x: (Variable 3 x [] Local (IntegerConstant 3 (Integer 4 [])) () Default (Integer 4 []) Source Public Required .false.)}) f (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [g] [] [(= (Var 3 a) (IntegerConstant 5 (Integer 4 [])) ()) (= (Var 3 x) (IntegerConstant 3 (Integer 4 [])) ()) (= (Var 3 x) (IntegerConstant 5 (Integer 4 [])) ()) (= (Var 3 b) (IntegerBinOp (Var 3 x) Add (IntegerConstant 1 (Integer 4 [])) (Integer 4 []) ()) ()) (Print () [(Var 3 a) (Var 3 b)] () ()) (Assert (IntegerCompare (Var 3 b) Eq (IntegerConstant 6 (Integer 4 [])) (Logical 4 []) ()) ()) (SubroutineCall 7 g () [((IntegerBinOp (IntegerBinOp (Var 3 a) Mul (Var 3 b) (Integer 4 []) ()) Add (IntegerConstant 3 (Integer 4 [])) (Integer 4 []) ()))] ())] () Public .false. .false.), g: (Function (SymbolTable 2 {x: (Variable 2 x [] In () () Default (Integer 4 []) Source Public Required .false.)}) g (FunctionType [(Integer 4 [])] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [] [(Var 2 x)] [(Print () [(Var 2 x)] () ())] () Public .false. .false.), x: (Variable 7 x [] Local () () Default (Integer 4 []) Source Public Required .false.)}) _global_symbols [] .false. .false.), main_program: (Program (SymbolTable 5 {_lpython_main_program: (ExternalSymbol 5 _lpython_main_program 7 _lpython_main_program _global_symbols [] _lpython_main_program Public)}) main_program [_global_symbols] [(SubroutineCall 5 _lpython_main_program () [] ())])}) [])

tests/reference/llvm-print_04-443a8d8.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "llvm-print_04-443a8d8.stdout",
9-
"stdout_hash": "82ed5155deef7a5b597b62a427ca9c33a1dd6650f757223c5c604889",
9+
"stdout_hash": "740498e6d0b9c0a6ffc7f6711e1f8c5bae27b834b726d108fb0e0c18",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

tests/reference/llvm-print_04-443a8d8.stdout

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
; ModuleID = 'LFortran'
22
source_filename = "LFortran"
33

4-
@u = global i64 -922337203685477580
5-
@x = global i32 -2147483648
6-
@y = global i16 -32768
7-
@z = global i8 -128
4+
@u = global i64 0
5+
@x = global i32 0
6+
@y = global i16 0
7+
@z = global i8 0
88
@0 = private unnamed_addr constant [2 x i8] c" \00", align 1
99
@1 = private unnamed_addr constant [2 x i8] c"\0A\00", align 1
1010
@2 = private unnamed_addr constant [7 x i8] c"%lld%s\00", align 1
@@ -20,6 +20,10 @@ source_filename = "LFortran"
2020

2121
define void @__module__global_symbols__lpython_main_program() {
2222
.entry:
23+
store i64 -922337203685477580, i64* @u, align 4
24+
store i32 -2147483648, i32* @x, align 4
25+
store i16 -32768, i16* @y, align 2
26+
store i8 -128, i8* @z, align 1
2327
%0 = load i64, i64* @u, align 4
2428
call void (i8*, ...) @_lfortran_printf(i8* getelementptr inbounds ([7 x i8], [7 x i8]* @2, i32 0, i32 0), i64 %0, i8* getelementptr inbounds ([2 x i8], [2 x i8]* @1, i32 0, i32 0))
2529
%1 = load i32, i32* @x, align 4

0 commit comments

Comments
 (0)