diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index ba6f7ccace3e..b36664cb81ff 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -1,6 +1,7 @@ import os import tempfile from typing import Callable, List, Optional, Union +from urllib.parse import unquote, urlparse import PIL.Image import PIL.ImageOps @@ -80,12 +81,22 @@ def load_video( ) if is_url: - video_data = requests.get(video, stream=True).raw - suffix = os.path.splitext(video)[1] or ".mp4" + response = requests.get(video, stream=True) + if response.status_code != 200: + raise ValueError(f"Failed to download video. Status code: {response.status_code}") + + parsed_url = urlparse(video) + file_name = os.path.basename(unquote(parsed_url.path)) + + suffix = os.path.splitext(file_name)[1] or ".mp4" video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name + was_tempfile_created = True + + video_data = response.iter_content(chunk_size=8192) with open(video_path, "wb") as f: - f.write(video_data.read()) + for chunk in video_data: + f.write(chunk) video = video_path