
Commit a8e73f4

adding alias types, wip
1 parent fbe6fdb commit a8e73f4

9 files changed: +105 -52 lines changed


stubs/tensorflow/tensorflow/__init__.pyi

Lines changed: 12 additions & 12 deletions
@@ -19,7 +19,7 @@ from tensorflow import (
     keras as keras,
     math as math,
 )
-from tensorflow._aliases import ContainerGradients, ContainerTensors, ContainerTensorsLike, Gradients, TensorLike
+from tensorflow._aliases import _ContainerGradients, _ContainerTensors, _ContainerTensorsLike, _Gradients, _TensorLike
 from tensorflow.core.protobuf import struct_pb2

 # Explicit import of DType is covered by the wildcard, but
@@ -307,39 +307,39 @@ class GradientTape:
     @overload
     def gradient(
         self,
-        target: ContainerTensors,
-        sources: TensorLike,
+        target: _ContainerTensors,
+        sources: _TensorLike,
         output_gradients: list[Tensor] | None = None,
         unconnected_gradients: UnconnectedGradients = ...,
-    ) -> Gradients: ...
+    ) -> _Gradients: ...
     @overload
     def gradient(
         self,
-        target: ContainerTensors,
+        target: _ContainerTensors,
         sources: Sequence[Tensor],
         output_gradients: list[Tensor] | None = None,
         unconnected_gradients: UnconnectedGradients = ...,
-    ) -> list[Gradients]: ...
+    ) -> list[_Gradients]: ...
     @overload
     def gradient(
         self,
-        target: ContainerTensors,
+        target: _ContainerTensors,
         sources: Mapping[str, Tensor],
         output_gradients: list[Tensor] | None = None,
         unconnected_gradients: UnconnectedGradients = ...,
-    ) -> dict[str, Gradients]: ...
+    ) -> dict[str, _Gradients]: ...
     @overload
     def gradient(
         self,
-        target: ContainerTensors,
-        sources: ContainerTensors,
+        target: _ContainerTensors,
+        sources: _ContainerTensors,
         output_gradients: list[Tensor] | None = None,
         unconnected_gradients: UnconnectedGradients = ...,
-    ) -> ContainerGradients: ...
+    ) -> _ContainerGradients: ...
     @contextmanager
     def stop_recording(self) -> Generator[None, None, None]: ...
     def reset(self) -> None: ...
-    def watch(self, tensor: ContainerTensorsLike) -> None: ...
+    def watch(self, tensor: _ContainerTensorsLike) -> None: ...
     def watched_variables(self) -> tuple[Variable, ...]: ...
     def __getattr__(self, name: str) -> Incomplete: ...
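
For illustration, a minimal runtime sketch (not part of this commit) of how the GradientTape.gradient overloads above line up with behaviour: the structure of sources determines the structure of the returned gradients. Assumes an installed TensorFlow; the variables are arbitrary.

import tensorflow as tf

x = tf.Variable(2.0)
y = tf.Variable(3.0)

# persistent=True so gradient() can be called more than once on the same tape.
with tf.GradientTape(persistent=True) as tape:
    loss = x * x + y * y

# Mapping[str, Tensor] sources -> dict[str, _Gradients] (third overload).
grads = tape.gradient(loss, {"x": x, "y": y})
assert float(grads["x"]) == 4.0  # d(x*x + y*y)/dx at x = 2

# Sequence[Tensor] sources -> list[_Gradients] (second overload).
grads_list = tape.gradient(loss, [x, y])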

stubs/tensorflow/tensorflow/_aliases.pyi

Lines changed: 38 additions & 11 deletions
@@ -5,26 +5,53 @@
 from collections.abc import Mapping, Sequence
 from typing import Any, Protocol, TypeVar
 from typing_extensions import TypeAlias
+from typing import Any, Iterable, Mapping, Sequence, TypeVar

-import numpy
+import numpy as np
 import tensorflow as tf
+from tensorflow.keras.layers import InputSpec

 _T1 = TypeVar("_T1")
-ContainerGeneric: TypeAlias = Mapping[str, ContainerGeneric[_T1]] | Sequence[ContainerGeneric[_T1]] | _T1
+_T = TypeVar("_T")
+_ContainerGeneric: TypeAlias = Mapping[str, _ContainerGeneric[_T1]] | Sequence[_ContainerGeneric[_T1]] | _T1

-TensorLike: TypeAlias = tf.Tensor | tf.RaggedTensor | tf.SparseTensor
-Gradients: TypeAlias = tf.Tensor | tf.IndexedSlices
-
-ContainerTensorsLike: TypeAlias = ContainerGeneric[TensorLike]
-ContainerTensors: TypeAlias = ContainerGeneric[tf.Tensor]
-ContainerGradients: TypeAlias = ContainerGeneric[Gradients]
-
-AnyArray: TypeAlias = numpy.ndarray[Any, Any]
+_TensorLike: TypeAlias = tf.Tensor | tf.RaggedTensor | tf.SparseTensor
+_SparseTensorLike = tf.Tensor | tf.SparseTensor
+_RaggedTensorLike = tf.Tensor | tf.RaggedTensor
+_RaggedTensorLikeT = TypeVar("_RaggedTensorLikeT", tf.Tensor, tf.RaggedTensor)
+_Gradients: TypeAlias = tf.Tensor | tf.IndexedSlices

 class _KerasSerializable1(Protocol):
     def get_config(self) -> dict[str, Any]: ...

 class _KerasSerializable2(Protocol):
     __name__: str

-KerasSerializable: TypeAlias = _KerasSerializable1 | _KerasSerializable2
+_KerasSerializable: TypeAlias = _KerasSerializable1 | _KerasSerializable2
+
+_FloatDataSequence = Sequence[float] | Sequence[_FloatDataSequence]
+_StrDataSequence = Sequence[str] | Sequence[_StrDataSequence]
+_ScalarTensorCompatible = tf.Tensor | str | float | np.ndarray[Any, Any] | np.number[Any]
+
+_TensorCompatible = _ScalarTensorCompatible | Sequence[_TensorCompatible]
+_TensorCompatibleT = TypeVar("_TensorCompatibleT", bound=_TensorCompatible)
+# Sparse tensors are very annoying. Some operations work on them, but many do not. You
+# will need to manually verify if an operation supports them. SparseTensorCompatible is intended to be a
+# broader type than TensorCompatible and not all operations will support broader version. If unsure,
+# use TensorCompatible instead.
+_SparseTensorCompatible = _TensorCompatible | tf.SparseTensor
+
+_ShapeLike = tf.TensorShape | Iterable[_ScalarTensorCompatible | None] | int | tf.Tensor
+_DTypeLike = tf.DType | str | np.dtype[Any] | int
+_GradientsT = tf.Tensor | tf.IndexedSlices
+
+_ContainerTensors = _ContainerGeneric[tf.Tensor]
+_ContainerTensorsLike = _ContainerGeneric[_TensorLike]
+_ContainerTensorCompatible = _ContainerGeneric[_TensorCompatible]
+_ContainerGradients = _ContainerGeneric[_GradientsT]
+_ContainerTensorShape = _ContainerGeneric[tf.TensorShape]
+_ContainerInputSpec = _ContainerGeneric[InputSpec]
+
+_AnyArray = np.ndarray[Any, Any]
+_FloatArray = np.ndarray[Any, np.dtype[np.float_ | np.float16 | np.float32 | np.float64]]
+_IntArray = np.ndarray[Any, np.dtype[np.int_ | np.uint8 | np.int32 | np.int64]]
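
As a rough illustration (not part of the commit) of the values these aliases are meant to describe; the aliases themselves exist only for type checking, so the sketch uses plain runtime objects:

import numpy as np
import tensorflow as tf

# _TensorCompatible: scalars, numpy arrays, and nested sequences of them can
# all be converted to a tensor.
tf.constant(1.5)
tf.constant(np.zeros((2, 2)))
tf.constant([[1, 2], [3, 4]])

# _ContainerGeneric[tf.Tensor]: an arbitrarily nested Mapping/Sequence with
# Tensor leaves, e.g. the structures GradientTape.gradient accepts.
nested = {"weights": [tf.constant(1.0), tf.constant(2.0)], "bias": tf.constant(0.5)}

# _AnyArray / _FloatArray / _IntArray: plain numpy arrays, narrowed by dtype.
floats = np.zeros(3, dtype=np.float32)
ints = np.arange(3, dtype=np.int64)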

stubs/tensorflow/tensorflow/data/__init__.pyi

Lines changed: 11 additions & 11 deletions
@@ -7,7 +7,7 @@ from typing_extensions import Self
 import numpy as np
 import tensorflow as tf
 from tensorflow import TypeSpec, _ScalarTensorCompatible, _TensorCompatible
-from tensorflow._aliases import ContainerGeneric
+from tensorflow._aliases import _ContainerGeneric
 from tensorflow.data import experimental as experimental
 from tensorflow.data.experimental import AUTOTUNE as AUTOTUNE
 from tensorflow.dtypes import DType
@@ -21,7 +21,7 @@ _T3 = TypeVar("_T3")
 class Iterator(_Iterator[_T1], Trackable, ABC):
     @property
     @abstractmethod
-    def element_spec(self) -> ContainerGeneric[TypeSpec[Any]]: ...
+    def element_spec(self) -> _ContainerGeneric[TypeSpec[Any]]: ...
     @abstractmethod
     def get_next(self) -> _T1: ...
     @abstractmethod
@@ -43,8 +43,8 @@ class Dataset(ABC, Generic[_T1]):
         element_length_func: Callable[[_T1], _ScalarTensorCompatible],
         bucket_boundaries: Sequence[int],
         bucket_batch_sizes: Sequence[int],
-        padded_shapes: ContainerGeneric[tf.TensorShape | _TensorCompatible] | None = None,
-        padding_values: ContainerGeneric[_ScalarTensorCompatible] | None = None,
+        padded_shapes: _ContainerGeneric[tf.TensorShape | _TensorCompatible] | None = None,
+        padding_values: _ContainerGeneric[_ScalarTensorCompatible] | None = None,
         pad_to_bucket_boundary: bool = False,
         no_padding: bool = False,
         drop_remainder: bool = False,
@@ -63,18 +63,18 @@ class Dataset(ABC, Generic[_T1]):
     ) -> Dataset[tf.Tensor]: ...
     @property
     @abstractmethod
-    def element_spec(self) -> ContainerGeneric[TypeSpec[Any]]: ...
+    def element_spec(self) -> _ContainerGeneric[TypeSpec[Any]]: ...
     def enumerate(self, start: _ScalarTensorCompatible = 0, name: str | None = None) -> Dataset[tuple[int, _T1]]: ...
     def filter(self, predicate: Callable[[_T1], bool | tf.Tensor], name: str | None = None) -> Dataset[_T1]: ...
     def flat_map(self, map_func: Callable[[_T1], Dataset[_T2]], name: str | None = None) -> Dataset[_T2]: ...
     # PEP 646 can be used here for a more precise type when better supported.
     @staticmethod
     def from_generator(
         generator: Callable[..., _T2],
-        output_types: ContainerGeneric[DType] | None = None,
-        output_shapes: ContainerGeneric[tf.TensorShape | Sequence[int | None]] | None = None,
+        output_types: _ContainerGeneric[DType] | None = None,
+        output_shapes: _ContainerGeneric[tf.TensorShape | Sequence[int | None]] | None = None,
         args: tuple[object, ...] | None = None,
-        output_signature: ContainerGeneric[TypeSpec[Any]] | None = None,
+        output_signature: _ContainerGeneric[TypeSpec[Any]] | None = None,
         name: str | None = None,
     ) -> Dataset[_T2]: ...
     @staticmethod
@@ -111,7 +111,7 @@ class Dataset(ABC, Generic[_T1]):
     @staticmethod
     def load(
         path: str,
-        element_spec: ContainerGeneric[tf.TypeSpec[Any]] | None = None,
+        element_spec: _ContainerGeneric[tf.TypeSpec[Any]] | None = None,
         compression: _CompressionTypes = None,
         reader_func: Callable[[Dataset[Dataset[Any]]], Dataset[Any]] | None = None,
     ) -> Dataset[Any]: ...
@@ -127,8 +127,8 @@ class Dataset(ABC, Generic[_T1]):
     def padded_batch(
         self,
         batch_size: _ScalarTensorCompatible,
-        padded_shapes: ContainerGeneric[tf.TensorShape | _TensorCompatible] | None = None,
-        padding_values: ContainerGeneric[_ScalarTensorCompatible] | None = None,
+        padded_shapes: _ContainerGeneric[tf.TensorShape | _TensorCompatible] | None = None,
+        padding_values: _ContainerGeneric[_ScalarTensorCompatible] | None = None,
         drop_remainder: bool = False,
         name: str | None = None,
     ) -> Dataset[_T1]: ...
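
A small sketch (not part of the commit) of the nested output_signature that _ContainerGeneric[TypeSpec[Any]] is meant to describe in Dataset.from_generator; the generator and keys are made up:

import tensorflow as tf

def gen():
    yield {"x": [1.0, 2.0], "y": 0}

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature={
        "x": tf.TensorSpec(shape=(2,), dtype=tf.float32),
        "y": tf.TensorSpec(shape=(), dtype=tf.int32),
    },
)
# element_spec mirrors the nested structure passed above.
print(ds.element_spec)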

stubs/tensorflow/tensorflow/io/__init__.pyi

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@ from typing import Literal, NamedTuple
 from typing_extensions import Self, TypeAlias

 from tensorflow import _DTypeLike, _ShapeLike, _TensorCompatible
-from tensorflow._aliases import TensorLike
+from tensorflow._aliases import _TensorLike
 from tensorflow.io import gfile as gfile

 _FeatureSpecs: TypeAlias = Mapping[str, FixedLenFeature | FixedLenSequenceFeature | VarLenFeature | RaggedFeature | SparseFeature]
@@ -102,5 +102,5 @@ class RaggedFeature(NamedTuple):

 def parse_example(
     serialized: _TensorCompatible, features: _FeatureSpecs, example_names: Iterable[str] | None = None, name: str | None = None
-) -> dict[str, TensorLike]: ...
+) -> dict[str, _TensorLike]: ...
 def __getattr__(name: str) -> Incomplete: ...
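
For context (not part of the commit): parse_example returns a mapping from feature name to a dense, sparse, or ragged tensor, which is what dict[str, _TensorLike] captures. The feature name below is made up.

import tensorflow as tf

serialized = tf.train.Example(
    features=tf.train.Features(
        feature={"x": tf.train.Feature(float_list=tf.train.FloatList(value=[1.0, 2.0]))}
    )
).SerializeToString()

parsed = tf.io.parse_example(
    [serialized], {"x": tf.io.FixedLenFeature(shape=(2,), dtype=tf.float32)}
)
# parsed["x"] is a dense Tensor; VarLenFeature / RaggedFeature entries come back
# as SparseTensor / RaggedTensor, hence the _TensorLike union.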

stubs/tensorflow/tensorflow/keras/__init__.pyi

Lines changed: 16 additions & 0 deletions
@@ -1,4 +1,5 @@
 from _typeshed import Incomplete
+from typing import Callable

 from tensorflow.keras import (
     activations as activations,
@@ -12,5 +13,20 @@ from tensorflow.keras import (
     regularizers as regularizers,
 )
 from tensorflow.keras.models import Model as Model
+import tensorflow as tf
+from tensorflow._aliases import _TensorCompatible

 def __getattr__(name: str) -> Incomplete: ...
+
+_Loss = (
+    str
+    | tf.keras.losses.Loss
+    | Callable[[_TensorCompatible, _TensorCompatible], tf._Tensor]
+)
+
+_Metric = (
+    str
+    | tf.keras.metrics.Metric
+    | Callable[[_TensorCompatible, _TensorCompatible], tf._Tensor]
+    | None
+)
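
For illustration (not part of the commit), each arm of the new _Loss and _Metric unions corresponds to a form Keras already accepts at runtime: a string identifier, a Loss/Metric instance, or a plain (y_true, y_pred) callable returning a tensor.

import tensorflow as tf

loss_by_name = "mse"
loss_instance = tf.keras.losses.MeanSquaredError()

def loss_callable(y_true, y_pred):
    # A bare callable from (y_true, y_pred) to a Tensor also qualifies.
    return tf.reduce_mean(tf.square(y_pred - y_true))

metric_by_name = "accuracy"
metric_instance = tf.keras.metrics.AUC()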

stubs/tensorflow/tensorflow/keras/layers.pyi

Lines changed: 3 additions & 3 deletions
@@ -5,7 +5,7 @@ from typing_extensions import Self, TypeAlias

 import tensorflow as tf
 from tensorflow import Tensor, Variable, VariableAggregation, VariableSynchronization, _TensorCompatible
-from tensorflow._aliases import AnyArray
+from tensorflow._aliases import _AnyArray
 from tensorflow.keras.activations import _Activation
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.initializers import _Initializer
@@ -90,8 +90,8 @@ class Layer(tf.Module, Generic[_InputT, _OutputT]):
     def non_trainable_weights(self) -> list[Variable]: ...
     @property
     def losses(self) -> list[Tensor]: ...
-    def get_weights(self) -> list[AnyArray]: ...
-    def set_weights(self, weights: Sequence[AnyArray]) -> None: ...
+    def get_weights(self) -> list[_AnyArray]: ...
+    def set_weights(self, weights: Sequence[_AnyArray]) -> None: ...
     def get_config(self) -> dict[str, Any]: ...
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> Self: ...
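
A small sketch (not part of the commit) of why get_weights/set_weights are typed with the numpy-array alias rather than Tensor; the layer is arbitrary:

import numpy as np
import tensorflow as tf

layer = tf.keras.layers.Dense(2)
layer.build(input_shape=(None, 3))

weights = layer.get_weights()  # list of np.ndarray: [kernel (3, 2), bias (2,)]
layer.set_weights([np.ones((3, 2)), np.zeros(2)])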

stubs/tensorflow/tensorflow/keras/losses.pyi

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@ from typing import Any, Final, Literal, TypeVar, overload
 from typing_extensions import Self, TypeAlias, TypeGuard

 from tensorflow import Tensor, _TensorCompatible
-from tensorflow._aliases import KerasSerializable
+from tensorflow._aliases import _KerasSerializable
 from tensorflow.keras.metrics import (
     binary_crossentropy as binary_crossentropy,
     categorical_crossentropy as categorical_crossentropy,
@@ -136,7 +136,7 @@ def log_cosh(y_true: _TensorCompatible, y_pred: _TensorCompatible) -> Tensor: ...
 def deserialize(
     name: str | dict[str, Any], custom_objects: dict[str, Any] | None = None, use_legacy_format: bool = False
 ) -> Loss: ...
-def serialize(loss: KerasSerializable, use_legacy_format: bool = False) -> dict[str, Any]: ...
+def serialize(loss: _KerasSerializable, use_legacy_format: bool = False) -> dict[str, Any]: ...

 _FuncT = TypeVar("_FuncT", bound=Callable[..., Any])
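
For context (not part of the commit), serialize accepts anything exposing get_config(), which is what the _KerasSerializable protocol models:

import tensorflow as tf

config = tf.keras.losses.serialize(tf.keras.losses.MeanSquaredError())
restored = tf.keras.losses.deserialize(config)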

stubs/tensorflow/tensorflow/keras/models.pyi

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,18 @@ import tensorflow
1010
import tensorflow as tf
1111
from tensorflow import Variable, _ShapeLike, _TensorCompatible
1212
from tensorflow.keras.layers import Layer, _InputT, _OutputT
13-
from tensorflow.keras.optimizers.legacy import Optimizer
13+
from tensorflow.keras import _Loss, _Metric
14+
from tensorflow._aliases import _ContainerGeneric
15+
16+
_BothOptimizer = tf.optimizers.Optimizer | tf.optimizers.experimental.Optimizer
1417

1518
class Model(Layer[_InputT, _OutputT], tf.Module):
19+
_train_counter: tf.Variable
20+
_test_counter: tf.Variable
21+
optimizer: _BothOptimizer | None
22+
loss: tf.keras.losses.Loss | dict[str, tf.keras.losses.Loss]
23+
stop_training: bool
24+
1625
def __new__(cls, *args: Any, **kwargs: Any) -> Model[_InputT, _OutputT]: ...
1726
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
1827
def __setattr__(self, name: str, value: Any) -> None: ...
@@ -21,13 +30,14 @@ class Model(Layer[_InputT, _OutputT], tf.Module):
2130
def build(self, input_shape: _ShapeLike) -> None: ...
2231
def __call__(self, inputs: _InputT, *, training: bool = False, mask: _TensorCompatible | None = None) -> _OutputT: ...
2332
def call(self, inputs: _InputT, training: bool | None = None, mask: _TensorCompatible | None = None) -> _OutputT: ...
33+
# Ideally loss/metrics/output would share the same structure but higher kinded types are not supported.
2434
def compile(
2535
self,
26-
optimizer: Optimizer | str = "rmsprop",
27-
loss: tf.keras.losses.Loss | str | None = None,
28-
metrics: list[tf.keras.metrics.Metric | str] | None = None,
29-
loss_weights: list[float] | dict[str, float] | None = None,
30-
weighted_metrics: list[tf.keras.metrics.Metric] | None = None,
36+
optimizer: _BothOptimizer | str = "rmsprop",
37+
loss: _ContainerGeneric[_Loss] | None = None,
38+
metrics: _ContainerGeneric[_Metric] | None = None,
39+
loss_weights: _ContainerGeneric[float] | None = None,
40+
weighted_metrics: _ContainerGeneric[_Metric] | None = None,
3141
run_eagerly: bool | None = None,
3242
steps_per_execution: int | Literal["auto"] | None = None,
3343
jit_compile: bool | None = None,
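
A rough sketch (not part of the commit) of a compile call the container-typed parameters are meant to admit: per-output losses, metrics, and loss weights for a two-output functional model. The layer names are made up.

import tensorflow as tf

inputs = tf.keras.Input(shape=(4,))
out_a = tf.keras.layers.Dense(1, name="a")(inputs)
out_b = tf.keras.layers.Dense(3, activation="softmax", name="b")(inputs)
model = tf.keras.Model(inputs, [out_a, out_b])

model.compile(
    optimizer="rmsprop",
    loss={"a": "mse", "b": tf.keras.losses.CategoricalCrossentropy()},
    metrics={"a": ["mae"], "b": ["accuracy"]},
    loss_weights={"a": 1.0, "b": 0.5},
)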

stubs/tensorflow/tensorflow/keras/optimizers/legacy/__init__.pyi

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@ from typing import Any
55
from typing_extensions import Self, TypeAlias
66

77
import tensorflow as tf
8-
from tensorflow._aliases import Gradients
8+
from tensorflow._aliases import _Gradients
99
from tensorflow.keras.optimizers import schedules as schedules
1010
from tensorflow.python.trackable.base import Trackable
1111

1212
_Initializer: TypeAlias = str | Callable[[], tf.Tensor] | dict[str, Any]
1313
_Shape: TypeAlias = tf.TensorShape | Iterable[int | None]
1414
_Dtype: TypeAlias = tf.DType | str | None
1515
_LearningRate: TypeAlias = float | tf.Tensor | schedules.LearningRateSchedule | Callable[[], float | tf.Tensor]
16-
_GradientAggregator: TypeAlias = Callable[[list[tuple[Gradients, tf.Variable]]], list[tuple[Gradients, tf.Variable]]] | None
16+
_GradientAggregator: TypeAlias = Callable[[list[tuple[_Gradients, tf.Variable]]], list[tuple[_Gradients, tf.Variable]]] | None
1717
_GradientTransformer: TypeAlias = (
18-
Iterable[Callable[[list[tuple[Gradients, tf.Variable]]], list[tuple[Gradients, tf.Variable]]]] | None
18+
Iterable[Callable[[list[tuple[_Gradients, tf.Variable]]], list[tuple[_Gradients, tf.Variable]]]] | None
1919
)
2020

2121
# kwargs here and in other optimizers can be given better type after Unpack[TypedDict], PEP 692, is supported.
@@ -53,7 +53,7 @@ class Optimizer(Trackable):
5353
) -> tf.Variable: ...
5454
def apply_gradients(
5555
self,
56-
grads_and_vars: Iterable[tuple[Gradients, tf.Variable]],
56+
grads_and_vars: Iterable[tuple[_Gradients, tf.Variable]],
5757
name: str | None = None,
5858
experimental_aggregate_gradients: bool = True,
5959
) -> tf.Operation | None: ...
@@ -64,7 +64,7 @@ class Optimizer(Trackable):
6464
def get_config(self) -> dict[str, Any]: ...
6565
def get_slot(self, var: tf.Variable, slot_name: str) -> tf.Variable: ...
6666
def get_slot_names(self) -> list[str]: ...
67-
def get_gradients(self, loss: tf.Tensor, params: list[tf.Variable]) -> list[Gradients]: ...
67+
def get_gradients(self, loss: tf.Tensor, params: list[tf.Variable]) -> list[_Gradients]: ...
6868
def minimize(
6969
self,
7070
loss: tf.Tensor | Callable[[], tf.Tensor],
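
For illustration (not part of the commit), apply_gradients consumes pairs of a gradient (a Tensor or IndexedSlices, i.e. the _Gradients union) and its variable. A minimal sketch, assuming the legacy optimizer namespace is available (TF 2.11+):

import tensorflow as tf

opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1)
var = tf.Variable(1.0)

with tf.GradientTape() as tape:
    loss = var * var

grads = tape.gradient(loss, [var])      # list[_Gradients]
opt.apply_gradients(zip(grads, [var]))  # Iterable[tuple[_Gradients, tf.Variable]]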
