
Conversation

@sayakpaul
Member

@sayakpaul sayakpaul commented Feb 5, 2024

What does this PR do?

As discussed in #6396, this PR adds support for the "balanced" device map (and its variants) in the pipelines. It's NOT complete yet.

The major USP of this PR

Let's try to mimic a real use case. I want to test for the following situation.

Let's say a user has two consumer GPUs, each with 8 GB of VRAM. Using enable_model_cpu_offload() might not work well here because:

  • it works only on a single GPU
  • a single model might not fit on a single GPU

enable_sequential_cpu_offload() might work, but it will be extremely slow, and here too we're limited to a single GPU.

So, this is probably one of the best situations where the "balanced" device_map is useful: it gives the user a straightforward way to utilize both GPUs without many bells and whistles.
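As a rough sketch of the intended usage in that two-GPU scenario (treat the max_memory values and argument surface as assumptions while the PR is in progress):

from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    device_map="balanced",
    # optional per-GPU caps, e.g. for two 8GB cards (assumed values)
    max_memory={0: "8GB", 1: "8GB"},
)
image = pipeline("picture of a dog", num_inference_steps=25).images[0]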

TODOs

  • Catch errors for unexpected configurations
  • Docs
  • Tests

Testing

unfold for the code
from diffusers import DiffusionPipeline
import argparse
import torch

def run_pipeline(args):
    if args.do_device_map:
        pipeline = DiffusionPipeline.from_pretrained(
            args.ckpt_id,
            variant="fp16",
            torch_dtype=torch.float16,
            device_map="balanced",
        )
        if hasattr(pipeline, "safety_checker"):
            pipeline.safety_checker = None

    else:
        pipeline = DiffusionPipeline.from_pretrained(
            args.ckpt_id,
            variant="fp16",
            torch_dtype=torch.float16,
        )
        if hasattr(pipeline, "safety_checker"):
            pipeline.safety_checker = None
        
        pipeline = pipeline.to("cuda")

    image = pipeline(
        "picture of a dog", num_inference_steps=args.num_inference_steps, generator=torch.manual_seed(0)
    ).images[0]
    image.save(f"resultant_image_{args.do_device_map}.png")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_id", default="runwayml/stable-diffusion-v1-5", choices=["runwayml/stable-diffusion-v1-5", "stabilityai/stable-diffusion-xl-base-1.0"])
    parser.add_argument("--num_inference_steps", type=int, default=5)
    parser.add_argument("--do_device_map", action="store_true")
    args = parser.parse_args()
    run_pipeline(args)

Tested on the DGX (with accelerate installed from source).

CUDA_VISIBLE_DEVICES=1,2 python test_device_map_pipelines.py  --num_inference_steps=50

VAE: tensor([0.2964, 0.2983, 0.3008, 0.2917, 0.3213, 0.3174, 0.3298, 0.3298, 0.2352,
        0.2367, 0.2539, 0.2510], device='cuda:0', dtype=torch.float16)
CUDA_VISIBLE_DEVICES=1,2 python test_device_map_pipelines.py  --num_inference_steps=50 --do_device_map

VAE: tensor([0.2964, 0.2983, 0.3008, 0.2917, 0.3213, 0.3174, 0.3298, 0.3298, 0.2352,
        0.2367, 0.2539, 0.2510], device='cuda:0', dtype=torch.float16)

We can see that the outputs match when using device mapping. There's a proper test suite in the PR as well.

Cc: @yiyixuxu for visibility. Immense thanks to @SunMarc for helping me with this every step of the way. Without Marc's support, this wouldn't have been possible.

@sayakpaul sayakpaul requested a review from SunMarc February 5, 2024 11:48
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul
Member Author

@SunMarc some more progress here.

I was able to determine a dictionary mapping the model components to the available GPU devices. Example:

{'unet': 0, 'text_encoder': 1, 'vae': 1}

Under these conditions:

{'unet': 27653351728, 'text_encoder': 3481712032, 'vae': 2594036992}
{0: 24978194432, 1: 24978194432}

(I hope the dictionaries are self-explanatory)

This device map is created here BEFORE the actual models are loaded.

This is because load_sub_model() passes a device_map. Inside that method, we determine the loading method for a given model, which will always resolve to from_pretrained().
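To make the idea concrete, here is a rough sketch (not the PR's actual utility) of how such a component -> device assignment could be derived from the two dictionaries above, assuming a greedy "largest component to the freest GPU" heuristic:

def assign_components(component_sizes, gpu_memory):
    # component_sizes: {'unet': ..., 'text_encoder': ..., 'vae': ...}
    # gpu_memory:      {0: ..., 1: ...}
    remaining = dict(gpu_memory)
    device_map = {}
    for name, size in sorted(component_sizes.items(), key=lambda kv: kv[1], reverse=True):
        # place the largest remaining component on the GPU with the most free memory
        device = max(remaining, key=remaining.get)
        device_map[name] = device
        remaining[device] -= size
    return device_map

print(assign_components(
    {"unet": 27653351728, "text_encoder": 3481712032, "vae": 2594036992},
    {0: 24978194432, 1: 24978194432},
))
# {'unet': 0, 'text_encoder': 1, 'vae': 1}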

From #6396 (comment):

Since each model is loaded on only one device, we don't add hooks by default. You need to set force_hook=True in load_checkpoint_and_dispatch. By doing that, we will add hooks that will move the data to the correct device when performing inference.

I think I can still pass a boolean indicator within load_sub_model() indicating that the force_hook argument should be set to True. But how do we handle the text encoders' case here? How can we let transformers know that it should set force_hook to True from the diffusers codebase?
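For reference, a minimal sketch of what forcing the hooks looks like on the accelerate side (an illustration, not the PR's code; it assumes the force_hooks argument is available in the installed accelerate version):

from accelerate import dispatch_model
from diffusers import UNet2DConditionModel
import torch

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", variant="fp16", torch_dtype=torch.float16
)
# Even though the whole model fits on one device, force_hooks=True attaches the
# dispatch hooks that move inputs to the model's device at call time.
unet = dispatch_model(unet, device_map={"": 0}, force_hooks=True)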

Anyway the following seems to be working:

from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    variant="fp16",
    torch_dtype=torch.float16,
    device_map="auto",
    safety_checker=None,
)

for name, component in pipeline.components.items():
    if isinstance(component, torch.nn.Module):
        print(name, component.device)


_ = pipeline("picture of a dog", num_inference_steps=50)

Prints:

vae cuda:1
text_encoder cuda:1
unet cuda:0

But also throws:

You shouldn't move a model when it is dispatched on multiple devices.

However, it also produces:

/home/sayak/diffusers/src/diffusers/image_processor.py:90: RuntimeWarning: invalid value encountered in cast
  images = (images * 255).round().astype("uint8")

Could this be related to the device placement?

It is not the case when we do:

from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    variant="fp16",
    torch_dtype=torch.float16,
    safety_checker=None,
).to("cuda")

_ = pipeline("picture of a dog", num_inference_steps=50)

@sayakpaul
Member Author

Been trying to debug the cause of the black image results stemming from device_map="auto". Here are some findings.

Without any device_map, we have (five denoising steps):

Prompt embeds: tensor([[-3.8843e-01,  2.2949e-02, -5.2338e-02, -1.8420e-01],
        [-3.7183e-01, -1.4492e+00, -3.3936e-01, -1.0754e-01],
        [-5.1074e-01, -1.4629e+00, -2.9272e-01,  1.2255e-03],
        [-5.5518e-01, -1.4248e+00, -2.8711e-01,  6.1646e-02]], device='cuda:0',
       dtype=torch.float16)

Initial latents: tensor([-0.9014,  0.1541,  0.2152, -0.6416,  1.0215, -0.3105, -1.4922,  0.0122,
        -0.9941,  0.6323,  0.5259,  0.1608,  0.9238, -1.2178,  0.4255, -1.7715],
       device='cuda:0', dtype=torch.float16)

UNet predictions 0: tensor([-0.9072,  0.1572,  0.2156, -0.6226,  1.0273, -0.2859, -1.4678,  0.0511,
        -1.0010,  0.6211,  0.4883,  0.1566,  0.9639, -1.2197,  0.4321, -1.7646],
       device='cuda:0', dtype=torch.float16)
UNet predictions 1: tensor([-0.9502,  0.1558,  0.1925, -0.6182,  0.9028, -0.2756, -1.4424,  0.0494,
        -1.0000,  0.5996,  0.4292,  0.1602,  1.0068, -1.1729,  0.4441, -1.6826],
       device='cuda:0', dtype=torch.float16)
UNet predictions 2: tensor([-0.9478,  0.1094,  0.1261, -0.6631,  0.9561, -0.3069, -1.5049,  0.0400,
        -0.9795,  0.6064,  0.4529,  0.1589,  1.0205, -1.1543,  0.4565, -1.6758],
       device='cuda:0', dtype=torch.float16)
UNet predictions 3: tensor([-0.8262,  0.1978,  0.0524, -0.5957,  0.9131, -0.1609, -1.4922,  0.1169,
        -1.0059,  0.6895,  0.3044,  0.2878,  1.0361, -0.9067,  0.3342, -1.4824],
       device='cuda:0', dtype=torch.float16)
UNet predictions 4: tensor([-1.0234,  0.4153, -0.0172, -0.6655,  0.7754, -0.0658, -1.7363,  0.0820,
        -1.3252,  1.0186,  0.2703,  0.2452,  0.5400, -0.4561, -0.0053, -1.4590],
       device='cuda:0', dtype=torch.float16)
UNet predictions 5: tensor([-0.3188,  0.1713, -0.0659, -0.3005,  0.2207, -0.0583, -0.3931,  0.0102,
        -0.2925,  0.0295, -0.4709,  0.0939,  0.2411, -0.3926,  0.0777, -0.2218],
       device='cuda:0', dtype=torch.float16)


VAE: tensor([ 0.0027,  0.0687,  0.1221,  0.1288, -0.0571,  0.0011,  0.0557,  0.0393,
        -0.1168, -0.0524,  0.0131,  0.0068], device='cuda:0',
       dtype=torch.float16)

With device_map="auto":

Prompt embeds: tensor([[-3.8843e-01,  2.2949e-02, -5.2338e-02, -1.8420e-01],
        [-3.7183e-01, -1.4492e+00, -3.3936e-01, -1.0754e-01],
        [-5.1074e-01, -1.4629e+00, -2.9272e-01,  1.2255e-03],
        [-5.5518e-01, -1.4248e+00, -2.8711e-01,  6.1646e-02]], device='cuda:1',
       dtype=torch.float16)

Initial latents: tensor([-0.9014,  0.1541,  0.2152, -0.6416,  1.0215, -0.3105, -1.4922,  0.0122,
        -0.9941,  0.6323,  0.5259,  0.1608,  0.9238, -1.2178,  0.4255, -1.7715],
       device='cuda:1', dtype=torch.float16)

UNet predictions 0: tensor([-0.3884, -0.2473, -0.2957, -0.0774, -0.6709,  1.2314, -0.7563,  0.9395,
        -0.7852, -0.3325, -1.0791, -0.0468, -0.4641, -1.5430, -0.6104, -1.1377],
       device='cuda:1', dtype=torch.float16)
UNet predictions 1: tensor([-0.9014,  0.1541,  0.2152, -0.6416,  1.0215, -0.3105, -1.4922,  0.0122,
        -0.9941,  0.6323,  0.5259,  0.1608,  0.9238, -1.2178,  0.4255, -1.7715],
       device='cuda:1', dtype=torch.float16)
UNet predictions 2: tensor([  0.8125,  -7.2891,   5.1875, -12.3125,   3.8184,  10.3203,   1.4805,
          5.5859,  -6.2812,  11.1172,  -2.3516,  12.0703,   3.5977, -18.9219,
          3.1055, -23.1250], device='cuda:1', dtype=torch.float16)
UNet predictions 3: tensor([ -0.0234,  -3.5703,   2.6973,  -6.4609,   2.3984,   5.0117,   0.0273,
          2.7969,  -3.6152,   5.8594,  -0.9258,   6.1094,   2.2402, -10.0469,
          1.7559, -12.4062], device='cuda:1', dtype=torch.float16)
UNet predictions 4: tensor([-1.7969,  4.6445, -2.8281,  6.5234, -0.8281, -6.7031, -3.0605, -3.3594,
         2.3633, -5.8125,  2.1855, -7.0938, -0.8398,  9.7031, -1.2617, 11.4375],
       device='cuda:1', dtype=torch.float16)
UNet predictions 5: tensor([-1.3057,  3.2754, -1.9756,  4.5508, -0.5312, -4.7227, -2.2227, -2.3555,
         1.6113, -4.0469,  1.5605, -4.9844, -0.5464,  6.7656, -0.8643,  7.9375],
       device='cuda:1', dtype=torch.float16)

VAE: tensor([ 0.0008,  0.0041, -0.0304, -0.0286, -0.0887, -0.1207, -0.1835, -0.2190,
        -0.2219, -0.2612, -0.3252, -0.3804], device='cuda:1',
       dtype=torch.float16)

We can clearly see that the predictions from the UNet start to differ and, in general, have a much higher norm than in the case without device_map. Note that the prompt_embeds and the initial latents don't change (as is evident from the outputs).
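For reference, a small debugging sketch of how slices like the ones above can be captured (an assumption about the logging, not the exact code used here): register a forward hook on the UNet of the pipeline loaded earlier and print a slice of every prediction so the two runs can be diffed.

import torch

def log_unet_output(module, args, output):
    # the UNet returns either a tuple or an output object with a .sample field
    sample = output[0] if isinstance(output, tuple) else output.sample
    print("UNet prediction:", sample.flatten()[:16])

handle = pipeline.unet.register_forward_hook(log_unet_output)
_ = pipeline("picture of a dog", num_inference_steps=5, generator=torch.manual_seed(0))
handle.remove()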

SunMarc
SunMarc previously approved these changes Feb 13, 2024
Member

@SunMarc SunMarc left a comment

Thanks for working on this! I've left a few comments.

@SunMarc SunMarc self-requested a review February 13, 2024 21:16
@SunMarc
Member

SunMarc commented Feb 13, 2024

We can clearly see that the predictions from the UNet start to differ and, in general, have a much higher norm than in the case without device_map. Note that the prompt_embeds and the initial latents don't change (as is evident from the outputs).

This is strange. In my tests, the prompt_embeds don't change, but the initial latents change every time, whether I'm using device_map or not.
Can you try with device_map but only on one device (you'll need to remove the error that you are raising)? If we indeed get the same results, the issue is probably due to the data being moved across the GPUs.

You shouldn't move a model when it is dispatched on multiple devices.

This warning should be fixed in this PR

/home/sayak/diffusers/src/diffusers/image_processor.py:90: RuntimeWarning: invalid value encountered in cast
images = (images * 255).round().astype("uint8")

Oh that's strange, I don't have this issue

I think I can still pass a boolean indicator within load_sub_model() indicating that the force_hook argument should be set to True. But how do we handle the text encoders' case here? How can we let transformers know that it should set force_hook to True from the diffusers codebase?

The simplest solution would be to add this arg in the from_pretrained method of transformers in a PR.

I'll try to dig deeper into why I'm not able to reproduce the results. Thanks again for your work!

@sayakpaul
Member Author

@ArthurZucker need some help regarding #6857 (comment):

I think I can still pass a boolean indicator within load_sub_model() indicating that the force_hook argument should be set to True. But how do we handle the text encoders' case here? How can we let transformers know that it should set force_hook to True from the diffusers codebase?

@sayakpaul
Member Author

This is strange. In my tests, the prompt_embeds don't change, but the initial latents change every time, whether I'm using device_map or not.
Can you try with device_map but only on one device (you'll need to remove the error that you are raising)? If we indeed get the same results, the issue is probably due to the data being moved across the GPUs.

Your hunch was correct. I tried the device map but with a single device, and the intermediate values matched. They don't match when using multiple GPUs, though. Different results because of data movement may be acceptable to an extent, as long as they don't lead to black images as in this case.

Oh that's strange, I don't have this issue

It's somewhat random for a smaller number of steps. When you bump num_inference_steps up to 50, you should see it.

Let me know.

@SunMarc
Member

SunMarc commented Feb 14, 2024

Your hunch was correct. I tried the device map but with a single device, and the intermediate values matched. They don't match when using multiple GPUs, though. Different results because of data movement may be acceptable to an extent, as long as they don't lead to black images as in this case.

On my side, even without device_map, I get a different latent space each time. Can you confirm that this is not the case on your side?

@sayakpaul
Member Author

On my side, even without device_map, I get a different latent space each time. Can you confirm that this is not the case on your side?

I can confirm that it is NOT the case on my end. Let me send over my testing script:

from diffusers import DiffusionPipeline
import argparse
import torch

def run_pipeline(args):
    if args.do_device_map:
        pipeline = DiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            variant="fp16",
            torch_dtype=torch.float16,
            device_map="auto",
            safety_checker=None,
        )

        for name, component in pipeline.components.items():
            if isinstance(component, torch.nn.Module):
                print(name, component.hf_device_map)
    else:
        pipeline = DiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            variant="fp16",
            torch_dtype=torch.float16,
            safety_checker=None,
        ).to("cuda")

    _ = pipeline("picture of a dog", num_inference_steps=args.num_inference_steps, generator=torch.manual_seed(0))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_inference_steps", type=int, default=5)
    parser.add_argument("--do_device_map", action="store_true")
    args = parser.parse_args()
    run_pipeline(args)

This should probably work because in the previous examples there was no generator controlling the randomness. LMK.
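A quick illustration of that point: with a fixed generator, the sampled latents are bit-exact across runs, independent of any device_map. A minimal check:

import torch

gen = torch.manual_seed(0)
latents_a = torch.randn((1, 4, 64, 64), generator=gen)
gen = torch.manual_seed(0)
latents_b = torch.randn((1, 4, 64, 64), generator=gen)
assert torch.equal(latents_a, latents_b)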

@sayakpaul
Member Author

@SunMarc on my local clone of transformers I did pass force_hooks=True here:

https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/modeling_utils.py#L3558

But it still didn't prevent the NaN issues described above. I would appreciate some guidance.

@sayakpaul
Member Author

@SunMarc a gentle ping here.

For example, if you have two 8GB GPUs, then using [`~DiffusionPipeline.enable_model_cpu_offload`] may not work so well because:

* it only works on a single GPU
* a single model might not fit on a single GPU ([`~DiffusionPipeline.enable_sequential_cpu_offload`] might work but it will be extremely slow and it is also limited to a single GPU)
Member

I think I made a comment yesterday about this but can't find it now. How does this PR help when a model does not fit on a single GPU?

Member Author

Yeah, that's a GitHub bug, apparently. I responded to it but my response disappeared as well.

We offload to the CPU. In case there's not enough CPU memory, we error out.

The next alternative would be to split the model across devices - GPU, CPU, and disk. This is a separate PR: #6396.

@sayakpaul
Member Author

@pcuenca not sure if you had the chance to test-drive yesterday.

@sayakpaul
Member Author

Going to merge it once the CI is green.

@sayakpaul sayakpaul merged commit 3e4a6bd into main Apr 10, 2024
@sayakpaul sayakpaul deleted the pipeline-device-map-auto branch April 10, 2024 03:29
@MikeHanKK

@sayakpaul so right now on the main branch, we cannot use device_map="auto" because device_map has to be "balanced"? But at the same time, device_map="balanced" conflicts with enable_model_cpu_offload(), so we cannot use enable_model_cpu_offload() anymore?

@sayakpaul
Member Author

But at the same time, device_map="balanced" conflicts with enable_model_cpu_offload(), so we cannot use enable_model_cpu_offload() anymore?

I don't think this is true. On a single GPU, you will still be able to use enable_model_cpu_offload(); the usefulness of device_map="balanced" becomes evident in multi-GPU use cases.

@MikeHanKK

MikeHanKK commented May 11, 2024

(Screenshot of the error traceback.)
When device_map="balanced" is set, calling enable_model_cpu_offload() reports an error.
@sayakpaul

@sayakpaul
Member Author

Yeah, that is expected. Don't use a balanced device map while using model offloading.

@MikeHanKK

I was using device_map="auto" when I used model offloading. But it seems device_map="auto" is not supported anymore. If I still want model offloading, should I just not set device_map at all? @sayakpaul

@sayakpaul
Member Author

On a single device, yes.
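To spell that out, a minimal sketch of the two supported setups (assuming the reset_device_map() helper added in this PR):

from diffusers import DiffusionPipeline
import torch

# Multi-GPU: balanced device map, no offloading on top of it.
pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, device_map="balanced"
)

# Single GPU: remove the device map first, then enable model offloading.
pipeline.reset_device_map()
pipeline.enable_model_cpu_offload()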

sayakpaul added a commit that referenced this pull request Dec 23, 2024
* get device <-> component mapping when using multiple gpus.

* condition the device_map bits.

* relax condition

* device_map progress.

* device_map enhancement

* some cleaning up and debugging

* Apply suggestions from code review

Co-authored-by: Marc Sun <[email protected]>

* incorporate suggestions from PR.

* remove multi-gpu condition for now.

* guard check the component -> device mapping

* fix: device_memory variable

* dispatching transformers model to have force_hooks=True

* better guarding for transformers device_map

* introduce support balanced_low_memory and balanced_ultra_low_memory.

* remove device_map patch.

* fix: intermediate variable scoping.

* fix: condition in cpu offload.

* fix: flax class restrictions.

* remove modifications from cpu_offload and model_offload

* incorporate changes.

* add a simple forward pass test

* add: torch_device in get_inputs()

* add: tests

* remove print

* safe-guard to(), model offloading and cpu offloading when balanced is used as a device_map.

* style

* remove .

* safeguard device_map with more checks and remove invalid device_mapping strategies.

* make a class attribute and adjust tests accordingly.

* fix device_map check

* fix test

* adjust comment

* fix: device_map attribute

* fix: dispatching.

* max_memory test for pipeline

* version guard the tests

* fix guard.

* address review feedback.

* reset_device_map method.

* add: test for reset_hf_device_map

* fix a couple things.

* add reset_device_map() in the error message.

* add tests for checking reset_device_map doesn't have unintended consequences.

* fix reset_device_map and offloading tests.

* create _get_final_device_map utility.

* hf_device_map -> _hf_device_map

* add documentation

* add notes suggested by Marc.

* styling.

* Apply suggestions from code review

Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>

* move updates within gpu condition.

* other docs related things

* note on ignoring a device not specified in .

* provide a suggestion if device mapping errors out.

* fix: typo.

* _hf_device_map -> hf_device_map

* Empty-Commit

* add: example hf_device_map.

---------

Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>