Skip to content

Commit e0d8c9e

Browse files
authored
Support for Offset Noise in examples (#2753)
* add noise offset
* make style
1 parent 92e1164 commit e0d8c9e

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

examples/text_to_image/train_text_to_image.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ def parse_args():
297297
parser.add_argument(
298298
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
299299
)
300+
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
300301

301302
args = parser.parse_args()
302303
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -705,6 +706,12 @@ def collate_fn(examples):
705706

706707
# Sample noise that we'll add to the latents
707708
noise = torch.randn_like(latents)
709+
if args.noise_offset:
710+
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
711+
noise += args.noise_offset * torch.randn(
712+
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device
713+
)
714+
708715
bsz = latents.shape[0]
709716
# Sample a random timestep for each image
710717
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ def parse_args():
333333
parser.add_argument(
334334
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
335335
)
336+
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
336337

337338
args = parser.parse_args()
338339
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -718,6 +719,12 @@ def collate_fn(examples):
718719

719720
# Sample noise that we'll add to the latents
720721
noise = torch.randn_like(latents)
722+
if args.noise_offset:
723+
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
724+
noise += args.noise_offset * torch.randn(
725+
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device
726+
)
727+
721728
bsz = latents.shape[0]
722729
# Sample a random timestep for each image
723730
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)

0 commit comments

Comments (0)