diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 815b9e052a..65549078ec 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -370,6 +370,7 @@ RUN(NAME test_import_02 LABELS cpython llvm c) RUN(NAME test_import_03 LABELS cpython llvm c) RUN(NAME test_import_04 IMPORT_PATH .. LABELS cpython llvm c) +RUN(NAME test_import_05 LABELS cpython llvm c wasm wasm_x86 wasm_x64) RUN(NAME test_math LABELS cpython llvm) RUN(NAME test_numpy_01 LABELS cpython llvm c) RUN(NAME test_numpy_02 LABELS cpython llvm c) diff --git a/integration_tests/test_import/sys.py b/integration_tests/test_import/sys.py new file mode 100644 index 0000000000..ac905eb29d --- /dev/null +++ b/integration_tests/test_import/sys.py @@ -0,0 +1,5 @@ +from lpython import i32 + +def hi_from_user_sys() -> i32: + print("hi from user sys!") + return -5 diff --git a/integration_tests/test_import_05.py b/integration_tests/test_import_05.py new file mode 100644 index 0000000000..8ee6c54c7a --- /dev/null +++ b/integration_tests/test_import_05.py @@ -0,0 +1,3 @@ +from test_import.sys import hi_from_user_sys + +assert hi_from_user_sys() == -5 diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index bfd2121594..e7511fa1a6 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -201,34 +201,29 @@ Result get_full_path(const std::string &filename, if (status) { return file_path; } else { - status = read_file(file_path, input); - if (status) { - return file_path; - } else { - // If this is `lpython`, do a special lookup - if (filename == "lpython.py") { - file_path = runtime_library_dir + "/lpython/" + filename; - status = read_file(file_path, input); - if (status) { - lpython = true; - return file_path; - } else { - return Error(); - } - } else if (startswith(filename, "numpy.py")) { - file_path = runtime_library_dir + "/lpython_intrinsic_" + filename; - status = read_file(file_path, input); - if (status) { - return file_path; - } else { - return Error(); - } - } else if (startswith(filename, "enum.py")) { - enum_py = true; + // If this is `lpython`, do a special lookup + if (filename == "lpython.py") { + file_path = runtime_library_dir + "/lpython/" + filename; + status = read_file(file_path, input); + if (status) { + lpython = true; + return file_path; + } else { return Error(); + } + } else if (startswith(filename, "numpy.py")) { + file_path = runtime_library_dir + "/lpython_intrinsic_" + filename; + status = read_file(file_path, input); + if (status) { + return file_path; } else { return Error(); } + } else if (startswith(filename, "enum.py")) { + enum_py = true; + return Error(); + } else { + return Error(); } } } @@ -3644,19 +3639,16 @@ class SymbolTableVisitor : public CommonVisitor { Search all the paths in order and stop when the desired module is found. */ - bool module_found = false; + std::string path_found = ""; for( auto& path: paths ) { if(is_directory(path + "/" + directory + mod_sym)) { - module_found = true; // Directory i.e., x/y/__init__.py - path += '/' + directory + mod_sym; + path_found = path + '/' + directory + mod_sym; mod_sym = "__init__"; + break; } else if(path_exists(path + "/" + directory + mod_sym + ".py")) { - module_found = true; // File i.e., x/y.py - path += '/' + directory; - } - if( module_found ) { + path_found = path + '/' + directory; break; } } @@ -3666,15 +3658,27 @@ class SymbolTableVisitor : public CommonVisitor { specified and if its a directory then prioritise the directory itself. */ - if( !module_found ) { + if( path_found.empty() ) { if (is_directory(directory + mod_sym)) { // Directory i.e., x/__init__.py - paths.insert(paths.begin(), directory + mod_sym); mod_sym = "__init__"; + path_found = directory + mod_sym; } else if (path_exists(directory + mod_sym + ".py")) { - paths.insert(paths.begin(), directory); + path_found = directory; } } + + // Update paths to contain only the found path (if found) + // so that later load_module() should only use this path + // to read the package/module file + // load_module() need not search through all paths again + if (path_found.empty()) { + // include the runtime library dir so that later + // runtime library modules could be imported + paths = {get_runtime_library_dir()}; + } else { + paths = {path_found}; + } } void visit_ImportFrom(const AST::ImportFrom_t &x) {