@@ -70,6 +70,7 @@ def _gen_instance_class(fields):
7070
7171 class _FieldType :
7272 def __init__ (self , name , type_ ):
73+ assert isinstance (name , str ), f"Field name must be str, got { name } "
7374 self .name = name
7475 self .type_ = type_
7576 self .annotation = f"{ type_ .__module__ } .{ type_ .__name__ } "
@@ -86,11 +87,13 @@ def indent(level, s):
8687
8788 cls_name = "Instances_patched{}" .format (_counter )
8889
90+ field_names = tuple (x .name for x in fields )
8991 lines .append (
9092 f"""
9193class { cls_name } :
9294 def __init__(self, image_size: Tuple[int, int]):
9395 self.image_size = image_size
96+ self._field_names = { field_names }
9497"""
9598 )
9699
@@ -155,23 +158,6 @@ def has(self, name: str) -> bool:
155158"""
156159 )
157160
158- # support an additional method `from_instances` to convert from the original Instances class
159- lines .append (
160- f"""
161- @torch.jit.unused
162- @staticmethod
163- def from_instances(instances: Instances) -> "{ cls_name } ":
164- fields = instances.get_fields()
165- image_size = instances.image_size
166- new_instances = { cls_name } (image_size)
167- for name, val in fields.items():
168- assert hasattr(new_instances, '_{{}}'.format(name)), \\
169- "No attribute named {{}} in { cls_name } ".format(name)
170- setattr(new_instances, name, deepcopy(val))
171- return new_instances
172- """
173- )
174-
175161 # support method `to`
176162 lines .append (
177163 f"""
@@ -197,6 +183,34 @@ def to(self, device: torch.device) -> "{cls_name}":
197183 return ret
198184"""
199185 )
186+
187+ # support additional methods `from_instances` and `to_instances` to
188+ # convert from/to the original Instances class
189+ lines .append (
190+ f"""
191+ @torch.jit.unused
192+ @staticmethod
193+ def from_instances(instances: Instances) -> "{ cls_name } ":
194+ fields = instances.get_fields()
195+ image_size = instances.image_size
196+ new_instances = { cls_name } (image_size)
197+ for name, val in fields.items():
198+ assert hasattr(new_instances, '_{{}}'.format(name)), \\
199+ "No attribute named {{}} in { cls_name } ".format(name)
200+ setattr(new_instances, name, deepcopy(val))
201+ return new_instances
202+
203+ @torch.jit.unused
204+ def to_instances(self):
205+ ret = Instances(self.image_size)
206+ for name in self._field_names:
207+ val = getattr(self, "_" + name, None)
208+ if val is not None:
209+ ret.set(name, deepcopy(val))
210+ return ret
211+ """
212+ )
213+
200214 return cls_name , os .linesep .join (lines )
201215
202216
0 commit comments