diff --git a/stubs/tensorflow/METADATA.toml b/stubs/tensorflow/METADATA.toml new file mode 100644 index 000000000000..5b1aac831e20 --- /dev/null +++ b/stubs/tensorflow/METADATA.toml @@ -0,0 +1,2 @@ +version = "2.8.*" +requires = ["numpy"] \ No newline at end of file diff --git a/stubs/tensorflow/tensorflow/__init__.pyi b/stubs/tensorflow/tensorflow/__init__.pyi new file mode 100644 index 000000000000..20fc68d3b070 --- /dev/null +++ b/stubs/tensorflow/tensorflow/__init__.pyi @@ -0,0 +1,68 @@ +# Alias for bool is used because tensorflow name shadows bool with tf.bool. +from builtins import bool as _bool +from typing import Any, Iterable, Iterator, NoReturn, overload + +import numpy as np +from tensorflow._aliases import _TensorCompatible + +# Most tf.math functions are exported from tf., but not all of them are. +from tensorflow.math import abs as abs + +def __getattr__(name: str) -> Any: ... # incomplete + +class Tensor: + @property + def shape(self) -> TensorShape: ... + def get_shape(self) -> TensorShape: ... + @property + def name(self) -> str: ... + def numpy(self) -> np.ndarray[Any, Any]: ... + def __int__(self) -> int: ... + def __abs__(self) -> Tensor: ... + def __add__(self, other: _TensorCompatible) -> Tensor: ... + def __radd__(self, other: _TensorCompatible) -> Tensor: ... + def __sub__(self, other: _TensorCompatible) -> Tensor: ... + def __rsub__(self, other: _TensorCompatible) -> Tensor: ... + def __mul__(self, other: _TensorCompatible) -> Tensor: ... + def __rmul__(self, other: _TensorCompatible) -> Tensor: ... + def __matmul__(self, other: _TensorCompatible) -> Tensor: ... + def __rmatmul__(self, other: _TensorCompatible) -> Tensor: ... + def __floordiv__(self, other: _TensorCompatible) -> Tensor: ... + def __rfloordiv__(self, other: _TensorCompatible) -> Tensor: ... + def __truediv__(self, other: _TensorCompatible) -> Tensor: ... + def __rtruediv__(self, other: _TensorCompatible) -> Tensor: ... + def __neg__(self) -> Tensor: ... + def __and__(self, other: _TensorCompatible) -> Tensor: ... + def __rand__(self, other: _TensorCompatible) -> Tensor: ... + def __or__(self, other: _TensorCompatible) -> Tensor: ... + def __ror__(self, other: _TensorCompatible) -> Tensor: ... + def __eq__(self, other: _TensorCompatible) -> Tensor: ... + def __ne__(self, other: _TensorCompatible) -> Tensor: ... + def __ge__(self, other: _TensorCompatible) -> Tensor: ... + def __gt__(self, other: _TensorCompatible) -> Tensor: ... + def __le__(self, other: _TensorCompatible) -> Tensor: ... + def __lt__(self, other: _TensorCompatible) -> Tensor: ... + def __bool__(self) -> NoReturn: ... + def __getitem__(self, slice_spec: int | slice | tuple[int | slice, ...]) -> Tensor: ... + def __len__(self) -> int: ... + # This only works for rank 0 tensors. + def __index__(self) -> int: ... + def __getattr__(self, name: str) -> Any: ... # incomplete + +class TensorShape: + def __init__(self, dims: Iterable[int | None]): ... + @property + def rank(self) -> int: ... + def as_list(self) -> list[int | None]: ... + def assert_has_rank(self, rank: int) -> None: ... + def __bool__(self) -> _bool: ... + @overload + def __getitem__(self, key: int) -> int | None: ... + @overload + def __getitem__(self, key: slice) -> TensorShape: ... + def __iter__(self) -> Iterator[int | None]: ... + def __len__(self) -> int: ... + def __add__(self, other: Iterable[int | None]) -> TensorShape: ... + def __radd__(self, other: Iterable[int | None]) -> TensorShape: ... + def __eq__(self, other: Iterable[int | None]) -> _bool: ... # type: ignore + def __getattr__(self, name: str) -> Any: ... # incomplete diff --git a/stubs/tensorflow/tensorflow/_aliases.pyi b/stubs/tensorflow/tensorflow/_aliases.pyi new file mode 100644 index 000000000000..5f595244768e --- /dev/null +++ b/stubs/tensorflow/tensorflow/_aliases.pyi @@ -0,0 +1,36 @@ +# Collection of commonly need type aliases. These are all private +# and do not exist at runtime. + +from typing import Iterable, Mapping, Optional, Sequence, TypeVar, Union + +import numpy as np +import tensorflow as tf + +# These aliases mostly ignore rank/shape/dtype information as that +# will complicate the types heavily and can be a follow up problem. +_FloatDataSequence = Union[Sequence[float], Sequence["_FloatDataSequence"]] +_StrDataSequence = Union[Sequence[str], Sequence["_StrDataSequence"]] +_ScalarTensorConvertible = Union[str, float, np.number, np.ndarray] +_ScalarTensorCompatible = Union[tf.Tensor, _ScalarTensorConvertible] +_TensorConvertible = Union[_ScalarTensorConvertible, _FloatDataSequence, _StrDataSequence] +_TensorCompatible = Union[tf.Tensor, _TensorConvertible] + +# Sparse tensors need to be treated carefully. Most functions do +# not document if they handle sparse tensors. Most functions do +# not support them. Ragged tensors usually work and are documented +# here, https://www.tensorflow.org/api_docs/python/tf/ragged +_SparseTensorCompatible = Union[_TensorCompatible, tf.SparseTensor] +_RaggedTensorCompatible = Union[_TensorCompatible, tf.RaggedTensor] +_AnyTensorCompatible = Union[_TensorCompatible, tf.Tensor, tf.Variable] + +_SparseTensorLike = Union[tf.Tensor, tf.SparseTensor] +_RaggedTensorLike = Union[tf.Tensor, tf.RaggedTensor] +_AnyTensorLike = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor] + +_T1 = TypeVar("_T1", covariant=True) +_ContainerGeneric = Union[Mapping[str, "_ContainerGeneric[_T1]"], Sequence["_ContainerGeneric[_T1]"], _T1] +_ContainerTensors = _ContainerGeneric[tf.Tensor] +_ContainerTensorCompatible = _ContainerGeneric[_TensorCompatible] + +_ShapeLike = Union[tf.TensorShape, Iterable[Optional[int]], int, tf.Tensor] +_DTypeLike = Union[tf.DType, str, np.dtype] diff --git a/stubs/tensorflow/tensorflow/math.pyi b/stubs/tensorflow/tensorflow/math.pyi new file mode 100644 index 000000000000..952c3bf4afa8 --- /dev/null +++ b/stubs/tensorflow/tensorflow/math.pyi @@ -0,0 +1,11 @@ +from typing import overload + +from tensorflow import RaggedTensor, SparseTensor, Tensor +from tensorflow._aliases import _TensorCompatible + +@overload +def abs(x: _TensorCompatible, name: str | None = ...) -> Tensor: ... +@overload +def abs(x: SparseTensor, name: str | None = ...) -> SparseTensor: ... +@overload +def abs(x: RaggedTensor, name: str | None = ...) -> RaggedTensor: ...