
Add TensorFlow stubs #11306


Closed · wants to merge 77 commits

Changes from all commits · 77 commits
ba67d5b
Add convert_to_tensor
hoel-bagard Jan 19, 2024
c994063
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2024
476c5f6
Add a few functions.
hoel-bagard Jan 21, 2024
8a8df54
feat: add config functions.
hoel-bagard Jan 22, 2024
9f69588
Add missing tf.keras.Model methods
hoel-bagard Jan 24, 2024
cbc2641
Add TF2 functions/models/methods.
hoel-bagard Jan 24, 2024
5a353e3
revert TensorShape changes.
hoel-bagard Jan 24, 2024
c265a66
Fix tf functions.
hoel-bagard Jan 24, 2024
eb3098f
Fix CheckpointOptions's callbacks type.
hoel-bagard Jan 24, 2024
005cf22
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2024
48a2479
fix: file name train.py -> train.pyi
hoel-bagard Jan 24, 2024
e1515cc
fix: Model MRO error
hoel-bagard Jan 24, 2024
d598706
fix: concat's values type
hoel-bagard Jan 24, 2024
e3911ea
fix: remove __getitem__ from IndexedSlices.
hoel-bagard Jan 24, 2024
f0cc8eb
fix: remove __index__ from TensorShape.
hoel-bagard Jan 24, 2024
2db437a
fix: reorganize config module.
hoel-bagard Jan 24, 2024
8e01170
fix: add missing Model methods/properties (wip)
hoel-bagard Jan 24, 2024
8962c93
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2024
8f7df46
fix: use Iterator from collections.abc instead of typing
hoel-bagard Jan 24, 2024
8bc143f
fix: import Self from typing_extensions instead of typing
hoel-bagard Jan 24, 2024
3e27624
feat: add empty Metric class
hoel-bagard Jan 25, 2024
59f6f6f
feat: Add type hints for keras Model's compile, fit and evaluate
hoel-bagard Jan 25, 2024
1f216a4
feat: Add type hints for keras Model's predict, train_on_batch, fit_g…
hoel-bagard Jan 25, 2024
a6f7ba2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2024
f7cc67d
feat: Add type hints for keras Model's save, save_weights and load_we…
hoel-bagard Jan 25, 2024
e3573cf
feat: Add type hints for keras Model's get_layer
hoel-bagard Jan 25, 2024
f31992b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2024
6c63b89
fix: add Any type to args and kwargs.
hoel-bagard Jan 25, 2024
d7b08a6
fix: allow tuple[int, ...] as input to expand_dims
hoel-bagard Jan 25, 2024
f10de3f
fix: add empty tensorflow.kerras.callbacks.Callback calls
hoel-bagard Jan 25, 2024
41b5a23
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2024
9d1d1b4
fix: Model's get_config return type.
hoel-bagard Jan 25, 2024
b63f23d
feat: add empty History class
hoel-bagard Jan 25, 2024
2a422dc
fix: add type hints for test_on_batch, add missing return types
hoel-bagard Jan 25, 2024
d2f218e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2024
5d99337
fix: add missing types
hoel-bagard Jan 25, 2024
924a59f
fix: allow call's training to be None
hoel-bagard Jan 25, 2024
f5384f3
fix: predict input type.
hoel-bagard Jan 25, 2024
29b0d4a
fix: fix CheckpointOptions (add init)
hoel-bagard Jan 25, 2024
0326484
fix: add __init__ and __new__ to Metric
hoel-bagard Jan 25, 2024
f0f1c01
fix: comment out properties not present at runtime.
hoel-bagard Jan 25, 2024
3c772c5
fix: import Self from typing_extensions instead of typing
hoel-bagard Jan 25, 2024
1e804ab
adding alias types, wip
hoel-bagard Jan 26, 2024
5e75c93
Fix tf functions.
hoel-bagard Jan 24, 2024
3488064
Fix CheckpointOptions's callbacks type.
hoel-bagard Jan 24, 2024
cfd5048
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2024
b762251
fix: file name train.py -> train.pyi
hoel-bagard Jan 24, 2024
33f0132
fix: Model MRO error
hoel-bagard Jan 24, 2024
d67f0b9
fix: concat's values type
hoel-bagard Jan 24, 2024
114f213
fix: add missing Model methods/properties (wip)
hoel-bagard Jan 24, 2024
cd5b7cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2024
6ee7428
fix: use Iterator from collections.abc instead of typing
hoel-bagard Jan 24, 2024
36499f8
fix: import Self from typing_extensions instead of typing
hoel-bagard Jan 24, 2024
38f8480
feat: add empty Metric class
hoel-bagard Jan 25, 2024
10b432f
feat: Add type hints for keras Model's compile, fit and evaluate
hoel-bagard Jan 25, 2024
7cae2a0
feat: Add type hints for keras Model's predict, train_on_batch, fit_g…
hoel-bagard Jan 25, 2024
95ee069
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2024
08a0aa1
feat: Add type hints for keras Model's save, save_weights and load_we…
hoel-bagard Jan 25, 2024
0164ed9
feat: Add type hints for keras Model's get_layer
hoel-bagard Jan 25, 2024
14ea8de
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2024
2c09fcc
fix: add Any type to args and kwargs.
hoel-bagard Jan 25, 2024
2c1603a
fix: allow tuple[int, ...] as input to expand_dims
hoel-bagard Jan 25, 2024
9ad65a2
fix: add empty tensorflow.kerras.callbacks.Callback calls
hoel-bagard Jan 25, 2024
3177cd1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2024
eb911e5
fix: Model's get_config return type.
hoel-bagard Jan 25, 2024
1ba6088
feat: add empty History class
hoel-bagard Jan 25, 2024
ff7bf94
fix: add type hints for test_on_batch, add missing return types
hoel-bagard Jan 25, 2024
a7a12dc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2024
53a6825
fix: add missing types
hoel-bagard Jan 25, 2024
4bd2f0b
fix: allow call's training to be None
hoel-bagard Jan 25, 2024
fe1599b
fix: predict input type.
hoel-bagard Jan 25, 2024
d9f9669
fix: fix CheckpointOptions (add init)
hoel-bagard Jan 25, 2024
1405abe
fix: add __init__ and __new__ to Metric
hoel-bagard Jan 25, 2024
054c8cf
fix: comment out properties not present at runtime.
hoel-bagard Jan 25, 2024
b2a5e7b
fix: import Self from typing_extensions instead of typing
hoel-bagard Jan 25, 2024
04c93ea
adding alias types, wip
hoel-bagard Jan 26, 2024
62f42ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2024
40 changes: 23 additions & 17 deletions stubs/tensorflow/tensorflow/__init__.pyi
@@ -19,7 +19,7 @@ from tensorflow import (
keras as keras,
math as math,
)
from tensorflow._aliases import ContainerGradients, ContainerTensors, ContainerTensorsLike, Gradients, TensorLike
from tensorflow._aliases import _ContainerGradients, _ContainerTensors, _ContainerTensorsLike, _Gradients, _TensorLike
from tensorflow.core.protobuf import struct_pb2

# Explicit import of DType is covered by the wildcard, but
@@ -73,7 +73,7 @@ from tensorflow.sparse import SparseTensor as SparseTensor
# we will skip making Tensor generic. Also good type hints for shapes will
# run quickly into many places where type system is not strong enough today.
# So shape typing is probably not worth doing anytime soon.
_Slice: TypeAlias = int | slice | None
_Slice: TypeAlias = int | slice | Tensor | None

_FloatDataSequence: TypeAlias = Sequence[float] | Sequence[_FloatDataSequence]
_StrDataSequence: TypeAlias = Sequence[str] | Sequence[_StrDataSequence]
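
As a quick sanity check on the widened `_Slice` alias above (now `int | slice | Tensor | None`), the following indexing patterns should all type-check. This is an illustrative sketch, not part of the diff:

```python
import tensorflow as tf

x = tf.constant([[1, 2], [3, 4]])
i = tf.constant(1)       # scalar Tensor used as an index
row = x[i]               # Tensor index, newly allowed by _Slice
head = x[0:1]            # int/slice indexing as before
expanded = x[:, None]    # None inserts a new axis
```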
@@ -307,39 +307,39 @@ class GradientTape:
@overload
def gradient(
self,
target: ContainerTensors,
sources: TensorLike,
target: _ContainerTensors,
sources: _TensorLike,
output_gradients: list[Tensor] | None = None,
unconnected_gradients: UnconnectedGradients = ...,
) -> Gradients: ...
) -> _Gradients: ...
@overload
def gradient(
self,
target: ContainerTensors,
target: _ContainerTensors,
sources: Sequence[Tensor],
output_gradients: list[Tensor] | None = None,
unconnected_gradients: UnconnectedGradients = ...,
) -> list[Gradients]: ...
) -> list[_Gradients]: ...
@overload
def gradient(
self,
target: ContainerTensors,
target: _ContainerTensors,
sources: Mapping[str, Tensor],
output_gradients: list[Tensor] | None = None,
unconnected_gradients: UnconnectedGradients = ...,
) -> dict[str, Gradients]: ...
) -> dict[str, _Gradients]: ...
@overload
def gradient(
self,
target: ContainerTensors,
sources: ContainerTensors,
target: _ContainerTensors,
sources: _ContainerTensors,
output_gradients: list[Tensor] | None = None,
unconnected_gradients: UnconnectedGradients = ...,
) -> ContainerGradients: ...
) -> _ContainerGradients: ...
@contextmanager
def stop_recording(self) -> Generator[None, None, None]: ...
def reset(self) -> None: ...
def watch(self, tensor: ContainerTensorsLike) -> None: ...
def watch(self, tensor: _ContainerTensorsLike) -> None: ...
def watched_variables(self) -> tuple[Variable, ...]: ...
def __getattr__(self, name: str) -> Incomplete: ...
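
The overload set above encodes that `gradient`'s return structure follows its `sources` argument. A minimal sketch (assuming eager mode; it passes `Variable` sources, which the runtime accepts even though the overloads are written against `Tensor`):

```python
import tensorflow as tf

x = tf.Variable(3.0)
y = tf.Variable(4.0)
with tf.GradientTape(persistent=True) as tape:
    loss = x * y + y**2

gx, gy = tape.gradient(loss, [x, y])           # Sequence -> list of gradients
grads = tape.gradient(loss, {"x": x, "y": y})  # Mapping  -> dict of gradients
del tape  # drop the reference so the persistent tape's resources are freed
```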

@@ -407,8 +407,14 @@ class RaggedTensorSpec(TypeSpec[struct_pb2.TypeSpecProto]):

def __getattr__(name: str) -> Incomplete: ...
def convert_to_tensor(
value: _TensorCompatible | IndexedSlices,
dtype: _DTypeLike | None = None,
dtype_hint: _DTypeLike | None = None,
name: str | None = None,
value: _TensorCompatible, dtype: DType | None = None, dtype_hint: DType | None = None, name: str | None = None
) -> Tensor: ...
def expand_dims(input: Tensor | tuple[int, ...], axis: int, name: None | str = None) -> Tensor: ...
def concat(values: _TensorCompatible, axis: int | Tensor, name: None | str = "concat") -> Tensor: ...
def squeeze(input: Tensor, axis: None | _TensorCompatible = None, name: None | str = None) -> Tensor: ...
def tensor_scatter_nd_update(tensor: Tensor, indices: Tensor, updates: Tensor, name: None | str = None) -> Tensor: ...
def constant(
value: _TensorCompatible, dtype: None | _DTypeLike = None, shape: None | _ShapeLike = None, name: None | str = "Const"
) -> Tensor: ...
def cast(x: Tensor | _TensorCompatible, dtype: DType, name: None | str = None) -> Tensor: ...
def reshape(tensor: Tensor | _TensorCompatible, shape: _ShapeLike, name: None | str = None) -> Tensor: ...
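
A short usage sketch for the newly stubbed module-level functions (illustrative only; resulting shapes noted in comments):

```python
import tensorflow as tf

t = tf.convert_to_tensor([[1.0, 2.0]], dtype=tf.float32)  # shape (1, 2)
t = tf.expand_dims(t, axis=0)                             # shape (1, 1, 2)
t = tf.squeeze(t, axis=0)                                 # shape (1, 2)
both = tf.concat([t, t], axis=0)                          # shape (2, 2)
flat = tf.reshape(tf.cast(both, tf.int32), [-1])          # shape (4,)
```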
50 changes: 38 additions & 12 deletions stubs/tensorflow/tensorflow/_aliases.pyi
@@ -3,28 +3,54 @@
# equivalent.

from collections.abc import Mapping, Sequence
from typing import Any, Protocol, TypeVar
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Protocol, TypeVar
from typing_extensions import TypeAlias

import numpy
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import InputSpec

_T1 = TypeVar("_T1")
ContainerGeneric: TypeAlias = Mapping[str, ContainerGeneric[_T1]] | Sequence[ContainerGeneric[_T1]] | _T1
_T = TypeVar("_T")
_ContainerGeneric: TypeAlias = Mapping[str, _ContainerGeneric[_T1]] | Sequence[_ContainerGeneric[_T1]] | _T1

TensorLike: TypeAlias = tf.Tensor | tf.RaggedTensor | tf.SparseTensor
Gradients: TypeAlias = tf.Tensor | tf.IndexedSlices

ContainerTensorsLike: TypeAlias = ContainerGeneric[TensorLike]
ContainerTensors: TypeAlias = ContainerGeneric[tf.Tensor]
ContainerGradients: TypeAlias = ContainerGeneric[Gradients]

AnyArray: TypeAlias = numpy.ndarray[Any, Any]
_TensorLike: TypeAlias = tf.Tensor | tf.RaggedTensor | tf.SparseTensor
_SparseTensorLike: TypeAlias = tf.Tensor | tf.SparseTensor
_RaggedTensorLike: TypeAlias = tf.Tensor | tf.RaggedTensor
_RaggedTensorLikeT = TypeVar("_RaggedTensorLikeT", tf.Tensor, tf.RaggedTensor)
_Gradients: TypeAlias = tf.Tensor | tf.IndexedSlices

class _KerasSerializable1(Protocol):
def get_config(self) -> dict[str, Any]: ...

class _KerasSerializable2(Protocol):
__name__: str

KerasSerializable: TypeAlias = _KerasSerializable1 | _KerasSerializable2
_KerasSerializable: TypeAlias = _KerasSerializable1 | _KerasSerializable2

_FloatDataSequence: TypeAlias = Sequence[float] | Sequence[_FloatDataSequence]
_StrDataSequence: TypeAlias = Sequence[str] | Sequence[_StrDataSequence]
_ScalarTensorCompatible: TypeAlias = tf.Tensor | str | float | np.ndarray[Any, Any] | np.number[Any]

_TensorCompatible: TypeAlias = _ScalarTensorCompatible | Sequence[_TensorCompatible]
_TensorCompatibleT = TypeVar("_TensorCompatibleT", bound=_TensorCompatible)
# Sparse tensors are very annoying. Some operations work on them, but many do not, so
# you will need to verify manually whether an operation supports them. _SparseTensorCompatible
# is intentionally broader than _TensorCompatible, and not every operation accepts the
# broader type. If unsure, use _TensorCompatible instead.
_SparseTensorCompatible: TypeAlias = _TensorCompatible | tf.SparseTensor

_ShapeLike: TypeAlias = tf.TensorShape | Iterable[_ScalarTensorCompatible | None] | int | tf.Tensor
_DTypeLike: TypeAlias = tf.DType | str | np.dtype[Any] | int
_GradientsT: TypeAlias = tf.Tensor | tf.IndexedSlices

_ContainerTensors: TypeAlias = _ContainerGeneric[tf.Tensor]
_ContainerTensorsLike: TypeAlias = _ContainerGeneric[_TensorLike]
_ContainerTensorCompatible: TypeAlias = _ContainerGeneric[_TensorCompatible]
_ContainerGradients: TypeAlias = _ContainerGeneric[_GradientsT]
_ContainerTensorShape: TypeAlias = _ContainerGeneric[tf.TensorShape]
_ContainerInputSpec: TypeAlias = _ContainerGeneric[InputSpec]

_AnyArray: TypeAlias = np.ndarray[Any, Any]
_FloatArray: TypeAlias = np.ndarray[Any, np.dtype[np.float_ | np.float16 | np.float32 | np.float64]]
_IntArray: TypeAlias = np.ndarray[Any, np.dtype[np.int_ | np.uint8 | np.int32 | np.int64]]
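
To illustrate the split between `_TensorCompatible` and `_SparseTensorCompatible` described in the comment above (a sketch, not part of the stubs):

```python
import numpy as np
import tensorflow as tf

# All of these satisfy _TensorCompatible: Python scalars, nested
# sequences, NumPy arrays, and Tensors.
tf.constant(3.5)
tf.constant([[1, 2], [3, 4]])
tf.constant(np.zeros((2, 2)))

# SparseTensor only satisfies the broader _SparseTensorCompatible.
sp = tf.SparseTensor(indices=[[0, 0]], values=[1.0], dense_shape=[2, 2])
dense = tf.sparse.to_dense(sp)  # a sparse-aware op
# tf.reshape(sp, [4])           # dense-only ops reject SparseTensor; use tf.sparse.reshape
```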
10 changes: 10 additions & 0 deletions stubs/tensorflow/tensorflow/config/__init__.pyi
@@ -0,0 +1,10 @@
from typing import NamedTuple

from tensorflow.config import experimental as experimental

class PhysicalDevice(NamedTuple):
name: str
device_type: str

def list_physical_devices(device_type: None | str = None) -> list[PhysicalDevice]: ...
def set_visible_devices(devices: list[PhysicalDevice], device_type: None | str = None) -> None: ...
3 changes: 3 additions & 0 deletions stubs/tensorflow/tensorflow/config/experimental.pyi
@@ -0,0 +1,3 @@
from tensorflow.config import PhysicalDevice

def set_memory_growth(device: PhysicalDevice, enable: bool) -> None: ...
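
Typical usage of the `tf.config` surface stubbed here (a sketch; device availability obviously varies):

```python
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")
if gpus:
    # Pin TensorFlow to the first GPU and allocate its memory on demand.
    tf.config.set_visible_devices(gpus[:1], "GPU")
    tf.config.experimental.set_memory_growth(gpus[0], True)
```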
22 changes: 11 additions & 11 deletions stubs/tensorflow/tensorflow/data/__init__.pyi
@@ -7,7 +7,7 @@ from typing_extensions import Self
import numpy as np
import tensorflow as tf
from tensorflow import TypeSpec, _ScalarTensorCompatible, _TensorCompatible
from tensorflow._aliases import ContainerGeneric
from tensorflow._aliases import _ContainerGeneric
from tensorflow.data import experimental as experimental
from tensorflow.data.experimental import AUTOTUNE as AUTOTUNE
from tensorflow.dtypes import DType
@@ -21,7 +21,7 @@ _T3 = TypeVar("_T3")
class Iterator(_Iterator[_T1], Trackable, ABC):
@property
@abstractmethod
def element_spec(self) -> ContainerGeneric[TypeSpec[Any]]: ...
def element_spec(self) -> _ContainerGeneric[TypeSpec[Any]]: ...
@abstractmethod
def get_next(self) -> _T1: ...
@abstractmethod
@@ -43,8 +43,8 @@ class Dataset(ABC, Generic[_T1]):
element_length_func: Callable[[_T1], _ScalarTensorCompatible],
bucket_boundaries: Sequence[int],
bucket_batch_sizes: Sequence[int],
padded_shapes: ContainerGeneric[tf.TensorShape | _TensorCompatible] | None = None,
padding_values: ContainerGeneric[_ScalarTensorCompatible] | None = None,
padded_shapes: _ContainerGeneric[tf.TensorShape | _TensorCompatible] | None = None,
padding_values: _ContainerGeneric[_ScalarTensorCompatible] | None = None,
pad_to_bucket_boundary: bool = False,
no_padding: bool = False,
drop_remainder: bool = False,
@@ -63,18 +63,18 @@ class Dataset(ABC, Generic[_T1]):
) -> Dataset[tf.Tensor]: ...
@property
@abstractmethod
def element_spec(self) -> ContainerGeneric[TypeSpec[Any]]: ...
def element_spec(self) -> _ContainerGeneric[TypeSpec[Any]]: ...
def enumerate(self, start: _ScalarTensorCompatible = 0, name: str | None = None) -> Dataset[tuple[int, _T1]]: ...
def filter(self, predicate: Callable[[_T1], bool | tf.Tensor], name: str | None = None) -> Dataset[_T1]: ...
def flat_map(self, map_func: Callable[[_T1], Dataset[_T2]], name: str | None = None) -> Dataset[_T2]: ...
# PEP 646 can be used here for a more precise type when better supported.
@staticmethod
def from_generator(
generator: Callable[..., _T2],
output_types: ContainerGeneric[DType] | None = None,
output_shapes: ContainerGeneric[tf.TensorShape | Sequence[int | None]] | None = None,
output_types: _ContainerGeneric[DType] | None = None,
output_shapes: _ContainerGeneric[tf.TensorShape | Sequence[int | None]] | None = None,
args: tuple[object, ...] | None = None,
output_signature: ContainerGeneric[TypeSpec[Any]] | None = None,
output_signature: _ContainerGeneric[TypeSpec[Any]] | None = None,
name: str | None = None,
) -> Dataset[_T2]: ...
@staticmethod
@@ -111,7 +111,7 @@ class Dataset(ABC, Generic[_T1]):
@staticmethod
def load(
path: str,
element_spec: ContainerGeneric[tf.TypeSpec[Any]] | None = None,
element_spec: _ContainerGeneric[tf.TypeSpec[Any]] | None = None,
compression: _CompressionTypes = None,
reader_func: Callable[[Dataset[Dataset[Any]]], Dataset[Any]] | None = None,
) -> Dataset[Any]: ...
@@ -127,8 +127,8 @@ class Dataset(ABC, Generic[_T1]):
def padded_batch(
self,
batch_size: _ScalarTensorCompatible,
padded_shapes: ContainerGeneric[tf.TensorShape | _TensorCompatible] | None = None,
padding_values: ContainerGeneric[_ScalarTensorCompatible] | None = None,
padded_shapes: _ContainerGeneric[tf.TensorShape | _TensorCompatible] | None = None,
padding_values: _ContainerGeneric[_ScalarTensorCompatible] | None = None,
drop_remainder: bool = False,
name: str | None = None,
) -> Dataset[_T1]: ...
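
A sketch of how `from_generator` and `padded_batch` line up with the container-typed parameters above (illustrative; assumes an eager TF2 runtime):

```python
import tensorflow as tf

def gen():
    for n in range(1, 4):
        yield tf.range(n)  # variable-length 1-D int32 tensors

ds = tf.data.Dataset.from_generator(
    gen, output_signature=tf.TensorSpec(shape=(None,), dtype=tf.int32)
)
# padded_shapes/padding_values may be scalars, sequences, or mappings
# matching the element structure.
ds = ds.padded_batch(2, padded_shapes=[3])
for batch in ds:
    print(batch.shape)  # (2, 3), then (1, 3)
```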
4 changes: 2 additions & 2 deletions stubs/tensorflow/tensorflow/io/__init__.pyi
@@ -5,7 +5,7 @@ from typing import Literal, NamedTuple
from typing_extensions import Self, TypeAlias

from tensorflow import _DTypeLike, _ShapeLike, _TensorCompatible
from tensorflow._aliases import TensorLike
from tensorflow._aliases import _TensorLike
from tensorflow.io import gfile as gfile

_FeatureSpecs: TypeAlias = Mapping[str, FixedLenFeature | FixedLenSequenceFeature | VarLenFeature | RaggedFeature | SparseFeature]
@@ -102,5 +102,5 @@ class RaggedFeature(NamedTuple):

def parse_example(
serialized: _TensorCompatible, features: _FeatureSpecs, example_names: Iterable[str] | None = None, name: str | None = None
) -> dict[str, TensorLike]: ...
) -> dict[str, _TensorLike]: ...
def __getattr__(name: str) -> Incomplete: ...
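
For reference, the `dict[str, _TensorLike]` return type reflects that `parse_example` yields a mix of dense and sparse tensors depending on the feature spec; a minimal sketch:

```python
import tensorflow as tf

feature_spec = {
    "label": tf.io.FixedLenFeature([], tf.int64, default_value=0),
    "tokens": tf.io.VarLenFeature(tf.string),
}
example = tf.train.Example(
    features=tf.train.Features(
        feature={"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[1]))}
    )
)
parsed = tf.io.parse_example([example.SerializeToString()], feature_spec)
# parsed["label"] is a dense tf.Tensor; parsed["tokens"] is a tf.SparseTensor.
```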
9 changes: 9 additions & 0 deletions stubs/tensorflow/tensorflow/keras/__init__.pyi
@@ -1,14 +1,23 @@
from _typeshed import Incomplete
from collections.abc import Callable

import tensorflow as tf
from tensorflow._aliases import _TensorCompatible
from tensorflow.keras import (
activations as activations,
constraints as constraints,
initializers as initializers,
layers as layers,
losses as losses,
metrics as metrics,
models as models,
optimizers as optimizers,
regularizers as regularizers,
)
from tensorflow.keras.models import Model as Model

def __getattr__(name: str) -> Incomplete: ...

_Loss = str | tf.keras.losses.Loss | Callable[[_TensorCompatible, _TensorCompatible], tf.Tensor]

_Metric = str | tf.keras.metrics.Metric | Callable[[_TensorCompatible, _TensorCompatible], tf.Tensor] | None
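
The `_Loss`/`_Metric` aliases mirror the spellings `Model.compile` accepts at runtime: a registered string, a `Loss`/`Metric` object, or a bare callable. A sketch:

```python
import tensorflow as tf

model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.MeanSquaredError(),                   # Loss object
    metrics=["mae", tf.keras.metrics.RootMeanSquaredError()],  # string or Metric
)
```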
4 changes: 4 additions & 0 deletions stubs/tensorflow/tensorflow/keras/callbacks.pyi
@@ -0,0 +1,4 @@
from _typeshed import Incomplete

class Callback(Incomplete): ...
class History(Incomplete): ...
57 changes: 53 additions & 4 deletions stubs/tensorflow/tensorflow/keras/layers.pyi
@@ -1,11 +1,11 @@
from _typeshed import Incomplete
from collections.abc import Callable, Iterable, Sequence
from typing import Any, Generic, TypeVar, overload
from typing import Any, Generic, Literal, TypeVar, overload
from typing_extensions import Self, TypeAlias

import tensorflow as tf
from tensorflow import Tensor, Variable, VariableAggregation, VariableSynchronization, _TensorCompatible
from tensorflow._aliases import AnyArray
from tensorflow._aliases import _AnyArray
from tensorflow.keras.activations import _Activation
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.initializers import _Initializer
@@ -90,8 +90,8 @@ class Layer(tf.Module, Generic[_InputT, _OutputT]):
def non_trainable_weights(self) -> list[Variable]: ...
@property
def losses(self) -> list[Tensor]: ...
def get_weights(self) -> list[AnyArray]: ...
def set_weights(self, weights: Sequence[AnyArray]) -> None: ...
def get_weights(self) -> list[_AnyArray]: ...
def set_weights(self, weights: Sequence[_AnyArray]) -> None: ...
def get_config(self) -> dict[str, Any]: ...
@classmethod
def from_config(cls, config: dict[str, Any]) -> Self: ...
@@ -194,4 +194,53 @@ class Embedding(Layer[tf.Tensor, tf.Tensor]):
name: str | None = None,
) -> None: ...

class Conv2D(Layer[tf.Tensor, tf.Tensor]):
def __init__(
self,
filters: int,
kernel_size: int | tuple[int, int],
strides: int | tuple[int, int] = (1, 1),
padding: Literal["valid", "same"] = "valid",
data_format: None | Literal["channels_last", "channels_first"] = None,
dilation_rate: int | tuple[int, int] = (1, 1),
groups: int = 1,
activation: _Activation = None,
use_bias: bool = True,
kernel_initializer: _Initializer = "glorot_uniform",
bias_initializer: _Initializer = "zeros",
kernel_regularizer: _Regularizer = None,
bias_regularizer: _Regularizer = None,
activity_regularizer: _Regularizer = None,
kernel_constraint: _Constraint = None,
bias_constraint: _Constraint = None,
trainable: bool = True,
dtype: _LayerDtype = None,
dynamic: bool = False,
name: str | None = None,
) -> None: ...

class Identity(Layer[tf.Tensor, tf.Tensor]):
def __init__(
self, trainable: bool = True, dtype: _LayerDtype = None, dynamic: bool = False, name: str | None = None
) -> None: ...

class LayerNormalization(Layer[tf.Tensor, tf.Tensor]):
def __init__(
self,
axis: int = -1,
epsilon: float = 0.001,
center: bool = True,
scale: bool = True,
beta_initializer: _Initializer = "zeros",
gamma_initializer: _Initializer = "ones",
beta_regularizer: _Regularizer = None,
gamma_regularizer: _Regularizer = None,
beta_constraint: _Constraint = None,
gamma_constraint: _Constraint = None,
trainable: bool = True,
dtype: _LayerDtype = None,
dynamic: bool = False,
name: str | None = None,
) -> None: ...

def __getattr__(name: str) -> Incomplete: ...
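
A usage sketch for the layer constructors added above (illustrative; shapes in comments assume channels-last input):

```python
import tensorflow as tf

conv = tf.keras.layers.Conv2D(
    filters=32,
    kernel_size=(3, 3),
    strides=1,          # int | tuple[int, int]
    padding="same",     # Literal["valid", "same"]
    activation="relu",
)
norm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-3)

x = tf.zeros([1, 28, 28, 3])
y = norm(conv(x))       # shape (1, 28, 28, 32)
```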
4 changes: 2 additions & 2 deletions stubs/tensorflow/tensorflow/keras/losses.pyi
@@ -5,7 +5,7 @@ from typing import Any, Final, Literal, TypeVar, overload
from typing_extensions import Self, TypeAlias, TypeGuard

from tensorflow import Tensor, _TensorCompatible
from tensorflow._aliases import KerasSerializable
from tensorflow._aliases import _KerasSerializable
from tensorflow.keras.metrics import (
binary_crossentropy as binary_crossentropy,
categorical_crossentropy as categorical_crossentropy,
@@ -136,7 +136,7 @@ def log_cosh(y_true: _TensorCompatible, y_pred: _TensorCompatible) -> Tensor: ...
def deserialize(
name: str | dict[str, Any], custom_objects: dict[str, Any] | None = None, use_legacy_format: bool = False
) -> Loss: ...
def serialize(loss: KerasSerializable, use_legacy_format: bool = False) -> dict[str, Any]: ...
def serialize(loss: _KerasSerializable, use_legacy_format: bool = False) -> dict[str, Any]: ...

_FuncT = TypeVar("_FuncT", bound=Callable[..., Any])

8 changes: 8 additions & 0 deletions stubs/tensorflow/tensorflow/keras/metrics.pyi
@@ -1,4 +1,12 @@
from typing import Any
from typing_extensions import Self

from tensorflow import Tensor, _TensorCompatible
from tensorflow.dtypes import DType

class Metric:
def __init__(self, name: str, dtype: DType) -> None: ...
def __new__(cls, *args: Any, **kwargs: Any) -> Self: ...

def binary_crossentropy(
y_true: _TensorCompatible, y_pred: _TensorCompatible, from_logits: bool = False, label_smoothing: float = 0.0, axis: int = -1
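
And a one-liner exercising the functional metric stubbed above (illustrative):

```python
import tensorflow as tf

y_true = tf.constant([[0.0, 1.0]])
y_pred = tf.constant([[0.1, 0.9]])
bce = tf.keras.metrics.binary_crossentropy(y_true, y_pred, from_logits=False)
```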