@@ -10,9 +10,18 @@ import tensorflow
10
10
import tensorflow as tf
11
11
from tensorflow import Variable , _ShapeLike , _TensorCompatible
12
12
from tensorflow .keras .layers import Layer , _InputT , _OutputT
13
- from tensorflow .keras .optimizers .legacy import Optimizer
13
+ from tensorflow .keras import _Loss , _Metric
14
+ from tensorflow ._aliases import _ContainerGeneric
15
+
16
+ _BothOptimizer = tf .optimizers .Optimizer | tf .optimizers .experimental .Optimizer
14
17
15
18
class Model (Layer [_InputT , _OutputT ], tf .Module ):
19
+ _train_counter : tf .Variable
20
+ _test_counter : tf .Variable
21
+ optimizer : _BothOptimizer | None
22
+ loss : tf .keras .losses .Loss | dict [str , tf .keras .losses .Loss ]
23
+ stop_training : bool
24
+
16
25
def __new__ (cls , * args : Any , ** kwargs : Any ) -> Model [_InputT , _OutputT ]: ...
17
26
def __init__ (self , * args : Any , ** kwargs : Any ) -> None : ...
18
27
def __setattr__ (self , name : str , value : Any ) -> None : ...
@@ -21,13 +30,14 @@ class Model(Layer[_InputT, _OutputT], tf.Module):
21
30
def build (self , input_shape : _ShapeLike ) -> None : ...
22
31
def __call__ (self , inputs : _InputT , * , training : bool = False , mask : _TensorCompatible | None = None ) -> _OutputT : ...
23
32
def call (self , inputs : _InputT , training : bool | None = None , mask : _TensorCompatible | None = None ) -> _OutputT : ...
33
+ # Ideally loss/metrics/output would share the same structure but higher kinded types are not supported.
24
34
def compile (
25
35
self ,
26
- optimizer : Optimizer | str = "rmsprop" ,
27
- loss : tf . keras . losses . Loss | str | None = None ,
28
- metrics : list [ tf . keras . metrics . Metric | str ] | None = None ,
29
- loss_weights : list [ float ] | dict [ str , float ] | None = None ,
30
- weighted_metrics : list [ tf . keras . metrics . Metric ] | None = None ,
36
+ optimizer : _BothOptimizer | str = "rmsprop" ,
37
+ loss : _ContainerGeneric [ _Loss ] | None = None ,
38
+ metrics : _ContainerGeneric [ _Metric ] | None = None ,
39
+ loss_weights : _ContainerGeneric [ float ] | None = None ,
40
+ weighted_metrics : _ContainerGeneric [ _Metric ] | None = None ,
31
41
run_eagerly : bool | None = None ,
32
42
steps_per_execution : int | Literal ["auto" ] | None = None ,
33
43
jit_compile : bool | None = None ,
0 commit comments