
Commit a2ac4de

Remove nested import logic for torchvision (#40940)
* remove nested import logic for torchvision
* remove unnecessary protected imports
* remove unnecessary protected imports in modular (and modeling)
* fix wrongly removed protected imports
1 parent 8e837f6 commit a2ac4de
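The core of the change is visible in src/transformers/image_processing_utils_fast.py below: the torchvision v2 check no longer sits inside the is_torchvision_available() branch, and the functional namespace is chosen with a flat if/elif at module level. A minimal before/after sketch of the pattern, reconstructed from the hunks in this commit (surrounding code omitted):

# Before: the v2 check was nested under the availability check
if is_torchvision_available():
    if is_torchvision_v2_available():
        from torchvision.transforms.v2 import functional as F
    else:
        from torchvision.transforms import functional as F

# After: flat, module-level selection of the functional namespace
if is_torchvision_v2_available():
    from torchvision.transforms.v2 import functional as F
elif is_torchvision_available():
    from torchvision.transforms import functional as F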

90 files changed: +403 additions, -838 deletions. Large commit: only a subset of the changed files is shown below.

src/transformers/image_processing_utils_fast.py (5 additions, 5 deletions)

@@ -61,14 +61,14 @@
 
 if is_torchvision_available():
     from .image_utils import pil_torch_interpolation_mapping
-
-    if is_torchvision_v2_available():
-        from torchvision.transforms.v2 import functional as F
-    else:
-        from torchvision.transforms import functional as F
 else:
     pil_torch_interpolation_mapping = None
 
+if is_torchvision_v2_available():
+    from torchvision.transforms.v2 import functional as F
+elif is_torchvision_available():
+    from torchvision.transforms import functional as F
+
 logger = logging.get_logger(__name__)

src/transformers/models/aria/modeling_aria.py (3 additions, 6 deletions)

@@ -21,6 +21,9 @@
 from dataclasses import dataclass
 from typing import Callable, Optional, Union
 
+import torch
+from torch import nn
+
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin

@@ -35,16 +38,10 @@
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import check_model_inputs
-from ...utils.import_utils import is_torch_available
 from ..auto import AutoModel
 from .configuration_aria import AriaConfig, AriaTextConfig
 
 
-if is_torch_available():
-    import torch
-    from torch import nn
-
-
 @use_kernel_forward_from_hub("RMSNorm")
 class AriaTextRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
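The other recurring edit, shown for aria above and repeated across the modeling and modular files, drops the is_torch_available() guard around torch imports in favor of plain top-level imports. A sketch of the two forms (the relative import path is as used inside the transformers package; file-specific imports omitted):

# Before: torch imported only behind an availability guard
from ...utils.import_utils import is_torch_available

if is_torch_available():
    import torch
    from torch import nn

# After: unconditional top-level imports
import torch
from torch import nn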

src/transformers/models/aria/modular_aria.py (2 additions, 5 deletions)

@@ -16,6 +16,8 @@
 from typing import Optional, Union
 
 import numpy as np
+import torch
+from torch import nn
 
 from ...activations import ACT2FN
 from ...cache_utils import Cache

@@ -39,7 +41,6 @@
 from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils import PreTokenizedInput, TextInput
 from ...utils import TensorType, TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.import_utils import is_torch_available
 from ..auto import CONFIG_MAPPING, AutoConfig, AutoTokenizer
 from ..llama.configuration_llama import LlamaConfig
 from ..llama.modeling_llama import (

@@ -62,10 +63,6 @@
 
 logger = logging.get_logger(__name__)
 
-if is_torch_available():
-    import torch
-    from torch import nn
-
 
 def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert):
     """

src/transformers/models/beit/image_processing_beit_fast.py (3 additions, 6 deletions)

@@ -16,6 +16,8 @@
 
 from typing import Optional, Union
 
+import torch
+
 from ...image_processing_utils import BatchFeature
 from ...image_processing_utils_fast import (
     BaseImageProcessorFast,

@@ -36,18 +38,13 @@
 from ...utils import (
     TensorType,
     auto_docstring,
-    is_torch_available,
-    is_torchvision_available,
     is_torchvision_v2_available,
 )
 
 
-if is_torch_available():
-    import torch
-
 if is_torchvision_v2_available():
     from torchvision.transforms.v2 import functional as F
-elif is_torchvision_available():
+else:
     from torchvision.transforms import functional as F

src/transformers/models/bridgetower/image_processing_bridgetower_fast.py (7 additions, 9 deletions)

@@ -17,6 +17,8 @@
 from collections.abc import Iterable
 from typing import Optional, Union
 
+import torch
+
 from ...image_processing_utils_fast import (
     BaseImageProcessorFast,
     BatchFeature,

@@ -29,17 +31,13 @@
     reorder_images,
 )
 from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
-from ...utils import auto_docstring, is_torch_available, is_torchvision_available, is_torchvision_v2_available
-
+from ...utils import auto_docstring, is_torchvision_v2_available
 
-if is_torch_available():
-    import torch
 
-if is_torchvision_available():
-    if is_torchvision_v2_available():
-        from torchvision.transforms.v2 import functional as F
-    else:
-        from torchvision.transforms import functional as F
+if is_torchvision_v2_available():
+    from torchvision.transforms.v2 import functional as F
+else:
+    from torchvision.transforms import functional as F
 
 
 def make_pixel_mask(

src/transformers/models/chameleon/image_processing_chameleon_fast.py (9 additions, 19 deletions)

@@ -17,28 +17,18 @@
 from typing import Optional
 
 import numpy as np
+import PIL
+import torch
 
 from ...image_processing_utils_fast import BaseImageProcessorFast
 from ...image_utils import ImageInput, PILImageResampling, SizeDict
-from ...utils import (
-    auto_docstring,
-    is_torch_available,
-    is_torchvision_available,
-    is_torchvision_v2_available,
-    is_vision_available,
-    logging,
-)
-
-
-if is_vision_available():
-    import PIL
-if is_torch_available():
-    import torch
-if is_torchvision_available():
-    if is_torchvision_v2_available():
-        from torchvision.transforms.v2 import functional as F
-    else:
-        from torchvision.transforms import functional as F
+from ...utils import auto_docstring, is_torchvision_v2_available, logging
+
+
+if is_torchvision_v2_available():
+    from torchvision.transforms.v2 import functional as F
+else:
+    from torchvision.transforms import functional as F
 
 logger = logging.get_logger(__name__)

src/transformers/models/cohere2_vision/modular_cohere2_vision.py (1 addition, 5 deletions)

@@ -32,11 +32,7 @@
 from ...cache_utils import Cache
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...processing_utils import Unpack
-from ...utils import (
-    TransformersKwargs,
-    auto_docstring,
-    logging,
-)
+from ...utils import TransformersKwargs, auto_docstring, logging
 from ...utils.generic import check_model_inputs
 from .configuration_cohere2_vision import Cohere2VisionConfig

src/transformers/models/colpali/modular_colpali.py (0 additions, 1 deletion)

@@ -28,7 +28,6 @@
 if is_torch_available():
     import torch
 
-
 logger = logging.get_logger(__name__)

src/transformers/models/colqwen2/modular_colqwen2.py (0 additions, 1 deletion)

@@ -30,7 +30,6 @@
 if is_torch_available():
     import torch
 
-
 logger = logging.get_logger(__name__)

src/transformers/models/conditional_detr/image_processing_conditional_detr_fast.py (6 additions, 20 deletions)

@@ -7,6 +7,10 @@
 import pathlib
 from typing import Any, Optional, Union
 
+import torch
+from torch import nn
+from torchvision.io import read_image
+
 from ...image_processing_utils import BatchFeature, get_size_dict
 from ...image_processing_utils_fast import (
     BaseImageProcessorFast,

@@ -29,14 +33,7 @@
     validate_annotations,
 )
 from ...processing_utils import Unpack
-from ...utils import (
-    TensorType,
-    auto_docstring,
-    is_torch_available,
-    is_torchvision_available,
-    is_torchvision_v2_available,
-    logging,
-)
+from ...utils import TensorType, auto_docstring, is_torchvision_v2_available, logging
 from ...utils.import_utils import requires
 from .image_processing_conditional_detr import (
     compute_segments,

@@ -46,20 +43,9 @@
 )
 
 
-if is_torch_available():
-    import torch
-
-
-if is_torch_available():
-    from torch import nn
-
-
 if is_torchvision_v2_available():
-    from torchvision.io import read_image
     from torchvision.transforms.v2 import functional as F
-
-elif is_torchvision_available():
-    from torchvision.io import read_image
+else:
     from torchvision.transforms import functional as F
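Across the fast image processors touched here, both functional modules end up bound to the same F alias, so downstream code does not care which torchvision generation is installed. A small usage sketch under that assumption (the tensor shape and target size are illustrative, not taken from this commit):

import torch

from transformers.utils import is_torchvision_available, is_torchvision_v2_available

if is_torchvision_v2_available():
    from torchvision.transforms.v2 import functional as F
elif is_torchvision_available():
    from torchvision.transforms import functional as F

# resize() exists with a compatible call signature in both functional modules
image = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)
resized = F.resize(image, [224, 224])
print(resized.shape)  # torch.Size([3, 224, 224])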
