|
2 | 2 |
|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
| 5 | +from collections.abc import Callable |
5 | 6 | from contextlib import suppress
|
6 | 7 | from dataclasses import dataclass
|
7 | 8 | from functools import cache
|
8 | 9 | from importlib import import_module
|
9 | 10 | from inspect import isclass, isroutine
|
10 |
| -from typing import Any, Callable, Union |
| 11 | +from types import UnionType |
| 12 | +from typing import Any, Union, get_type_hints |
11 | 13 |
|
12 | 14 | from sphinx_codeautolink.parse import Name, NameBreak
|
13 | 15 |
|
@@ -116,34 +118,27 @@ def call_value(cursor: Cursor) -> None:
|
116 | 118 |
|
117 | 119 | def get_return_annotation(func: Callable) -> type | None:
|
118 | 120 | """Determine the target of a function return type hint."""
|
119 |
| - annotations = getattr(func, "__annotations__", {}) |
120 |
| - ret_annotation = annotations.get("return", None) |
| 121 | + annotation = get_type_hints(func).get("return") |
121 | 122 |
|
122 | 123 | # Inner type from typing.Optional or Union[None, T]
|
123 |
| - origin = getattr(ret_annotation, "__origin__", None) |
124 |
| - args = getattr(ret_annotation, "__args__", None) |
125 |
| - if origin is Union and len(args) == 2: # noqa: PLR2004 |
| 124 | + origin = getattr(annotation, "__origin__", None) |
| 125 | + args = getattr(annotation, "__args__", None) |
| 126 | + if (origin is Union or isinstance(annotation, UnionType)) and len(args) == 2: # noqa: PLR2004 |
126 | 127 | nonetype = type(None)
|
127 | 128 | if args[0] is nonetype:
|
128 |
| - ret_annotation = args[1] |
| 129 | + annotation = args[1] |
129 | 130 | elif args[1] is nonetype:
|
130 |
| - ret_annotation = args[0] |
131 |
| - |
132 |
| - # Try to resolve a string annotation in the module scope |
133 |
| - if isinstance(ret_annotation, str): |
134 |
| - location = fully_qualified_name(func) |
135 |
| - mod, _ = closest_module(tuple(location.split("."))) |
136 |
| - ret_annotation = getattr(mod, ret_annotation, ret_annotation) |
| 131 | + annotation = args[0] |
137 | 132 |
|
138 | 133 | if (
|
139 |
| - not ret_annotation |
140 |
| - or not isinstance(ret_annotation, type) |
141 |
| - or hasattr(ret_annotation, "__origin__") |
| 134 | + not annotation |
| 135 | + or not isinstance(annotation, type) |
| 136 | + or hasattr(annotation, "__origin__") |
142 | 137 | ):
|
143 | 138 | msg = f"Unable to follow return annotation of {get_name_for_debugging(func)}."
|
144 | 139 | raise CouldNotResolve(msg)
|
145 | 140 |
|
146 |
| - return ret_annotation |
| 141 | + return annotation |
147 | 142 |
|
148 | 143 |
|
149 | 144 | def fully_qualified_name(thing: type | Callable) -> str:
|
|
0 commit comments