Skip to content

Commit 9540c29

Browse files
bottlerfacebook-github-bot
authored andcommitted
Make Module.__init__ automatic
Summary: If a configurable class inherits torch.nn.Module and is instantiated, automatically call `torch.nn.Module.__init__` on it before doing anything else. Reviewed By: shapovalov Differential Revision: D42760349 fbshipit-source-id: 409894911a4252b7987e1fd218ee9ecefbec8e62
1 parent 97f8f9b commit 9540c29

29 files changed

+36
-87
lines changed

projects/implicitron_trainer/README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,7 @@ from pytorch3d.implicitron.tools.config import registry
212212
class XRayRenderer(BaseRenderer, torch.nn.Module):
213213
n_pts_per_ray: int = 64
214214
215-
# if there are other base classes, make sure to call `super().__init__()` explicitly
216215
def __post_init__(self):
217-
super().__init__()
218216
# custom initialization
219217
220218
def forward(

pytorch3d/implicitron/eval_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def evaluate_dbir_for_category(
130130
raise ValueError("Image size should be set in the dataset")
131131

132132
# init the simple DBIR model
133-
model = ModelDBIR( # pyre-ignore[28]: c’tor implicitly overridden
133+
model = ModelDBIR(
134134
render_image_width=image_size,
135135
render_image_height=image_size,
136136
bg_color=bg_color,

pytorch3d/implicitron/models/base_model.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,6 @@ class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
4949
# the training loop.
5050
log_vars: List[str] = field(default_factory=lambda: ["objective"])
5151

52-
def __init__(self) -> None:
53-
super().__init__()
54-
5552
def forward(
5653
self,
5754
*, # force keyword-only arguments

pytorch3d/implicitron/models/feature_extractor/feature_extractor.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@ class FeatureExtractorBase(ReplaceableBase, torch.nn.Module):
1515
Base class for an extractor of a set of features from images.
1616
"""
1717

18-
def __init__(self):
19-
super().__init__()
20-
2118
def get_feat_dims(self) -> int:
2219
"""
2320
Returns:

pytorch3d/implicitron/models/feature_extractor/resnet_feature_extractor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
7878
feature_rescale: float = 1.0
7979

8080
def __post_init__(self):
81-
super().__init__()
8281
if self.normalize_image:
8382
# register buffers needed to normalize the image
8483
for k, v in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):

pytorch3d/implicitron/models/generic_model.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,6 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
304304
)
305305

306306
def __post_init__(self):
307-
super().__init__()
308-
309307
if self.view_pooler_enabled:
310308
if self.image_feature_extractor_class_type is None:
311309
raise ValueError(

pytorch3d/implicitron/models/global_encoder/autodecoder.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ class Autodecoder(Configurable, torch.nn.Module):
2929
ignore_input: bool = False
3030

3131
def __post_init__(self):
32-
super().__init__()
33-
3432
if self.n_instances <= 0:
3533
raise ValueError(f"Invalid n_instances {self.n_instances}")
3634

pytorch3d/implicitron/models/global_encoder/global_encoder.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@ class GlobalEncoderBase(ReplaceableBase):
2626
(`SequenceAutodecoder`).
2727
"""
2828

29-
def __init__(self) -> None:
30-
super().__init__()
31-
3229
def get_encoding_dim(self):
3330
"""
3431
Returns the dimensionality of the returned encoding.
@@ -69,7 +66,6 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 1
6966
autodecoder: Autodecoder
7067

7168
def __post_init__(self):
72-
super().__init__()
7369
run_auto_creation(self)
7470

7571
def get_encoding_dim(self):
@@ -103,7 +99,6 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
10399
time_divisor: float = 1.0
104100

105101
def __post_init__(self):
106-
super().__init__()
107102
self._harmonic_embedding = HarmonicEmbedding(
108103
n_harmonic_functions=self.n_harmonic_functions,
109104
append_input=self.append_input,

pytorch3d/implicitron/models/implicit_function/base.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414

1515

1616
class ImplicitFunctionBase(ABC, ReplaceableBase):
17-
def __init__(self):
18-
super().__init__()
19-
2017
@abstractmethod
2118
def forward(
2219
self,

pytorch3d/implicitron/models/implicit_function/decoding_functions.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,6 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
4545
space and transforms it into the required quantity (for example density and color).
4646
"""
4747

48-
def __post_init__(self):
49-
super().__init__()
50-
5148
def forward(
5249
self, features: torch.Tensor, z: Optional[torch.Tensor] = None
5350
) -> torch.Tensor:
@@ -83,7 +80,6 @@ class ElementwiseDecoder(DecoderFunctionBase):
8380
operation: DecoderActivation = DecoderActivation.IDENTITY
8481

8582
def __post_init__(self):
86-
super().__post_init__()
8783
if self.operation not in [
8884
DecoderActivation.RELU,
8985
DecoderActivation.SOFTPLUS,
@@ -163,8 +159,6 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
163159
use_xavier_init: bool = True
164160

165161
def __post_init__(self):
166-
super().__init__()
167-
168162
try:
169163
last_activation = {
170164
DecoderActivation.RELU: torch.nn.ReLU(True),
@@ -284,7 +278,6 @@ class MLPDecoder(DecoderFunctionBase):
284278
network: MLPWithInputSkips
285279

286280
def __post_init__(self):
287-
super().__post_init__()
288281
run_auto_creation(self)
289282

290283
def forward(

0 commit comments

Comments
 (0)