From 52fa0fb701e327f2b4fa74a96bf2494d27baaf5b Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Fri, 28 Jun 2019 14:32:54 +0300 Subject: [PATCH] xml.etree.ElementTree: fix missing None in get(), findtext() return type Before, this code from xml.etree.ElementTree import Element reveal_type(Element('foo').findtext('bar')) would reveal `builtins.str`. This is incorrect because if the path is not found, a default None is returned. After, it reveals `Union[builtins.str, None]`. The solution with `overload`s is the same as used by `Mapping.get` in `stdlib/3/typing.pyi`. --- stdlib/2and3/xml/etree/ElementTree.pyi | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/stdlib/2and3/xml/etree/ElementTree.pyi b/stdlib/2and3/xml/etree/ElementTree.pyi index d5f64b67a96e..2f2c1a59ecca 100644 --- a/stdlib/2and3/xml/etree/ElementTree.pyi +++ b/stdlib/2and3/xml/etree/ElementTree.pyi @@ -1,6 +1,6 @@ # Stubs for xml.etree.ElementTree -from typing import Any, Callable, Dict, Generator, IO, ItemsView, Iterable, Iterator, KeysView, List, MutableSequence, Optional, overload, Sequence, Text, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Generator, IO, ItemsView, Iterable, Iterator, KeysView, List, MutableSequence, Optional, overload, Sequence, Text, Tuple, TypeVar, Union, overload import io import sys @@ -53,8 +53,14 @@ class Element(MutableSequence[Element]): def extend(self, elements: Iterable[Element]) -> None: ... def find(self, path: _str_argument_type, namespaces: Optional[Dict[_str_argument_type, _str_argument_type]] = ...) -> Optional[Element]: ... def findall(self, path: _str_argument_type, namespaces: Optional[Dict[_str_argument_type, _str_argument_type]] = ...) -> List[Element]: ... - def findtext(self, path: _str_argument_type, default: Optional[_T] = ..., namespaces: Optional[Dict[_str_argument_type, _str_argument_type]] = ...) -> Union[_T, _str_result_type]: ... - def get(self, key: _str_argument_type, default: Optional[_T] = ...) -> Union[_str_result_type, _T]: ... + @overload + def findtext(self, path: _str_argument_type, *, namespaces: Optional[Dict[_str_argument_type, _str_argument_type]] = ...) -> Optional[_str_result_type]: ... + @overload + def findtext(self, path: _str_argument_type, default: _T, namespaces: Optional[Dict[_str_argument_type, _str_argument_type]] = ...) -> Union[_T, _str_result_type]: ... + @overload + def get(self, key: _str_argument_type) -> Optional[_str_result_type]: ... + @overload + def get(self, key: _str_argument_type, default: _T) -> Union[_str_result_type, _T]: ... def getchildren(self) -> List[Element]: ... def getiterator(self, tag: Optional[_str_argument_type] = ...) -> List[Element]: ... if sys.version_info >= (3, 2): @@ -102,7 +108,10 @@ class ElementTree: def iter(self, tag: Optional[_str_argument_type] = ...) -> Generator[Element, None, None]: ... def getiterator(self, tag: Optional[_str_argument_type] = ...) -> List[Element]: ... def find(self, path: _str_argument_type, namespaces: Optional[Dict[_str_argument_type, _str_argument_type]] = ...) -> Optional[Element]: ... - def findtext(self, path: _str_argument_type, default: Optional[_T] = ..., namespaces: Optional[Dict[_str_argument_type, _str_argument_type]] = ...) -> Union[_T, _str_result_type]: ... + @overload + def findtext(self, path: _str_argument_type, *, namespaces: Optional[Dict[_str_argument_type, _str_argument_type]] = ...) -> Optional[_str_result_type]: ... + @overload + def findtext(self, path: _str_argument_type, default: _T, namespaces: Optional[Dict[_str_argument_type, _str_argument_type]] = ...) -> Union[_T, _str_result_type]: ... def findall(self, path: _str_argument_type, namespaces: Optional[Dict[_str_argument_type, _str_argument_type]] = ...) -> List[Element]: ... def iterfind(self, path: _str_argument_type, namespaces: Optional[Dict[_str_argument_type, _str_argument_type]] = ...) -> List[Element]: ... if sys.version_info >= (3, 4):