|
2 | 2 | import inspect
|
3 | 3 | import sys
|
4 | 4 | from dataclasses import dataclass, fields
|
| 5 | +from functools import partial |
5 | 6 | from inspect import signature
|
6 | 7 | from types import ModuleType
|
7 | 8 | from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union
|
@@ -37,6 +38,32 @@ class Weights:
|
37 | 38 | transforms: Callable
|
38 | 39 | meta: Dict[str, Any]
|
39 | 40 |
|
| 41 | + def __eq__(self, other: Any) -> bool: |
| 42 | + # We need this custom implementation for correct deep-copy and deserialization behavior. |
| 43 | + # TL;DR: After the definition of an enum, creating a new instance, i.e. by deep-copying or deserializing it, |
| 44 | + # involves an equality check against the defined members. Unfortunately, the `transforms` attribute is often |
| 45 | + # defined with `functools.partial` and `fn = partial(...); assert deepcopy(fn) != fn`. Without custom handling |
| 46 | + # for it, the check against the defined members would fail and effectively prevent the weights from being |
| 47 | + # deep-copied or deserialized. |
| 48 | + # See https://github.com/pytorch/vision/pull/7107 for details. |
| 49 | + if not isinstance(other, Weights): |
| 50 | + return NotImplemented |
| 51 | + |
| 52 | + if self.url != other.url: |
| 53 | + return False |
| 54 | + |
| 55 | + if self.meta != other.meta: |
| 56 | + return False |
| 57 | + |
| 58 | + if isinstance(self.transforms, partial) and isinstance(other.transforms, partial): |
| 59 | + return ( |
| 60 | + self.transforms.func == other.transforms.func |
| 61 | + and self.transforms.args == other.transforms.args |
| 62 | + and self.transforms.keywords == other.transforms.keywords |
| 63 | + ) |
| 64 | + else: |
| 65 | + return self.transforms == other.transforms |
| 66 | + |
40 | 67 |
|
41 | 68 | class WeightsEnum(StrEnum):
|
42 | 69 | """
|
@@ -75,9 +102,6 @@ def __getattr__(self, name):
|
75 | 102 | return object.__getattribute__(self.value, name)
|
76 | 103 | return super().__getattr__(name)
|
77 | 104 |
|
78 |
| - def __deepcopy__(self, memodict=None): |
79 |
| - return self |
80 |
| - |
81 | 105 |
|
82 | 106 | def get_weight(name: str) -> WeightsEnum:
|
83 | 107 | """
|
|
0 commit comments