Skip to content

Add tensorflow type stubs #220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
479 changes: 479 additions & 0 deletions tensorflow/__init__.pyi

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions tensorflow/autograph/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from tensorflow.autograph import experimental as experimental
10 changes: 10 additions & 0 deletions tensorflow/autograph/experimental.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from enum import Enum

class Feature(Enum):
    """Stub for tf.autograph.experimental.Feature.

    Flags selecting which optional AutoGraph conversions are enabled.
    Member values mirror the member names as strings.
    """

    ALL = "ALL"
    AUTO_CONTROL_DEPS = "AUTO_CONTROL_DEPS"
    ASSERT_STATEMENTS = "ASSERT_STATEMENTS"
    BUILTIN_FUNCTIONS = "BUILTIN_FUNCTIONS"
    EQUALITY_OPERATORS = "EQUALITY_OPERATORS"
    LISTS = "LISTS"
    NAME_SCOPES = "NAME_SCOPES"
1 change: 1 addition & 0 deletions tensorflow/compat/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from tensorflow.compat import v1 as v1
80 changes: 80 additions & 0 deletions tensorflow/compat/v1/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Any, Mapping, MutableSequence, Sequence, overload
from typing_extensions import Self

from types import TracebackType

from google.protobuf.message import Message

import numpy as np

import tensorflow as tf
from tensorflow.compat.v1 import graph_util as graph_util
from tensorflow.compat.v1 import saved_model as saved_model

from bento.utils.tensor_types import FloatDataSequence

# Would be better to use mypy-protobuf to make this.
# Hand-written stand-in for the generated GraphDef protobuf message.
# Only the `node` field is typed explicitly; every other proto field
# falls through to __getattr__ and types as Any.
class GraphDef(Message):
    node: MutableSequence[NodeDef]
    def __getattr__(self, name: str) -> Any: ...

class MetaGraphDef(Message): ...

# Hand-written stand-in for the generated NodeDef protobuf message.
# `name`, `op`, and `input` are the fields this stub set actually uses;
# anything else falls through to __getattr__ and types as Any.
class NodeDef(Message):
    name: str
    op: str
    input: MutableSequence[str]
    def __getattr__(self, name: str) -> Any: ...

# Opaque request/trace protos accepted by Session.run; fields are not modeled.
class RunOptions(Message): ...
class RunMetadata(Message): ...

# Anything Session.run can fetch: a graph element, or its name as a string.
_GraphElement = tf.Tensor | tf.SparseTensor | tf.Operation | str
# Values accepted on the value side of a feed_dict entry.
_FeedElement = float | str | np.ndarray[Any, Any] | FloatDataSequence
# This is a simplification. Key being invariant in a Mapping makes the real type difficult to write. This
# is enough to cover vast majority of use cases.
_FeedDict = Mapping[str, _FeedElement] | Mapping[tf.Tensor, _FeedElement] | Mapping[tf.SparseTensor, _FeedElement]

class Session:
    """Stub for tf.compat.v1.Session.

    Usable as a context manager; __exit__ returns None, so exceptions
    raised inside the `with` block propagate.
    """

    graph: tf.Graph
    graph_def: GraphDef
    # Only the keyword-only `graph` argument is stubbed; the real
    # constructor also accepts `target` and `config`, which callers of
    # this stub set do not use.
    def __init__(
        self,
        *,
        graph: tf.Graph | None = None,
    ) -> None: ...
    # The three run() overloads map the structure of `fetches` onto the
    # structure of the result: one element -> one array, a sequence ->
    # a list of arrays, a str-keyed mapping -> a dict of arrays.
    @overload
    def run(
        self,
        fetches: _GraphElement,
        feed_dict: _FeedDict | None = None,
        options: RunOptions | None = None,
        run_metadata: RunMetadata | None = None,
    ) -> np.ndarray[Any, Any]: ...
    @overload
    def run(
        self,
        fetches: Sequence[_GraphElement],
        feed_dict: _FeedDict | None = None,
        options: RunOptions | None = None,
        run_metadata: RunMetadata | None = None,
    ) -> list[np.ndarray[Any, Any]]: ...
    @overload
    def run(
        self,
        fetches: Mapping[str, _GraphElement],
        feed_dict: _FeedDict | None = None,
        options: RunOptions | None = None,
        run_metadata: RunMetadata | None = None,
    ) -> dict[str, np.ndarray[Any, Any]]: ...
    def __enter__(self) -> Self: ...
    def __exit__(
        self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
    ) -> None: ...

