Skip to content

Commit ae70c6f

Browse files
anton-lpatrickvonplaten
authored andcommitted
Add diffusers version and pipeline class to the Hub UA (huggingface#814)
* Add diffusers version and pipeline class to the Hub UA * Fallback to class name for pipelines * Update src/diffusers/modeling_utils.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/modeling_flax_utils.py Co-authored-by: Patrick von Platen <[email protected]> * Remove autoclass Co-authored-by: Patrick von Platen <[email protected]>
1 parent 4b894e3 commit ae70c6f

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

src/diffusers/modeling_flax_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
2828
from requests import HTTPError
2929

30-
from . import is_torch_available
30+
from . import __version__, is_torch_available
3131
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
3232
from .utils import (
3333
CONFIG_NAME,
@@ -286,10 +286,13 @@ def from_pretrained(
286286
local_files_only = kwargs.pop("local_files_only", False)
287287
use_auth_token = kwargs.pop("use_auth_token", None)
288288
revision = kwargs.pop("revision", None)
289-
from_auto_class = kwargs.pop("_from_auto", False)
290289
subfolder = kwargs.pop("subfolder", None)
291290

292-
user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
291+
user_agent = {
292+
"diffusers": __version__,
293+
"file_type": "model",
294+
"framework": "flax",
295+
}
293296

294297
# Load config if we don't provide a configuration
295298
config_path = config if config is not None else pretrained_model_name_or_path

src/diffusers/modeling_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
2727
from requests import HTTPError
2828

29+
from . import __version__
2930
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging
3031

3132

@@ -292,12 +293,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
292293
local_files_only = kwargs.pop("local_files_only", False)
293294
use_auth_token = kwargs.pop("use_auth_token", None)
294295
revision = kwargs.pop("revision", None)
295-
from_auto_class = kwargs.pop("_from_auto", False)
296296
torch_dtype = kwargs.pop("torch_dtype", None)
297297
subfolder = kwargs.pop("subfolder", None)
298298
device_map = kwargs.pop("device_map", None)
299299

300-
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
300+
user_agent = {
301+
"diffusers": __version__,
302+
"file_type": "model",
303+
"framework": "pytorch",
304+
}
301305

302306
# Load config if we don't provide a configuration
303307
config_path = pretrained_model_name_or_path

src/diffusers/pipeline_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from PIL import Image
3030
from tqdm.auto import tqdm
3131

32+
from . import __version__
3233
from .configuration_utils import ConfigMixin
3334
from .dynamic_modules_utils import get_class_from_dynamic_module
3435
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
@@ -376,6 +377,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
376377
if custom_pipeline is not None:
377378
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
378379

380+
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
381+
user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class}
382+
if custom_pipeline is not None:
383+
user_agent["custom_pipeline"] = custom_pipeline
384+
379385
# download all allow_patterns
380386
cached_folder = snapshot_download(
381387
pretrained_model_name_or_path,
@@ -386,6 +392,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
386392
use_auth_token=use_auth_token,
387393
revision=revision,
388394
allow_patterns=allow_patterns,
395+
user_agent=user_agent,
389396
)
390397
else:
391398
cached_folder = pretrained_model_name_or_path

0 commit comments

Comments
 (0)