Skip to content

ENH Add a css wrapper to generated types #397

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 48 additions & 6 deletions src/sphinx_autodoc_typehints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AnyStr, Callable, ForwardRef, NewType, TypeVar, get_type_hints

from docutils import nodes
from docutils.frontend import OptionParser
from docutils.parsers.rst import Parser as RstParser
from docutils.parsers.rst import states
from docutils.utils import new_document
from sphinx.ext.autodoc.mock import mock
from sphinx.util import logging
from sphinx.util import logging, rst
from sphinx.util.inspect import signature as sphinx_signature
from sphinx.util.inspect import stringify_signature

Expand Down Expand Up @@ -209,7 +211,7 @@ def format_annotation(annotation: Any, config: Config) -> str: # noqa: C901, PL
fully_qualified: bool = getattr(config, "typehints_fully_qualified", False)
prefix = "" if fully_qualified or full_name == class_name else "~"
role = "data" if module == "typing" and class_name in _PYDATA_ANNOTATIONS else "class"
args_format = "\\[{}]"
args_format = "\\ \\[{}]"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it would be better to use r here and in lots of other places: r"\ \[{}]" would have half as many slashes...

formatted_args: str | None = ""

# Some types require special handling
Expand Down Expand Up @@ -242,9 +244,9 @@ def format_annotation(annotation: Any, config: Config) -> str: # noqa: C901, PL
args = tuple(x for x in args if x is not type(None))
elif full_name in ("typing.Callable", "collections.abc.Callable") and args and args[0] is not ...:
fmt = [format_annotation(arg, config) for arg in args]
formatted_args = f"\\[\\[{', '.join(fmt[:-1])}], {fmt[-1]}]"
formatted_args = f"\\ \\[\\[{', '.join(fmt[:-1])}], {fmt[-1]}]"
elif full_name == "typing.Literal":
formatted_args = f"\\[{', '.join(f'``{arg!r}``' for arg in args)}]"
formatted_args = f"\\ \\[{', '.join(f'``{arg!r}``' for arg in args)}]"
elif full_name == "types.UnionType":
return " | ".join([format_annotation(arg, config) for arg in args])

