Skip to content

Commit c06d52b

Browse files
authored
properly support deepcopying and serialization of model weights (#7107)
1 parent 93df9a5 commit c06d52b

File tree

2 files changed

+54
-7
lines changed

2 files changed

+54
-7
lines changed

test/test_extended_models.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import copy
22
import os
3+
import pickle
34

45
import pytest
56
import test_models as TM
@@ -73,10 +74,32 @@ def test_get_model_weights(name, weight):
7374
],
7475
)
7576
def test_weights_copyable(copy_fn, name):
76-
model_weights = models.get_model_weights(name)
77-
for weights in list(model_weights):
78-
copied_weights = copy_fn(weights)
79-
assert copied_weights is weights
77+
for weights in list(models.get_model_weights(name)):
78+
# It is somewhat surprising that (deep-)copying is an identity operation here, but this is the default behavior
79+
# of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
80+
# Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
81+
# support for the identity operation in the future.
82+
assert copy_fn(weights) is weights
83+
84+
85+
@pytest.mark.parametrize(
86+
"name",
87+
[
88+
"resnet50",
89+
"retinanet_resnet50_fpn_v2",
90+
"raft_large",
91+
"quantized_resnet50",
92+
"lraspp_mobilenet_v3_large",
93+
"mvit_v1_b",
94+
],
95+
)
96+
def test_weights_deserializable(name):
97+
for weights in list(models.get_model_weights(name)):
98+
# It is somewhat surprising that deserialization is an identity operation here, but this is the default behavior
99+
# of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
100+
# Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
101+
# support for the identity operation in the future.
102+
assert pickle.loads(pickle.dumps(weights)) is weights
80103

81104

82105
@pytest.mark.parametrize(

torchvision/models/_api.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import inspect
33
import sys
44
from dataclasses import dataclass, fields
5+
from functools import partial
56
from inspect import signature
67
from types import ModuleType
78
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union
@@ -37,6 +38,32 @@ class Weights:
3738
transforms: Callable
3839
meta: Dict[str, Any]
3940

41+
def __eq__(self, other: Any) -> bool:
42+
# We need this custom implementation for correct deep-copy and deserialization behavior.
43+
# TL;DR: After the definition of an enum, creating a new instance, i.e. by deep-copying or deserializing it,
44+
# involves an equality check against the defined members. Unfortunately, the `transforms` attribute is often
45+
# defined with `functools.partial` and `fn = partial(...); assert deepcopy(fn) != fn`. Without custom handling
46+
# for it, the check against the defined members would fail and effectively prevent the weights from being
47+
# deep-copied or deserialized.
48+
# See https://github.com/pytorch/vision/pull/7107 for details.
49+
if not isinstance(other, Weights):
50+
return NotImplemented
51+
52+
if self.url != other.url:
53+
return False
54+
55+
if self.meta != other.meta:
56+
return False
57+
58+
if isinstance(self.transforms, partial) and isinstance(other.transforms, partial):
59+
return (
60+
self.transforms.func == other.transforms.func
61+
and self.transforms.args == other.transforms.args
62+
and self.transforms.keywords == other.transforms.keywords
63+
)
64+
else:
65+
return self.transforms == other.transforms
66+
4067

4168
class WeightsEnum(StrEnum):
4269
"""
@@ -75,9 +102,6 @@ def __getattr__(self, name):
75102
return object.__getattribute__(self.value, name)
76103
return super().__getattr__(name)
77104

78-
def __deepcopy__(self, memodict=None):
79-
return self
80-
81105

82106
def get_weight(name: str) -> WeightsEnum:
83107
"""

0 commit comments

Comments
 (0)