Skip to content

Commit 28557e0

Browse files
authored
Fix copypaste collate pickle issues (#6181)
1 parent d0d7058 commit 28557e0

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

references/detection/train.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@
3535
from transforms import SimpleCopyPaste
3636

3737

38+
def copypaste_collate_fn(batch):
39+
copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR)
40+
return copypaste(*utils.collate_fn(batch))
41+
42+
3843
def get_dataset(name, image_set, transform, data_path):
3944
paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)}
4045
p, ds_fn, num_classes = paths[name]
@@ -194,11 +199,6 @@ def main(args):
194199
if args.data_augmentation != "lsj":
195200
raise RuntimeError("SimpleCopyPaste algorithm currently only supports the 'lsj' data augmentation policies")
196201

197-
copypaste = SimpleCopyPaste(resize_interpolation=InterpolationMode.BILINEAR, blending=True)
198-
199-
def copypaste_collate_fn(batch):
200-
return copypaste(*utils.collate_fn(batch))
201-
202202
train_collate_fn = copypaste_collate_fn
203203

204204
data_loader = torch.utils.data.DataLoader(

0 commit comments

Comments
 (0)