Skip to content

Commit 9b6a6db

Browse files
Apply suggestions from code review
Co-authored-by: Will Berman <[email protected]>
1 parent 06c4a65 commit 9b6a6db

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/instruct_pix2pix/train_instruct_pix2pix.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,7 @@ def collate_fn(examples):
907907
original_image = download_image(args.val_image_url)
908908
edited_images = []
909909
with torch.autocast(
910-
str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
910+
"cuda", enabled=accelerator.mixed_precision == "fp16"
911911
):
912912
for _ in range(args.num_validation_images):
913913
edited_images.append(
@@ -963,7 +963,7 @@ def collate_fn(examples):
963963
if args.validation_prompt is not None:
964964
edited_images = []
965965
pipeline = pipeline.to(accelerator.device)
966-
with torch.autocast(str(accelerator.device).replace(":0", "")):
966+
with torch.autocast("cuda"):
967967
for _ in range(args.num_validation_images):
968968
edited_images.append(
969969
pipeline(

0 commit comments

Comments
 (0)