diff --git a/stubs/tensorflow/tensorflow/config/__init__.pyi b/stubs/tensorflow/tensorflow/config/__init__.pyi new file mode 100644 index 000000000000..78721953f462 --- /dev/null +++ b/stubs/tensorflow/tensorflow/config/__init__.pyi @@ -0,0 +1,13 @@ +from _typeshed import Incomplete +from typing import NamedTuple + +from tensorflow.config import experimental as experimental + +class PhysicalDevice(NamedTuple): + name: str + device_type: str + +def list_physical_devices(device_type: None | str = None) -> list[PhysicalDevice]: ... +def get_visible_devices(device_type: None | str = None) -> list[PhysicalDevice]: ... +def set_visible_devices(devices: list[PhysicalDevice] | PhysicalDevice, device_type: None | str = None) -> None: ... +def __getattr__(name: str) -> Incomplete: ... diff --git a/stubs/tensorflow/tensorflow/config/experimental.pyi b/stubs/tensorflow/tensorflow/config/experimental.pyi new file mode 100644 index 000000000000..53b4277656ab --- /dev/null +++ b/stubs/tensorflow/tensorflow/config/experimental.pyi @@ -0,0 +1,17 @@ +import typing_extensions +from _typeshed import Incomplete +from typing import TypedDict + +from tensorflow.config import PhysicalDevice + +class _MemoryInfo(TypedDict): + current: int + peak: int + +def get_memory_info(device: str) -> _MemoryInfo: ... +def reset_memory_stats(device: str) -> None: ... +@typing_extensions.deprecated("This function is deprecated in favor of tf.config.experimental.get_memory_info") +def get_memory_usage(device: PhysicalDevice) -> int: ... +def get_memory_growth(device: PhysicalDevice) -> bool: ... +def set_memory_growth(device: PhysicalDevice, enable: bool) -> None: ... +def __getattr__(name: str) -> Incomplete: ...