Skip to content

Conversation

yiyixuxu
Copy link
Collaborator

@yiyixuxu yiyixuxu commented Jun 6, 2023

refactored stable diffusion 4x upscaler( see discussion here #3654)

test

import requests
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionUpscalePipeline
import torch
import numpy as np

# load model and scheduler
model_id = "stabilityai/stable-diffusion-x4-upscaler"
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")

# let's download an  image
url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
response = requests.get(url)
low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
low_res_img = low_res_img.resize((128, 128))

prompt = "a white cat"

generator = torch.Generator(device="cuda").manual_seed(0)
upscaled_image = pipeline(prompt=prompt, image=low_res_img, generator=generator).images[0]
upscaled_image.save("upsampled_cat.png")


# test pt [0,1]


low_res_img_np = np.array(low_res_img).astype(np.float32) / 255.0
low_res_img_pt = torch.from_numpy(low_res_img_np.transpose(2,0,1))[None,:]
print(low_res_img_pt.min(), low_res_img_pt.max(), low_res_img_pt.shape)


generator = torch.Generator(device="cuda").manual_seed(0)
upscaled_image = pipeline(prompt=prompt, image=low_res_img_pt, generator=generator).images[0]
upscaled_image.save("upsampled_cat_pt.png")


# test pt[-1,1]
low_res_img_pt = low_res_img_pt * 2 -1
print(low_res_img_pt.min(), low_res_img_pt.max(), low_res_img_pt.shape)


generator = torch.Generator(device="cuda").manual_seed(0)
upscaled_image = pipeline(prompt=prompt, image=low_res_img_pt, generator=generator).images[0]
upscaled_image.save("upsampled_cat_pt_neg2_to_1.png")

upsampled_cat
upsampled_cat_pt
upsampled_cat_pt_neg2_to_1

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jun 6, 2023

The documentation is not available anymore as the PR was closed or merged.

@yiyixuxu yiyixuxu requested a review from patrickvonplaten June 6, 2023 18:08
@patrickvonplaten
Copy link
Contributor

Great!

@patrickvonplaten patrickvonplaten merged commit 017ee16 into main Jun 6, 2023
@kashif kashif deleted the image-processor-4xupscaler branch June 7, 2023 09:41
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* refactor x4 upscaler

* style

* copies

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* refactor x4 upscaler

* style

* copies

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants