Skip to content

Commit e5d5acb

Browse files
authored
Remove torchvision requirements from web (huggingface#860)
1 parent 00e38ab commit e5d5acb

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

web/models/stable_diffusion/main.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
22
import os
33
from PIL import Image
4-
import torchvision.transforms as T
54
from tqdm.auto import tqdm
65
from models.stable_diffusion.cache_objects import model_cache
76
from models.stable_diffusion.stable_args import args
@@ -268,10 +267,8 @@ def stable_diff_inf(
268267
print(f"\nTotal image generation time: {total_time}sec")
269268

270269
# generate outputs to web.
271-
transform = T.ToPILImage()
272-
pil_images = [
273-
transform(image) for image in torch.from_numpy(images).to(torch.uint8)
274-
]
270+
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
271+
pil_images = [Image.fromarray(image) for image in images.numpy()]
275272

276273
text_output = f"prompt={args.prompts}"
277274
text_output += f"\nnegative prompt={args.negative_prompts}"

0 commit comments

Comments
 (0)