# Graph-mode setup helpers and initializer ops from the tf.compat.v1 API.
def disable_eager_execution() -> None: ...
def disable_v2_behavior() -> None: ...
def global_variables_initializer() -> tf.Operation: ...
def tables_initializer() -> tf.Operation: ...
def get_default_graph() -> tf.Graph: ...

# Catch-all: the large part of the tf.compat.v1 surface that is not yet
# stubbed types as Any instead of erroring.
def __getattr__(name: str) -> Any: ...
12 changes: 12 additions & 0 deletions tensorflow/compat/v1/graph_util.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Iterable

from tensorflow.compat.v1 import GraphDef, Session

def extract_sub_graph(graph_def: GraphDef, dest_nodes: Iterable[str]) -> GraphDef: ...
# Freezes a graph: replaces variables read via `sess` with constants,
# keeping only what `output_node_names` depend on. The whitelist /
# blacklist restrict which variable names are converted.
def convert_variables_to_constants(
    sess: Session,
    input_graph_def: GraphDef,
    output_node_names: Iterable[str],
    variable_names_whitelist: Iterable[str] | None = None,
    variable_names_blacklist: Iterable[str] | None = None,
) -> GraphDef: ...
17 changes: 17 additions & 0 deletions tensorflow/compat/v1/saved_model/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Mapping

import tensorflow as tf
from tensorflow.compat.v1 import Session
from tensorflow.compat.v1.saved_model import loader as loader
from tensorflow.compat.v1.saved_model import tag_constants as tag_constants
from tensorflow.compat.v1.saved_model.builder import SavedModelBuilder

Builder = SavedModelBuilder

# Writes a single-MetaGraph SavedModel for `session` to `export_dir`,
# with the given named input/output tensors as its signature.
def simple_save(
    session: Session,
    export_dir: str,
    inputs: Mapping[str, tf.Tensor],
    outputs: Mapping[str, tf.Tensor],
    legacy_init_op: tf.Operation | None = None,
) -> None: ...
3 changes: 3 additions & 0 deletions tensorflow/compat/v1/saved_model/builder.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class SavedModelBuilder:
    """Stub for tf.compat.v1.saved_model.builder.SavedModelBuilder.

    Only the constructor and save() are stubbed; other methods
    (e.g. add_meta_graph_and_variables) are intentionally omitted.
    """

    def __init__(self, export_dir: str) -> None: ...
    # NOTE(review): stubbed as returning the path written as a str —
    # confirm against the upstream API before relying on the return value.
    def save(self, as_text: bool = False) -> str: ...
7 changes: 7 additions & 0 deletions tensorflow/compat/v1/saved_model/loader.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import Iterable

from tensorflow.compat.v1 import MetaGraphDef, Session

# Loads the MetaGraph matching `tags` from the SavedModel at `export_dir`
# into `sess`. `**saver_kwargs` is left unannotated, so it is implicitly
# Any in a stub; it is forwarded to the underlying Saver.
def load(
    sess: Session, tags: Iterable[str], export_dir: str, import_scope: str | None = None, **saver_kwargs
) -> MetaGraphDef: ...
4 changes: 4 additions & 0 deletions tensorflow/compat/v1/saved_model/tag_constants.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from tensorflow.saved_model import GPU as GPU
from tensorflow.saved_model import SERVING as SERVING
from tensorflow.saved_model import TPU as TPU
from tensorflow.saved_model import TRAINING as TRAINING
Empty file added tensorflow/config/__init__.pyi
Empty file.
1 change: 1 addition & 0 deletions tensorflow/config/experimental.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
def enable_op_determinism() -> None: ...
63 changes: 63 additions & 0 deletions tensorflow/data/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import Any, Callable, Generic, Iterator, Literal, Sequence, TypeVar
from typing_extensions import Self

import numpy as np

from tensorflow import Tensor, TensorCompatibleT, TypeSpec
from tensorflow.data import experimental as experimental

from bento.utils.tensor_types import ContainerGeneric, ScalarTensorCompatible, TensorCompatible

