diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index c0c5bf9694..f8521e024b 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -441,4 +441,5 @@ RUN(NAME test_argv_01 LABELS llvm) # TODO: Test using CPython RUN(NAME global_syms_01 LABELS cpython llvm c) RUN(NAME global_syms_02 LABELS cpython llvm c) RUN(NAME global_syms_03_b LABELS cpython llvm c) +RUN(NAME global_syms_03_c LABELS cpython llvm c) RUN(NAME global_syms_04 LABELS cpython llvm c wasm wasm_x64) diff --git a/integration_tests/global_syms_03_a.py b/integration_tests/global_syms_03_a.py index 8eb775a9d4..ad8f62f759 100644 --- a/integration_tests/global_syms_03_a.py +++ b/integration_tests/global_syms_03_a.py @@ -1,7 +1,13 @@ -from lpython import i32 +from lpython import i32, f64 + +print("Imported from global_syms_03_a") l_1: list[str] = ['Monday', 'Tuesday', 'Wednesday'] +l_1.append('Thursday') def populate_lists() -> list[i32]: return [10, -20] l_2: list[i32] = populate_lists() + +l_3: list[f64] +l_3 = [1.0, 2.0, 3.0] diff --git a/integration_tests/global_syms_03_b.py b/integration_tests/global_syms_03_b.py index 66295b1f1b..0d77cfc79e 100644 --- a/integration_tests/global_syms_03_b.py +++ b/integration_tests/global_syms_03_b.py @@ -1,5 +1,5 @@ from global_syms_03_a import l_1, l_2 -assert len(l_1) == 3 +assert len(l_1) == 4 assert l_1[1] == "Tuesday" assert l_2[1] == -20 diff --git a/integration_tests/global_syms_03_c.py b/integration_tests/global_syms_03_c.py new file mode 100644 index 0000000000..cf6f05375b --- /dev/null +++ b/integration_tests/global_syms_03_c.py @@ -0,0 +1,5 @@ +import global_syms_03_a + +assert len(global_syms_03_a.l_1) == 4 +assert global_syms_03_a.l_1[3] == "Thursday" +assert global_syms_03_a.l_3 == [1.0, 2.0, 3.0] diff --git a/src/libasr/asr_scopes.cpp b/src/libasr/asr_scopes.cpp index 5fc47695f7..c21d237761 100644 --- a/src/libasr/asr_scopes.cpp +++ b/src/libasr/asr_scopes.cpp @@ -137,13 +137,10 @@ 
std::string SymbolTable::get_unique_name(const std::string &name) { void SymbolTable::move_symbols_from_global_scope(Allocator &al, SymbolTable *module_scope, Vec &syms, - Vec &mod_dependencies, Vec &func_dependencies, - Vec &var_init) { + Vec &mod_dependencies) { // TODO: This isn't scalable. We have write a visitor in asdl_cpp.py syms.reserve(al, 4); mod_dependencies.reserve(al, 4); - func_dependencies.reserve(al, 4); - var_init.reserve(al, 4); for (auto &a : scope) { switch (a.second->type) { case (ASR::symbolType::Module): { @@ -206,47 +203,6 @@ void SymbolTable::move_symbols_from_global_scope(Allocator &al, es->m_parent_symtab = module_scope; ASR::symbol_t *s = ASRUtils::symbol_get_past_external(a.second); LCOMPILERS_ASSERT(s); - if (ASR::is_a(*s)) { - ASR::Variable_t *v = ASR::down_cast(s); - if (v->m_symbolic_value && !ASR::is_a(*v->m_type) - && ASR::is_a(*v->m_type)) { - ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t( - al, v->base.base.loc, (ASR::symbol_t *) es)); - ASR::expr_t *value = v->m_symbolic_value; - v->m_symbolic_value = nullptr; - v->m_value = nullptr; - if (ASR::is_a(*value)) { - ASR::FunctionCall_t *call = - ASR::down_cast(value); - ASR::Module_t *m = ASRUtils::get_sym_module(s); - ASR::symbol_t *func = m->m_symtab->get_symbol( - ASRUtils::symbol_name(call->m_name)); - ASR::Function_t *f = ASR::down_cast(func); - std::string func_name = std::string(m->m_name) + - "@" + f->m_name; - ASR::symbol_t *es_func; - if (!module_scope->get_symbol(func_name)) { - es_func = ASR::down_cast( - ASR::make_ExternalSymbol_t(al, f->base.base.loc, - module_scope, s2c(al, func_name), func, m->m_name, - nullptr, 0, s2c(al, f->m_name), ASR::accessType::Public)); - module_scope->add_symbol(func_name, es_func); - if (!present(func_dependencies, s2c(al, func_name))) { - func_dependencies.push_back(al, s2c(al,func_name)); - } - } else { - es_func = module_scope->get_symbol(func_name); - } - value = ASRUtils::EXPR(ASR::make_FunctionCall_t(al, - call->base.base.loc, 
es_func, call->m_original_name, - call->m_args, call->n_args, call->m_type, - call->m_value, call->m_dt)); - } - ASR::asr_t* assign = ASR::make_Assignment_t(al, - v->base.base.loc, target, value, nullptr); - var_init.push_back(al, ASRUtils::STMT(assign)); - } - } module_scope->add_symbol(a.first, (ASR::symbol_t *) es); syms.push_back(al, s2c(al, a.first)); break; @@ -271,24 +227,6 @@ void SymbolTable::move_symbols_from_global_scope(Allocator &al, } case (ASR::symbolType::Variable) : { ASR::Variable_t *v = ASR::down_cast(a.second); v->m_parent_symtab = module_scope; - // Make the Assignment statement only for the data-types (List, - // Dict, ...), that cannot be handled in the LLVM global scope - if (v->m_symbolic_value && !ASR::is_a(*v->m_type) - && ASR::is_a(*v->m_type)) { - ASR::expr_t* v_expr = ASRUtils::EXPR(ASR::make_Var_t( - al, v->base.base.loc, (ASR::symbol_t *) v)); - ASR::asr_t* assign = ASR::make_Assignment_t(al, - v->base.base.loc, v_expr, v->m_symbolic_value, nullptr); - var_init.push_back(al, ASRUtils::STMT(assign)); - v->m_symbolic_value = nullptr; - v->m_value = nullptr; - Vec v_dependencies; - v_dependencies.reserve(al, 1); - ASRUtils::collect_variable_dependencies(al, - v_dependencies, v->m_type); - v->m_dependencies = v_dependencies.p; - v->n_dependencies = v_dependencies.size(); - } module_scope->add_symbol(a.first, (ASR::symbol_t *) v); syms.push_back(al, s2c(al, a.first)); break; diff --git a/src/libasr/asr_scopes.h b/src/libasr/asr_scopes.h index 08395e30a7..936043cd61 100644 --- a/src/libasr/asr_scopes.h +++ b/src/libasr/asr_scopes.h @@ -85,8 +85,7 @@ struct SymbolTable { void move_symbols_from_global_scope(Allocator &al, SymbolTable *module_scope, Vec &syms, - Vec &mod_dependencies, Vec &func_dependencies, - Vec &var_init); + Vec &mod_dependencies); }; } // namespace LCompilers diff --git a/src/libasr/pass/global_symbols.cpp b/src/libasr/pass/global_symbols.cpp index 7e231b2ec8..895cea85d6 100644 --- a/src/libasr/pass/global_symbols.cpp +++ 
b/src/libasr/pass/global_symbols.cpp @@ -22,35 +22,16 @@ void pass_wrap_global_syms_into_module(Allocator &al, SymbolTable *module_scope = al.make_new(unit.m_global_scope); Vec moved_symbols; Vec mod_dependencies; - Vec func_dependencies; - Vec var_init; // Move all the symbols from global into the module scope unit.m_global_scope->move_symbols_from_global_scope(al, module_scope, - moved_symbols, mod_dependencies, func_dependencies, var_init); + moved_symbols, mod_dependencies); // Erase the symbols that are moved into the module for (auto &sym: moved_symbols) { unit.m_global_scope->erase_symbol(sym); } - if (module_scope->get_symbol(pass_options.run_fun) && var_init.n > 0) { - ASR::Function_t *f = ASR::down_cast( - module_scope->get_symbol(pass_options.run_fun)); - for (size_t i = 0; i < f->n_body; i++) { - var_init.push_back(al, f->m_body[i]); - } - for (size_t i = 0; i < f->n_dependencies; i++) { - func_dependencies.push_back(al, f->m_dependencies[i]); - } - f->m_body = var_init.p; - f->n_body = var_init.n; - f->m_dependencies = func_dependencies.p; - f->n_dependencies = func_dependencies.n; - // Overwrites the function: `_lpython_main_program` - module_scope->add_symbol(f->m_name, (ASR::symbol_t *) f); - } - ASR::symbol_t *module = (ASR::symbol_t *) ASR::make_Module_t(al, loc, module_scope, module_name, mod_dependencies.p, mod_dependencies.n, false, false); diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index ba1dc49576..127b1c4461 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -520,6 +521,9 @@ class CommonVisitor : public AST::BaseVisitor { */ std::vector tmp_vec; + // Used to store the initializer for the global variables like list, ... 
+ Vec global_init; + Allocator &al; LocationManager &lm; SymbolTable *current_scope; @@ -557,6 +561,7 @@ class CommonVisitor : public AST::BaseVisitor { current_body{nullptr}, ann_assign_target_type{nullptr}, assign_ast_target{nullptr}, is_c_p_pointer_call{false}, allow_implicit_casting{allow_implicit_casting_} { current_module_dependencies.reserve(al, 4); + global_init.reserve(al, 1); } ASR::asr_t* resolve_variable(const Location &loc, const std::string &var_name) { @@ -2239,13 +2244,17 @@ class CommonVisitor : public AST::BaseVisitor { ASR::symbol_t* v_sym = ASR::down_cast(v); ASR::Variable_t* v_variable = ASR::down_cast(v_sym); - if( init_expr && current_body && + if( init_expr && (current_body || ASR::is_a(*type)) && (is_runtime_expression || !is_variable_const)) { ASR::expr_t* v_expr = ASRUtils::EXPR(ASR::make_Var_t(al, loc, v_sym)); cast_helper(v_expr, init_expr, true); ASR::asr_t* assign = ASR::make_Assignment_t(al, loc, v_expr, init_expr, nullptr); - current_body->push_back(al, ASRUtils::STMT(assign)); + if (current_body) { + current_body->push_back(al, ASRUtils::STMT(assign)); + } else if (ASR::is_a(*type)) { + global_init.push_back(al, assign); + } v_variable->m_symbolic_value = nullptr; v_variable->m_value = nullptr; @@ -3945,10 +3954,53 @@ class BodyVisitor : public CommonVisitor { mod->m_dependencies = current_module_dependencies.p; mod->n_dependencies = current_module_dependencies.size(); } - // These global statements are added to the translation unit for now, - // but they should be adding to a module initialization function + + if (global_init.n > 0 && main_module_sym) { + // unit->m_items is used and set to nullptr in the + // `pass_wrap_global_stmts_into_function` pass + unit->m_items = global_init.p; + unit->n_items = global_init.size(); + std::string func_name = "global_initializer"; + LCompilers::PassOptions pass_options; + pass_options.run_fun = func_name; + pass_wrap_global_stmts_into_function(al, *unit, pass_options); + + ASR::Module_t *mod 
= ASR::down_cast(main_module_sym); + ASR::symbol_t *f_sym = unit->m_global_scope->get_symbol(func_name); + if (f_sym) { + // Add the `global_initializer` function into the `__main__` + // module and later call this function to initialize the + // global variables like list, ... + ASR::Function_t *f = ASR::down_cast(f_sym); + f->m_symtab->parent = mod->m_symtab; + mod->m_symtab->add_symbol(func_name, (ASR::symbol_t *) f); + // Erase the function in TranslationUnit + unit->m_global_scope->erase_symbol(func_name); + } + } + unit->m_items = items.p; unit->n_items = items.size(); + if (items.n > 0 && main_module_sym) { + std::string func_name = "global_statements"; + // Wrap all the global statements into a Function + LCompilers::PassOptions pass_options; + pass_options.run_fun = func_name; + pass_wrap_global_stmts_into_function(al, *unit, pass_options); + + ASR::Module_t *mod = ASR::down_cast(main_module_sym); + ASR::symbol_t *f_sym = unit->m_global_scope->get_symbol(func_name); + if (f_sym) { + // Add the `global_statements` function into the `__main__` + // module and later call this function to execute the + // global_statements + ASR::Function_t *f = ASR::down_cast(f_sym); + f->m_symtab->parent = mod->m_symtab; + mod->m_symtab->add_symbol(func_name, (ASR::symbol_t *) f); + // Erase the function in TranslationUnit + unit->m_global_scope->erase_symbol(func_name); + } + } tmp = asr; } @@ -4009,8 +4061,81 @@ class BodyVisitor : public CommonVisitor { } } - void visit_Import(const AST::Import_t &/*x*/) { - // visited in symbol visitor + void visit_Import(const AST::Import_t &x) { + // All the modules are imported in the SymbolTable visitor + // Here, we call the global_initializer & global_statements to + // initialize and execute the global symbols + for (size_t i = 0; i < x.n_names; i++) { + std::string mod_name = x.m_names[i].m_name; + ASR::symbol_t *mod_sym = current_scope->resolve_symbol(mod_name); + if (mod_sym) { + ASR::Module_t *mod = ASR::down_cast(mod_sym); 
+ std::string g_func_name = mod_name + "@global_initializer"; + ASR::symbol_t *g_func = mod->m_symtab->get_symbol("global_initializer"); + if (g_func && !current_scope->get_symbol(g_func_name)) { + ASR::symbol_t *es = ASR::down_cast( + ASR::make_ExternalSymbol_t(al, mod->base.base.loc, + current_scope, s2c(al, g_func_name), g_func, + s2c(al, mod_name), nullptr, 0, s2c(al, "global_initializer"), + ASR::accessType::Public)); + current_scope->add_symbol(g_func_name, es); + tmp_vec.push_back(ASR::make_SubroutineCall_t(al, x.base.base.loc, + es, g_func, nullptr, 0, nullptr)); + } + + g_func_name = mod_name + "@global_statements"; + g_func = mod->m_symtab->get_symbol("global_statements"); + if (g_func && !current_scope->get_symbol(g_func_name)) { + ASR::symbol_t *es = ASR::down_cast( + ASR::make_ExternalSymbol_t(al, mod->base.base.loc, + current_scope, s2c(al, g_func_name), g_func, + s2c(al, mod_name), nullptr, 0, s2c(al, "global_statements"), + ASR::accessType::Public)); + current_scope->add_symbol(g_func_name, es); + tmp_vec.push_back(ASR::make_SubroutineCall_t(al, x.base.base.loc, + es, g_func, nullptr, 0, nullptr)); + } + } + } + } + + void visit_ImportFrom(const AST::ImportFrom_t &x) { + // Handled by SymbolTableVisitor already + // Here, we call the global_initializer & global_statements to + // initialize and execute the global symbols + std::string mod_name = x.m_module; + ASR::symbol_t *mod_sym = current_scope->resolve_symbol(mod_name); + if (mod_sym) { + ASR::Module_t *mod = ASR::down_cast(mod_sym); + + std::string g_func_name = mod_name + "@global_initializer"; + ASR::symbol_t *g_func = mod->m_symtab->get_symbol("global_initializer"); + if (g_func && !current_scope->get_symbol(g_func_name)) { + ASR::symbol_t *es = ASR::down_cast( + ASR::make_ExternalSymbol_t(al, mod->base.base.loc, + current_scope, s2c(al, g_func_name), g_func, + s2c(al, mod_name), nullptr, 0, s2c(al, "global_initializer"), + ASR::accessType::Public)); + current_scope->add_symbol(g_func_name, 
es); + tmp_vec.push_back(ASR::make_SubroutineCall_t(al, x.base.base.loc, + es, g_func, nullptr, 0, nullptr)); + } + + g_func_name = mod_name + "@global_statements"; + g_func = mod->m_symtab->get_symbol("global_statements"); + if (g_func && !current_scope->get_symbol(g_func_name)) { + ASR::symbol_t *es = ASR::down_cast( + ASR::make_ExternalSymbol_t(al, mod->base.base.loc, + current_scope, s2c(al, g_func_name), g_func, + s2c(al, mod_name), nullptr, 0, s2c(al, "global_statements"), + ASR::accessType::Public)); + current_scope->add_symbol(g_func_name, es); + tmp_vec.push_back(ASR::make_SubroutineCall_t(al, x.base.base.loc, + es, g_func, nullptr, 0, nullptr)); + } + } + tmp = nullptr; } void visit_AnnAssign(const AST::AnnAssign_t &x) { @@ -5749,15 +5874,15 @@ class BodyVisitor : public CommonVisitor { fn_args.push_back(al, suffix); } else if (attr_name == "partition") { - /* + /* str.partition(seperator) ----> - Split the string at the first occurrence of sep, and return a 3-tuple containing the part - before the separator, the separator itself, and the part after the separator. - If the separator is not found, return a 3-tuple containing the string itself, followed + Split the string at the first occurrence of sep, and return a 3-tuple containing the part + before the separator, the separator itself, and the part after the separator. + If the separator is not found, return a 3-tuple containing the string itself, followed by two empty strings. 
*/ - + if(args.size() != 1) { throw SemanticError("str.partition() takes one argument", loc); @@ -5830,7 +5955,7 @@ class BodyVisitor : public CommonVisitor { return res; } - ASR::expr_t* eval_partition(std::string &s_var, ASR::expr_t* arg_seperator, + ASR::expr_t* eval_partition(std::string &s_var, ASR::expr_t* arg_seperator, const Location &loc, ASR::ttype_t *arg_seperator_type) { /* Invoked when Seperator argument is provided as a constant string @@ -5841,8 +5966,8 @@ class BodyVisitor : public CommonVisitor { throw SemanticError("empty separator", arg_seperator->base.loc); } /* - using KMP algorithm to find seperator inside string - res_tuple: stores the resulting 3-tuple expression ---> + using KMP algorithm to find seperator inside string + res_tuple: stores the resulting 3-tuple expression ---> (if seperator exist) tuple: (left of seperator, seperator, right of seperator) (if seperator does not exist) tuple: (string, "", "") res_tuple_type: stores the type of each expression present in resulting 3-tuple @@ -5850,7 +5975,7 @@ class BodyVisitor : public CommonVisitor { int seperator_pos = KMP_string_match(s_var, seperator); Vec res_tuple; Vec res_tuple_type; - res_tuple.reserve(al, 3); + res_tuple.reserve(al, 3); res_tuple_type.reserve(al, 3); std :: string first_res, second_res, third_res; if(seperator_pos == -1) { @@ -6104,11 +6229,11 @@ class BodyVisitor : public CommonVisitor { } return; } else if (attr_name == "partition") { - /* + /* str.partition(seperator) ----> - Split the string at the first occurrence of sep, and return a 3-tuple containing the part - before the separator, the separator itself, and the part after the separator. - If the separator is not found, return a 3-tuple containing the string itself, followed + Split the string at the first occurrence of sep, and return a 3-tuple containing the part + before the separator, the separator itself, and the part after the separator. 
+ If the separator is not found, return a 3-tuple containing the string itself, followed by two empty strings. */ if (args.size() != 1) { @@ -6438,11 +6563,6 @@ class BodyVisitor : public CommonVisitor { false, x.m_args, x.n_args, x.m_keywords, x.n_keywords); } - void visit_ImportFrom(const AST::ImportFrom_t &/*x*/) { - // Handled by SymbolTableVisitor already - tmp = nullptr; - } - void visit_Global(const AST::Global_t &/*x*/) { tmp = nullptr; }