
Commit 8e70c0a

Merge branch 'dev-tensorrt-txt2img-pipeline' of github.com:asfiyab-nvidia/diffusers into dev-tensorrt-txt2img-pipeline
Signed-off-by: Asfiya Baig <[email protected]>
2 parents ea03376 + d995745 commit 8e70c0a

File tree

10 files changed: +595 −479 lines changed


docs/source/en/using-diffusers/loading.mdx

Lines changed: 196 additions & 412 deletions
Large diffs are not rendered by default.

examples/controlnet/train_controlnet_flax.py

Lines changed: 40 additions & 9 deletions
@@ -27,13 +27,13 @@
 import torch
 import torch.utils.checkpoint
 import transformers
-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
 from flax import jax_utils
 from flax.core.frozen_dict import unfreeze
 from flax.training import train_state
 from flax.training.common_utils import shard
 from huggingface_hub import create_repo, upload_folder
-from PIL import Image
+from PIL import Image, PngImagePlugin
 from torch.utils.data import IterableDataset
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -49,6 +49,11 @@
 from diffusers.utils import check_min_version, is_wandb_available
 
 
+# To prevent an error that occurs when there are abnormally large compressed data chunk in the png image
+# see more https://github.com/python-pillow/Pillow/issues/5610
+LARGE_ENOUGH_NUMBER = 100
+PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
+
 if is_wandb_available():
     import wandb
 
@@ -246,6 +251,12 @@ def parse_args():
         default=None,
         help="Total number of training steps to perform.",
     )
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=5000,
+        help=("Save a checkpoint of the training state every X updates."),
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -344,9 +355,17 @@ def parse_args():
         type=str,
         default=None,
         help=(
-            "A folder containing the training data. Folder contents must follow the structure described in"
-            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
-            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+            "A folder containing the training dataset. By default it will use `load_dataset` method to load a custom dataset from the folder."
+            "Folder must contain a dataset script as described here https://huggingface.co/docs/datasets/dataset_script) ."
+            "If `--load_from_disk` flag is passed, it will use `load_from_disk` method instead. Ignored if `dataset_name` is specified."
+        ),
+    )
+    parser.add_argument(
+        "--load_from_disk",
+        action="store_true",
+        help=(
+            "If True, will load a dataset that was previously saved using `save_to_disk` from `--train_data_dir`"
+            "See more https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.load_from_disk"
         ),
     )
     parser.add_argument(
@@ -478,10 +497,15 @@ def make_train_dataset(args, tokenizer, batch_size=None):
         )
     else:
         if args.train_data_dir is not None:
-            dataset = load_dataset(
-                args.train_data_dir,
-                cache_dir=args.cache_dir,
-            )
+            if args.load_from_disk:
+                dataset = load_from_disk(
+                    args.train_data_dir,
+                )
+            else:
+                dataset = load_dataset(
+                    args.train_data_dir,
+                    cache_dir=args.cache_dir,
+                )
         # See more about loading custom images at
         # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
 
@@ -545,6 +569,7 @@ def tokenize_captions(examples, is_train=True):
     image_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
             transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5]),
         ]
@@ -553,6 +578,7 @@ def tokenize_captions(examples, is_train=True):
     conditioning_image_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
             transforms.ToTensor(),
         ]
     )
@@ -981,6 +1007,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
                         "train/loss": jax_utils.unreplicate(train_metric)["loss"],
                     }
                 )
+            if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
+                controlnet.save_pretrained(
+                    f"{args.output_dir}/{global_step}",
+                    params=get_params_to_save(state.params),
+                )
 
             train_metric = jax_utils.unreplicate(train_metric)
             train_step_progress_bar.close()
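The two new flags work together: `--checkpointing_steps` saves the ControlNet weights every N updates to `{output_dir}/{global_step}`, and `--load_from_disk` switches `make_train_dataset` from `load_dataset` to `load_from_disk`. A minimal sketch of preparing a dataset for the new flag, assuming data previously written out with `datasets`' `save_to_disk`; the dataset name and paths below are illustrative, not part of the commit:

# prepare_dataset.py - illustrative only
from datasets import load_dataset, load_from_disk

# Download/process once, then persist the Arrow files locally.
dataset = load_dataset("fusing/fill50k", split="train")
dataset.save_to_disk("./fill50k_on_disk")

# This mirrors the branch the training script takes when --load_from_disk is passed:
# the saved Arrow files are reloaded without re-running a dataset script.
dataset = load_from_disk("./fill50k_on_disk")
print(dataset)

Pointing `--train_data_dir` at the saved directory and passing `--load_from_disk` (plus a `--checkpointing_steps` value) would then reuse this copy during training.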

examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py

Lines changed: 2 additions & 2 deletions
@@ -405,14 +405,14 @@ def main():
     args = parse_args()
     logging_dir = Path(args.output_dir, args.logging_dir)
 
-    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
+    project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
 
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with="tensorboard",
         logging_dir=logging_dir,
-        accelerator_project_config=accelerator_project_config,
+        project_config=project_config,
     )
 
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
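For context, `accelerate`'s `Accelerator` takes the `ProjectConfiguration` through its `project_config` keyword, which is why the old `accelerator_project_config=` call fails. A minimal sketch of the corrected wiring, with placeholder values instead of the script's parsed arguments:

from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration

# Keep at most three checkpoints on disk; the other arguments are illustrative.
project_config = ProjectConfiguration(total_limit=3)
accelerator = Accelerator(
    gradient_accumulation_steps=1,
    mixed_precision="fp16",
    log_with="tensorboard",
    project_config=project_config,
)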

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 7 additions & 6 deletions
@@ -734,14 +734,15 @@ def __call__(
             image = latents
             has_nsfw_concept = None
 
-        image = self.decode_latents(latents)
-
-        if self.safety_checker is not None:
-            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
         else:
-            has_nsfw_concept = False
+            image = self.decode_latents(latents)
+
+            if self.safety_checker is not None:
+                image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+            else:
+                has_nsfw_concept = False
 
-        image = self.image_processor.postprocess(image, output_type=output_type)
+            image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
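The re-indentation moves VAE decoding, the safety checker, and postprocessing under the `else` branch, so they are skipped when the caller asks for raw latents. A hedged sketch of the behavior this enables, assuming the guarding condition checks `output_type == "latent"` as in the sibling Stable Diffusion pipelines; the model id and inputs are illustrative:

import torch
from PIL import Image
from diffusers import AltDiffusionImg2ImgPipeline

pipe = AltDiffusionImg2ImgPipeline.from_pretrained("BAAI/AltDiffusion", torch_dtype=torch.float16).to("cuda")
init_image = Image.new("RGB", (512, 512), color="gray")

# With output_type="latent", the pipeline returns latents directly instead of
# decoding them (and running the safety checker) as the old code path did.
out = pipe(prompt="a fantasy landscape", image=init_image, output_type="latent")
latents = out.images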

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 96 additions & 35 deletions
@@ -134,7 +134,7 @@ class AudioPipelineOutput(BaseOutput):
     audios: np.ndarray
 
 
-def is_safetensors_compatible(filenames, variant=None) -> bool:
+def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
     """
     Checking for safetensors compatibility:
     - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
@@ -150,9 +150,14 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:
 
     sf_filenames = set()
 
+    passed_components = passed_components or []
+
     for filename in filenames:
         _, extension = os.path.splitext(filename)
 
+        if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
+            continue
+
         if extension == ".bin":
             pt_filenames.append(filename)
         elif extension == ".safetensors":
@@ -163,10 +168,8 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:
         path, filename = os.path.split(filename)
         filename, extension = os.path.splitext(filename)
 
-        if filename == "pytorch_model":
-            filename = "model"
-        elif filename == f"pytorch_model.{variant}":
-            filename = f"model.{variant}"
+        if filename.startswith("pytorch_model"):
+            filename = filename.replace("pytorch_model", "model")
         else:
             filename = filename
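Taken together, the helper now ignores weights belonging to components the caller supplies in memory, and the `startswith` rewrite lets sharded `pytorch_model-00001-of-00002`-style names map to their `model*` safetensors counterparts. A hedged sketch of the intended effect; the helper is internal (the import path may differ across versions) and the filenames are made up:

from diffusers.pipelines.pipeline_utils import is_safetensors_compatible

filenames = [
    "unet/diffusion_pytorch_model.safetensors",
    "vae/diffusion_pytorch_model.safetensors",
    "text_encoder/pytorch_model.bin",  # no safetensors copy for the text encoder
]

# The .bin-only text encoder blocks a safetensors-only download...
print(is_safetensors_compatible(filenames))  # False
# ...unless that component is passed to the pipeline and never downloaded.
print(is_safetensors_compatible(filenames, passed_components=["text_encoder"]))  # True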

@@ -196,24 +199,51 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
     weight_prefixes = [w.split(".")[0] for w in weight_names]
     # .bin, .safetensors, ...
     weight_suffixs = [w.split(".")[-1] for w in weight_names]
+    # -00001-of-00002
+    transformers_index_format = "\d{5}-of-\d{5}"
+
+    if variant is not None:
+        # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetenstors`
+        variant_file_re = re.compile(
+            f"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
+        )
+        # `text_encoder/pytorch_model.bin.index.fp16.json`
+        variant_index_re = re.compile(
+            f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
+        )
 
-    variant_file_regex = (
-        re.compile(f"({'|'.join(weight_prefixes)})(.{variant}.)({'|'.join(weight_suffixs)})")
-        if variant is not None
-        else None
+    # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetenstors`
+    non_variant_file_re = re.compile(
+        f"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
     )
-    non_variant_file_regex = re.compile(f"{'|'.join(weight_names)}")
+    # `text_encoder/pytorch_model.bin.index.json`
+    non_variant_index_re = re.compile(f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
 
     if variant is not None:
-        variant_filenames = {f for f in filenames if variant_file_regex.match(f.split("/")[-1]) is not None}
+        variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
+        variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
+        variant_filenames = variant_weights | variant_indexes
     else:
         variant_filenames = set()
 
-    non_variant_filenames = {f for f in filenames if non_variant_file_regex.match(f.split("/")[-1]) is not None}
+    non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
+    non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
+    non_variant_filenames = non_variant_weights | non_variant_indexes
 
+    # all variant filenames will be used by default
     usable_filenames = set(variant_filenames)
+
+    def convert_to_variant(filename):
+        if "index" in filename:
+            variant_filename = filename.replace("index", f"index.{variant}")
+        elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
+            variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
+        else:
+            variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
+        return variant_filename
+
     for f in non_variant_filenames:
-        variant_filename = f"{f.split('.')[0]}.{variant}.{f.split('.')[1]}"
+        variant_filename = convert_to_variant(f)
         if variant_filename not in usable_filenames:
             usable_filenames.add(f)
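The rewritten patterns also recognize transformers-style sharded checkpoints and their index files. A small, hedged illustration of what the variant pattern is meant to accept, with the prefix/suffix lists and filenames invented for the example rather than taken from a real pipeline:

import re

# Invented stand-ins for the values derived from the pipeline's weight names.
transformers_index_format = r"\d{5}-of-\d{5}"
weight_prefixes = ["pytorch_model", "diffusion_pytorch_model", "model"]
weight_suffixs = ["bin", "safetensors"]
variant = "fp16"

variant_file_re = re.compile(
    rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
)

for name in [
    "diffusion_pytorch_model.fp16.bin",               # single-file variant weight -> matches
    "pytorch_model.fp16-00001-of-00002.safetensors",  # sharded variant weight -> matches
    "pytorch_model.bin",                              # non-variant weight -> no match
]:
    print(f"{name}: {bool(variant_file_re.match(name))}")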

@@ -292,6 +322,27 @@ def get_class_obj_and_candidates(library_name, class_name, importable_classes, p
     return class_obj, class_candidates
 
 
+def _get_pipeline_class(class_obj, config, custom_pipeline=None, cache_dir=None, revision=None):
+    if custom_pipeline is not None:
+        if custom_pipeline.endswith(".py"):
+            path = Path(custom_pipeline)
+            # decompose into folder & file
+            file_name = path.name
+            custom_pipeline = path.parent.absolute()
+        else:
+            file_name = CUSTOM_PIPELINE_FILE_NAME
+
+        return get_class_from_dynamic_module(
+            custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision
+        )
+
+    if class_obj != DiffusionPipeline:
+        return class_obj
+
+    diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
+    return getattr(diffusers_module, config["_class_name"])
+
+
 def load_sub_model(
     library_name: str,
     class_name: str,
@@ -779,7 +830,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         device_map = kwargs.pop("device_map", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
-        kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
 
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -794,8 +845,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 use_auth_token=use_auth_token,
                 revision=revision,
                 from_flax=from_flax,
+                use_safetensors=use_safetensors,
                 custom_pipeline=custom_pipeline,
+                custom_revision=custom_revision,
                 variant=variant,
+                **kwargs,
             )
         else:
             cached_folder = pretrained_model_name_or_path
@@ -810,29 +864,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             for folder in os.listdir(cached_folder):
                 folder_path = os.path.join(cached_folder, folder)
                 is_folder = os.path.isdir(folder_path) and folder in config_dict
-                variant_exists = is_folder and any(path.split(".")[1] == variant for path in os.listdir(folder_path))
+                variant_exists = is_folder and any(
+                    p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
+                )
                 if variant_exists:
                     model_variants[folder] = variant
 
         # 3. Load the pipeline class, if using custom module then load it from the hub
        # if we load from explicit class, let's use it
-        if custom_pipeline is not None:
-            if custom_pipeline.endswith(".py"):
-                path = Path(custom_pipeline)
-                # decompose into folder & file
-                file_name = path.name
-                custom_pipeline = path.parent.absolute()
-            else:
-                file_name = CUSTOM_PIPELINE_FILE_NAME
-
-            pipeline_class = get_class_from_dynamic_module(
-                custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=custom_revision
-            )
-        elif cls != DiffusionPipeline:
-            pipeline_class = cls
-        else:
-            diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
-            pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
+        pipeline_class = _get_pipeline_class(
+            cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision
+        )
 
         # DEPRECATED: To be removed in 1.0.0
         if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
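The class-resolution logic is now shared by `from_pretrained` and `download` through `_get_pipeline_class`, with `custom_revision` threaded through both paths. From the caller's side this stays the existing public interface; a hedged sketch, with the repository and community pipeline names chosen for illustration:

from diffusers import DiffusionPipeline

# Load a community pipeline from the Hub, pinning the revision of the custom pipeline code.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",
    custom_revision="main",
)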
@@ -1095,6 +1137,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
         revision = kwargs.pop("revision", None)
         from_flax = kwargs.pop("from_flax", False)
         custom_pipeline = kwargs.pop("custom_pipeline", None)
+        custom_revision = kwargs.pop("custom_revision", None)
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
 
@@ -1153,7 +1196,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
         # this enables downloading schedulers, tokenizers, ...
         allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names]
         # also allow downloading config.json files with the model
-        allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names]
+        allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
 
         allow_patterns += [
             SCHEDULER_CONFIG_NAME,
@@ -1162,17 +1205,28 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
             CUSTOM_PIPELINE_FILE_NAME,
         ]
 
+        # retrieve passed components that should not be downloaded
+        pipeline_class = _get_pipeline_class(
+            cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision
+        )
+        expected_components, _ = cls._get_signature_keys(pipeline_class)
+        passed_components = [k for k in expected_components if k in kwargs]
+
         if (
             use_safetensors
             and not allow_pickle
-            and not is_safetensors_compatible(model_filenames, variant=variant)
+            and not is_safetensors_compatible(
+                model_filenames, variant=variant, passed_components=passed_components
+            )
         ):
             raise EnvironmentError(
                 f"Could not found the necessary `safetensors` weights in {model_filenames} (variant={variant})"
             )
         if from_flax:
             ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
-        elif use_safetensors and is_safetensors_compatible(model_filenames, variant=variant):
+        elif use_safetensors and is_safetensors_compatible(
+            model_filenames, variant=variant, passed_components=passed_components
+        ):
            ignore_patterns = ["*.bin", "*.msgpack"]
 
             safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
@@ -1194,6 +1248,13 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
                     f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
                 )
 
+        # Don't download any objects that are passed
+        allow_patterns = [
+            p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
+        ]
+        # Don't download index files of forbidden patterns either
+        ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns]
+
         re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
         re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]
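The download path now resolves the pipeline class early so it can tell which components the caller already supplies; files for those components are dropped from `allow_patterns` and never fetched. A hedged sketch of the call pattern this optimizes, with illustrative model ids:

from diffusers import DiffusionPipeline
from transformers import CLIPTextModel

# The text encoder is supplied in memory, so under this change its weight files
# should no longer be downloaded along with the rest of the pipeline.
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    text_encoder=text_encoder,
    use_safetensors=True,
)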
