Skip to content

Commit 3d6152f

Browse files
committed
fix: add missing Model methods/properties (wip)
1 parent 2480b4c commit 3d6152f

File tree

1 file changed

+200
-10
lines changed

1 file changed

+200
-10
lines changed
Lines changed: 200 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,187 @@
1+
from _typeshed import Incomplete
12
from collections.abc import Callable
23
from pathlib import Path
4+
from typing import Any, Iterator, Literal, Self, Sequence
5+
6+
import numpy as np
7+
import numpy.typing as npt
38

49
import tensorflow
510
import tensorflow as tf
6-
from tensorflow import _ShapeLike, _TensorCompatible
11+
from tensorflow import _ShapeLike, _TensorCompatible, Variable
712
from tensorflow.keras.layers import Layer, _InputT, _OutputT
813

914
class Model(Layer[_InputT, _OutputT], tf.Module):
15+
def __new__(cls, *args, **kwargs) -> Model[_InputT, _OutputT]: ...
1016
def __init__(self, *args, **kwargs) -> None: ...
17+
def __setattr__(self, name, value) -> None: ...
18+
def __reduce__(self) -> Incomplete: ...
19+
def __deepcopy__(self, memo) -> Incomplete: ...
1120
def build(self, input_shape: _ShapeLike) -> None: ...
12-
def summary(
21+
def __call__(self, inputs: _InputT, *, training: bool = False, mask: _TensorCompatible | None = None) -> _OutputT: ...
22+
def call(self, inputs: _InputT, training: bool = False, mask: _TensorCompatible | None = None) -> _OutputT: ...
23+
def compile(
1324
self,
14-
line_length: None | int = None,
15-
positions: None | list[float] = None,
16-
print_fn: None | Callable[[str], None] = None,
17-
expand_nested: bool = False,
18-
show_trainable: bool = False,
19-
layer_range: None | list[str] | tuple[str, str] = None,
25+
optimizer="rmsprop",
26+
loss=None,
27+
metrics=None,
28+
loss_weights=None,
29+
weighted_metrics=None,
30+
run_eagerly=None,
31+
steps_per_execution=None,
32+
jit_compile=None,
33+
pss_evaluation_shards=0,
34+
**kwargs,
35+
) -> Incomplete: ...
36+
@property
37+
def metrics(self) -> list[Incomplete]: ...
38+
@property
39+
def metrics_names(self) -> list[str]: ...
40+
@property
41+
def distribute_strategy(self) -> Incomplete: ... # tf.distribute.Strategy
42+
@property
43+
def run_eagerly(self) -> bool: ...
44+
@property
45+
def autotune_steps_per_execution(self) -> Incomplete: ...
46+
@property
47+
def steps_per_execution(self) -> int: ...
48+
@property
49+
def jit_compile(self) -> bool: ...
50+
@property
51+
def distribute_reduction_method(self) -> Incomplete | Literal["auto"]: ...
52+
def train_step(self, data: _TensorCompatible) -> Incomplete: ...
53+
def compute_loss(
54+
self,
55+
x: _TensorCompatible | None = None,
56+
y: _TensorCompatible | None = None,
57+
y_pred: _TensorCompatible | None = None,
58+
sample_weight: Incomplete | None = None,
59+
) -> tf.Tensor | None: ...
60+
def compute_metrics(
61+
self,
62+
x: _TensorCompatible,
63+
y: _TensorCompatible,
64+
y_pred: _TensorCompatible,
65+
sample_weight: Incomplete,
66+
) -> dict[str, float]: ...
67+
def get_metrics_result(self) -> dict[str, float]: ...
68+
def make_train_function(self, force: bool = False) -> Callable[[tf.data.Iterator[Incomplete]], dict[str, float]]: ...
69+
def fit(
70+
self,
71+
x=None,
72+
y=None,
73+
batch_size=None,
74+
epochs=1,
75+
verbose="auto",
76+
callbacks=None,
77+
validation_split=0.0,
78+
validation_data=None,
79+
shuffle=True,
80+
class_weight=None,
81+
sample_weight=None,
82+
initial_epoch=0,
83+
steps_per_epoch=None,
84+
validation_steps=None,
85+
validation_batch_size=None,
86+
validation_freq=1,
87+
max_queue_size=10,
88+
workers=1,
89+
use_multiprocessing=False,
90+
): ...
91+
def test_step(self, data: _TensorCompatible) -> dict[str, float]: ...
92+
def make_test_function(self, force: bool = False) -> Callable[[tf.data.Iterator[Incomplete]], dict[str, float]]: ...
93+
def evaluate(
94+
self,
95+
x=None,
96+
y=None,
97+
batch_size=None,
98+
verbose="auto",
99+
sample_weight=None,
100+
steps=None,
101+
callbacks=None,
102+
max_queue_size=10,
103+
workers=1,
104+
use_multiprocessing=False,
105+
return_dict=False,
106+
**kwargs,
107+
): ...
108+
def predict_step(self, data: _InputT) -> _OutputT: ...
109+
def make_predict_function(self, force: bool = False) -> Callable[[tf.data.Iterator[Incomplete]], _OutputT]: ...
110+
def predict(
111+
self,
112+
x,
113+
batch_size=None,
114+
verbose="auto",
115+
steps=None,
116+
callbacks=None,
117+
max_queue_size=10,
118+
workers=1,
119+
use_multiprocessing=False,
120+
): ...
121+
def reset_metrics(self) -> None: ...
122+
def train_on_batch(
123+
self,
124+
x,
125+
y=None,
126+
sample_weight=None,
127+
class_weight=None,
128+
reset_metrics=True,
129+
return_dict=False,
130+
) -> float | list[float]: ...
131+
def test_on_batch(
132+
self,
133+
x,
134+
y=None,
135+
sample_weight=None,
136+
reset_metrics=True,
137+
return_dict=False,
138+
) -> float | list[float]: ...
139+
def predict_on_batch(self, x: Iterator[_InputT]) -> Incomplete: ... # npt.NDArray[_OutputT]
140+
def fit_generator(
141+
self,
142+
generator,
143+
steps_per_epoch=None,
144+
epochs=1,
145+
verbose=1,
146+
callbacks=None,
147+
validation_data=None,
148+
validation_steps=None,
149+
validation_freq=1,
150+
class_weight=None,
151+
max_queue_size=10,
152+
workers=1,
153+
use_multiprocessing=False,
154+
shuffle=True,
155+
initial_epoch=0,
156+
): ...
157+
def evaluate_generator(
158+
self,
159+
generator,
160+
steps=None,
161+
callbacks=None,
162+
max_queue_size=10,
163+
workers=1,
164+
use_multiprocessing=False,
165+
verbose=0,
166+
): ...
167+
def predict_generator(
168+
self,
169+
generator,
170+
steps=None,
171+
callbacks=None,
172+
max_queue_size=10,
173+
workers=1,
174+
use_multiprocessing=False,
175+
verbose=0,
176+
): ...
177+
@property
178+
def trainable_weights(self) -> list[Variable]: ...
179+
@property
180+
def non_trainable_weights(self) -> list[Variable]: ...
181+
def get_weights(self): ...
182+
def save(self, filepath, overwrite=True, save_format=None, **kwargs): ...
183+
def save_weights(
184+
self, filepath, overwrite=True, save_format=None, options=None
20185
): ...
21186
def load_weights(
22187
self,
@@ -25,5 +190,30 @@ class Model(Layer[_InputT, _OutputT], tf.Module):
25190
by_name: bool = False,
26191
options: None | tensorflow.train.CheckpointOptions = None,
27192
): ...
28-
def __call__(self, inputs: _InputT, *, training: bool = False, mask: _TensorCompatible | None = None) -> _OutputT: ...
29-
def call(self, __inputs: _InputT) -> _OutputT: ...
193+
def get_config(self): ...
194+
@classmethod
195+
def from_config(cls, config: dict[str, Any], custom_objects=None) -> Self: ...
196+
def to_json(self, **kwargs) -> str: ...
197+
def to_yaml(self, **kwargs) -> str: ...
198+
def reset_states(self) -> None: ...
199+
@property
200+
def state_updates(self) -> list[Incomplete]: ...
201+
@property
202+
def weights(self) -> list[Variable]: ...
203+
def summary(
204+
self,
205+
line_length: None | int = None,
206+
positions: None | list[float] = None,
207+
print_fn: None | Callable[[str], None] = None,
208+
expand_nested: bool = False,
209+
show_trainable: bool = False,
210+
layer_range: None | list[str] | tuple[str, str] = None,
211+
): ...
212+
@property
213+
def layers(self) -> list[Layer[Incomplete, Incomplete]]: ...
214+
def get_layer(self, name=None, index=None) -> Layer[Incomplete, Incomplete]: ...
215+
def get_weight_paths(self): ...
216+
def get_compile_config(self) -> dict[str, Any]: ...
217+
def compile_from_config(self, config: dict[str, Any]) -> Self: ...
218+
def export(self, filepath: str | Path) -> None: ...
219+
def save_spec(self, dynamic_batch: bool = True) -> tuple[tuple[tf.TensorSpec, ...], dict[str, tf.TensorSpec]] | None: ...

0 commit comments

Comments
 (0)