
Conversation

@sayakpaul
Member

@sayakpaul sayakpaul commented Dec 30, 2023

What does this PR do?

Adds utilities to support _no_split_modules to the ModelMixin. Closely follows what's done in https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py.

Part of #6240.

I think it's better to tackle the introduction of device_map="auto" to pipelines in multiple PRs. @SunMarc laid out a very nice plan here (internal Slack link).
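
For context, here is a minimal sketch of how a model opts in (the class name is illustrative; the exact entries for each model are settled later in this PR):

from diffusers import ConfigMixin, ModelMixin

class MyDiffusionModel(ModelMixin, ConfigMixin):
    # Modules listed here are never split across devices when a
    # device_map is inferred, e.g. because their forward() carries
    # residual connections that must stay on a single device.
    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]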

TODO

  • Get initial reviews from an accelerate core maintainer
  • Propagate to other important models inheriting ModelMixin
  • Add tests
  • Docs (if needed)

@sayakpaul sayakpaul marked this pull request as draft December 30, 2023 05:05
Comment on lines -724 to -588
device_map=device_map,
max_memory=max_memory,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
Member Author

I am not sure why we used to pass these, but they are NOT used anywhere in configuration_utils.py. Given that, I think they are best removed:

  • No unwanted cognitive burden in thinking about what these are doing for configuration parsing.
  • Reduces LoC (albeit by a small amount)

Member

@SunMarc SunMarc left a comment


Thanks for working on this @sayakpaul! This is exactly what I was thinking! Let's first make it work on diffusers models; extending it to pipelines should be straightforward!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul
Member Author

@SunMarc so, I incorporated the changes and tested with:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="auto")
print(unet.hf_device_map)

It prints:

{'': 0}

I tested this on a single GPU. Does this look correct?

@patrickvonplaten I have gone through the structures but would appreciate confirmation that BasicTransformerBlock and ResnetBlock2D are indeed the only blocks that contain a residual path in their forward() method (considering the base model is an SDXL UNet).

@SunMarc
Member

SunMarc commented Jan 10, 2024

I tested this on a single GPU. Does this look correct?

Yes, it looks correct. Try playing with multiple GPUs and check whether you are able to run the model correctly, since users use device_map to split the model across multiple GPUs.

@sayakpaul
Member Author

Try playing with multiple GPUs and check whether you are able to run the model correctly, since users use device_map to split the model across multiple GPUs.

Do you mean using the same code example but on multiple GPUs? How should the inputs be constructed, then? How should we handle device placement for them?

@sayakpaul
Member Author

sayakpaul commented Jan 11, 2024

@SunMarc I tried on two GPUs. Here are some findings.

Test code
from diffusers import UNet2DConditionModel
import torch 

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="sequential"
)
print(unet.hf_device_map)

# Inputs
sample = torch.randn(1, 4, 128, 128).to("cuda")
t = torch.randint(1, 1000, size=(1, )).to("cuda")
encoder_hidden_states = torch.randn(1, 77, 2048).to("cuda")
add_text_embeds = torch.randn(1, 1280).to("cuda")
add_time_ids = torch.randn(1, 6).to("cuda")
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

# Forward
with torch.no_grad():
    outputs = unet(
        sample=sample,
        timestep=t,
        encoder_hidden_states=encoder_hidden_states,
        added_cond_kwargs=added_cond_kwargs
    ).sample
    print(outputs.shape)

With ["BasicTransformerBlock", "ResnetBlock2D"] specified in _no_split_modules of UNet2DConditionModel, it leads to the following device map:

{'conv_in': 0, 'time_proj': 0, 'time_embedding': 0, 'add_time_proj': 0, 'add_embedding': 0, 'down_blocks': 0, 'up_blocks.0.attentions.0': 0, 'up_blocks.0.attentions.1.norm': 0, 'up_blocks.0.attentions.1.proj_in': 0, 'up_blocks.0.attentions.1.transformer_blocks.0': 0, 'up_blocks.0.attentions.1.transformer_blocks.1': 1, 'up_blocks.0.attentions.1.transformer_blocks.2': 1, 'up_blocks.0.attentions.1.transformer_blocks.3': 1, 'up_blocks.0.attentions.1.transformer_blocks.4': 1, 'up_blocks.0.attentions.1.transformer_blocks.5': 1, 'up_blocks.0.attentions.1.transformer_blocks.6': 1, 'up_blocks.0.attentions.1.transformer_blocks.7': 1, 'up_blocks.0.attentions.1.transformer_blocks.8': 1, 'up_blocks.0.attentions.1.transformer_blocks.9': 1, 'up_blocks.0.attentions.1.proj_out': 1, 'up_blocks.0.attentions.2': 1, 'up_blocks.0.resnets': 1, 'up_blocks.0.upsamplers': 1, 'up_blocks.1': 1, 'up_blocks.2': 1, 'mid_block': 1, 'conv_norm_out': 1, 'conv_act': 1, 'conv_out': 1}

However, it leads to the following error:

Traceback (most recent call last):
  File "/home/sayak/diffusers/test_single_file.py", line 19, in <module>
    outputs = unet(
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/models/unet_2d_condition.py", line 1197, in forward
    sample = upsample_block(
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/models/unet_2d_blocks.py", line 2324, in forward
    hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument tensors in method wrapper_CUDA_cat)

"CrossAttnUpBlock2D" is the block that causes this and when added to _no_split_modules alongside ["BasicTransformerBlock", "ResnetBlock2D"], the error went away and I was able to obtain the output. The device map prints as follows:

{'': 0}

Seems like nothing is being split. Is that the expected result here?

@SunMarc
Member

SunMarc commented Jan 11, 2024

Do you mean using the same code example but on multiple GPUs? How should the inputs be constructed, then? How should we handle device placement for them?

The inputs will be automatically dispatched to the right device because accelerate adds hooks for that to the modules.
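
For context, this is roughly the mechanism (a toy sketch of accelerate's existing dispatch_model API, not code from this PR): dispatch_model attaches an AlignDevicesHook to each mapped submodule, and the hook moves that submodule's inputs onto its assigned device before forward() runs.

import torch.nn as nn
from accelerate import dispatch_model

# Toy two-layer model; the keys "0" and "1" are the Sequential
# children's names, pinned to GPUs 0 and 1 respectively.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
model = dispatch_model(model, device_map={"0": 0, "1": 1})

# Inputs may now live on any device; the hooks relocate them before
# each submodule's forward() runs.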

"CrossAttnUpBlock2D" is the block that causes this and when added to _no_split_modules alongside ["BasicTransformerBlock", "ResnetBlock2D"], the error went away and I was able to obtain the output. The device map prints as follows:

{'': 0}
Seems like nothing is being split. Is that the expected result here?

No, that should not be the case, since we want the model to be split. The result we should get is something like:

{'conv_in': 0, 'time_proj': 0, 'time_embedding': 0, 'add_time_proj': 0, 'add_embedding': 0, 'down_blocks': 0, 'up_blocks.0.attentions.0': 0, 'up_blocks.0.attentions.1': 1, 'up_blocks.0.attentions.2': 1, 'up_blocks.0.resnets': 1, 'up_blocks.0.upsamplers': 1, 'up_blocks.1': 1, 'up_blocks.2': 1, 'mid_block': 1, 'conv_norm_out': 1, 'conv_act': 1, 'conv_out': 1}

In the previous example, inference failed because CrossAttnUpBlock2D concatenates hidden_states coming from different devices. I suspect the problem comes from the following mapping, which splits the attention block. So indeed, we should add CrossAttnUpBlock2D to _no_split_modules. Another way would be to make sure that hidden_states and res_hidden_states are on the same device, but I prefer not to add anything to the modeling code:

{'up_blocks.0.attentions.1.norm': 0, 'up_blocks.0.attentions.1.proj_in': 0, 'up_blocks.0.attentions.1.transformer_blocks.0': 0, 'up_blocks.0.attentions.1.transformer_blocks.1': 1, 'up_blocks.0.attentions.1.transformer_blocks.2': 1, 'up_blocks.0.attentions.1.transformer_blocks.3': 1, 'up_blocks.0.attentions.1.transformer_blocks.4': 1, 'up_blocks.0.attentions.1.transformer_blocks.5': 1, 'up_blocks.0.attentions.1.transformer_blocks.6': 1, 'up_blocks.0.attentions.1.transformer_blocks.7': 1, 'up_blocks.0.attentions.1.transformer_blocks.8': 1, 'up_blocks.0.attentions.1.transformer_blocks.9': 1, 'up_blocks.0.attentions.1.proj_out': 1}
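
For reference, the failure reduces to this (an illustrative two-GPU repro, not the actual diffusers code):

import torch

hidden_states = torch.randn(1, 640, 64, 64, device="cuda:0")
res_hidden_states = torch.randn(1, 640, 64, 64, device="cuda:1")

# RuntimeError: Expected all tensors to be on the same device, ...
torch.cat([hidden_states, res_hidden_states], dim=1)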

@sayakpaul
Member Author

Thanks for providing your inputs.

Another way would be to make sure that hidden_states and res_hidden_states are on the same device, but I prefer not to add anything to the modeling code:

Indeed, this should be preferred. We don't want to touch the forward call unless absolutely necessary.

I suspect the problem comes from the following mapping, which splits the attention block. So indeed, we should add CrossAttnUpBlock2D to _no_split_modules.

But when I did that, the model didn't seem to split. What are we missing here? Would you be able to take a deeper look or give me some pointers to see this through?

@SunMarc
Member

SunMarc commented Jan 12, 2024

I've traced the issue back. It is an issue in accelerate where the memory allocation and module placement are not very good for models where the largest non-splittable layer is very big compared to the whole model. In our case, by specifying CrossAttnUpBlock2D, the module up_blocks.0 becomes non-splittable, and since it represents half of the memory (5GB out of 10GB), we get a bad module placement. This is why I was recommending having smaller non-splittable blocks. Nevertheless, this is what needs to be added to _no_split_modules if we don't want to modify the modeling file.
I can try to fix it in accelerate, but it might require quite some time since, depending on the fix, it can impact all models in transformers. This model is pretty small, so it will fit on one GPU. To continue with the PR, can you try other models by adding _no_split_modules to them? This way, we can see whether this is a recurrent issue or not.

I forgot to mention that you can also pass your own device_map to check whether inference works for a specific placement, since the generated device_map is not optimal. For example, the following device map works with UNet2DConditionModel. It shows that you indeed need to keep the up_blocks unsplit.

device_map = {
    "conv_in": 0,
    "time_proj": 0,
    "time_embedding": 0,
    "add_time_proj": 0,
    "add_embedding": 0,
    "down_blocks": 0,
    "up_blocks.0": 0, 
    "up_blocks.1": 1,
    "up_blocks.2": 1,
    "mid_block": 1,
    "conv_norm_out": 1,
    "conv_act": 1,
    "conv_out": 1,
}
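
Loading with such a hand-written map might look like this (a sketch reusing the checkpoint from the earlier examples; a dict device_map pins each listed module to the given device):

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
    device_map=device_map,  # the dict defined above
)
print(unet.hf_device_map)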

@sayakpaul
Member Author

Nevertheless, this is what needs to be added to _no_split_modules if we don't want to modify the modeling file.

Following what transformers does, I think we definitely don't want to change the modeling code.

I will try on other models and maybe even on a smaller GPU. The smallest I have access to is 16GB.

@sayakpaul
Member Author

sayakpaul commented Jan 15, 2024

@SunMarc seems like good progress now.

Since I am trying this on a machine with two 4090s, I tried the following to restrict the memory so that device_map takes effect:

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet", 
    device_map="auto",
    max_memory={0: "6GiB", 1: "10GiB"},
)
print(unet.hf_device_map)

Worked like a charm!

The device map:

{'conv_in': 0, 'time_proj': 0, 'time_embedding': 0, 'add_time_proj': 0, 'add_embedding': 0, 'down_blocks.0': 0, 'down_blocks.1': 0, 'down_blocks.2.attentions.0.norm': 0, 'down_blocks.2.attentions.0.proj_in': 0, 'down_blocks.2.attentions.0.transformer_blocks.0': 0, 'down_blocks.2.attentions.0.transformer_blocks.1': 0, 'down_blocks.2.attentions.0.transformer_blocks.2': 0, 'down_blocks.2.attentions.0.transformer_blocks.3': 0, 'down_blocks.2.attentions.0.transformer_blocks.4': 0, 'down_blocks.2.attentions.0.transformer_blocks.5': 0, 'down_blocks.2.attentions.0.transformer_blocks.6': 0, 'down_blocks.2.attentions.0.transformer_blocks.7': 0, 'down_blocks.2.attentions.0.transformer_blocks.8': 0, 'down_blocks.2.attentions.0.transformer_blocks.9': 1, 'down_blocks.2.attentions.0.proj_out': 1, 'down_blocks.2.attentions.1': 1, 'down_blocks.2.resnets': 1, 'up_blocks': 1, 'mid_block': 1, 'conv_norm_out': 1, 'conv_act': 1, 'conv_out': 1}

I have also added two tests closely following this and this. I have tested them too with the following:

RUN_SLOW=1 pytest tests/models/test_models_unet_2d_condition.py -k "offload"
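
For reference, here is a minimal sketch of what such a multi-GPU placement check could look like (illustrative, not the exact test added in this PR; assumes two visible GPUs):

from diffusers import UNet2DConditionModel

def check_model_parallelism():
    unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="unet",
        device_map="auto",
        max_memory={0: "6GiB", 1: "10GiB"},
    )
    # With a tight per-GPU budget, the weights must span both devices.
    assert len(set(unet.hf_device_map.values())) > 1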

I think we can add docs after we ship this feature to pipelines because that provides a fuller context.

Meanwhile, could you go through the PR once in detail and let me know your thoughts? Once that's done, I will add _no_split_modules to other models and mark it ready for review.

Cc: @DN6 for awareness.

Member

@SunMarc SunMarc left a comment


Thanks for adding the CPU/disk offload tests! Can you also add the multi-GPU test?

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Feb 25, 2024
@sayakpaul
Member Author

Not stale. This PR serves as a valuable reference for models that would need splitting.

@yiyixuxu yiyixuxu removed the stale Issues that haven't received updates label Feb 26, 2024
@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Mar 22, 2024
@sayakpaul sayakpaul removed the stale Issues that haven't received updates label Mar 23, 2024
@pcuenca pcuenca added the wip label Apr 2, 2024
@sayakpaul
Member Author

@yiyixuxu I think this PR is ready for a review now.

To stress-test it somewhat, I did the following:

Created an ~8B Transformer variant:

import torch
from accelerate import init_empty_weights
from diffusers import Transformer2DModel


with init_empty_weights():
    pixart_transformer = Transformer2DModel.from_config("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer")

actual_bigger_transformer = Transformer2DModel.from_config(
    pixart_transformer.config, num_layers=72, num_attention_heads=36, cross_attention_dim=2592
)
actual_bigger_transformer.save_pretrained("/raid/.cache/actual_bigger_transformer")

It has about 7.8B parameters and takes ~29.135GB to store.

I then ran the model like so:

from diffusers import Transformer2DModel
import tempfile
import torch
import os

def get_inputs():
    sample = torch.randn(1, 4, 128, 128)
    timestep = torch.randint(0, 1000, size=(1, ))
    encoder_hidden_states = torch.randn(1, 120, 4096)

    resolution = torch.tensor([1024, 1024]).repeat(1, 1)
    aspect_ratio = torch.tensor([1.]).repeat(1, 1)
    added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
    return sample, timestep, encoder_hidden_states, added_cond_kwargs

with torch.no_grad():
    max_memory = {0: "15GB"} # reasonable estimate for a consumer-gpu.
    with tempfile.TemporaryDirectory() as tmp_dir:
        new_model = Transformer2DModel.from_pretrained(
            "/raid/.cache/actual_bigger_transformer", 
            device_map="auto",
            max_memory=max_memory, 
            offload_folder=os.path.join("/raid/.cache/huggingface", tmp_dir)
        )

        sample, timestep, encoder_hidden_states, added_cond_kwargs = get_inputs()
        out = new_model(
            hidden_states=sample,
            encoder_hidden_states=encoder_hidden_states,
            timestep=timestep, 
            added_cond_kwargs=added_cond_kwargs
        ).sample
        print(out.shape)

It successfully runs.

Happy to add it to the test suite if you want.

Would be quite nice to also support:

  • Sharding of big checkpoints
  • Loading of sharded checkpoints

Since we will have bigger models, I think it makes sense to at least support these, because downloading a single big checkpoint is quite messy. Both of the above can be done easily with accelerate (see the sketch below).
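
A rough sketch of what both could build on (existing accelerate utilities, not what this PR implements; shard size and paths are illustrative):

from accelerate import init_empty_weights, load_checkpoint_and_dispatch, save_model
from diffusers import Transformer2DModel

# Shard: save_model splits the state dict into shards of at most
# max_shard_size and writes an index file alongside them.
save_model(actual_bigger_transformer, "sharded-transformer", max_shard_size="5GB")

# Load: build an empty-weights skeleton, then load and dispatch the
# shards according to an automatically inferred device map.
with init_empty_weights():
    model = Transformer2DModel.from_config(actual_bigger_transformer.config)
model = load_checkpoint_and_dispatch(
    model,
    "sharded-transformer",
    device_map="auto",
    no_split_module_classes=["BasicTransformerBlock"],
)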

Ccing @SunMarc for visibility and awareness. If you have any comments, feel free to share them.

@sayakpaul
Member Author

@yiyixuxu a gentle ping here.

Member

@SunMarc SunMarc left a comment


LGTM! Thanks for testing this PR on a real example. As for sharding, it makes sense to support it, but maybe in a follow-up PR. For the sharding, you can take inspiration from the transformers save_pretrained function or from save_model in accelerate, where I tried to mimic save_pretrained while removing transformers-specific code. I think it will be better to add your own logic for flexibility. As for loading, sharded checkpoints are supported in load_checkpoint_in_model.

@yiyixuxu
Collaborator

so we only support this on the model level, not the pipeline level, right?

@sayakpaul
Member Author

Yes. This is because the pipeline-level and model-level device mapping strategies are conceptually different from one another.

Once this feature is in, I will work on the pipeline-level device-mapping strategy to facilitate the change.

@sayakpaul
Member Author

I am going to merge this without docs because, in a follow-up PR, I am going to add support for serializing sharded checkpoints and tests to ensure we can load them.

@sayakpaul sayakpaul merged commit 3fd31ee into main Apr 30, 2024
@sayakpaul sayakpaul deleted the feat/device-map-auto branch April 30, 2024 03:16
sayakpaul added a commit that referenced this pull request Dec 23, 2024
* introduce _no_split_modules.

* unnecessary spaces.

* remove unnecessary kwargs and style

* fix: accelerate imports.

* change to _determine_device_map

* add the blocks that have residual connections.

* add: CrossAttnUpBlock2D

* add: testing

* style

* line-spaces

* quality

* add disk offload test without safetensors.

* checking disk offloading percentages.

* change model split

* add: utility for checking multi-gpu requirement.

* model parallelism test

* splits.

* splits.

* splits

* splits.

* splits.

* splits.

* offload folder to test_disk_offload_with_safetensors

* add _no_split_modules

* fix-copies