Labels: bug, single_file, stale
Describe the bug
I tried to load a .safetensors checkpoint and save it as a diffusers-format model. During loading I got this warning:
Some weights of the model checkpoint were not used when initializing CLIPTextModelWithProjection: ['text_model.embeddings.position_ids']
import torch
from diffusers import StableDiffusionXLPipeline

model_params = {
    "pretrained_model_link_or_path": "/workspace/playground-v2.5-1024px-aesthetic.fp16.safetensors",
    "torch_dtype": torch.float16,
}
pipe = StableDiffusionXLPipeline.from_single_file(**model_params)
pipe.save_pretrained(
    save_directory="/workspace/playground-v2.5.fp16",
    safe_serialization=True,
    variant="fp16",
    push_to_hub=False,
)
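For what it's worth, the position_ids warning is usually benign on its own: recent transformers releases compute CLIP position_ids on the fly instead of storing them as a persistent buffer, so a checkpoint that still carries the key is simply ignored. A minimal check of what the single-file checkpoint actually contains (a sketch assuming the file path above, not part of the original report):

from safetensors import safe_open

# List any position_ids entries stored in the raw checkpoint.
ckpt = "/workspace/playground-v2.5-1024px-aesthetic.fp16.safetensors"
with safe_open(ckpt, framework="pt", device="cpu") as f:
    print([k for k in f.keys() if "position_ids" in k])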
When I then try to load the saved pipeline and move it to CUDA, I get a NotImplementedError: Cannot copy out of meta tensor; no data! error:
pipe = StableDiffusionXLPipeline.from_pretrained("/workspace/playground-v2.5.fp16", torch_dtype=torch.float16, variant="fp16")
pipe = pipe.to("cuda")
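A quick way to confirm that some weights never received real storage is to look for parameters still on PyTorch's meta device after from_pretrained returns (a hypothetical diagnostic, not from the original report):

# Print any parameters left on the meta device; these are exactly
# the tensors that .to("cuda") cannot copy.
for name, component in pipe.components.items():
    if hasattr(component, "named_parameters"):
        metas = [n for n, p in component.named_parameters() if p.device.type == "meta"]
        if metas:
            print(name, len(metas), metas[:3])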
Reproduction
import torch
from diffusers import StableDiffusionXLPipeline

model_params = {
    "pretrained_model_link_or_path": "/workspace/playground-v2.5-1024px-aesthetic.fp16.safetensors",
    "torch_dtype": torch.float16,
}
pipe = StableDiffusionXLPipeline.from_single_file(**model_params)
pipe.save_pretrained(
    save_directory="/workspace/playground-v2.5.fp16",
    safe_serialization=True,
    variant="fp16",
    push_to_hub=False,
)
pipe = StableDiffusionXLPipeline.from_pretrained("/workspace/playground-v2.5.fp16", torch_dtype=torch.float16, variant="fp16")
pipe = pipe.to("cuda")
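One workaround worth trying (a sketch under the assumption that the failure comes from accelerate's fast-init path leaving missing keys on the meta device; not verified against this exact checkpoint) is to disable low-CPU-memory loading so every weight is materialized, which should either load cleanly or surface the real missing-key error:

import torch
from diffusers import StableDiffusionXLPipeline

# low_cpu_mem_usage=False skips the meta-device fast init, so missing
# keys are initialized normally instead of being left without data.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "/workspace/playground-v2.5.fp16",
    torch_dtype=torch.float16,
    variant="fp16",
    low_cpu_mem_usage=False,
)
pipe = pipe.to("cuda")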
Logs
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In[11], line 2
1 pipe = StableDiffusionXLPipeline.from_pretrained("/workspace/playground-v2.5.fp16", torch_dtype=torch.float16, variant="fp16")
----> 2 pipe = pipe.to("cuda")
3 pipe.safety_checker = None
4 # pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
5
6 #3
File /usr/local/lib/python3.10/dist-packages/diffusers/pipelines/pipeline_utils.py:418, in DiffusionPipeline.to(self, *args, **kwargs)
414 logger.warning(
415 f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}."
416 )
417 else:
--> 418 module.to(device, dtype)
420 if (
421 module.dtype == torch.float16
422 and str(device) in ["cpu"]
423 and not silence_dtype_warnings
424 and not is_offloaded
425 ):
426 logger.warning(
427 "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It"
428 " is not recommended to move them to `cpu` as running them will fail. Please make"
(...)
431 " `torch_dtype=torch.float16` argument, or use another device for inference."
432 )
File /usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:2576, in PreTrainedModel.to(self, *args, **kwargs)
2571 if dtype_present_in_args:
2572 raise ValueError(
2573 "You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired"
2574 " `dtype` by passing the correct `torch_dtype` argument."
2575 )
-> 2576 return super().to(*args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1145, in Module.to(self, *args, **kwargs)
1141 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
1142 non_blocking, memory_format=convert_to_format)
1143 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
-> 1145 return self._apply(convert)
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:797, in Module._apply(self, fn)
795 def _apply(self, fn):
796 for module in self.children():
--> 797 module._apply(fn)
799 def compute_should_use_set_data(tensor, tensor_applied):
800 if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
801 # If the new tensor has compatible tensor type as the existing tensor,
802 # the current behavior is to change the tensor in-place using `.data =`,
(...)
807 # global flag to let the user control whether they want the future
808 # behavior of overwriting the existing tensor or not.
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:797, in Module._apply(self, fn)
795 def _apply(self, fn):
796 for module in self.children():
--> 797 module._apply(fn)
799 def compute_should_use_set_data(tensor, tensor_applied):
800 if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
801 # If the new tensor has compatible tensor type as the existing tensor,
802 # the current behavior is to change the tensor in-place using `.data =`,
(...)
807 # global flag to let the user control whether they want the future
808 # behavior of overwriting the existing tensor or not.
[... skipping similar frames: Module._apply at line 797 (3 times)]
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:797, in Module._apply(self, fn)
795 def _apply(self, fn):
796 for module in self.children():
--> 797 module._apply(fn)
799 def compute_should_use_set_data(tensor, tensor_applied):
800 if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
801 # If the new tensor has compatible tensor type as the existing tensor,
802 # the current behavior is to change the tensor in-place using `.data =`,
(...)
807 # global flag to let the user control whether they want the future
808 # behavior of overwriting the existing tensor or not.
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:820, in Module._apply(self, fn)
816 # Tensors stored in modules are graph leaves, and we don't want to
817 # track autograd history of `param_applied`, so we have to use
818 # `with torch.no_grad():`
819 with torch.no_grad():
--> 820 param_applied = fn(param)
821 should_use_set_data = compute_should_use_set_data(param, param_applied)
822 if should_use_set_data:
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1143, in Module.to.<locals>.convert(t)
1140 if convert_to_format is not None and t.dim() in (4, 5):
1141 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
1142 non_blocking, memory_format=convert_to_format)
-> 1143 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
NotImplementedError: Cannot copy out of meta tensor; no data!
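For context, this error is PyTorch's generic behavior for any tensor on the meta device, which holds shape and dtype but no underlying storage; a two-line illustration, independent of diffusers:

import torch

t = torch.empty(3, device="meta")  # metadata only, no storage allocated
t.to("cpu")  # raises NotImplementedError: Cannot copy out of meta tensor; no data!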
System Info
- diffusers version: 0.27.0
- Platform: Linux-5.15.0-79-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- PyTorch version (GPU?): 2.0.1+cu118 (True)
- Huggingface_hub version: 0.22.1
- Transformers version: 4.39.1
- Accelerate version: 0.28.0
- xFormers version: not installed
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
Who can help?