Skip to content

Commit eb04363

Browse files
authored
Raise error instead of warning when using meta device in from_pretrained (#40942)
* raise instead of warning
* add timm
* remove
1 parent ecc1d77 commit eb04363

File tree

5 files changed

+10
-26
lines changed

5 files changed

+10
-26
lines changed

src/transformers/modeling_utils.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4910,11 +4910,10 @@ def from_pretrained(
49104910
if device_map is None and not is_deepspeed_zero3_enabled():
49114911
device_in_context = get_torch_context_manager_or_global_device()
49124912
if device_in_context == torch.device("meta"):
4913-
# TODO Cyril: raise an error instead of the warning in v4.53 (and change the test to check for raise instead of success)
4914-
logger.warning(
4915-
"We detected that you are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`\n"
4916-
"This is an anti-pattern and will raise an Error in version v4.53\nIf you want to initialize a model on the meta device, use "
4917-
"the context manager or global device with `from_config`, or `ModelClass(config)`"
4913+
raise RuntimeError(
4914+
"You are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`.\n"
4915+
"This is an anti-pattern as `from_pretrained` wants to load existing weights.\nIf you want to initialize an "
4916+
"empty model on the meta device, use the context manager or global device with `from_config`, or `ModelClass(config)`"
49184917
)
49194918
device_map = device_in_context
49204919

tests/models/perception_lm/test_modeling_perception_lm.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -313,10 +313,6 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
313313
def test_can_be_initialized_on_meta(self):
314314
pass
315315

316-
@unittest.skip("ViT PE / TimmWrapperModel cannot be tested with meta device")
317-
def test_can_load_with_meta_device_context_manager(self):
318-
pass
319-
320316
@unittest.skip("Specifying both inputs_embeds and pixel_values are not supported for PerceptionLM")
321317
def test_generate_from_inputs_embeds_0_greedy(self):
322318
pass

tests/models/timm_backbone/test_modeling_timm_backbone.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -169,7 +169,7 @@ def test_can_load_with_global_device_set(self):
169169
pass
170170

171171
@unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
172-
def test_can_load_with_meta_device_context_manager(self):
172+
def test_cannot_load_with_meta_device_context_manager(self):
173173
pass
174174

175175
@unittest.skip(reason="model weights aren't tied in TimmBackbone.")

tests/models/xcodec/test_modeling_xcodec.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -151,10 +151,6 @@ def test_gradient_checkpointing_backward_compatibility(self):
151151
model = model_class(config)
152152
self.assertTrue(model.is_gradient_checkpointing)
153153

154-
@unittest.skip("XcodecModel cannot be tested with meta device")
155-
def test_can_load_with_meta_device_context_manager(self):
156-
pass
157-
158154
@unittest.skip(reason="We cannot configure to output a smaller model.")
159155
def test_model_is_small(self):
160156
pass

tests/test_modeling_common.py

Lines changed: 5 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -4488,7 +4488,7 @@ def test_can_load_with_global_device_set(self):
44884488
unique_devices, {device}, f"All parameters should be on {device}, but found {unique_devices}."
44894489
)
44904490

4491-
def test_can_load_with_meta_device_context_manager(self):
4491+
def test_cannot_load_with_meta_device_context_manager(self):
44924492
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
44934493
for model_class in self.all_model_classes:
44944494
# Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
@@ -4497,18 +4497,11 @@ def test_can_load_with_meta_device_context_manager(self):
44974497

44984498
with tempfile.TemporaryDirectory() as tmpdirname:
44994499
model.save_pretrained(tmpdirname)
4500-
45014500
with torch.device("meta"):
4502-
new_model = model_class.from_pretrained(tmpdirname)
4503-
unique_devices = {param.device for param in new_model.parameters()} | {
4504-
buffer.device for buffer in new_model.buffers()
4505-
}
4506-
4507-
self.assertEqual(
4508-
unique_devices,
4509-
{torch.device("meta")},
4510-
f"All parameters should be on meta device, but found {unique_devices}.",
4511-
)
4501+
with self.assertRaisesRegex(
4502+
RuntimeError, "You are using `from_pretrained` with a meta device context manager"
4503+
):
4504+
_ = model_class.from_pretrained(tmpdirname)
45124505

45134506
def test_config_attn_implementation_setter(self):
45144507
config, _ = self.model_tester.prepare_config_and_inputs_for_common()

0 commit comments

Comments
 (0)