Expand Down Expand Up @@ -724,7 +726,7 @@ def _inject_signature( # noqa: C901
if annotation is None:
type_annotation = f":type {arg_name}: "
else:
formatted_annotation = format_annotation(annotation, app.config)
formatted_annotation = add_type_css_class(format_annotation(annotation, app.config))
type_annotation = f":type {arg_name}: {formatted_annotation}"

if app.config.typehints_defaults:
Expand Down Expand Up @@ -843,7 +845,7 @@ def _inject_rtype( # noqa: PLR0913
if not app.config.typehints_use_rtype and r.found_return and " -- " in lines[insert_index]:
return

formatted_annotation = format_annotation(type_hints["return"], app.config)
formatted_annotation = add_type_css_class(format_annotation(type_hints["return"], app.config))

if r.found_param and insert_index < len(lines) and lines[insert_index].strip():
insert_index -= 1
Expand Down Expand Up @@ -874,6 +876,45 @@ def validate_config(app: Sphinx, env: BuildEnvironment, docnames: list[str]) ->
raise ValueError(msg)


def unescape(escaped: str) -> str:
# For some reason the string we get has a bunch of null bytes in it??
# Remove them...
escaped = escaped.replace("\x00", "")
# For some reason the extra slash before spaces gets lost between the .rst
# source and when this directive is called. So don't replace "\<space>" =>
# "<space>"
return re.sub(r"\\([^ ])", r"\1", escaped)


def add_type_css_class(type_rst: str) -> str:
return f":sphinx_autodoc_typehints_type:`{rst.escape(type_rst)}`"


def sphinx_autodoc_typehints_type_role(
_role: str,
_rawtext: str,
text: str,
_lineno: int,
inliner: states.Inliner,
_options: dict[str, Any] | None = None,
_content: list[str] | None = None,
) -> tuple[list[Node], list[Node]]:
"""
Add css tag around rendered type.

The body should be escaped rst. This renders its body as rst and wraps the
result in <span class="sphinx_autodoc_typehints-type"> </span>
"""
unescaped = unescape(text)
# the typestubs for docutils don't have any info about Inliner
doc = new_document("", inliner.document.settings) # type: ignore[attr-defined]
RstParser().parse(unescaped, doc)
n = nodes.inline(text)
n["classes"].append("sphinx_autodoc_typehints-type")
n += doc.children[0].children
return [n], []


def setup(app: Sphinx) -> dict[str, bool]:
app.add_config_value("always_document_param_types", False, "html") # noqa: FBT003
app.add_config_value("typehints_fully_qualified", False, "env") # noqa: FBT003
Expand All @@ -884,6 +925,7 @@ def setup(app: Sphinx) -> dict[str, bool]:
app.add_config_value("typehints_formatter", None, "env")
app.add_config_value("typehints_use_signature", False, "env") # noqa: FBT003
app.add_config_value("typehints_use_signature_return", False, "env") # noqa: FBT003
app.add_role("sphinx_autodoc_typehints_type", sphinx_autodoc_typehints_type_role)
app.connect("env-before-read-docs", validate_config) # config may be changed after “config-inited” event
app.connect("autodoc-process-signature", process_signature)
app.connect("autodoc-process-docstring", process_docstring)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def function_with_escaped_default(x: str = "\b"): # noqa: ANN201, ARG001
Function docstring.

Parameters:
**x** (*a.b.c*) -- foo
**x** (a.b.c) -- foo
""",
)
def function_with_unresolvable_annotation(x: a.b.c): # noqa: ANN201, ARG001, F821
Expand Down
111 changes: 53 additions & 58 deletions tests/test_sphinx_autodoc_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,87 +201,87 @@ def test_parse_annotation(annotation: Any, module: str, class_name: str, args: t
(type, ":py:class:`type`"),
(collections.abc.Callable, ":py:class:`~collections.abc.Callable`"),
(Type, ":py:class:`~typing.Type`"),
(Type[A], ":py:class:`~typing.Type`\\[:py:class:`~%s.A`]" % __name__),
(Type[A], ":py:class:`~typing.Type`\\ \\[:py:class:`~%s.A`]" % __name__),
(Any, ":py:data:`~typing.Any`"),
(AnyStr, ":py:data:`~typing.AnyStr`"),
(Generic[T], ":py:class:`~typing.Generic`\\[:py:class:`~typing.TypeVar`\\(``T``)]"),
(Generic[T], ":py:class:`~typing.Generic`\\ \\[:py:class:`~typing.TypeVar`\\(``T``)]"),
(Mapping, ":py:class:`~typing.Mapping`"),
(
Mapping[T, int], # type: ignore[valid-type]
":py:class:`~typing.Mapping`\\[:py:class:`~typing.TypeVar`\\(``T``), :py:class:`int`]",
":py:class:`~typing.Mapping`\\ \\[:py:class:`~typing.TypeVar`\\(``T``), :py:class:`int`]",
),
(
Mapping[str, V_contra], # type: ignore[valid-type]
":py:class:`~typing.Mapping`\\[:py:class:`str`, :py:class:`~typing.TypeVar`\\("
":py:class:`~typing.Mapping`\\ \\[:py:class:`str`, :py:class:`~typing.TypeVar`\\("
"``V_contra``, contravariant=True)]",
),
(
Mapping[T, U_co], # type: ignore[valid-type]
":py:class:`~typing.Mapping`\\[:py:class:`~typing.TypeVar`\\(``T``), "
":py:class:`~typing.Mapping`\\ \\[:py:class:`~typing.TypeVar`\\(``T``), "
":py:class:`~typing.TypeVar`\\(``U_co``, covariant=True)]",
),
(Mapping[str, bool], ":py:class:`~typing.Mapping`\\[:py:class:`str`, :py:class:`bool`]"),
(Mapping[str, bool], ":py:class:`~typing.Mapping`\\ \\[:py:class:`str`, :py:class:`bool`]"),
(Dict, ":py:class:`~typing.Dict`"),
(
Dict[T, int], # type: ignore[valid-type]
":py:class:`~typing.Dict`\\[:py:class:`~typing.TypeVar`\\(``T``), :py:class:`int`]",
":py:class:`~typing.Dict`\\ \\[:py:class:`~typing.TypeVar`\\(``T``), :py:class:`int`]",
),
(
Dict[str, V_contra], # type: ignore[valid-type]
":py:class:`~typing.Dict`\\[:py:class:`str`, :py:class:`~typing.TypeVar`\\(``V_contra``, contravariant=True)]",
":py:class:`~typing.Dict`\\ \\[:py:class:`str`, :py:class:`~typing.TypeVar`\\(``V_contra``, contravariant=True)]", # noqa: E501
),
(
Dict[T, U_co], # type: ignore[valid-type]
":py:class:`~typing.Dict`\\[:py:class:`~typing.TypeVar`\\(``T``),"
":py:class:`~typing.Dict`\\ \\[:py:class:`~typing.TypeVar`\\(``T``),"
" :py:class:`~typing.TypeVar`\\(``U_co``, covariant=True)]",
),
(Dict[str, bool], ":py:class:`~typing.Dict`\\[:py:class:`str`, :py:class:`bool`]"),
(Dict[str, bool], ":py:class:`~typing.Dict`\\ \\[:py:class:`str`, :py:class:`bool`]"),
(Tuple, ":py:data:`~typing.Tuple`"),
(Tuple[str, bool], ":py:data:`~typing.Tuple`\\[:py:class:`str`, :py:class:`bool`]"),
(Tuple[int, int, int], ":py:data:`~typing.Tuple`\\[:py:class:`int`, :py:class:`int`, :py:class:`int`]"),
(Tuple[str, ...], ":py:data:`~typing.Tuple`\\[:py:class:`str`, :py:data:`...<Ellipsis>`]"),
(Tuple[str, bool], ":py:data:`~typing.Tuple`\\ \\[:py:class:`str`, :py:class:`bool`]"),
(Tuple[int, int, int], ":py:data:`~typing.Tuple`\\ \\[:py:class:`int`, :py:class:`int`, :py:class:`int`]"),
(Tuple[str, ...], ":py:data:`~typing.Tuple`\\ \\[:py:class:`str`, :py:data:`...<Ellipsis>`]"),
(Union, ":py:data:`~typing.Union`"),
(Union[str, bool], ":py:data:`~typing.Union`\\[:py:class:`str`, :py:class:`bool`]"),
(Union[str, bool, None], ":py:data:`~typing.Union`\\[:py:class:`str`, :py:class:`bool`, :py:obj:`None`]"),
pytest.param(Union[str, Any], ":py:data:`~typing.Union`\\[:py:class:`str`, :py:data:`~typing.Any`]"),
(Optional[str], ":py:data:`~typing.Optional`\\[:py:class:`str`]"),
(Union[str, None], ":py:data:`~typing.Optional`\\[:py:class:`str`]"),
(Union[str, bool], ":py:data:`~typing.Union`\\ \\[:py:class:`str`, :py:class:`bool`]"),
(Union[str, bool, None], ":py:data:`~typing.Union`\\ \\[:py:class:`str`, :py:class:`bool`, :py:obj:`None`]"),
pytest.param(Union[str, Any], ":py:data:`~typing.Union`\\ \\[:py:class:`str`, :py:data:`~typing.Any`]"),
(Optional[str], ":py:data:`~typing.Optional`\\ \\[:py:class:`str`]"),
(Union[str, None], ":py:data:`~typing.Optional`\\ \\[:py:class:`str`]"),
(
Optional[Union[str, bool]],
":py:data:`~typing.Union`\\[:py:class:`str`, :py:class:`bool`, :py:obj:`None`]",
":py:data:`~typing.Union`\\ \\[:py:class:`str`, :py:class:`bool`, :py:obj:`None`]",
),
(Callable, ":py:data:`~typing.Callable`"),
(Callable[..., int], ":py:data:`~typing.Callable`\\[:py:data:`...<Ellipsis>`, :py:class:`int`]"),
(Callable[[int], int], ":py:data:`~typing.Callable`\\[\\[:py:class:`int`], :py:class:`int`]"),
(Callable[..., int], ":py:data:`~typing.Callable`\\ \\[:py:data:`...<Ellipsis>`, :py:class:`int`]"),
(Callable[[int], int], ":py:data:`~typing.Callable`\\ \\[\\[:py:class:`int`], :py:class:`int`]"),
(
Callable[[int, str], bool],
":py:data:`~typing.Callable`\\[\\[:py:class:`int`, :py:class:`str`], :py:class:`bool`]",
":py:data:`~typing.Callable`\\ \\[\\[:py:class:`int`, :py:class:`str`], :py:class:`bool`]",
),
(
Callable[[int, str], None],
":py:data:`~typing.Callable`\\[\\[:py:class:`int`, :py:class:`str`], :py:obj:`None`]",
":py:data:`~typing.Callable`\\ \\[\\[:py:class:`int`, :py:class:`str`], :py:obj:`None`]",
),
(
Callable[[T], T],
":py:data:`~typing.Callable`\\[\\[:py:class:`~typing.TypeVar`\\(``T``)],"
":py:data:`~typing.Callable`\\ \\[\\[:py:class:`~typing.TypeVar`\\(``T``)],"
" :py:class:`~typing.TypeVar`\\(``T``)]",
),
(
AbcCallable[[int, str], bool], # type: ignore[valid-type,misc,type-arg]
":py:class:`~collections.abc.Callable`\\[\\[:py:class:`int`, :py:class:`str`], :py:class:`bool`]",
":py:class:`~collections.abc.Callable`\\ \\[\\[:py:class:`int`, :py:class:`str`], :py:class:`bool`]",
),
(Pattern, ":py:class:`~typing.Pattern`"),
(Pattern[str], ":py:class:`~typing.Pattern`\\[:py:class:`str`]"),
(Pattern[str], ":py:class:`~typing.Pattern`\\ \\[:py:class:`str`]"),
(IO, ":py:class:`~typing.IO`"),
(IO[str], ":py:class:`~typing.IO`\\[:py:class:`str`]"),
(IO[str], ":py:class:`~typing.IO`\\ \\[:py:class:`str`]"),
(Metaclass, ":py:class:`~%s.Metaclass`" % __name__),
(A, ":py:class:`~%s.A`" % __name__),
(B, ":py:class:`~%s.B`" % __name__),
(B[int], ":py:class:`~%s.B`\\[:py:class:`int`]" % __name__),
(B[int], ":py:class:`~%s.B`\\ \\[:py:class:`int`]" % __name__),
(C, ":py:class:`~%s.C`" % __name__),
(D, ":py:class:`~%s.D`" % __name__),
(E, ":py:class:`~%s.E`" % __name__),
(E[int], ":py:class:`~%s.E`\\[:py:class:`int`]" % __name__),
(E[int], ":py:class:`~%s.E`\\ \\[:py:class:`int`]" % __name__),
(W, f":py:{'class' if PY310_PLUS else 'func'}:`~typing.NewType`\\(``W``, :py:class:`str`)"),
(T, ":py:class:`~typing.TypeVar`\\(``T``)"),
(U_co, ":py:class:`~typing.TypeVar`\\(``U_co``, covariant=True)"),
Expand All @@ -306,17 +306,17 @@ def test_parse_annotation(annotation: Any, module: str, class_name: str, args: t
# Zero-length tuple remains
(Tuple[()], ":py:data:`~typing.Tuple`"),
# Internal single tuple with simple types is flattened in the output
(Tuple[(int,)], ":py:data:`~typing.Tuple`\\[:py:class:`int`]"),
(Tuple[(int, int)], ":py:data:`~typing.Tuple`\\[:py:class:`int`, :py:class:`int`]"),
(Tuple[(int,)], ":py:data:`~typing.Tuple`\\ \\[:py:class:`int`]"),
(Tuple[(int, int)], ":py:data:`~typing.Tuple`\\ \\[:py:class:`int`, :py:class:`int`]"),
# Ellipsis in single tuple also gets flattened
(Tuple[(int, ...)], ":py:data:`~typing.Tuple`\\[:py:class:`int`, :py:data:`...<Ellipsis>`]"),
(Tuple[(int, ...)], ":py:data:`~typing.Tuple`\\ \\[:py:class:`int`, :py:data:`...<Ellipsis>`]"),
(
RecList,
":py:data:`~typing.Union`\\[:py:class:`int`, :py:class:`~typing.List`\\[RecList]]",
":py:data:`~typing.Union`\\ \\[:py:class:`int`, :py:class:`~typing.List`\\ \\[RecList]]",
),
(
MutualRecA,
":py:data:`~typing.Union`\\[:py:class:`bool`, :py:class:`~typing.List`\\[MutualRecB]]",
":py:data:`~typing.Union`\\ \\[:py:class:`bool`, :py:class:`~typing.List`\\ \\[MutualRecB]]",
),
]

Expand All @@ -327,39 +327,39 @@ def test_parse_annotation(annotation: Any, module: str, class_name: str, args: t
(
nptyping.NDArray[nptyping.Shape["*"], nptyping.Float],
(
":py:class:`~nptyping.ndarray.NDArray`\\[:py:class:`~nptyping.base_meta_classes.Shape`\\[*], "
":py:class:`~nptyping.ndarray.NDArray`\\ \\[:py:class:`~nptyping.base_meta_classes.Shape`\\ \\[*], "
":py:class:`~numpy.float64`]"
),
),
(
nptyping.NDArray[nptyping.Shape["64"], nptyping.Float],
(
":py:class:`~nptyping.ndarray.NDArray`\\[:py:class:`~nptyping.base_meta_classes.Shape`\\[64],"
":py:class:`~nptyping.ndarray.NDArray`\\ \\[:py:class:`~nptyping.base_meta_classes.Shape`\\ \\[64],"
" :py:class:`~numpy.float64`]"
),
),
(
nptyping.NDArray[nptyping.Shape["*, *"], nptyping.Float],
(
":py:class:`~nptyping.ndarray.NDArray`\\[:py:class:`~nptyping.base_meta_classes.Shape`\\[*, "
":py:class:`~nptyping.ndarray.NDArray`\\ \\[:py:class:`~nptyping.base_meta_classes.Shape`\\ \\[*, "
"*], :py:class:`~numpy.float64`]"
),
),
(
nptyping.NDArray[nptyping.Shape["*, ..."], nptyping.Float],
":py:class:`~nptyping.ndarray.NDArray`\\[:py:data:`~typing.Any`, :py:class:`~numpy.float64`]",
":py:class:`~nptyping.ndarray.NDArray`\\ \\[:py:data:`~typing.Any`, :py:class:`~numpy.float64`]",
),
(
nptyping.NDArray[nptyping.Shape["*, 3"], nptyping.Float],
(
":py:class:`~nptyping.ndarray.NDArray`\\[:py:class:`~nptyping.base_meta_classes.Shape`\\[*, 3"
":py:class:`~nptyping.ndarray.NDArray`\\ \\[:py:class:`~nptyping.base_meta_classes.Shape`\\ \\[*, 3"
"], :py:class:`~numpy.float64`]"
),
),
(
nptyping.NDArray[nptyping.Shape["3, ..."], nptyping.Float],
(
":py:class:`~nptyping.ndarray.NDArray`\\[:py:class:`~nptyping.base_meta_classes.Shape`\\[3, "
":py:class:`~nptyping.ndarray.NDArray`\\ \\[:py:class:`~nptyping.base_meta_classes.Shape`\\ \\[3, "
"...], :py:class:`~numpy.float64`]"
),
),
Expand All @@ -379,7 +379,7 @@ def test_format_annotation(inv: Inventory, annotation: Any, expected_result: str
# subsequent tests
expected_result_not_simplified = expected_result.replace(", ``None``", "")
# encapsulate Union in typing.Optional
expected_result_not_simplified = ":py:data:`~typing.Optional`\\[" + expected_result_not_simplified
expected_result_not_simplified = ":py:data:`~typing.Optional`\\ \\[" + expected_result_not_simplified
expected_result_not_simplified += "]"
conf = create_autospec(Config, simplify_optional_unions=False, _annotation_globals=globals())
assert format_annotation(annotation, conf) == expected_result_not_simplified
Expand Down Expand Up @@ -421,11 +421,11 @@ def test_format_annotation(inv: Inventory, annotation: Any, expected_result: str
@pytest.mark.parametrize(
("annotation", "params", "expected_result"),
[
("ClassVar", int, ":py:data:`~typing.ClassVar`\\[:py:class:`int`]"),
("ClassVar", int, ":py:data:`~typing.ClassVar`\\ \\[:py:class:`int`]"),
("NoReturn", None, ":py:data:`~typing.NoReturn`"),
("Literal", ("a", 1), ":py:data:`~typing.Literal`\\[``'a'``, ``1``]"),
("Literal", ("a", 1), ":py:data:`~typing.Literal`\\ \\[``'a'``, ``1``]"),
("Type", None, ":py:class:`~typing.Type`"),
("Type", (A,), f":py:class:`~typing.Type`\\[:py:class:`~{__name__}.A`]"),
("Type", (A,), f":py:class:`~typing.Type`\\ \\[:py:class:`~{__name__}.A`]"),
],
)
def test_format_annotation_both_libs(library: ModuleType, annotation: str, params: Any, expected_result: str) -> None:
Expand Down Expand Up @@ -524,16 +524,11 @@ class dummy_module.DataClass(x)

def maybe_fix_py310(expected_contents: str) -> str:
if not PY310_PLUS:
return expected_contents
return expected_contents.replace('"', "")

for old, new in [
("*bool** | **None*", '"Optional"["bool"]'),
("*int** | **str** | **float*", '"int" | "str" | "float"'),
("*str** | **None*", '"Optional"["str"]'),
("(*bool*)", '("bool")'),
("(*int*", '("int"'),
(" str", ' "str"'),
('"Optional"["str"]', '"Optional"["str"]'),
('"Optional"["Callable"[["int", "bytes"], "int"]]', '"Optional"["Callable"[["int", "bytes"], "int"]]'),
("bool | None", '"Optional"["bool"]'),
("str | None", '"Optional"["str"]'),
]:
expected_contents = expected_contents.replace(old, new)
return expected_contents
Expand All @@ -559,14 +554,14 @@ def test_sphinx_output_future_annotations(app: SphinxTestApp, status: StringIO)
Method docstring.

Parameters:
* **x** (*bool** | **None*) -- foo
* **x** (bool | None) -- foo

* **y** (*int** | **str** | **float*) -- bar
* **y** ("int" | "str" | "float") -- bar

* **z** (*str** | **None*) -- baz
* **z** (str | None) -- baz

Return type:
str
"str"
"""
expected_contents = maybe_fix_py310(dedent(expected_contents))
assert contents == expected_contents
Expand Down Expand Up @@ -625,7 +620,7 @@ def test_sphinx_output_defaults(
("formatter_config_val", "expected"),
[
(None, ['("bool") -- foo', '("int") -- bar', '"str"']),
(lambda ann, conf: "Test", ["(*Test*) -- foo", "(*Test*) -- bar", "Test"]), # noqa: ARG005
(lambda ann, conf: "Test", ["(Test) -- foo", "(Test) -- bar", "Test"]), # noqa: ARG005
("some string", Exception("needs to be callable or `None`")),
],
)
Expand Down