Member

@SunMarc SunMarc commented Jun 13, 2024

What does this PR do?

This PR fixes the loading of sharded checkpoints when no device_map is passed. Currently, the following doesn't work:

from diffusers import UNet2DConditionModel
import torch

# Load a sharded UNet checkpoint without passing device_map.
unet = UNet2DConditionModel.from_pretrained(
    "sayakpaul/sdxl-unet-sharded", torch_dtype=torch.float16
)

More details are available here.
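For readers unfamiliar with sharded checkpoints, here is a minimal sketch of the idea, assuming a checkpoint is stored as several shard files plus an index that maps each weight name to its shard. It uses only the standard library and JSON files; the helper names (`save_sharded`, `load_sharded`), the file names, and the toy weights are hypothetical and are not the diffusers implementation. The point is that loading without any device placement just means merging every shard named in the index into one state dict.

```python
import json
import tempfile
from pathlib import Path


def save_sharded(state, directory, max_keys_per_shard=2):
    """Split `state` across shard files and write an index of key -> shard."""
    directory = Path(directory)
    keys = sorted(state)
    chunks = [keys[i:i + max_keys_per_shard] for i in range(0, len(keys), max_keys_per_shard)]
    weight_map = {}
    for n, chunk in enumerate(chunks, start=1):
        shard_name = f"shard-{n:05d}-of-{len(chunks):05d}.json"
        (directory / shard_name).write_text(json.dumps({k: state[k] for k in chunk}))
        weight_map.update({k: shard_name for k in chunk})
    (directory / "index.json").write_text(json.dumps({"weight_map": weight_map}))


def load_sharded(directory):
    """Merge every shard listed in the index into one state dict (no device placement)."""
    directory = Path(directory)
    index = json.loads((directory / "index.json").read_text())
    merged = {}
    for shard_name in sorted(set(index["weight_map"].values())):
        merged.update(json.loads((directory / shard_name).read_text()))
    return merged


# Toy weights standing in for real tensors (an assumption for illustration).
state = {"conv.weight": [1.0, 2.0], "conv.bias": [0.5], "out.weight": [3.0]}
with tempfile.TemporaryDirectory() as tmp:
    save_sharded(state, tmp)
    shard_count = len(list(Path(tmp).glob("shard-*.json")))
    restored = load_sharded(tmp)
```

A loader that only reads a single weights file would miss all but one shard, which is the kind of gap this PR addresses for the default device_map.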

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member Author

SunMarc commented Jun 13, 2024

There is still one path where sharding is not handled: when low_cpu_mem_usage=False. I see that low_cpu_mem_usage defaults to True; is that the case for most models? cc @sayakpaul
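The cases discussed so far can be summarized in a small sketch. The function `sharded_status` is hypothetical and only restates what this thread says at this point in the review; it is not code from diffusers. The thread does not discuss combining low_cpu_mem_usage=False with a device_map, so that combination is lumped in with the unhandled path as a simplification.

```python
from itertools import product


def sharded_status(low_cpu_mem_usage=True, device_map=None):
    """Hypothetical summary of the sharded-loading paths discussed in this thread."""
    if not low_cpu_mem_usage:
        # Path flagged in the comment above as still unhandled.
        return "not handled yet"
    if device_map is None:
        # Default arguments: the case this PR targets.
        return "fixed by this PR"
    # An explicit device_map such as "auto".
    return "already worked"


table = {
    (low, dm): sharded_status(low, dm)
    for low, dm in product([True, False], [None, "auto"])
}
for (low, dm), status in table.items():
    print(f"low_cpu_mem_usage={low!s:<5} device_map={dm!s:<4} -> {status}")
```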

@SunMarc SunMarc requested review from sayakpaul and yiyixuxu and removed request for yiyixuxu June 13, 2024 12:51
Collaborator

@yiyixuxu yiyixuxu left a comment

Thank you!
Very nice tests too :)

Is it possible to explain device_map=None in the docstring for device_map too?

Member Author

SunMarc commented Jun 14, 2024

is it possible to explain device_map=None in the doc string for device_map too?

Done!

self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

@require_torch_gpu
def test_sharded_checkpoints(self):
Member

This test is already here:

def test_sharded_checkpoints(self):

Is it different?

Collaborator

He renamed the existing test to test_sharded_checkpoints_device_map because it loads with device_map='auto'; this is a new test that covers the default value of device_map.

Member

Thanks for explaining.

Member Author

Yes, I renamed the tests since it makes more sense this way.
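The split between the two tests discussed in this thread can be sketched as follows. This is a standard-library illustration only: `load_model` is a hypothetical stand-in for a real loader, and the toy checkpoint replaces actual sharded weights. It shows one test for the default call with no device_map and one for an explicit device_map, using the two test names mentioned above.

```python
import unittest


def load_model(checkpoint, device_map=None):
    """Hypothetical stand-in for a loader; weights must match on every path."""
    weights = dict(checkpoint)
    placement = "cpu" if device_map is None else device_map
    return {"weights": weights, "placement": placement}


class ShardedCheckpointTests(unittest.TestCase):
    checkpoint = {"w": [1.0, 2.0], "b": [0.5]}

    def test_sharded_checkpoints(self):
        # Default path: no device_map argument at all.
        model = load_model(self.checkpoint)
        self.assertEqual(model["weights"], self.checkpoint)

    def test_sharded_checkpoints_device_map(self):
        # Explicit placement path, mirroring device_map="auto" in the thread.
        model = load_model(self.checkpoint, device_map="auto")
        self.assertEqual(model["weights"], self.checkpoint)


suite = unittest.defaultTestLoader.loadTestsFromTestCase(ShardedCheckpointTests)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Keeping the two paths in separately named tests means a regression in the default path shows up on its own rather than being hidden behind the device_map case.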

Member

@sayakpaul sayakpaul left a comment

Thanks so much, Marc. I think there's some confusion in the tests, as they already exist on main. Am I missing something?

@sayakpaul
Member

Alright then! Let’s merge this.

Member

@sayakpaul sayakpaul left a comment

Thanks!

@yiyixuxu yiyixuxu merged commit 96399c3 into huggingface:main Jun 18, 2024
yiyixuxu pushed a commit that referenced this pull request Jun 20, 2024
* Fix sharding when no device_map is passed

* style

* add tests

* align

* add docstring

* format

---------

Co-authored-by: Sayak Paul <[email protected]>
sayakpaul added a commit that referenced this pull request Dec 23, 2024