19
19
#include < libasr/utils.h>
20
20
#include < libasr/pass/global_stmts_program.h>
21
21
#include < libasr/pass/instantiate_template.h>
22
+ #include < libasr/pass/global_stmts.h>
22
23
#include < libasr/modfile.h>
23
24
24
25
#include < lpython/python_ast.h>
@@ -520,6 +521,9 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
520
521
*/
521
522
std::vector<ASR::asr_t *> tmp_vec;
522
523
524
+ // Used to store the initializer for the global variables like list, ...
525
+ Vec<ASR::asr_t *> global_init;
526
+
523
527
Allocator &al;
524
528
LocationManager &lm;
525
529
SymbolTable *current_scope;
@@ -557,6 +561,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
557
561
current_body{nullptr }, ann_assign_target_type{nullptr },
558
562
assign_ast_target{nullptr }, is_c_p_pointer_call{false }, allow_implicit_casting{allow_implicit_casting_} {
559
563
current_module_dependencies.reserve (al, 4 );
564
+ global_init.reserve (al, 1 );
560
565
}
561
566
562
567
ASR::asr_t * resolve_variable (const Location &loc, const std::string &var_name) {
@@ -2239,13 +2244,17 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
2239
2244
ASR::symbol_t * v_sym = ASR::down_cast<ASR::symbol_t >(v);
2240
2245
ASR::Variable_t* v_variable = ASR::down_cast<ASR::Variable_t>(v_sym);
2241
2246
2242
- if ( init_expr && current_body &&
2247
+ if ( init_expr && ( current_body || ASR::is_a<ASR::List_t>(*type)) &&
2243
2248
(is_runtime_expression || !is_variable_const)) {
2244
2249
ASR::expr_t * v_expr = ASRUtils::EXPR (ASR::make_Var_t (al, loc, v_sym));
2245
2250
cast_helper (v_expr, init_expr, true );
2246
2251
ASR::asr_t * assign = ASR::make_Assignment_t (al, loc, v_expr,
2247
2252
init_expr, nullptr );
2248
- current_body->push_back (al, ASRUtils::STMT (assign));
2253
+ if (current_body) {
2254
+ current_body->push_back (al, ASRUtils::STMT (assign));
2255
+ } else if (ASR::is_a<ASR::List_t>(*type)) {
2256
+ global_init.push_back (al, assign);
2257
+ }
2249
2258
2250
2259
v_variable->m_symbolic_value = nullptr ;
2251
2260
v_variable->m_value = nullptr ;
@@ -3945,10 +3954,53 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
3945
3954
mod->m_dependencies = current_module_dependencies.p ;
3946
3955
mod->n_dependencies = current_module_dependencies.size ();
3947
3956
}
3948
- // These global statements are added to the translation unit for now,
3949
- // but they should be adding to a module initialization function
3957
+
3958
+ if (global_init.n > 0 && main_module_sym) {
3959
+ // unit->m_items is used and set to nullptr in the
3960
+ // `pass_wrap_global_stmts_into_function` pass
3961
+ unit->m_items = global_init.p ;
3962
+ unit->n_items = global_init.size ();
3963
+ std::string func_name = " global_initializer" ;
3964
+ LCompilers::PassOptions pass_options;
3965
+ pass_options.run_fun = func_name;
3966
+ pass_wrap_global_stmts_into_function (al, *unit, pass_options);
3967
+
3968
+ ASR::Module_t *mod = ASR::down_cast<ASR::Module_t>(main_module_sym);
3969
+ ASR::symbol_t *f_sym = unit->m_global_scope ->get_symbol (func_name);
3970
+ if (f_sym) {
3971
+ // Add the `global_initilaizer` function into the `__main__`
3972
+ // module and later call this function to initialize the
3973
+ // global variables like list, ...
3974
+ ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(f_sym);
3975
+ f->m_symtab ->parent = mod->m_symtab ;
3976
+ mod->m_symtab ->add_symbol (func_name, (ASR::symbol_t *) f);
3977
+ // Erase the function in TranslationUnit
3978
+ unit->m_global_scope ->erase_symbol (func_name);
3979
+ }
3980
+ }
3981
+
3950
3982
unit->m_items = items.p ;
3951
3983
unit->n_items = items.size ();
3984
+ if (items.n > 0 && main_module_sym) {
3985
+ std::string func_name = " global_statements" ;
3986
+ // Wrap all the global statements into a Function
3987
+ LCompilers::PassOptions pass_options;
3988
+ pass_options.run_fun = func_name;
3989
+ pass_wrap_global_stmts_into_function (al, *unit, pass_options);
3990
+
3991
+ ASR::Module_t *mod = ASR::down_cast<ASR::Module_t>(main_module_sym);
3992
+ ASR::symbol_t *f_sym = unit->m_global_scope ->get_symbol (func_name);
3993
+ if (f_sym) {
3994
+ // Add the `global_statements` function into the `__main__`
3995
+ // module and later call this function to execute the
3996
+ // global_statements
3997
+ ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(f_sym);
3998
+ f->m_symtab ->parent = mod->m_symtab ;
3999
+ mod->m_symtab ->add_symbol (func_name, (ASR::symbol_t *) f);
4000
+ // Erase the function in TranslationUnit
4001
+ unit->m_global_scope ->erase_symbol (func_name);
4002
+ }
4003
+ }
3952
4004
3953
4005
tmp = asr;
3954
4006
}
@@ -4009,8 +4061,81 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
4009
4061
}
4010
4062
}
4011
4063
4012
- void visit_Import (const AST::Import_t &/* x*/ ) {
4013
- // visited in symbol visitor
4064
+ void visit_Import (const AST::Import_t &x) {
4065
+ // All the modules are imported in the SymbolTable visitor
4066
+ // Here, we call the global_initializer & global_statements to
4067
+ // initialize and execute the global symbols
4068
+ for (size_t i = 0 ; i < x.n_names ; i++) {
4069
+ std::string mod_name = x.m_names [i].m_name ;
4070
+ ASR::symbol_t *mod_sym = current_scope->resolve_symbol (mod_name);
4071
+ if (mod_sym) {
4072
+ ASR::Module_t *mod = ASR::down_cast<ASR::Module_t>(mod_sym);
4073
+
4074
+ std::string g_func_name = mod_name + " @global_initializer" ;
4075
+ ASR::symbol_t *g_func = mod->m_symtab ->get_symbol (" global_initializer" );
4076
+ if (g_func && !current_scope->get_symbol (g_func_name)) {
4077
+ ASR::symbol_t *es = ASR::down_cast<ASR::symbol_t >(
4078
+ ASR::make_ExternalSymbol_t (al, mod->base .base .loc ,
4079
+ current_scope, s2c (al, g_func_name), g_func,
4080
+ s2c (al, mod_name), nullptr , 0 , s2c (al, " global_initializer" ),
4081
+ ASR::accessType::Public));
4082
+ current_scope->add_symbol (g_func_name, es);
4083
+ tmp_vec.push_back (ASR::make_SubroutineCall_t (al, x.base .base .loc ,
4084
+ es, g_func, nullptr , 0 , nullptr ));
4085
+ }
4086
+
4087
+ g_func_name = mod_name + " @global_statements" ;
4088
+ g_func = mod->m_symtab ->get_symbol (" global_statements" );
4089
+ if (g_func && !current_scope->get_symbol (g_func_name)) {
4090
+ ASR::symbol_t *es = ASR::down_cast<ASR::symbol_t >(
4091
+ ASR::make_ExternalSymbol_t (al, mod->base .base .loc ,
4092
+ current_scope, s2c (al, g_func_name), g_func,
4093
+ s2c (al, mod_name), nullptr , 0 , s2c (al, " global_statements" ),
4094
+ ASR::accessType::Public));
4095
+ current_scope->add_symbol (g_func_name, es);
4096
+ tmp_vec.push_back (ASR::make_SubroutineCall_t (al, x.base .base .loc ,
4097
+ es, g_func, nullptr , 0 , nullptr ));
4098
+ }
4099
+ }
4100
+ }
4101
+ }
4102
+
4103
+ void visit_ImportFrom (const AST::ImportFrom_t &x) {
4104
+ // Handled by SymbolTableVisitor already
4105
+ // Here, we call the global_initializer & global_statements to
4106
+ // initialize and execute the global symbols
4107
+ std::string mod_name = x.m_module ;
4108
+ ASR::symbol_t *mod_sym = current_scope->resolve_symbol (mod_name);
4109
+ if (mod_sym) {
4110
+ ASR::Module_t *mod = ASR::down_cast<ASR::Module_t>(mod_sym);
4111
+
4112
+ std::string g_func_name = mod_name + " @global_initializer" ;
4113
+ ASR::symbol_t *g_func = mod->m_symtab ->get_symbol (" global_initializer" );
4114
+ if (g_func && !current_scope->get_symbol (g_func_name)) {
4115
+ ASR::symbol_t *es = ASR::down_cast<ASR::symbol_t >(
4116
+ ASR::make_ExternalSymbol_t (al, mod->base .base .loc ,
4117
+ current_scope, s2c (al, g_func_name), g_func,
4118
+ s2c (al, mod_name), nullptr , 0 , s2c (al, " global_initializer" ),
4119
+ ASR::accessType::Public));
4120
+ current_scope->add_symbol (g_func_name, es);
4121
+ tmp_vec.push_back (ASR::make_SubroutineCall_t (al, x.base .base .loc ,
4122
+ es, g_func, nullptr , 0 , nullptr ));
4123
+ }
4124
+
4125
+ g_func_name = mod_name + " @global_statements" ;
4126
+ g_func = mod->m_symtab ->get_symbol (" global_statements" );
4127
+ if (g_func && !current_scope->get_symbol (g_func_name)) {
4128
+ ASR::symbol_t *es = ASR::down_cast<ASR::symbol_t >(
4129
+ ASR::make_ExternalSymbol_t (al, mod->base .base .loc ,
4130
+ current_scope, s2c (al, g_func_name), g_func,
4131
+ s2c (al, mod_name), nullptr , 0 , s2c (al, " global_statements" ),
4132
+ ASR::accessType::Public));
4133
+ current_scope->add_symbol (g_func_name, es);
4134
+ tmp_vec.push_back (ASR::make_SubroutineCall_t (al, x.base .base .loc ,
4135
+ es, g_func, nullptr , 0 , nullptr ));
4136
+ }
4137
+ }
4138
+ tmp = nullptr ;
4014
4139
}
4015
4140
4016
4141
void visit_AnnAssign (const AST::AnnAssign_t &x) {
@@ -5749,15 +5874,15 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
5749
5874
fn_args.push_back (al, suffix);
5750
5875
} else if (attr_name == " partition" ) {
5751
5876
5752
- /*
5877
+ /*
5753
5878
str.partition(seperator) ---->
5754
5879
5755
- Split the string at the first occurrence of sep, and return a 3-tuple containing the part
5756
- before the separator, the separator itself, and the part after the separator.
5757
- If the separator is not found, return a 3-tuple containing the string itself, followed
5880
+ Split the string at the first occurrence of sep, and return a 3-tuple containing the part
5881
+ before the separator, the separator itself, and the part after the separator.
5882
+ If the separator is not found, return a 3-tuple containing the string itself, followed
5758
5883
by two empty strings.
5759
5884
*/
5760
-
5885
+
5761
5886
if (args.size () != 1 ) {
5762
5887
throw SemanticError (" str.partition() takes one argument" ,
5763
5888
loc);
@@ -5830,7 +5955,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
5830
5955
return res;
5831
5956
}
5832
5957
5833
- ASR::expr_t * eval_partition (std::string &s_var, ASR::expr_t * arg_seperator,
5958
+ ASR::expr_t * eval_partition (std::string &s_var, ASR::expr_t * arg_seperator,
5834
5959
const Location &loc, ASR::ttype_t *arg_seperator_type) {
5835
5960
/*
5836
5961
Invoked when Seperator argument is provided as a constant string
@@ -5841,16 +5966,16 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
5841
5966
throw SemanticError (" empty separator" , arg_seperator->base .loc );
5842
5967
}
5843
5968
/*
5844
- using KMP algorithm to find seperator inside string
5845
- res_tuple: stores the resulting 3-tuple expression --->
5969
+ using KMP algorithm to find seperator inside string
5970
+ res_tuple: stores the resulting 3-tuple expression --->
5846
5971
(if seperator exist) tuple: (left of seperator, seperator, right of seperator)
5847
5972
(if seperator does not exist) tuple: (string, "", "")
5848
5973
res_tuple_type: stores the type of each expression present in resulting 3-tuple
5849
5974
*/
5850
5975
int seperator_pos = KMP_string_match (s_var, seperator);
5851
5976
Vec<ASR::expr_t *> res_tuple;
5852
5977
Vec<ASR::ttype_t *> res_tuple_type;
5853
- res_tuple.reserve (al, 3 );
5978
+ res_tuple.reserve (al, 3 );
5854
5979
res_tuple_type.reserve (al, 3 );
5855
5980
std :: string first_res, second_res, third_res;
5856
5981
if (seperator_pos == -1 ) {
@@ -6104,11 +6229,11 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
6104
6229
}
6105
6230
return ;
6106
6231
} else if (attr_name == " partition" ) {
6107
- /*
6232
+ /*
6108
6233
str.partition(seperator) ---->
6109
- Split the string at the first occurrence of sep, and return a 3-tuple containing the part
6110
- before the separator, the separator itself, and the part after the separator.
6111
- If the separator is not found, return a 3-tuple containing the string itself, followed
6234
+ Split the string at the first occurrence of sep, and return a 3-tuple containing the part
6235
+ before the separator, the separator itself, and the part after the separator.
6236
+ If the separator is not found, return a 3-tuple containing the string itself, followed
6112
6237
by two empty strings.
6113
6238
*/
6114
6239
if (args.size () != 1 ) {
@@ -6438,11 +6563,6 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
6438
6563
false , x.m_args , x.n_args , x.m_keywords , x.n_keywords );
6439
6564
}
6440
6565
6441
- void visit_ImportFrom (const AST::ImportFrom_t &/* x*/ ) {
6442
- // Handled by SymbolTableVisitor already
6443
- tmp = nullptr ;
6444
- }
6445
-
6446
6566
void visit_Global (const AST::Global_t &/* x*/ ) {
6447
6567
tmp = nullptr ;
6448
6568
}
0 commit comments