1
+ from _typeshed import Incomplete
1
2
from collections .abc import Callable
2
3
from pathlib import Path
4
+ from typing import Any , Iterator , Literal , Self , Sequence
5
+
6
+ import numpy as np
7
+ import numpy .typing as npt
3
8
4
9
import tensorflow
5
10
import tensorflow as tf
6
- from tensorflow import _ShapeLike , _TensorCompatible
11
+ from tensorflow import _ShapeLike , _TensorCompatible , Variable
7
12
from tensorflow .keras .layers import Layer , _InputT , _OutputT
8
13
9
14
class Model (Layer [_InputT , _OutputT ], tf .Module ):
15
+ def __new__ (cls , * args , ** kwargs ) -> Model [_InputT , _OutputT ]: ...
10
16
def __init__ (self , * args , ** kwargs ) -> None : ...
17
+ def __setattr__ (self , name , value ) -> None : ...
18
+ def __reduce__ (self ) -> Incomplete : ...
19
+ def __deepcopy__ (self , memo ) -> Incomplete : ...
11
20
def build (self , input_shape : _ShapeLike ) -> None : ...
12
- def summary (
21
+ def __call__ (self , inputs : _InputT , * , training : bool = False , mask : _TensorCompatible | None = None ) -> _OutputT : ...
22
+ def call (self , inputs : _InputT , training : bool = False , mask : _TensorCompatible | None = None ) -> _OutputT : ...
23
+ def compile (
13
24
self ,
14
- line_length : None | int = None ,
15
- positions : None | list [float ] = None ,
16
- print_fn : None | Callable [[str ], None ] = None ,
17
- expand_nested : bool = False ,
18
- show_trainable : bool = False ,
19
- layer_range : None | list [str ] | tuple [str , str ] = None ,
25
+ optimizer = "rmsprop" ,
26
+ loss = None ,
27
+ metrics = None ,
28
+ loss_weights = None ,
29
+ weighted_metrics = None ,
30
+ run_eagerly = None ,
31
+ steps_per_execution = None ,
32
+ jit_compile = None ,
33
+ pss_evaluation_shards = 0 ,
34
+ ** kwargs ,
35
+ ) -> Incomplete : ...
36
+ @property
37
+ def metrics (self ) -> list [Incomplete ]: ...
38
+ @property
39
+ def metrics_names (self ) -> list [str ]: ...
40
+ @property
41
+ def distribute_strategy (self ) -> Incomplete : ... # tf.distribute.Strategy
42
+ @property
43
+ def run_eagerly (self ) -> bool : ...
44
+ @property
45
+ def autotune_steps_per_execution (self ) -> Incomplete : ...
46
+ @property
47
+ def steps_per_execution (self ) -> int : ...
48
+ @property
49
+ def jit_compile (self ) -> bool : ...
50
+ @property
51
+ def distribute_reduction_method (self ) -> Incomplete | Literal ["auto" ]: ...
52
+ def train_step (self , data : _TensorCompatible ) -> Incomplete : ...
53
+ def compute_loss (
54
+ self ,
55
+ x : _TensorCompatible | None = None ,
56
+ y : _TensorCompatible | None = None ,
57
+ y_pred : _TensorCompatible | None = None ,
58
+ sample_weight : Incomplete | None = None ,
59
+ ) -> tf .Tensor | None : ...
60
+ def compute_metrics (
61
+ self ,
62
+ x : _TensorCompatible ,
63
+ y : _TensorCompatible ,
64
+ y_pred : _TensorCompatible ,
65
+ sample_weight : Incomplete ,
66
+ ) -> dict [str , float ]: ...
67
+ def get_metrics_result (self ) -> dict [str , float ]: ...
68
+ def make_train_function (self , force : bool = False ) -> Callable [[tf .data .Iterator [Incomplete ]], dict [str , float ]]: ...
69
+ def fit (
70
+ self ,
71
+ x = None ,
72
+ y = None ,
73
+ batch_size = None ,
74
+ epochs = 1 ,
75
+ verbose = "auto" ,
76
+ callbacks = None ,
77
+ validation_split = 0.0 ,
78
+ validation_data = None ,
79
+ shuffle = True ,
80
+ class_weight = None ,
81
+ sample_weight = None ,
82
+ initial_epoch = 0 ,
83
+ steps_per_epoch = None ,
84
+ validation_steps = None ,
85
+ validation_batch_size = None ,
86
+ validation_freq = 1 ,
87
+ max_queue_size = 10 ,
88
+ workers = 1 ,
89
+ use_multiprocessing = False ,
90
+ ): ...
91
+ def test_step (self , data : _TensorCompatible ) -> dict [str , float ]: ...
92
+ def make_test_function (self , force : bool = False ) -> Callable [[tf .data .Iterator [Incomplete ]], dict [str , float ]]: ...
93
+ def evaluate (
94
+ self ,
95
+ x = None ,
96
+ y = None ,
97
+ batch_size = None ,
98
+ verbose = "auto" ,
99
+ sample_weight = None ,
100
+ steps = None ,
101
+ callbacks = None ,
102
+ max_queue_size = 10 ,
103
+ workers = 1 ,
104
+ use_multiprocessing = False ,
105
+ return_dict = False ,
106
+ ** kwargs ,
107
+ ): ...
108
+ def predict_step (self , data : _InputT ) -> _OutputT : ...
109
+ def make_predict_function (self , force : bool = False ) -> Callable [[tf .data .Iterator [Incomplete ]], _OutputT ]: ...
110
+ def predict (
111
+ self ,
112
+ x ,
113
+ batch_size = None ,
114
+ verbose = "auto" ,
115
+ steps = None ,
116
+ callbacks = None ,
117
+ max_queue_size = 10 ,
118
+ workers = 1 ,
119
+ use_multiprocessing = False ,
120
+ ): ...
121
+ def reset_metrics (self ) -> None : ...
122
+ def train_on_batch (
123
+ self ,
124
+ x ,
125
+ y = None ,
126
+ sample_weight = None ,
127
+ class_weight = None ,
128
+ reset_metrics = True ,
129
+ return_dict = False ,
130
+ ) -> float | list [float ]: ...
131
+ def test_on_batch (
132
+ self ,
133
+ x ,
134
+ y = None ,
135
+ sample_weight = None ,
136
+ reset_metrics = True ,
137
+ return_dict = False ,
138
+ ) -> float | list [float ]: ...
139
+ def predict_on_batch (self , x : Iterator [_InputT ]) -> Incomplete : ... # npt.NDArray[_OutputT]
140
+ def fit_generator (
141
+ self ,
142
+ generator ,
143
+ steps_per_epoch = None ,
144
+ epochs = 1 ,
145
+ verbose = 1 ,
146
+ callbacks = None ,
147
+ validation_data = None ,
148
+ validation_steps = None ,
149
+ validation_freq = 1 ,
150
+ class_weight = None ,
151
+ max_queue_size = 10 ,
152
+ workers = 1 ,
153
+ use_multiprocessing = False ,
154
+ shuffle = True ,
155
+ initial_epoch = 0 ,
156
+ ): ...
157
+ def evaluate_generator (
158
+ self ,
159
+ generator ,
160
+ steps = None ,
161
+ callbacks = None ,
162
+ max_queue_size = 10 ,
163
+ workers = 1 ,
164
+ use_multiprocessing = False ,
165
+ verbose = 0 ,
166
+ ): ...
167
+ def predict_generator (
168
+ self ,
169
+ generator ,
170
+ steps = None ,
171
+ callbacks = None ,
172
+ max_queue_size = 10 ,
173
+ workers = 1 ,
174
+ use_multiprocessing = False ,
175
+ verbose = 0 ,
176
+ ): ...
177
+ @property
178
+ def trainable_weights (self ) -> list [Variable ]: ...
179
+ @property
180
+ def non_trainable_weights (self ) -> list [Variable ]: ...
181
+ def get_weights (self ): ...
182
+ def save (self , filepath , overwrite = True , save_format = None , ** kwargs ): ...
183
+ def save_weights (
184
+ self , filepath , overwrite = True , save_format = None , options = None
20
185
): ...
21
186
def load_weights (
22
187
self ,
@@ -25,5 +190,30 @@ class Model(Layer[_InputT, _OutputT], tf.Module):
25
190
by_name : bool = False ,
26
191
options : None | tensorflow .train .CheckpointOptions = None ,
27
192
): ...
28
- def __call__ (self , inputs : _InputT , * , training : bool = False , mask : _TensorCompatible | None = None ) -> _OutputT : ...
29
- def call (self , __inputs : _InputT ) -> _OutputT : ...
193
+ def get_config (self ): ...
194
+ @classmethod
195
+ def from_config (cls , config : dict [str , Any ], custom_objects = None ) -> Self : ...
196
+ def to_json (self , ** kwargs ) -> str : ...
197
+ def to_yaml (self , ** kwargs ) -> str : ...
198
+ def reset_states (self ) -> None : ...
199
+ @property
200
+ def state_updates (self ) -> list [Incomplete ]: ...
201
+ @property
202
+ def weights (self ) -> list [Variable ]: ...
203
+ def summary (
204
+ self ,
205
+ line_length : None | int = None ,
206
+ positions : None | list [float ] = None ,
207
+ print_fn : None | Callable [[str ], None ] = None ,
208
+ expand_nested : bool = False ,
209
+ show_trainable : bool = False ,
210
+ layer_range : None | list [str ] | tuple [str , str ] = None ,
211
+ ): ...
212
+ @property
213
+ def layers (self ) -> list [Layer [Incomplete , Incomplete ]]: ...
214
+ def get_layer (self , name = None , index = None ) -> Layer [Incomplete , Incomplete ]: ...
215
+ def get_weight_paths (self ): ...
216
+ def get_compile_config (self ) -> dict [str , Any ]: ...
217
+ def compile_from_config (self , config : dict [str , Any ]) -> Self : ...
218
+ def export (self , filepath : str | Path ) -> None : ...
219
+ def save_spec (self , dynamic_batch : bool = True ) -> tuple [tuple [tf .TensorSpec , ...], dict [str , tf .TensorSpec ]] | None : ...
0 commit comments