Skip to content

Commit ffd7aa7

Browse files
DarkLight1337mzusman
authored andcommitted
[Bugfix] Override dunder methods of placeholder modules (vllm-project#11882)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 9a981e1 commit ffd7aa7

File tree

2 files changed

+220
-16
lines changed

2 files changed

+220
-16
lines changed

tests/test_utils.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
import torch
88
from vllm_test_utils import monitor
99

10-
from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
11-
get_open_port, memory_profiling, merge_async_iterators,
12-
supports_kw)
10+
from vllm.utils import (FlexibleArgumentParser, PlaceholderModule,
11+
StoreBoolean, deprecate_kwargs, get_open_port,
12+
memory_profiling, merge_async_iterators, supports_kw)
1313

1414
from .utils import error_on_warning, fork_new_process_for_each_test
1515

@@ -323,3 +323,44 @@ def measure_current_non_torch():
323323
del weights
324324
lib.cudaFree(handle1)
325325
lib.cudaFree(handle2)
326+
327+
328+
def test_placeholder_module_error_handling():
329+
placeholder = PlaceholderModule("placeholder_1234")
330+
331+
def build_ctx():
332+
return pytest.raises(ModuleNotFoundError,
333+
match="No module named")
334+
335+
with build_ctx():
336+
int(placeholder)
337+
338+
with build_ctx():
339+
placeholder()
340+
341+
with build_ctx():
342+
_ = placeholder.some_attr
343+
344+
with build_ctx():
345+
# Test conflict with internal __name attribute
346+
_ = placeholder.name
347+
348+
# OK to print the placeholder or use it in a f-string
349+
_ = repr(placeholder)
350+
_ = str(placeholder)
351+
352+
# No error yet; only error when it is used downstream
353+
placeholder_attr = placeholder.placeholder_attr("attr")
354+
355+
with build_ctx():
356+
int(placeholder_attr)
357+
358+
with build_ctx():
359+
placeholder_attr()
360+
361+
with build_ctx():
362+
_ = placeholder_attr.some_attr
363+
364+
with build_ctx():
365+
# Test conflict with internal __module attribute
366+
_ = placeholder_attr.module

vllm/utils.py

Lines changed: 176 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
import zmq.asyncio
4747
from packaging.version import Version
4848
from torch.library import Library
49-
from typing_extensions import ParamSpec, TypeIs, assert_never
49+
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
5050

5151
import vllm.envs as envs
5252
from vllm.logger import enable_trace_function_call, init_logger
@@ -1627,24 +1627,183 @@ def get_vllm_optional_dependencies():
16271627
}
16281628

16291629

