Conversation

@sayakpaul
Member

What does this PR do?

Follow-up of #6396.

This PR adds support for saving a big model's state dict into multiple shards for efficient portability and loading. Adds support for loading the sharded checkpoints, too.

This is much akin to how big models like T5XXL are handled.

Also added a test to ensure that models with _no_split_modules specified can be sharded and loaded back to perform inference, with numerical assertions on the outputs.
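
A minimal sketch of such a save-and-reload round trip (the tiny checkpoint id, shard size, and tolerance below are illustrative assumptions, not the exact test added in this PR):

import tempfile

import torch
from diffusers import UNet2DConditionModel

# Illustrative only: any diffusers model works the same way; the repo id below is an assumption.
model = UNet2DConditionModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet")
model.eval()

sample = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
timestep = torch.tensor([1])
encoder_hidden_states = torch.randn(1, 77, model.config.cross_attention_dim)

with torch.no_grad():
    expected = model(sample, timestep, encoder_hidden_states).sample

with tempfile.TemporaryDirectory() as tmp_dir:
    # A deliberately tiny max_shard_size so even a tiny model gets split into multiple shards.
    model.save_pretrained(tmp_dir, max_shard_size="100KB")
    reloaded = UNet2DConditionModel.from_pretrained(tmp_dir)

with torch.no_grad():
    actual = reloaded(sample, timestep, encoder_hidden_states).sample

assert torch.allclose(expected, actual, atol=1e-5)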

Here's a real use-case. Consider this Transformer2DModel checkpoint: https://huggingface.co/sayakpaul/actual_bigger_transformer/.

It was serialized like so:

from diffusers import Transformer2DModel
from accelerate.utils import compute_module_sizes
from accelerate import init_empty_weights
import torch.nn as nn

def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"

with init_empty_weights():
    # Instantiate on the meta device so we can size things up without allocating memory.
    pixart_transformer = Transformer2DModel.from_config("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer")
    bigger_transformer = Transformer2DModel.from_config(
        pixart_transformer.config, num_layers=72, num_attention_heads=36, cross_attention_dim=2592,
    )
    module_size = bytes_to_giga_bytes(compute_module_sizes(bigger_transformer)[""])
    print(f"{module_size=} GB")
    pytorch_total_params = sum(p.numel() for p in bigger_transformer.parameters()) / 1e9
    print(f"{pytorch_total_params=} B")

    # A plain stack of Linears of comparable size, just as a sanity check on the numbers above.
    model = nn.Sequential(*[nn.Linear(8944, 8944) for _ in range(1000)])
    module_size = bytes_to_giga_bytes(compute_module_sizes(model)[""])
    print(f"{module_size=} GB")
    pytorch_total_params = sum(p.numel() for p in model.parameters()) / 1e9
    print(f"{pytorch_total_params=} B")

# Materialize real weights this time and serialize them as 10GB shards, pushing to the Hub.
actual_bigger_transformer = Transformer2DModel.from_config(
    pixart_transformer.config, num_layers=72, num_attention_heads=36, cross_attention_dim=2592
)
actual_bigger_transformer.save_pretrained("/raid/.cache/actual_bigger_transformer", max_shard_size="10GB", push_to_hub=True)
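
A quick way to confirm what actually landed on the Hub (a small sketch using huggingface_hub):

from huggingface_hub import list_repo_files

# Should list config.json, an index file (diffusion_pytorch_model.safetensors.index.json),
# and the individual shard files.
print(list_repo_files("sayakpaul/actual_bigger_transformer"))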

As we can see from the Hub repo, its state dict is sharded. To perform inference with the model, all we have to do is this:

from diffusers import Transformer2DModel
import tempfile
import torch

def get_inputs():
    # Dummy PixArt-style inputs: a 128x128 latent, a random timestep, and T5-sized text embeddings.
    sample = torch.randn(1, 4, 128, 128)
    timestep = torch.randint(0, 1000, size=(1, ))
    encoder_hidden_states = torch.randn(1, 120, 4096)

    resolution = torch.tensor([1024, 1024]).repeat(1, 1)
    aspect_ratio = torch.tensor([1.]).repeat(1, 1)
    added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
    return sample, timestep, encoder_hidden_states, added_cond_kwargs

with torch.no_grad():
    # max_memory = {0: "15GB"} # reasonable estimate for a consumer-gpu.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Downloads the sharded checkpoint and dispatches it across the available devices.
        new_model = Transformer2DModel.from_pretrained(
            "sayakpaul/actual_bigger_transformer",
            device_map="auto",
        )

        sample, timestep, encoder_hidden_states, added_cond_kwargs = get_inputs()
        out = new_model(
            hidden_states=sample,
            encoder_hidden_states=encoder_hidden_states,
            timestep=timestep,
            added_cond_kwargs=added_cond_kwargs
        ).sample
        print(f"{out.shape=}, {out.device=}")

I have purposefully not added documentation because all of this will become useful once we use it in the context of a full-fledged pipeline execution (up next) :)

@sayakpaul sayakpaul requested review from SunMarc and yiyixuxu May 1, 2024 10:46
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul
Member Author

@yiyixuxu @SunMarc a gentle ping here.

@yiyixuxu yiyixuxu requested a review from BenjaminBossan May 13, 2024 22:24
Member

@BenjaminBossan BenjaminBossan left a comment


Always delightful to deal with the from_pretrained code ;)

I don't really have any bigger comments, as this should hopefully work well since it's based on the transformers implementation. Only some smaller comments.

Member

@SunMarc SunMarc left a comment


Thanks for your work @sayakpaul! Left a suggestion (not a blocker, we can do it afterwards if needed)! No major comments since @BenjaminBossan did a very thorough review already!

@sayakpaul
Member Author

sayakpaul commented May 29, 2024

> I'd rather have another pair of eyes reviewing it, given it's fairly easy to miss something when iterating/reviewing several times on the same code.

Yeah. @yiyixuxu would be the final approver here :)

Collaborator

@yiyixuxu yiyixuxu left a comment


thanks for the PR!!
I left some comments and questions :)

  revision = kwargs.pop("revision", None)
  _ = kwargs.pop("mirror", None)
- subfolder = kwargs.pop("subfolder", None)
+ subfolder = kwargs.pop("subfolder", None) or ""
Collaborator


why don't we handle it where it fails, then?

we would only need to change one place, no?
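
A hypothetical illustration of that idea (keeping subfolder as None and normalizing it only where the path is built; the names below are placeholders, not the actual diffusers code):

import os

# Hypothetical sketch: `subfolder` stays None until the single place that builds the path.
subfolder = None  # what kwargs.pop("subfolder", None) might return
weights_name = "diffusion_pytorch_model.safetensors"
weights_path = os.path.join(subfolder or "", weights_name)
print(weights_path)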

raise EnvironmentError(
    f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
)
# This should correspond to a shard index file.
Collaborator

@yiyixuxu yiyixuxu May 31, 2024


why do we need to return something different when we can't find the shard index file?

can we do

try:
    model_file = _get_model_file(...)
    ...
except ...:
    model_file = None

Collaborator


I guess I still have this question: why do we need to return None when we can't find a shard index file, whereas for any other file we can't find we raise errors?
Where in the code is this needed?

@sayakpaul
Member Author

sayakpaul commented Jun 3, 2024

@yiyixuxu do the recent changes work for you?

(I have run the tests)

@sayakpaul sayakpaul requested a review from yiyixuxu June 3, 2024 12:35
Collaborator

@yiyixuxu yiyixuxu left a comment


thanks!
I have one question! The rest looks good to me.

raise EnvironmentError(
    f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
)
# This should correspond to a shard index file.
Collaborator


I guess I still have this question: why do we need to return None when we can't find a shard index file, whereas for any other file we can't find we raise errors?
Where in the code is this needed?

Collaborator

@yiyixuxu yiyixuxu left a comment


Never mind - I got confused in my last review!
Good to merge!

@sayakpaul sayakpaul merged commit 7d88711 into main Jun 7, 2024
@sayakpaul sayakpaul deleted the feat-save-sharded-ckpt branch June 7, 2024 09:19
@Wauplin
Collaborator

Wauplin commented Jun 7, 2024

Yay! Great job @sayakpaul ! 🎉

sayakpaul added a commit that referenced this pull request Dec 23, 2024
* feat: support saving a model in sharded checkpoints.

* feat: make loading of sharded checkpoints work.

* add tests

* cleanse the loading logic a bit more.

* more resilience while loading from the Hub.

* parallelize shard downloads by using snapshot_download().

* default to a shard size.

* more fix

* Empty-Commit

* debug

* fix

* quality

* more debugging

* fix more

* initial comments from Benjamin

* move certain methods to loading_utils

* add test to check if the correct number of shards are present.

* add a test to check if loading of sharded checkpoints from the Hub is okay

* clarify the unit when passed as an int.

* use hf_hub for sharding.

* remove unnecessary code

* remove unnecessary function

* lucain's comments.

* fixes

* address high-level comments.

* fix test

* subfolder shenanigans.

* Update src/diffusers/utils/hub_utils.py

Co-authored-by: Lucain <[email protected]>

* Apply suggestions from code review

Co-authored-by: Lucain <[email protected]>

* remove _huggingface_hub_version as not needed.

* address more feedback.

* add a test for local_files_only=True.

* need hf hub to be at least 0.23.2

* style

* final comment.

* clean up subfolder.

* deal with suffixes in code.

* _add_variant default.

* use weights_name_pattern

* remove add_suffix_keyword

* clean up downloading of sharded ckpts.

* don't return something special when using index.json

* fix more

* don't use bare except

* remove comments and catch the errors better

* fix a couple of things when using is_file()

* empty

---------

Co-authored-by: Lucain <[email protected]>