From dc7519652c186bb2100452ba023b707a98beba73 Mon Sep 17 00:00:00 2001
From: haofanwang
Date: Mon, 20 Mar 2023 21:50:53 +0800
Subject: [PATCH 1/2] add noise offset

---
 examples/text_to_image/train_text_to_image.py      | 5 +++++
 examples/text_to_image/train_text_to_image_lora.py | 5 +++++
 2 files changed, 10 insertions(+)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 06a847e6ca61..43b947ad99a7 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -297,6 +297,7 @@ def parse_args():
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -705,6 +706,10 @@ def collate_fn(examples):
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
+                if args.noise_offset:
+                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
+
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 43bbd8ebf415..d67f010d0c91 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -333,6 +333,7 @@ def parse_args():
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -718,6 +719,10 @@ def collate_fn(examples):
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
+                if args.noise_offset:
+                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
+
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)

From fb24f8ff74b589483afeba1e03f8ea836eddc273 Mon Sep 17 00:00:00 2001
From: haofanwang
Date: Thu, 23 Mar 2023 11:38:01 +0800
Subject: [PATCH 2/2] make style

---
 examples/text_to_image/train_text_to_image.py      | 4 +++-
 examples/text_to_image/train_text_to_image_lora.py | 4 +++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 43b947ad99a7..6139a0e6514d 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -708,7 +708,9 @@ def collate_fn(examples):
                 noise = torch.randn_like(latents)
                 if args.noise_offset:
                     # https://www.crosslabs.org//blog/diffusion-with-offset-noise
-                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
+                    noise += args.noise_offset * torch.randn(
+                        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+                    )
 
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index d67f010d0c91..3b54cc286663 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -721,7 +721,9 @@ def collate_fn(examples):
                 noise = torch.randn_like(latents)
                 if args.noise_offset:
                     # https://www.crosslabs.org//blog/diffusion-with-offset-noise
-                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
+                    noise += args.noise_offset * torch.randn(
+                        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+                    )
 
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
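Note: for reviewers unfamiliar with the technique, the patch adds the "offset noise" trick from the linked CrossLabs post: on top of the usual per-element Gaussian noise, a single extra Gaussian sample per (image, channel) is added and broadcast across the spatial dimensions, which shifts each channel's noise mean and is argued to help the model generate very dark or very bright images. The following is a minimal standalone sketch of that idea, not part of the patch; the latent shape (2, 4, 64, 64) and the noise_offset value 0.1 are illustrative assumptions (the new --noise_offset flag defaults to 0, i.e. disabled).

import torch

# Illustrative latent batch: (batch, channels, height, width); SD latents have 4 channels.
latents = torch.randn(2, 4, 64, 64)
noise_offset = 0.1  # example value only; the flag's default of 0 leaves behavior unchanged

# Standard Gaussian noise, same shape as the latents.
noise = torch.randn_like(latents)

if noise_offset:
    # Offset noise: one extra Gaussian draw per (image, channel), broadcast over
    # the spatial dimensions, scaled by noise_offset.
    noise += noise_offset * torch.randn(
        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
    )

print(noise.shape)  # torch.Size([2, 4, 64, 64]) -- shape unchanged; the offset is broadcast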