diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 4aee44c56c3e..fb0ce92cb61c 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -62,6 +62,7 @@ jobs: run: | python -m pip install -e .[quality,test] python -m pip install -U git+https://github.com/huggingface/transformers + python -m pip install git+https://github.com/huggingface/accelerate - name: Environment run: | @@ -134,6 +135,7 @@ jobs: ${CONDA_RUN} python -m pip install --upgrade pip ${CONDA_RUN} python -m pip install -e .[quality,test] ${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu + ${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate - name: Environment shell: arch -arch arm64 bash {0} @@ -157,4 +159,4 @@ jobs: uses: actions/upload-artifact@v2 with: name: torch_mps_test_reports - path: reports \ No newline at end of file + path: reports diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index 93bbdae388e6..082b12404a85 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -60,6 +60,7 @@ jobs: apt-get update && apt-get install libsndfile1-dev -y python -m pip install -e .[quality,test] python -m pip install -U git+https://github.com/huggingface/transformers + python -m pip install git+https://github.com/huggingface/accelerate - name: Environment run: | @@ -126,6 +127,7 @@ jobs: ${CONDA_RUN} python -m pip install --upgrade pip ${CONDA_RUN} python -m pip install -e .[quality,test] ${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu + ${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate ${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers - name: Environment diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index df3a3bf0fdf2..2d4875b80ced 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -62,6 +62,7 @@ jobs: run: | python -m pip install -e .[quality,test] python -m pip install -U git+https://github.com/huggingface/transformers + python -m pip install git+https://github.com/huggingface/accelerate - name: Environment run: | @@ -130,6 +131,7 @@ jobs: - name: Install dependencies run: | python -m pip install -e .[quality,test,training] + python -m pip install git+https://github.com/huggingface/accelerate python -m pip install -U git+https://github.com/huggingface/transformers - name: Environment @@ -151,4 +153,4 @@ jobs: uses: actions/upload-artifact@v2 with: name: examples_test_reports - path: reports \ No newline at end of file + path: reports diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 6d934e6b3049..9afece3244e9 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import os from functools import partial from typing import Callable, List, Optional, Tuple, Union @@ -489,11 +490,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P state_dict = load_state_dict(model_file) # move the parms from meta device to cpu for param_name, param in state_dict.items(): - set_module_tensor_to_device(model, param_name, param_device, value=param) + accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) + if accepts_dtype: + set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype) + else: + set_module_tensor_to_device(model, param_name, param_device, value=param) else: # else let accelerate handle loading and dispatching. # Load weights and dispatch according to the device_map # by deafult the device_map is None and the weights are loaded on the CPU - accelerate.load_checkpoint_and_dispatch(model, model_file, device_map) + accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype) loading_info = { "missing_keys": [], @@ -519,20 +524,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = cls.from_config(config, **unused_kwargs) state_dict = load_state_dict(model_file) - dtype = set(v.dtype for v in state_dict.values()) - - if len(dtype) > 1 and torch.float32 not in dtype: - raise ValueError( - f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please" - f" make sure that {model_file} weights have only one dtype." - ) - elif len(dtype) > 1 and torch.float32 in dtype: - dtype = torch.float32 - else: - dtype = dtype.pop() - - # move model to correct dtype - model = model.to(dtype) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 68ab914b4209..e4c32c0e78c1 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -70,9 +70,9 @@ def test_from_save_pretrained_dtype(self): with tempfile.TemporaryDirectory() as tmpdirname: model.to(dtype) model.save_pretrained(tmpdirname) - new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True) + new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype) assert new_model.dtype == dtype - new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False) + new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype) assert new_model.dtype == dtype def test_determinism(self):