# _T1 is the element type a Dataset produces; covariant because datasets
# only yield elements, never consume them.
_T1 = TypeVar("_T1", covariant=True)
_T2 = TypeVar("_T2")
_T3 = TypeVar("_T3")

class Dataset(Generic[_T1]):
    """Stub for tf.data.Dataset, generic over the element type _T1.

    Element-type-preserving transformations (batch, cache, prefetch,
    repeat, shard, shuffle, take) are annotated with `self: Self` so
    they return the concrete subclass; element-type-changing ones
    (apply, map, zip) return Dataset of the new element type.
    """

    element_spec: ContainerGeneric[TypeSpec]
    def apply(self: Dataset[_T1], transformation_func: Callable[[Dataset[_T1]], Dataset[_T2]]) -> Dataset[_T2]: ...
    def as_numpy_iterator(self) -> Iterator[np.ndarray[Any, Any]]: ...
    def batch(
        self: Self,
        batch_size: ScalarTensorCompatible,
        drop_remainder: bool = False,
        num_parallel_calls: int | None = None,
        deterministic: bool | None = None,
        name: str | None = None,
    ) -> Self: ...
    def cache(self: Self, filename: str = "", name: str | None = None) -> Self: ...
    # The element type of the result is inferred from `tensors`, so a
    # slices-of-T input yields a Dataset[T].
    @classmethod
    def from_tensor_slices(
        cls, tensors: Sequence[TensorCompatibleT] | TensorCompatibleT, name: str | None = None
    ) -> Dataset[TensorCompatibleT]: ...
    def __iter__(self) -> Iterator[_T1]: ...
    def map(
        self: Dataset[_T1],
        map_func: Callable[[_T1], _T2],
        num_parallel_calls: int | None = None,
        deterministic: None | bool = None,
        name: str | None = None,
    ) -> Dataset[_T2]: ...
    def prefetch(self: Self, buffer_size: int, name: str | None = None) -> Self: ...
    def reduce(self, initial_state: _T2, reduce_func: Callable[[_T2, _T1], _T2], name: str | None = None) -> _T2: ...
    def repeat(self: Self, count: int | None = None, name: str | None = None) -> Self: ...
    def shard(self: Self, num_shards: int, index: int, name: str | None = None) -> Self: ...
    def shuffle(
        self: Self,
        buffer_size: int,
        seed: int | None = None,
        reshuffle_each_iteration: bool = True,
        name: str | None = None,
    ) -> Self: ...
    def take(self: Self, count: int, name: str | None = None) -> Self: ...
    # Only the two-dataset form of zip is stubbed; the real API accepts
    # arbitrary nested structures of datasets.
    @staticmethod
    def zip(datasets: tuple[Dataset[_T2], Dataset[_T3]], name: str | None = None) -> Dataset[tuple[_T2, _T3]]: ...

# A dataset of serialized records read from TFRecord files; each element
# is a scalar string Tensor, hence Dataset[Tensor].
class TFRecordDataset(Dataset[Tensor]):
    def __init__(
        self,
        filenames: TensorCompatible | Dataset[str],
        compression_type: Literal["", "ZLIB", "GZIP"] | None = None,
        buffer_size: int | None = None,
        num_parallel_reads: int | None = None,
        name: str | None = None,
    ) -> None: ...
20 changes: 20 additions & 0 deletions tensorflow/data/experimental.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Callable, TypeVar

from tensorflow.data import Dataset

# Sentinel integer constants from tf.data.experimental (e.g. AUTOTUNE is
# passed where a tuning parameter like num_parallel_calls is expected).
AUTOTUNE: int
INFINITE_CARDINALITY: int
SHARD_HINT: int
UNKNOWN_CARDINALITY: int

_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")

# Returns a transformation for Dataset.apply that maps each element of a
# Dataset[_T1] to a Dataset[_T2] and interleaves the results.
# NOTE(review): deprecated upstream in favor of Dataset.interleave —
# confirm whether this stub should be kept.
def parallel_interleave(
    map_func: Callable[[_T1], Dataset[_T2]],
    cycle_length: int,
    block_length: int = 1,
    sloppy: bool | None = False,
    buffer_output_elements: int | None = None,
    prefetch_input_elements: int | None = None,
) -> Callable[[Dataset[_T1]], Dataset[_T2]]: ...
39 changes: 39 additions & 0 deletions tensorflow/distribute.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Generic, Iterator, NamedTuple, TypeVar

