Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ Release date: TBA

Closes pylint-dev/pylint#9139

* Add ``AstroidManager.prefer_stubs`` attribute to control the astroid 3.2.0 feature that prefers stubs.

Refs pylint-dev/#9626
Refs pylint-dev/#9623


What's New in astroid 3.2.0?
============================
Expand Down
8 changes: 2 additions & 6 deletions astroid/interpreter/_import/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,10 @@ def find_module(
pass
submodule_path = sys.path

# We're looping on pyi first because if a pyi exists there's probably a reason
# (i.e. the code is hard or impossible to parse), so we take pyi into account
# But we're not quite ready to do this for numpy, see https://github.com/pylint-dev/astroid/pull/2375
suffixes = (".pyi", ".py", importlib.machinery.BYTECODE_SUFFIXES[0])
numpy_suffixes = (".py", ".pyi", importlib.machinery.BYTECODE_SUFFIXES[0])
suffixes = (".py", ".pyi", importlib.machinery.BYTECODE_SUFFIXES[0])
for entry in submodule_path:
package_directory = os.path.join(entry, modname)
for suffix in numpy_suffixes if "numpy" in entry else suffixes:
for suffix in suffixes:
package_file_name = "__init__" + suffix
file_path = os.path.join(package_directory, package_file_name)
if os.path.isfile(file_path):
Expand Down
14 changes: 13 additions & 1 deletion astroid/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class AstroidManager:
"extension_package_whitelist": set(),
"module_denylist": set(),
"_transform": TransformVisitor(),
"prefer_stubs": False,
}

def __init__(self) -> None:
Expand All @@ -73,6 +74,7 @@ def __init__(self) -> None:
]
self.module_denylist = AstroidManager.brain["module_denylist"]
self._transform = AstroidManager.brain["_transform"]
self.prefer_stubs = AstroidManager.brain["prefer_stubs"]

@property
def always_load_extensions(self) -> bool:
Expand Down Expand Up @@ -111,6 +113,14 @@ def unregister_transform(self):
def builtins_module(self) -> nodes.Module:
return self.astroid_cache["builtins"]

@property
def prefer_stubs(self) -> bool:
return AstroidManager.brain["prefer_stubs"]

@prefer_stubs.setter
def prefer_stubs(self, value: bool) -> None:
AstroidManager.brain["prefer_stubs"] = value

def visit_transforms(self, node: nodes.NodeNG) -> InferenceResult:
"""Visit the transforms and apply them to the given *node*."""
return self._transform.visit(node)
Expand All @@ -136,7 +146,9 @@ def ast_from_file(
# Call get_source_file() only after a cache miss,
# since it calls os.path.exists().
try:
filepath = get_source_file(filepath, include_no_ext=True)
filepath = get_source_file(
filepath, include_no_ext=True, prefer_stubs=self.prefer_stubs
)
source = True
except NoSourceFile:
pass
Expand Down
15 changes: 9 additions & 6 deletions astroid/modutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@


if sys.platform.startswith("win"):
PY_SOURCE_EXTS = ("pyi", "pyw", "py")
PY_SOURCE_EXTS = ("py", "pyw", "pyi")
PY_SOURCE_EXTS_STUBS_FIRST = ("pyi", "pyw", "py")
PY_COMPILED_EXTS = ("dll", "pyd")
else:
PY_SOURCE_EXTS = ("pyi", "py")
PY_SOURCE_EXTS = ("py", "pyi")
PY_SOURCE_EXTS_STUBS_FIRST = ("pyi", "py")
PY_COMPILED_EXTS = ("so",)


Expand Down Expand Up @@ -484,7 +486,9 @@ def get_module_files(
return files


def get_source_file(filename: str, include_no_ext: bool = False) -> str:
def get_source_file(
filename: str, include_no_ext: bool = False, prefer_stubs: bool = False
) -> str:
"""Given a python module's file name return the matching source file
name (the filename will be returned identically if it's already an
absolute path to a python source file).
Expand All @@ -499,7 +503,7 @@ def get_source_file(filename: str, include_no_ext: bool = False) -> str:
base, orig_ext = os.path.splitext(filename)
if orig_ext == ".pyi" and os.path.exists(f"{base}{orig_ext}"):
return f"{base}{orig_ext}"
for ext in PY_SOURCE_EXTS if "numpy" not in filename else reversed(PY_SOURCE_EXTS):
for ext in PY_SOURCE_EXTS_STUBS_FIRST if prefer_stubs else PY_SOURCE_EXTS:
source_path = f"{base}.{ext}"
if os.path.exists(source_path):
return source_path
Expand Down Expand Up @@ -671,8 +675,7 @@ def _has_init(directory: str) -> str | None:
else return None.
"""
mod_or_pack = os.path.join(directory, "__init__")
exts = reversed(PY_SOURCE_EXTS) if "numpy" in directory else PY_SOURCE_EXTS
for ext in (*exts, "pyc", "pyo"):
for ext in (*PY_SOURCE_EXTS, "pyc", "pyo"):
if os.path.exists(mod_or_pack + "." + ext):
return mod_or_pack + "." + ext
return None
Expand Down
3 changes: 2 additions & 1 deletion tests/test_modutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ def test_pyi_preferred(self) -> None:
package = resources.find("pyi_data/find_test")
module = os.path.join(package, "__init__.py")
self.assertEqual(
modutils.get_source_file(module), os.path.normpath(module) + "i"
modutils.get_source_file(module, prefer_stubs=True),
os.path.normpath(module) + "i",
)


Expand Down