Skip to content

Commit 8e3effc

Browse files
ppwwyyxxfacebook-github-bot
authored andcommitted
add to_instances helper method to scripted Instances class
Summary: Pull Request resolved: #2232 Reviewed By: theschnitz Differential Revision: D24753888 fbshipit-source-id: 424bfc7ab0cf085e99333dbb9475d9cb32cd567c
1 parent 93f0f36 commit 8e3effc

File tree

2 files changed

+44
-17
lines changed

2 files changed

+44
-17
lines changed

detectron2/export/torchscript_patch.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def _gen_instance_class(fields):
7070

7171
class _FieldType:
7272
def __init__(self, name, type_):
73+
assert isinstance(name, str), f"Field name must be str, got {name}"
7374
self.name = name
7475
self.type_ = type_
7576
self.annotation = f"{type_.__module__}.{type_.__name__}"
@@ -86,11 +87,13 @@ def indent(level, s):
8687

8788
cls_name = "Instances_patched{}".format(_counter)
8889

90+
field_names = tuple(x.name for x in fields)
8991
lines.append(
9092
f"""
9193
class {cls_name}:
9294
def __init__(self, image_size: Tuple[int, int]):
9395
self.image_size = image_size
96+
self._field_names = {field_names}
9497
"""
9598
)
9699

@@ -155,23 +158,6 @@ def has(self, name: str) -> bool:
155158
"""
156159
)
157160

158-
# support an additional method `from_instances` to convert from the original Instances class
159-
lines.append(
160-
f"""
161-
@torch.jit.unused
162-
@staticmethod
163-
def from_instances(instances: Instances) -> "{cls_name}":
164-
fields = instances.get_fields()
165-
image_size = instances.image_size
166-
new_instances = {cls_name}(image_size)
167-
for name, val in fields.items():
168-
assert hasattr(new_instances, '_{{}}'.format(name)), \\
169-
"No attribute named {{}} in {cls_name}".format(name)
170-
setattr(new_instances, name, deepcopy(val))
171-
return new_instances
172-
"""
173-
)
174-
175161
# support method `to`
176162
lines.append(
177163
f"""
@@ -197,6 +183,34 @@ def to(self, device: torch.device) -> "{cls_name}":
197183
return ret
198184
"""
199185
)
186+
187+
# support additional methods `from_instances` and `to_instances` to
188+
# convert from/to the original Instances class
189+
lines.append(
190+
f"""
191+
@torch.jit.unused
192+
@staticmethod
193+
def from_instances(instances: Instances) -> "{cls_name}":
194+
fields = instances.get_fields()
195+
image_size = instances.image_size
196+
new_instances = {cls_name}(image_size)
197+
for name, val in fields.items():
198+
assert hasattr(new_instances, '_{{}}'.format(name)), \\
199+
"No attribute named {{}} in {cls_name}".format(name)
200+
setattr(new_instances, name, deepcopy(val))
201+
return new_instances
202+
203+
@torch.jit.unused
204+
def to_instances(self):
205+
ret = Instances(self.image_size)
206+
for name in self._field_names:
207+
val = getattr(self, "_" + name, None)
208+
if val is not None:
209+
ret.set(name, deepcopy(val))
210+
return ret
211+
"""
212+
)
213+
200214
return cls_name, os.linesep.join(lines)
201215

202216

tests/structures/test_instances.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,19 @@ def forward(self, x: Instances):
151151
x.a = box_tensors
152152
script_module(x)
153153

154+
@unittest.skipIf(TORCH_VERSION < (1, 7), "Insufficient pytorch version")
155+
def test_from_to_instances(self):
156+
orig = Instances((30, 30))
157+
orig.proposal_boxes = Boxes(torch.rand(3, 4))
158+
159+
fields = {"proposal_boxes": Boxes, "a": Tensor}
160+
with patch_instances(fields) as NewInstances:
161+
# convert to NewInstances and back
162+
new1 = NewInstances.from_instances(orig)
163+
new2 = new1.to_instances()
164+
self.assertTrue(torch.equal(orig.proposal_boxes.tensor, new1.proposal_boxes.tensor))
165+
self.assertTrue(torch.equal(orig.proposal_boxes.tensor, new2.proposal_boxes.tensor))
166+
154167

155168
if __name__ == "__main__":
156169
unittest.main()

0 commit comments

Comments
 (0)