Conversation

@yiyixuxu (Collaborator) commented Mar 6, 2024

motivated by #6531

Create a Stable Diffusion pipeline with from_pretrained

from diffusers import DiffusionPipeline, StableDiffusionSAGPipeline, AnimateDiffPipeline, MotionAdapter, DDIMScheduler
from diffusers.utils import export_to_gif, load_image
import torch
import gc
from accelerate.utils import compute_module_sizes

def flush():
    gc.collect()
    torch.cuda.empty_cache()

def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024

base_repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
num_inference_steps = 50
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
prompt = "bear eating pizza"
negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality"

# test1
print(" ")
print("test1: pipe_sd")
pipe_sd = DiffusionPipeline.from_pretrained(base_repo, torch_dtype=torch.float16)
pipe_sd.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe_sd.set_ip_adapter_scale(0.6)
pipe_sd.to("cuda")

generator = torch.Generator(device="cpu").manual_seed(33)
out_sd = pipe_sd(
    prompt=prompt,
    negative_prompt=negative_prompt,
    ip_adapter_image=image,
    num_inference_steps=num_inference_steps,
    generator=generator,
).images[0]
out_sd.save("yiyi_test_4_out_1_sd.png")

flush()
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
Max memory allocated: 4.408166408538818 GB

[output image: yiyi_test_4_out_1_sd]

test2: SD -> SAG

# test2
print(" ")
print("test2: pipe_sd -> pipe_sag")

pipe_sag = StableDiffusionSAGPipeline.from_pipe(
    pipe_sd,
    safety_checker=None,
)
# pipe_sag already has the IP-Adapter loaded (components are shared with pipe_sd)
generator = torch.Generator(device="cpu").manual_seed(33)
out_sag = pipe_sag(
    prompt=prompt,
    negative_prompt=negative_prompt,
    ip_adapter_image=image,
    num_inference_steps=num_inference_steps,
    generator=generator,
    guidance_scale=1.0,
    sag_scale=0.75,
).images[0]
out_sag.save("yiyi_test_4_out_2_sag.png")
flush()
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
Max memory allocated: 4.408166408538818 GB

[output image: yiyi_test_4_out_2_sag]

test3: run SD again

# test3
print(" ")
print("test3: run pipe_sd again (should have same output as before)")
generator = torch.Generator(device="cpu").manual_seed(33)
out_sd = pipe_sd(
    prompt=prompt,
    negative_prompt=negative_prompt,
    ip_adapter_image=image,
    num_inference_steps=num_inference_steps,
    generator=generator,
).images[0]
out_sd.save("yiyi_test_4_out_3_sd.png")
flush()
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
Max memory allocated: 4.408166408538818 GB

[output image: yiyi_test_4_out_3_sd]

test4: run pipe_sd after pipe_sag.unload_ip_adapter()

# test4
print("")
print(f" test4: run pipe_sd after unload ip_adapter from pipe_sag (should get an error)")
pipe_sag.unload_ip_adapter()
try:
    generator = torch.Generator(device="cpu").manual_seed(33)
    out_sd = pipe_sd(
        prompt=prompt,
        negative_prompt=negative_prompt,
        ip_adapter_image=image,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]
except Exception as e:
    print(f"error: {e}")
error: 'NoneType' object has no attribute 'image_projection_layers'

Because from_pipe shares components, pipe_sd and pipe_sag use the same UNet, so unloading the IP-Adapter from pipe_sag removes it from pipe_sd as well.

test5: SD -> AnimateDiff

# test5
print(" ")
print("test5: pipe_sd -> pipe_animate")

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)

pipe_animate = AnimateDiffPipeline.from_pipe(pipe_sd, motion_adapter=adapter)
pipe_animate.scheduler = DDIMScheduler.from_config(pipe_animate.scheduler.config, beta_schedule="linear")
# load the IP-Adapter again and load LoRA weights
pipe_animate.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe_animate.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
pipe_animate.to("cuda")

generator = torch.Generator(device="cpu").manual_seed(33)
pipe_animate.set_adapters("zoom-out", adapter_weights=0.75)
out = pipe_animate(
    prompt=prompt,
    num_frames=16,
    num_inference_steps=num_inference_steps,
    ip_adapter_image=image,
    generator=generator,
).frames[0]

export_to_gif(out, "yiyi_test_4_out_5_animate.gif")
flush()
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
Max memory allocated: 15.185057640075684 GB

[output GIF: yiyi_test_4_out_5_animate]

test6: SD -> LPW

# test6
print(" ")
print("test6: pipe_sd -> pipe_lpw (community pipeline)")

pipe_lpw = DiffusionPipeline.from_pipe(
    pipe_sd,
    custom_pipeline="lpw_stable_diffusion",
).to("cuda")

prompt = "best_quality (1girl:1.3) bow bride brown_hair closed_mouth frilled_bow frilled_hair_tubes frills (full_body:1.3) fox_ear hair_bow hair_tubes happy hood japanese_clothes kimono long_sleeves red_bow smile solo tabi uchikake white_kimono wide_sleeves cherry_blossoms"
neg_prompt = "lowres, bad_anatomy, error_body, error_hair, error_arm, error_hands, bad_hands, error_fingers, bad_fingers, missing_fingers, error_legs, bad_legs, multiple_legs, missing_legs, error_lighting, error_shadow, error_reflection, text, error, extra_digit, fewer_digits, cropped, worst_quality, low_quality, normal_quality, jpeg_artifacts, signature, watermark, username, blurry"
generator = torch.Generator(device="cpu").manual_seed(33)
out_lpw = pipe_lpw.text2img(
    prompt,
    negative_prompt=neg_prompt,
    width=512,
    height=512,
    max_embeddings_multiples=3,
    num_inference_steps=num_inference_steps,
    generator=generator,
).images[0]
out_lpw.save("yiyi_test_4_out_6_lpw.png")

flush()
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
Max memory allocated: 15.185057640075684 GB

[output image: yiyi_test_4_out_6_lpw]

test7: run SD again

# test7
print(" ")
print("test7: pipe_sd")
generator = torch.Generator(device="cpu").manual_seed(33)
out_sd = pipe_sd(
    prompt=prompt,
    negative_prompt=negative_prompt,
    generator=generator,
    num_inference_steps=num_inference_steps,
).images[0]
out_sd.save("yiyi_test_4_out_7_sd.png")
flush()
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
Max memory allocated: 15.185057640075684 GB

[output image: yiyi_test_4_out_7_sd]

@yiyixuxu (Collaborator, Author) commented Mar 6, 2024

cc @vladmandic here
Still WIP, but let me know what you think about the API and the use cases it covers.
Do you have any other specific use cases in mind that I did not cover here?


# based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
- config = unet.config
+ config = dict(unet.config)

@DN6
currently we modify the original 2D UNet's config here, even though we do not use that UNet (we create a new UNet motion model instead); copying the config with dict() avoids mutating it
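
For context, a minimal self-contained sketch of the aliasing problem the dict() copy avoids (the literal dict below is illustrative, not the real UNet config):

# without a copy, both names point at the same dict object
original = {"sample_size": 64, "in_channels": 4}
alias = original
alias["in_channels"] = 8             # mutates `original` too
assert original["in_channels"] == 8

# dict() makes a shallow copy that is safe to modify
copied = dict(original)
copied["sample_size"] = 32
assert original["sample_size"] == 64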

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


if hasattr(pipeline, "_all_hooks") and len(pipeline._all_hooks) > 0:
@yiyixuxu (Collaborator, Author) commented:

this is the part where we make sure that if the pipeline previously called enable_model_cpu_offload, it will still work properly with from_pipe
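
A minimal sketch of what that reset can look like, assuming the public remove_all_hooks() helper; the surrounding logic here is illustrative, not the PR's exact code:

# if CPU offload was previously enabled, the components carry accelerate
# hooks; remove them so the new pipeline starts from a clean state
if hasattr(pipeline, "_all_hooks") and len(pipeline._all_hooks) > 0:
    pipeline.remove_all_hooks()
# offloading can then be re-enabled on the new pipeline if desired:
# new_pipe.enable_model_cpu_offload()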

@vladmandic (Contributor) commented:

thanks yiyi!
At a high level it seems to cover all the main use cases; if there is something borderline, we can think about that later.
My two comments are:

  • The target pipeline should inherit pipeline settings, not just components: model_cpu_offload and all the other enable_*() methods should be applied to the target if they were applied to the source (see the sketch after this list).
  • Testing: make sure the target pipeline actually works when additional components are added to it (e.g. a different VAE or pretty much anything). In my experience this is where accelerate often breaks, since it doesn't pull model components onto the right device in time and you end up with a RuntimeError about CUDA vs CPU tensors.
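
For illustration, a hedged sketch of re-applying such settings manually on the target pipeline (reusing pipe_sd from the test script above; as merged, from_pipe does not carry these settings over):

# settings such as CPU offload are not inherited by the new pipeline,
# so they are re-applied explicitly after the switch
pipe_sag = StableDiffusionSAGPipeline.from_pipe(pipe_sd, safety_checker=None)
pipe_sag.enable_model_cpu_offload()  # re-enable offload on the target
pipe_sag.enable_attention_slicing()  # likewise for other enable_*() settings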

@yiyixuxu (Collaborator, Author) commented:

thanks for the feedback! @vladmandic

to make sure that target pipeline actually works when there are additional components added to it

I have not run into any issues using enable_model_cpu_offload with additional components, and my tests are pretty extensive (we added a fast test covering all diffusers official pipelines that can use from_pipe). I think it is not a concern here, because I remove all the hooks and reset the offload device at the beginning.

target pipeline should inherit pipeline settings

I'm not so sure about this because:

  1. We allow adding and removing components with the from_pipe API, so the new pipeline may have different memory requirements and the user may want different settings. I think it is simpler to reset than to inherit the settings, unless users always want the same settings for the new pipeline.
  2. Not every pipeline implements all of these methods; e.g. in my testing, the LPW pipeline's enable_model_cpu_offload method did not work correctly. This is more likely to be an issue with community pipelines.
  3. I agree it is less convenient to have to re-apply settings, but I don't think it makes too much of a difference.

With this being said, I think it won't be hard to implement, and I'm open to it if you all think it's more intuitive and convenient to let the new pipelines inherit settings. cc @pcuenca here too; let me know what you think!

@yiyixuxu yiyixuxu requested review from DN6 and sayakpaul March 21, 2024 19:47
@yiyixuxu (Collaborator, Author) commented:

cc @DN6 @sayakpaul for a final review.
Let me know what you think about this #7241 (comment) too.
I'm slightly in favor of resetting the pipeline settings, but I don't feel strongly either way.

@vladmandic (Contributor) commented:

thanks @yiyixuxu

re: pipeline settings inheritance - IMO it would be more convenient and expected, since it's a pipeline switch using loaded model components (all or some), but it's not a deal breaker; from_pipe has massive value either way.

if name in expected_modules and name not in passed_class_obj:
    # for model components, we will not switch over if the class does not match the type hint in the new pipeline's signature
    if (
        not isinstance(component, torch.nn.Module)
@DN6 (Collaborator) commented Apr 1, 2024:

Maybe change this to see if it subclasses ModelMixin here?
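
For context, a minimal runnable sketch of the suggested check (hedged, not the exact diff; ModelMixin is exported from the diffusers top level):

import torch
from diffusers import ModelMixin  # base class of all diffusers model components

component = torch.nn.Linear(4, 4)  # hypothetical non-diffusers component

# testing against ModelMixin instead of torch.nn.Module matches only
# diffusers model components, not arbitrary torch modules
print(isinstance(component, torch.nn.Module))  # True
print(isinstance(component, ModelMixin))       # False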

@DN6 (Collaborator) left a review:
LGTM 👍🏽

@yiyixuxu yiyixuxu merged commit 7956c36 into main Apr 1, 2024
@yiyixuxu yiyixuxu deleted the from_pipe branch April 1, 2024 23:02
noskill pushed a commit to noskill/diffusers that referenced this pull request Apr 5, 2024
* add from_pipe



---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Dhruv Nair <[email protected]>
sayakpaul added a commit that referenced this pull request Dec 23, 2024
* add from_pipe



---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Dhruv Nair <[email protected]>