Skip to content

Commit a3e4ea3

Browse files
author
Prashant Kumar
authored
Remove the dependency of the torchvision. (huggingface#858)
Remove the dependency of torchvision library for the conversion of tensor layout format to what PIL library expects.
1 parent 56f16d6 commit a3e4ea3

File tree

1 file changed

+2
-6
lines changed
  • shark/examples/shark_inference/stable_diffusion

1 file changed

+2
-6
lines changed

shark/examples/shark_inference/stable_diffusion/main.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from transformers import CLIPTextModel, CLIPTokenizer
66
import torch
77
from PIL import Image
8-
import torchvision.transforms as T
98
from diffusers import (
109
LMSDiscreteScheduler,
1110
PNDMScheduler,
@@ -272,11 +271,8 @@ def end_profiling(device):
272271
print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
273272
print(f"\nTotal image generation time: {total_time}sec")
274273

275-
transform = T.ToPILImage()
276-
pil_images = [
277-
transform(image)
278-
for image in torch.from_numpy(images).to(torch.uint8)
279-
]
274+
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
275+
pil_images = [Image.fromarray(image) for image in images.numpy()]
280276

281277
if args.output_dir is not None:
282278
output_path = Path(args.output_dir)

0 commit comments

Comments
 (0)