1630-
@dataclass(frozen=True)
1631-
class PlaceholderModule:
1630+
class _PlaceholderBase:
1631+
"""
1632+
Disallows downstream usage of placeholder modules.
1633+
1634+
We need to explicitly override each dunder method because
1635+
:meth:`__getattr__` is not called when they are accessed.
1636+
1637+
See also:
1638+
[Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup)
1639+
"""
1640+
1641+
def __getattr__(self, key: str) -> Never:
1642+
"""
1643+
The main class should implement this to throw an error
1644+
for attribute accesses representing downstream usage.
1645+
"""
1646+
raise NotImplementedError
1647+
1648+
# [Basic customization]
1649+
1650+
def __lt__(self, other: object):
1651+
return self.__getattr__("__lt__")
1652+
1653+
def __le__(self, other: object):
1654+
return self.__getattr__("__le__")
1655+
1656+
def __eq__(self, other: object):
1657+
return self.__getattr__("__eq__")
1658+
1659+
def __ne__(self, other: object):
1660+
return self.__getattr__("__ne__")
1661+
1662+
def __gt__(self, other: object):
1663+
return self.__getattr__("__gt__")
1664+
1665+
def __ge__(self, other: object):
1666+
return self.__getattr__("__ge__")
1667+
1668+
def __hash__(self):
1669+
return self.__getattr__("__hash__")
1670+
1671+
def __bool__(self):
1672+
return self.__getattr__("__bool__")
1673+
1674+
# [Callable objects]
1675+
1676+
def __call__(self, *args: object, **kwargs: object):
1677+
return self.__getattr__("__call__")
1678+
1679+
# [Container types]
1680+
1681+
def __len__(self):
1682+
return self.__getattr__("__len__")
1683+
1684+
def __getitem__(self, key: object):
1685+
return self.__getattr__("__getitem__")
1686+
1687+
def __setitem__(self, key: object, value: object):
1688+
return self.__getattr__("__setitem__")
1689+
1690+
def __delitem__(self, key: object):
1691+
return self.__getattr__("__delitem__")
1692+
1693+
# __missing__ is optional according to __getitem__ specification,
1694+
# so it is skipped
1695+
1696+
# __iter__ and __reversed__ have a default implementation
1697+
# based on __len__ and __getitem__, so they are skipped.
1698+
1699+
# [Numeric Types]
1700+
1701+
def __add__(self, other: object):
1702+
return self.__getattr__("__add__")
1703+
1704+
def __sub__(self, other: object):
1705+
return self.__getattr__("__sub__")
1706+
1707+
def __mul__(self, other: object):
1708+
return self.__getattr__("__mul__")
1709+
1710+
def __matmul__(self, other: object):
1711+
return self.__getattr__("__matmul__")
1712+
1713+
def __truediv__(self, other: object):
1714+
return self.__getattr__("__truediv__")
1715+
1716+
def __floordiv__(self, other: object):
1717+
return self.__getattr__("__floordiv__")
1718+
1719+
def __mod__(self, other: object):
1720+
return self.__getattr__("__mod__")
1721+
1722+
def __divmod__(self, other: object):
1723+
return self.__getattr__("__divmod__")
1724+
1725+
def __pow__(self, other: object, modulo: object = ...):
1726+
return self.__getattr__("__pow__")
1727+
1728+
def __lshift__(self, other: object):
1729+
return self.__getattr__("__lshift__")
1730+
1731+
def __rshift__(self, other: object):
1732+
return self.__getattr__("__rshift__")
1733+
1734+
def __and__(self, other: object):
1735+
return self.__getattr__("__and__")
1736+
1737+
def __xor__(self, other: object):
1738+
return self.__getattr__("__xor__")
1739+
1740+
def __or__(self, other: object):
1741+
return self.__getattr__("__or__")
1742+
1743+
# r* and i* methods have lower priority than
1744+
# the methods for left operand so they are skipped
1745+
1746+
def __neg__(self):
1747+
return self.__getattr__("__neg__")
1748+
1749+
def __pos__(self):
1750+
return self.__getattr__("__pos__")
1751+
1752+
def __abs__(self):
1753+
return self.__getattr__("__abs__")
1754+
1755+
def __invert__(self):
1756+
return self.__getattr__("__invert__")
1757+
1758+
# __complex__, __int__ and __float__ have a default implementation
1759+
# based on __index__, so they are skipped.
1760+
1761+
def __index__(self):
1762+
return self.__getattr__("__index__")
1763+
1764+
def __round__(self, ndigits: object = ...):
1765+
return self.__getattr__("__round__")
1766+
1767+
def __trunc__(self):
1768+
return self.__getattr__("__trunc__")
1769+
1770+
def __floor__(self):
1771+
return self.__getattr__("__floor__")
1772+
1773+
def __ceil__(self):
1774+
return self.__getattr__("__ceil__")
1775+
1776+
# [Context managers]
1777+
1778+
def __enter__(self):
1779+
return self.__getattr__("__enter__")
1780+
1781+
def __exit__(self, *args: object, **kwargs: object):
1782+
return self.__getattr__("__exit__")
1783+
1784+
1785+
class PlaceholderModule(_PlaceholderBase):
16321786
"""
16331787
A placeholder object to use when a module does not exist.
16341788
16351789
This enables more informative errors when trying to access attributes
16361790
of a module that does not exists.
16371791
"""
1638-
name: str
1792+
1793+
def __init__(self, name: str) -> None:
1794+
super().__init__()
1795+
1796+
# Apply name mangling to avoid conflicting with module attributes
1797+
self.__name = name
16391798

16401799
def placeholder_attr(self, attr_path: str):
16411800
return _PlaceholderModuleAttr(self, attr_path)
16421801

16431802
def __getattr__(self, key: str):
1644-
name = self.name
1803+
name = self.__name
16451804

16461805
try:
1647-
importlib.import_module(self.name)
1806+
importlib.import_module(name)
16481807
except ImportError as exc:
16491808
for extra, names in get_vllm_optional_dependencies().items():
16501809
if name in names:
@@ -1657,17 +1816,21 @@ def __getattr__(self, key: str):
16571816
"when the original module can be imported")
16581817

16591818

1660-
@dataclass(frozen=True)
1661-
class _PlaceholderModuleAttr:
1662-
module: PlaceholderModule
1663-
attr_path: str
1819+
class _PlaceholderModuleAttr(_PlaceholderBase):
1820+
1821+
def __init__(self, module: PlaceholderModule, attr_path: str) -> None:
1822+
super().__init__()
1823+
1824+
# Apply name mangling to avoid conflicting with module attributes
1825+
self.__module = module
1826+
self.__attr_path = attr_path
16641827

16651828
def placeholder_attr(self, attr_path: str):
1666-
return _PlaceholderModuleAttr(self.module,
1667-
f"{self.attr_path}.{attr_path}")
1829+
return _PlaceholderModuleAttr(self.__module,
1830+
f"{self.__attr_path}.{attr_path}")
16681831

16691832
def __getattr__(self, key: str):
1670-
getattr(self.module, f"{self.attr_path}.{key}")
1833+
getattr(self.__module, f"{self.__attr_path}.{key}")
16711834

16721835
raise AssertionError("PlaceholderModule should not be used "
16731836
"when the original module can be imported")

0 commit comments

Comments
 (0)