diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 7aaf14d773bb..94a32bcc07f8 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -939,6 +939,32 @@ def __init__(
             self.class_data_root = Path(class_data_root)
             self.class_data_root.mkdir(parents=True, exist_ok=True)
             self.class_images_path = list(self.class_data_root.iterdir())
+
+            self.original_sizes_class_imgs = []
+            self.crop_top_lefts_class_imgs = []
+            self.pixel_values_class_imgs = []
+            self.class_images = [Image.open(path) for path in self.class_images_path]
+            for image in self.class_images:
+                image = exif_transpose(image)
+                if not image.mode == "RGB":
+                    image = image.convert("RGB")
+                self.original_sizes_class_imgs.append((image.height, image.width))
+                image = train_resize(image)
+                if args.random_flip and random.random() < 0.5:
+                    # flip
+                    image = train_flip(image)
+                if args.center_crop:
+                    y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+                    x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+                    image = train_crop(image)
+                else:
+                    y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+                    image = crop(image, y1, x1, h, w)
+                crop_top_left = (y1, x1)
+                self.crop_top_lefts_class_imgs.append(crop_top_left)
+                image = train_transforms(image)
+                self.pixel_values_class_imgs.append(image)
+
             if class_num is not None:
                 self.num_class_images = min(len(self.class_images_path), class_num)
             else:
@@ -961,12 +987,9 @@ def __len__(self):

     def __getitem__(self, index):
         example = {}
-        instance_image = self.pixel_values[index % self.num_instance_images]
-        original_size = self.original_sizes[index % self.num_instance_images]
-        crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
-        example["instance_images"] = instance_image
-        example["original_size"] = original_size
-        example["crop_top_left"] = crop_top_left
+        example["instance_images"] = self.pixel_values[index % self.num_instance_images]
+        example["original_size"] = self.original_sizes[index % self.num_instance_images]
+        example["crop_top_left"] = self.crop_top_lefts[index % self.num_instance_images]

         if self.custom_instance_prompts:
             caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -983,13 +1006,10 @@ def __getitem__(self, index):
             example["instance_prompt"] = self.instance_prompt

         if self.class_data_root:
-            class_image = Image.open(self.class_images_path[index % self.num_class_images])
-            class_image = exif_transpose(class_image)
-
-            if not class_image.mode == "RGB":
-                class_image = class_image.convert("RGB")
-            example["class_images"] = self.image_transforms(class_image)
             example["class_prompt"] = self.class_prompt
+            example["class_images"] = self.pixel_values_class_imgs[index % self.num_class_images]
+            example["class_original_size"] = self.original_sizes_class_imgs[index % self.num_class_images]
+            example["class_crop_top_left"] = self.crop_top_lefts_class_imgs[index % self.num_class_images]

         return example

@@ -1005,6 +1025,8 @@ def collate_fn(examples, with_prior_preservation=False):
     if with_prior_preservation:
         pixel_values += [example["class_images"] for example in examples]
         prompts += [example["class_prompt"] for example in examples]
+        original_sizes += [example["class_original_size"] for example in examples]
+        crop_top_lefts += [example["class_crop_top_left"] for example in examples]

     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
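
Note (not part of the diff): the point of extending collate_fn above is that, with prior preservation, original_sizes and crop_top_lefts now cover the class images as well as the instance images, so SDXL's per-image micro-conditioning stays aligned with the stacked pixel_values. Below is a minimal, self-contained sketch of how such a batch is typically turned into SDXL "add_time_ids"; compute_time_ids here is a simplified stand-in for the helper used in the training script, and the target_size default and batch values are illustrative assumptions.

import torch

def compute_time_ids(original_size, crop_top_left, target_size=(1024, 1024)):
    # SDXL micro-conditioning packs (orig_h, orig_w, crop_top, crop_left,
    # target_h, target_w) into one tensor per image.
    return torch.tensor(
        [list(original_size) + list(crop_top_left) + list(target_size)],
        dtype=torch.float32,
    )

# Hypothetical batch: one instance image and one appended class image,
# mirroring the keys produced by the collate_fn change above.
batch = {
    "original_sizes": [(768, 512), (1024, 1024)],
    "crop_top_lefts": [(0, 128), (0, 0)],
}
add_time_ids = torch.cat(
    [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
)
print(add_time_ids.shape)  # torch.Size([2, 6]) -- one row of time ids per image

Because the class-image conditioning is appended in the same order as the class pixel values, no special casing is needed downstream: one zip over the batch builds time ids for the whole concatenated batch.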