Skip to content

Fix user import when it overlaps with runtime library modules #1743

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ RUN(NAME test_import_02 LABELS cpython llvm c)
RUN(NAME test_import_03 LABELS cpython llvm c)
RUN(NAME test_import_04 IMPORT_PATH ..
LABELS cpython llvm c)
RUN(NAME test_import_05 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
RUN(NAME test_math LABELS cpython llvm)
RUN(NAME test_numpy_01 LABELS cpython llvm c)
RUN(NAME test_numpy_02 LABELS cpython llvm c)
Expand Down
5 changes: 5 additions & 0 deletions integration_tests/test_import/sys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from lpython import i32

def hi_from_user_sys() -> i32:
print("hi from user sys!")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably return an integer value from here.

return -5
3 changes: 3 additions & 0 deletions integration_tests/test_import_05.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from test_import.sys import hi_from_user_sys

assert hi_from_user_sys() == -5
72 changes: 38 additions & 34 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,34 +201,29 @@ Result<std::string> get_full_path(const std::string &filename,
if (status) {
return file_path;
} else {
status = read_file(file_path, input);
if (status) {
return file_path;
} else {
// If this is `lpython`, do a special lookup
if (filename == "lpython.py") {
file_path = runtime_library_dir + "/lpython/" + filename;
status = read_file(file_path, input);
if (status) {
lpython = true;
return file_path;
} else {
return Error();
}
} else if (startswith(filename, "numpy.py")) {
file_path = runtime_library_dir + "/lpython_intrinsic_" + filename;
status = read_file(file_path, input);
if (status) {
return file_path;
} else {
return Error();
}
} else if (startswith(filename, "enum.py")) {
enum_py = true;
// If this is `lpython`, do a special lookup
if (filename == "lpython.py") {
file_path = runtime_library_dir + "/lpython/" + filename;
status = read_file(file_path, input);
if (status) {
lpython = true;
return file_path;
} else {
return Error();
}
} else if (startswith(filename, "numpy.py")) {
file_path = runtime_library_dir + "/lpython_intrinsic_" + filename;
status = read_file(file_path, input);
if (status) {
return file_path;
} else {
return Error();
}
} else if (startswith(filename, "enum.py")) {
enum_py = true;
return Error();
} else {
return Error();
}
}
}
Expand Down Expand Up @@ -3644,19 +3639,16 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
Search all the paths in order and stop
when the desired module is found.
*/
bool module_found = false;
std::string path_found = "";
for( auto& path: paths ) {
if(is_directory(path + "/" + directory + mod_sym)) {
module_found = true;
// Directory i.e., x/y/__init__.py
path += '/' + directory + mod_sym;
path_found = path + '/' + directory + mod_sym;
mod_sym = "__init__";
break;
} else if(path_exists(path + "/" + directory + mod_sym + ".py")) {
module_found = true;
// File i.e., x/y.py
path += '/' + directory;
}
if( module_found ) {
path_found = path + '/' + directory;
break;
}
}
Expand All @@ -3666,15 +3658,27 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
specified and if its a directory
then prioritise the directory itself.
*/
if( !module_found ) {
if( path_found.empty() ) {
if (is_directory(directory + mod_sym)) {
// Directory i.e., x/__init__.py
paths.insert(paths.begin(), directory + mod_sym);
mod_sym = "__init__";
path_found = directory + mod_sym;
} else if (path_exists(directory + mod_sym + ".py")) {
paths.insert(paths.begin(), directory);
path_found = directory;
}
}

// Update paths to contain only the found path (if found)
// so that later load_module() should only use this path
// to read the package/module file
// load_module() need not search through all paths again
if (path_found.empty()) {
// include the runtime library dir so that later
// runtime library modules could be imported
paths = {get_runtime_library_dir()};
} else {
paths = {path_found};
}
}

void visit_ImportFrom(const AST::ImportFrom_t &x) {
Expand Down