from enum import Enum

from tensorflow import TypeSpec

from bento.utils.tensor_types import ContainerGeneric

class InputContext:
    """Stub for tf.distribute.InputContext.

    Read-only description of one input pipeline's position within a
    distributed job; all three constructor arguments are exposed as
    properties.
    """

    def __init__(
        self, num_input_pipelines: int = 1, input_pipeline_id: int = 0, num_replicas_in_sync: int = 1
    ) -> None: ...
    @property
    def num_input_pipelines(self) -> int: ...
    @property
    def input_pipeline_id(self) -> int: ...
    @property
    def num_replicas_in_sync(self) -> int: ...
    # Per-replica batch size for an evenly divisible global batch size.
    def get_per_replica_batch_size(self, global_batch_size: int) -> int: ...

# Whether the input pipeline is replicated per worker or per replica.
class InputReplicationMode(Enum):
    PER_WORKER = "PER_WORKER"
    PER_REPLICA = "PER_REPLICA"

# Options controlling how a tf.distribute strategy distributes a dataset;
# a NamedTuple, so all fields are optional positional-or-keyword.
class InputOptions(NamedTuple):
    experimental_fetch_to_device: bool | None = None
    experimental_replication_mode: InputReplicationMode = InputReplicationMode.PER_WORKER
    experimental_place_dataset_on_device: bool = False
    experimental_per_replica_buffer_size: int = 1

# Covariant element TypeVar: distributed containers only produce values.
_T1 = TypeVar("_T1", covariant=True)

# Iterator over per-replica values produced by a DistributedDataset.
class DistributedIterator(Generic[_T1]):
    element_spec: ContainerGeneric[TypeSpec]
    def __iter__(self) -> Iterator[_T1]: ...

# A dataset distributed by a tf.distribute strategy; iterating yields a
# DistributedIterator rather than a plain Iterator.
class DistributedDataset(Generic[_T1]):
    element_spec: ContainerGeneric[TypeSpec]
    def __iter__(self) -> DistributedIterator[_T1]: ...
54 changes: 54 additions & 0 deletions tensorflow/dtypes.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Any

from builtins import bool as _bool

import numpy as np

from tensorflow import DTypeLike

# If we want to handle tensors as generic on dtypes we likely need to make
# this class an Enum. That's a minor lie type wise, but Literals only work
# with basic types + enums.
class DType:
    """Stub for tf.dtypes.DType: an immutable tensor element type.

    Exposes only read-only introspection properties; the dtype singletons
    themselves are declared at module level below.
    """

    @property
    def name(self) -> str: ...
    # NOTE(review): typed as numeric only, but non-numeric dtypes
    # (bool, string) also exist in this module — confirm whether a wider
    # return type is needed for those.
    @property
    def as_numpy_dtype(self) -> type[np.number[Any]]: ...
    @property
    def is_numpy_compatible(self) -> _bool: ...
    @property
    def is_bool(self) -> _bool: ...
    @property
    def is_floating(self) -> _bool: ...
    @property
    def is_integer(self) -> _bool: ...
    @property
    def is_quantized(self) -> _bool: ...
    @property
    def is_unsigned(self) -> _bool: ...

# Module-level dtype singletons. `bool` deliberately shadows the builtin
# to match the public tf.bool name (the builtin is imported above as
# _bool for use inside this module). `half` aliases float16 and `double`
# aliases float64, mirroring the upstream API.
bool: DType = ...
complex128: DType = ...
complex64: DType = ...
bfloat16: DType = ...
float16: DType = ...
half: DType = ...
float32: DType = ...
float64: DType = ...
double: DType = ...
int8: DType = ...
int16: DType = ...
int32: DType = ...
int64: DType = ...
uint8: DType = ...
uint16: DType = ...
uint32: DType = ...
uint64: DType = ...
# Quantized integer dtypes.
qint8: DType = ...
qint16: DType = ...
qint32: DType = ...
quint8: DType = ...
quint16: DType = ...
string: DType = ...

def as_dtype(dtype: DTypeLike) -> DType: ...
3 changes: 3 additions & 0 deletions tensorflow/errors.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from typing import Any

def __getattr__(name: str) -> Any: ...
Loading