From a863e2687659f3f2c74501faf73ce6ceabfce861 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 17 Jun 2021 06:59:18 -0700 Subject: [PATCH] Support opt mode for torchvision ops (#4077) Summary: Pull Request resolved: https://github.com/pytorch/vision/pull/4077 Reviewed By: fmassa Differential Revision: D29159451 fbshipit-source-id: ee1d4b677daf1e0ce7bb0a3796669cc739f93ec4 --- torchvision/_register_extension.py | 41 ++++++++++++++++++++++++ torchvision/extension.py | 50 +++++------------------------- torchvision/io/_video_opt.py | 43 +++---------------------- torchvision/io/image.py | 45 +++------------------------ 4 files changed, 57 insertions(+), 122 deletions(-) create mode 100644 torchvision/_register_extension.py diff --git a/torchvision/_register_extension.py b/torchvision/_register_extension.py new file mode 100644 index 00000000000..e8cb097d9b5 --- /dev/null +++ b/torchvision/_register_extension.py @@ -0,0 +1,41 @@ +import os +import importlib.machinery + + +def _get_extension_path(lib_name): + + lib_dir = os.path.dirname(__file__) + if os.name == 'nt': + # Register the main torchvision library location on the default DLL path + import ctypes + import sys + + kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True) + with_load_library_flags = hasattr(kernel32, 'AddDllDirectory') + prev_error_mode = kernel32.SetErrorMode(0x0001) + + if with_load_library_flags: + kernel32.AddDllDirectory.restype = ctypes.c_void_p + + if sys.version_info >= (3, 8): + os.add_dll_directory(lib_dir) + elif with_load_library_flags: + res = kernel32.AddDllDirectory(lib_dir) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' + raise err + + kernel32.SetErrorMode(prev_error_mode) + + loader_details = ( + importlib.machinery.ExtensionFileLoader, + importlib.machinery.EXTENSION_SUFFIXES + ) + + extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) + ext_specs = extfinder.find_spec(lib_name) + if ext_specs is None: + raise ImportError + + return ext_specs.origin diff --git a/torchvision/extension.py b/torchvision/extension.py index 265c989a8ce..ade9cc21980 100644 --- a/torchvision/extension.py +++ b/torchvision/extension.py @@ -1,54 +1,18 @@ -_HAS_OPS = False - +import torch -def _has_ops(): - return False - - -def _register_extensions(): - import os - import importlib - import torch +from ._register_extension import _get_extension_path - # load the custom_op_library and register the custom ops - lib_dir = os.path.dirname(__file__) - if os.name == 'nt': - # Register the main torchvision library location on the default DLL path - import ctypes - import sys - kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True) - with_load_library_flags = hasattr(kernel32, 'AddDllDirectory') - prev_error_mode = kernel32.SetErrorMode(0x0001) - - if with_load_library_flags: - kernel32.AddDllDirectory.restype = ctypes.c_void_p - - if sys.version_info >= (3, 8): - os.add_dll_directory(lib_dir) - elif with_load_library_flags: - res = kernel32.AddDllDirectory(lib_dir) - if res is None: - err = ctypes.WinError(ctypes.get_last_error()) - err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' - raise err - - kernel32.SetErrorMode(prev_error_mode) +_HAS_OPS = False - loader_details = ( - importlib.machinery.ExtensionFileLoader, - importlib.machinery.EXTENSION_SUFFIXES - ) - extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) - ext_specs = extfinder.find_spec("_C") - if ext_specs is None: - raise ImportError - torch.ops.load_library(ext_specs.origin) +def _has_ops(): + return False try: - _register_extensions() + lib_path = _get_extension_path('_C') + torch.ops.load_library(lib_path) _HAS_OPS = True def _has_ops(): # noqa: F811 diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py index e92ac1bd396..07795b63348 100644 --- a/torchvision/io/_video_opt.py +++ b/torchvision/io/_video_opt.py @@ -1,5 +1,4 @@ -import importlib import math import os import warnings @@ -9,47 +8,15 @@ import numpy as np import torch +from .._register_extension import _get_extension_path -_HAS_VIDEO_OPT = False try: - lib_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) - - loader_details = ( - importlib.machinery.ExtensionFileLoader, - importlib.machinery.EXTENSION_SUFFIXES - ) - - extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) - ext_specs = extfinder.find_spec("video_reader") - - if os.name == 'nt': - # Load the video_reader extension using LoadLibraryExW - import ctypes - - kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True) - with_load_library_flags = hasattr(kernel32, 'AddDllDirectory') - prev_error_mode = kernel32.SetErrorMode(0x0001) - - if with_load_library_flags: - kernel32.LoadLibraryExW.restype = ctypes.c_void_p - - if ext_specs is not None: - res = kernel32.LoadLibraryExW(ext_specs.origin, None, 0x00001100) - if res is None: - err = ctypes.WinError(ctypes.get_last_error()) - err.strerror += (f' Error loading "{ext_specs.origin}" or any or ' - 'its dependencies.') - raise err - - kernel32.SetErrorMode(prev_error_mode) - - if ext_specs is not None: - torch.ops.load_library(ext_specs.origin) - _HAS_VIDEO_OPT = True + lib_path = _get_extension_path('video_reader') + torch.ops.load_library(lib_path) + _HAS_VIDEO_OPT = True except (ImportError, OSError): - pass - + _HAS_VIDEO_OPT = False default_timebase = Fraction(0, 1) diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 4f824abad60..08f6d65b1f4 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -1,49 +1,12 @@ import torch - -import os -import os.path as osp -import importlib.machinery - from enum import Enum -_HAS_IMAGE_OPT = False - -try: - lib_dir = osp.abspath(osp.join(osp.dirname(__file__), "..")) - - loader_details = ( - importlib.machinery.ExtensionFileLoader, - importlib.machinery.EXTENSION_SUFFIXES - ) - - extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) # type: ignore[arg-type] - ext_specs = extfinder.find_spec("image") +from .._register_extension import _get_extension_path - if os.name == 'nt': - # Load the image extension using LoadLibraryExW - import ctypes - kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True) - with_load_library_flags = hasattr(kernel32, 'AddDllDirectory') - prev_error_mode = kernel32.SetErrorMode(0x0001) - - kernel32.LoadLibraryW.restype = ctypes.c_void_p - if with_load_library_flags: - kernel32.LoadLibraryExW.restype = ctypes.c_void_p - - if ext_specs is not None: - res = kernel32.LoadLibraryExW(ext_specs.origin, None, 0x00001100) - if res is None: - err = ctypes.WinError(ctypes.get_last_error()) - err.strerror += (f' Error loading "{ext_specs.origin}" or any or ' - 'its dependencies.') - raise err - - kernel32.SetErrorMode(prev_error_mode) - - if ext_specs is not None: - torch.ops.load_library(ext_specs.origin) - _HAS_IMAGE_OPT = True +try: + lib_path = _get_extension_path('image') + torch.ops.load_library(lib_path) except (ImportError, OSError): pass