diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a6877dc6..98fa4d1ea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Fix a bug with `basilisp.edn/write-string` where nested double quotes were not escaped properly (#1071) * Fix a bug where additional arguments to `basilisp test` CLI subcommand were not being passed correctly to Pytest (#1075) +### Other + * Improve the state of the Python type hints in `basilisp.lang.multifn` (#800) + ## [v0.2.3] ### Added * Added a compiler metadata flag for suppressing warnings when Var indirection is unavoidable (#1052) diff --git a/src/basilisp/lang/multifn.py b/src/basilisp/lang/multifn.py index 2815a6abd..53265a1fa 100644 --- a/src/basilisp/lang/multifn.py +++ b/src/basilisp/lang/multifn.py @@ -1,5 +1,7 @@ import threading -from typing import Any, Callable, Generic, Optional, TypeVar +from typing import Any, Callable, Optional, TypeVar + +from typing_extensions import Concatenate, Generic, ParamSpec from basilisp.lang import map as lmap from basilisp.lang import runtime @@ -8,15 +10,16 @@ from basilisp.lang.set import PersistentSet T = TypeVar("T") -DispatchFunction = Callable[..., T] -Method = Callable[..., Any] +P = ParamSpec("P") +DispatchFunction = Callable[Concatenate[T, P], T] +Method = Callable[Concatenate[T, P], Any] _GLOBAL_HIERARCHY_SYM = sym.symbol("global-hierarchy", ns=runtime.CORE_NS) _ISA_SYM = sym.symbol("isa?", ns=runtime.CORE_NS) -class MultiFunction(Generic[T]): +class MultiFunction(Generic[T, P]): __slots__ = ( "_name", "_default", @@ -33,7 +36,7 @@ class MultiFunction(Generic[T]): def __init__( self, name: sym.Symbol, - dispatch: DispatchFunction, + dispatch: DispatchFunction[T, P], default: T, hierarchy: Optional[IRef] = None, ) -> None: @@ -63,11 +66,11 @@ def __init__( # caches. self._cached_hierarchy = self._hierarchy.deref() - def __call__(self, *args, **kwargs): - key = self._dispatch(*args, **kwargs) + def __call__(self, v: T, *args: P.args, **kwargs: P.kwargs) -> Any: + key = self._dispatch(v, *args, **kwargs) method = self.get_method(key) if method is not None: - return method(*args, **kwargs) + return method(v, *args, **kwargs) raise NotImplementedError def _reset_cache(self): @@ -94,14 +97,14 @@ def _precedes(self, tag: T, parent: T) -> bool: selection.""" return self._has_preference(tag, parent) or self._is_a(tag, parent) - def add_method(self, key: T, method: Method) -> None: + def add_method(self, key: T, method: Method[T, P]) -> None: """Add a new method to this function which will respond for key returned from the dispatch function.""" with self._lock: self._methods = self._methods.assoc(key, method) self._reset_cache() - def _find_and_cache_method(self, key: T) -> Optional[Method]: + def _find_and_cache_method(self, key: T) -> Optional[Method[T, P]]: """Find and cache the best method for dispatch value `key`.""" with self._lock: best_key: Optional[T] = None @@ -125,7 +128,7 @@ def _find_and_cache_method(self, key: T) -> Optional[Method]: return best_method - def get_method(self, key: T) -> Optional[Method]: + def get_method(self, key: T) -> Optional[Method[T, P]]: """Return the method which would handle this dispatch key or None if no method defined for this key and no default.""" if self._cached_hierarchy != self._hierarchy.deref(): @@ -159,7 +162,7 @@ def prefers(self): """Return a mapping of preferred values to the set of other values.""" return self._prefers - def remove_method(self, key: T) -> Optional[Method]: + def remove_method(self, key: T) -> Optional[Method[T, P]]: """Remove the method defined for this key and return it.""" with self._lock: method = self._methods.val_at(key, None) @@ -179,5 +182,5 @@ def default(self) -> T: return self._default @property - def methods(self) -> IPersistentMap[T, Method]: + def methods(self) -> IPersistentMap[T, Method[T, P]]: return self._methods