From 1a335644cdff0ce704bfbf9e1513c4dd5a63373e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 23 Mar 2023 23:53:06 +0000 Subject: [PATCH 01/12] add train_controlnet_flax --- examples/controlnet/train_controlnet_flax.py | 740 +++++++++++++++++++ 1 file changed, 740 insertions(+) create mode 100644 examples/controlnet/train_controlnet_flax.py diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py new file mode 100644 index 000000000000..3b0e5f96b8d6 --- /dev/null +++ b/examples/controlnet/train_controlnet_flax.py @@ -0,0 +1,740 @@ +import argparse +import logging +import math +import os +import random +from pathlib import Path +from typing import Optional + +import jax +import jax.numpy as jnp +import numpy as np +import optax +import torch +import torch.utils.checkpoint +import transformers +from datasets import load_dataset +from flax import jax_utils +from flax.training import train_state +from flax.training.common_utils import shard +from huggingface_hub import HfFolder, Repository, create_repo, whoami +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed + + +from diffusers import ( + FlaxAutoencoderKL, + FlaxDDPMScheduler, + FlaxPNDMScheduler, + FlaxStableDiffusionPipeline, + FlaxUNet2DConditionModel, +) +from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker +from diffusers.utils import check_min_version, is_wandb_available + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.15.0.dev0") + +logger = logging.getLogger(__name__) + +def log_validation(unet, unet_params, args, rng, weight_dtype): + logger.info("Running validation... ") + + pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=unet, + safety_checker=None, + torch_dtype=weight_dtype, + ) + params = jax_utils.replicate(params) + params['unet'] = unet_params + + num_samples = jax.device_count() + prng_seed = jax.random.split(rng, jax.device_count()) + + prompts = num_samples * [args.validation_prompt] + prompt_ids = pipeline.prepare_inputs(prompts) + prompt_ids = shard(prompt_ids) + + images = pipeline(prompt_ids, params, prng_seed, 50, jit=True).images + images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) + images = pipeline.numpy_to_pil(images) + + if args.report_to == 'wandb': + wandb.log({"Validation Images": [wandb.Image(img, caption=f"{i}:{args.validation_prompt}") for i, img in enumerate(images)]}) + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." 
+ ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnet conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
+ ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="train_text_to_image", + help=( + "The `project` argument passed to wandb"), + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help=( + "A prompt evaluated every `--validation_steps` and logged to `--report_to`."), + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` and logging the images." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + return args + +def make_train_dataset(args, tokenizer): + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. 
+        dataset = load_dataset(
+            args.dataset_name,
+            args.dataset_config_name,
+            cache_dir=args.cache_dir,
+        )
+    else:
+        data_files = {}
+        if args.train_data_dir is not None:
+            data_files["train"] = os.path.join(args.train_data_dir, "**")
+        dataset = load_dataset(
+            "imagefolder",
+            data_files=data_files,
+            cache_dir=args.cache_dir,
+        )
+        # See more about loading custom images at
+        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+    # Preprocessing the datasets.
+    # We need to tokenize inputs and targets.
+    column_names = dataset["train"].column_names
+
+    # Get the column names for input/target.
+    if args.image_column is None:
+        image_column = column_names[0]
+        logger.info(f"image column defaulting to {image_column}")
+    else:
+        image_column = args.image_column
+        if image_column not in column_names:
+            raise ValueError(
+                f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+            )
+
+    if args.caption_column is None:
+        caption_column = column_names[1]
+        logger.info(f"caption column defaulting to {caption_column}")
+    else:
+        caption_column = args.caption_column
+        if caption_column not in column_names:
+            raise ValueError(
+                f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+            )
+
+    if args.conditioning_image_column is None:
+        conditioning_image_column = column_names[2]
+        logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+    else:
+        conditioning_image_column = args.conditioning_image_column
+        if conditioning_image_column not in column_names:
+            raise ValueError(
+                f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+            )
+
+    def tokenize_captions(examples, is_train=True):
+        captions = []
+        for caption in examples[caption_column]:
+            if random.random() < args.proportion_empty_prompts:
+                captions.append("")
+            elif isinstance(caption, str):
+                captions.append(caption)
+            elif isinstance(caption, (list, np.ndarray)):
+                # take a random caption if there are multiple
+                captions.append(random.choice(caption) if is_train else caption[0])
+            else:
+                raise ValueError(
+                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + conditioning_image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + images = [image_transforms(image) for image in images] + + conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]] + conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images] + + examples["pixel_values"] = images + examples["conditioning_pixel_values"] = conditioning_images + examples["input_ids"] = tokenize_captions(examples) + + return examples + + if jax.process_index() == 0: + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + return train_dataset + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.stack([example["input_ids"] for example in examples]) + + return { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "input_ids": input_ids, + } + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def get_params_to_save(params): + return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params)) + + +def main(): + args = parse_args() + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + # Setup logging, we only want one process per machine to log things on the screen. 
+    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+    if jax.process_index() == 0:
+        transformers.utils.logging.set_verbosity_info()
+    else:
+        transformers.utils.logging.set_verbosity_error()
+
+    # handle wandb init
+    if jax.process_index() == 0 and args.report_to == 'wandb':
+        wandb.init(
+            project=args.tracker_project_name,
+            job_type="train",
+            config=args,
+        )
+
+    if args.seed is not None:
+        set_seed(args.seed)
+        rng = jax.random.PRNGKey(args.seed)
+    else:
+        rng = jax.random.PRNGKey(0)
+
+    # Handle the repository creation
+    if jax.process_index() == 0:
+        if args.push_to_hub:
+            if args.hub_model_id is None:
+                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
+            else:
+                repo_name = args.hub_model_id
+            create_repo(repo_name, exist_ok=True, token=args.hub_token)
+            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
+
+            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
+                if "step_*" not in gitignore:
+                    gitignore.write("step_*\n")
+                if "epoch_*" not in gitignore:
+                    gitignore.write("epoch_*\n")
+        elif args.output_dir is not None:
+            os.makedirs(args.output_dir, exist_ok=True)
+
+    # Load the tokenizer
+    if args.tokenizer_name:
+        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+    elif args.pretrained_model_name_or_path:
+        tokenizer = CLIPTokenizer.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+        )
+    else:
+        raise NotImplementedError("No tokenizer specified!")
+
+    # Get the datasets: you can either provide your own training and evaluation files (see below)
+    train_dataset = make_train_dataset(args, tokenizer)
+    total_train_batch_size = args.train_batch_size * jax.local_device_count()
+
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset,
+        shuffle=True,
+        collate_fn=collate_fn,
+        batch_size=total_train_batch_size,
+        num_workers=args.dataloader_num_workers,
+    )
+
+    weight_dtype = jnp.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = jnp.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = jnp.bfloat16
+
+    # Load models and create wrapper for stable diffusion
+    text_encoder = FlaxCLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision
+    )
+    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path, revision=args.revision, subfolder="vae", dtype=weight_dtype
+    )
+    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision
+    )
+
+    # to-do:
+    # 1. add an argument for from_pt
+    # 2. 
check trainable models (controlnet) are in full precision + if args.controlnet_model_name_or_path: + logger.info("Loading existing controlnet weights") + controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(args.controlnet_model_name_or_path, from_pt=True, dtype=jnp.float32) + else: + logger.info("Initializing controlnet weights from unet") + rng, rng_params = jax.random.split(rng) + + controlnet_config = FlaxControlNetModel.load_config(args.controlnet_model_name_or_path) + controlnet = FlaxControlNetModel.from_config(controlnet_config) + controlnet_params = controlnet.init_weights(rng=rng_params) + for key in ['conv_in', 'time_proj', 'time_embedding', 'down_blocks','mid_block']: + controlnet_params[key] = unet_params[key] + + # Optimization + if args.scale_lr: + args.learning_rate = args.learning_rate * total_train_batch_size + + constant_scheduler = optax.constant_schedule(args.learning_rate) + + adamw = optax.adamw( + learning_rate=constant_scheduler, + b1=args.adam_beta1, + b2=args.adam_beta2, + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + + optimizer = optax.chain( + optax.clip_by_global_norm(args.max_grad_norm), + adamw, + ) + + state = train_state.TrainState.create(apply_fn=controlnet.__call__, params=controlnet_params, tx=optimizer) + + noise_scheduler = FlaxDDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 + ) + noise_scheduler_state = noise_scheduler.create_state() + + # Initialize our training + train_rngs = jax.random.split(rng, jax.local_device_count()) + + def train_step(state, unet_param, text_encoder_params, vae_params, batch, train_rng): + dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3) + + def compute_loss(params): + # Convert images to latent space + vae_outputs = vae.apply( + {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode + ) + latents = vae_outputs.latent_dist.sample(sample_rng) + # (NHWC) -> (NCHW) + latents = jnp.transpose(latents, (0, 3, 1, 2)) + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise_rng, timestep_rng = jax.random.split(sample_rng) + noise = jax.random.normal(noise_rng, latents.shape) + # Sample a random timestep for each image + bsz = latents.shape[0] + timesteps = jax.random.randint( + timestep_rng, + (bsz,), + 0, + noise_scheduler.config.num_train_timesteps, + ) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder( + batch["input_ids"], + params=text_encoder_params, + train=False, + )[0] + + controlnet_image = batch["conditioning_pixel_values"] + + # Predict the noise residual and compute loss + down_block_res_samples, mid_block_res_sample = controlnet.apply( + {'params': params}, + noisy_latents, + timesteps, + encoder_hidden_states, + controlnet_cond, + train=True, + return_dict=False,) + + model_pred = unet.apply( + {"params": unet_params}, + noisy_latents, + timesteps, + encoder_hidden_states, + down_block_additional_residuals = down_block_res_samples, + mid_block_additional_residual = mid_block_res_sample, + ).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif 
noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + loss = (target - model_pred) ** 2 + loss = loss.mean() + + return loss + + grad_fn = jax.value_and_grad(compute_loss) + loss, grad = grad_fn(state.params) + grad = jax.lax.pmean(grad, "batch") + + new_state = state.apply_gradients(grads=grad) + + metrics = {"loss": loss} + metrics = jax.lax.pmean(metrics, axis_name="batch") + + return new_state, metrics, new_train_rng + + # Create parallel version of the train step + p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) + + # Replicate the train state on each device + state = jax_utils.replicate(state) + unet_params = jax_utils.replicate(unet_params) + text_encoder_params = jax_utils.replicate(text_encoder.params) + vae_params = jax_utils.replicate(vae_params) + + # Train! + num_update_steps_per_epoch = math.ceil(len(train_dataloader)) + + # Scheduler and math around the number of training steps. + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + global_step = 0 + + epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0, disable=jax.process_index() > 0,) + for epoch in epochs: + # ======================== Training ================================ + + train_metrics = [] + + steps_per_epoch = len(train_dataset) // total_train_batch_size + train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False, disable=jax.process_index() > 0,) + # train + for batch in train_dataloader: + batch = shard(batch) + state, train_metric, train_rngs = p_train_step(state, text_encoder_params, vae_params, batch, train_rngs) + train_metrics.append(train_metric) + + train_step_progress_bar.update(1) + + global_step += 1 + + if args.validation_prompt is not None and global_step % args.validation_steps == 0 and jax.process_index() == 0: + log_validation(unet, state.params, args, rng, weight_dtype) + + if global_step >= args.max_train_steps: + break + + train_metric = jax_utils.unreplicate(train_metric) + + train_step_progress_bar.close() + epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})") + if args.report_to == 'wandb' and jax.process_index() == 0: + wandb.log({ + "train/epoch": epoch, + "train/loss": train_metric['loss']}) + + # Create the pipeline using using the trained modules and save it. 
+ if jax.process_index() == 0: + scheduler = FlaxPNDMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True + ) + safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker", from_pt=True + ) + pipeline = FlaxStableDiffusionPipeline( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"), + ) + + pipeline.save_pretrained( + args.output_dir, + params={ + "text_encoder": get_params_to_save(text_encoder_params), + "vae": get_params_to_save(vae_params), + "unet": get_params_to_save(state.params), + "safety_checker": safety_checker.params, + }, + ) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + +if __name__ == "__main__": + main() From b5f71ddf3d3d05112845594cd99121dbd65b1e3e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 24 Mar 2023 09:13:18 +0000 Subject: [PATCH 02/12] fix --- examples/controlnet/train_controlnet_flax.py | 308 +++++++++++-------- 1 file changed, 177 insertions(+), 131 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 3b0e5f96b8d6..775744e865f3 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + import argparse import logging import math @@ -17,18 +32,21 @@ from flax import jax_utils from flax.training import train_state from flax.training.common_utils import shard +from flax.core.frozen_dict import unfreeze from huggingface_hub import HfFolder, Repository, create_repo, whoami from torchvision import transforms from tqdm.auto import tqdm -from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed +from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed +from PIL import Image from diffusers import ( FlaxAutoencoderKL, FlaxDDPMScheduler, FlaxPNDMScheduler, - FlaxStableDiffusionPipeline, + FlaxStableDiffusionControlNetPipeline, FlaxUNet2DConditionModel, + FlaxControlNetModel, ) from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker from diffusers.utils import check_min_version, is_wandb_available @@ -41,31 +59,51 @@ logger = logging.getLogger(__name__) -def log_validation(unet, unet_params, args, rng, weight_dtype): +def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_dtype): logger.info("Running validation... 
") - pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + pipeline, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=unet, + tokenizer=tokenizer, + controlnet=controlnet, safety_checker=None, - torch_dtype=weight_dtype, + dtype=weight_dtype, + revision=args.revision ) params = jax_utils.replicate(params) - params['unet'] = unet_params - + params['controlnet'] = controlnet_params + num_samples = jax.device_count() prng_seed = jax.random.split(rng, jax.device_count()) prompts = num_samples * [args.validation_prompt] - prompt_ids = pipeline.prepare_inputs(prompts) + prompt_ids = pipeline.prepare_text_inputs(prompts) prompt_ids = shard(prompt_ids) + + validation_image = Image.open(args.validation_image) + processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image]) + processed_image = shard(processed_image) + + images = pipeline( + prompt_ids=prompt_ids, + image=processed_image, + params=params, + prng_seed=prng_seed, + num_inference_steps=50, + jit=True).images - images = pipeline(prompt_ids, params, prng_seed, 50, jit=True).images images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) images = pipeline.numpy_to_pil(images) - + if args.report_to == 'wandb': - wandb.log({"Validation Images": [wandb.Image(img, caption=f"{i}:{args.validation_prompt}") for i, img in enumerate(images)]}) + + images_log = [] + images_log.append(wandb.Image(validation_image, caption="Controlnet conditioning")) + for i, image in enumerate(images): + image = wandb.Image(image, caption=f"{i}:{args.validation_prompt}") + images_log.append(image) + + wandb.log({"Validation": images_log}) def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") @@ -77,66 +115,29 @@ def parse_args(): help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( - "--revision", - type=str, - default=None, - required=False, - help="Revision of pretrained model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--dataset_name", + "--controlnet_model_name_or_path", type=str, default=None, - help=( - "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," - " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," - " or to a folder containing files that 🤗 Datasets can understand." - ), + help="Path to pretrained controlnet model or model identifier from huggingface.co/models." + " If not specified controlnet weights are initialized from unet.", ) parser.add_argument( - "--dataset_config_name", + "--revision", type=str, default=None, - help="The config of the Dataset, leave as None if there's only one config.", + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", ) parser.add_argument( - "--train_data_dir", + "--tokenizer_name", type=str, default=None, - help=( - "A folder containing the training data. Folder contents must follow the structure described in" - " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" - " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." - ), - ) - parser.add_argument( - "--image_column", type=str, default="image", help="The column of the dataset containing an image." 
- ) - parser.add_argument( - "--conditioning_image_column", - type=str, - default="conditioning_image", - help="The column of the dataset containing the controlnet conditioning image.", - ) - parser.add_argument( - "--caption_column", - type=str, - default="text", - help="The column of the dataset containing a caption or a list of captions.", - ) - parser.add_argument( - "--max_train_samples", - type=int, - default=None, - help=( - "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." - ), + help="Pretrained tokenizer name or path if not the same as model_name", ) parser.add_argument( "--output_dir", type=str, - default="sd-model-finetuned", + default="controlnet-model", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( @@ -156,21 +157,7 @@ def parse_args(): ), ) parser.add_argument( - "--center_crop", - default=False, - action="store_true", - help=( - "Whether to center crop the input images to the resolution. If not set, the images will be randomly" - " cropped. The images will be resized to the resolution first before cropping." - ), - ) - parser.add_argument( - "--random_flip", - action="store_true", - help="whether to randomly flip images horizontally", - ) - parser.add_argument( - "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader." ) parser.add_argument("--num_train_epochs", type=int, default=100) parser.add_argument( @@ -208,29 +195,6 @@ def parse_args(): "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." ), ) - parser.add_argument( - "--tracker_project_name", - type=str, - default="train_text_to_image", - help=( - "The `project` argument passed to wandb"), - ) - parser.add_argument( - "--validation_prompt", - type=str, - default=None, - help=( - "A prompt evaluated every `--validation_steps` and logged to `--report_to`."), - ) - parser.add_argument( - "--validation_steps", - type=int, - default=100, - help=( - "Run validation every X steps. Validation consists of running the prompt" - " `args.validation_prompt` and logging the images." - ), - ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") @@ -273,6 +237,95 @@ def parse_args(): "and an Nvidia Ampere GPU." ), ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. 
In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnet conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help=( + "A prompt evaluated every `--validation_steps` and logged to `--report_to`." + " Used with `--validation_image` "), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + help=( + "path to the controlnet conditioning image evaluated every `--validation_steps`" + " and logged to `--report_to`. Used with `--validation_prompt` " + ), + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` and logging the images." + ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="train_controlnet_flax", + help=( + "The `project` argument passed to wandb"), + ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") args = parser.parse_args() @@ -286,6 +339,7 @@ def parse_args(): return args + def make_train_dataset(args, tokenizer): # Get the datasets: you can either provide your own training and evaluation files (see below) # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 
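Note: the next hunk switches `collate_fn` to return NumPy arrays rather than torch tensors. The Flax training loop shards each batch across local devices with `flax.training.common_utils.shard`, which reshapes every array's leading batch dimension into `(jax.local_device_count(), per_device_batch, ...)` before the `pmap`-ed train step consumes it. A minimal sketch of that sharding step, with an illustrative device count and batch layout (not taken from this script):

import jax
import numpy as np
from flax.training.common_utils import shard

# Illustrative batch: total batch size 16, e.g. 8 local devices x 2 per device.
batch = {"pixel_values": np.zeros((16, 3, 512, 512), dtype=np.float32)}

sharded = shard(batch)  # leading dim 16 -> (jax.local_device_count(), -1, ...)
print(jax.tree_util.tree_map(lambda x: x.shape, sharded))
# On a host with 8 devices: {'pixel_values': (8, 2, 3, 512, 512)}

Torch tensors would not survive this tree-mapped reshape and device transfer, hence the NumPy conversion below.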
@@ -411,11 +465,13 @@ def collate_fn(examples): input_ids = torch.stack([example["input_ids"] for example in examples]) - return { + batch = { "pixel_values": pixel_values, "conditioning_pixel_values": conditioning_pixel_values, "input_ids": input_ids, } + batch = {k: v.numpy() for k, v in batch.items()} + return batch def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): @@ -457,9 +513,8 @@ def main(): if args.seed is not None: set_seed(args.seed) - rng = jax.random.PRNGKey(args.seed) - else: - rng = jax.random.PRNGKey(0) + + rng = jax.random.PRNGKey(0) # Handle the repository creation if jax.process_index() == 0: @@ -528,10 +583,21 @@ def main(): logger.info("Initializing controlnet weights from unet") rng, rng_params = jax.random.split(rng) - controlnet_config = FlaxControlNetModel.load_config(args.controlnet_model_name_or_path) - controlnet = FlaxControlNetModel.from_config(controlnet_config) + controlnet = FlaxControlNetModel( + in_channels=unet.config.in_channels, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + attention_head_dim=unet.config.attention_head_dim, + cross_attention_dim=unet.config.cross_attention_dim, + use_linear_projection=unet.config.use_linear_projection, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + ) controlnet_params = controlnet.init_weights(rng=rng_params) - for key in ['conv_in', 'time_proj', 'time_embedding', 'down_blocks','mid_block']: + controlnet_params = unfreeze(controlnet_params) + for key in ['conv_in', 'time_embedding', 'down_blocks_0', 'down_blocks_1', 'down_blocks_2', 'down_blocks_3', 'mid_block']: controlnet_params[key] = unet_params[key] # Optimization @@ -561,9 +627,10 @@ def main(): noise_scheduler_state = noise_scheduler.create_state() # Initialize our training - train_rngs = jax.random.split(rng, jax.local_device_count()) + validation_rng, train_rngs = jax.random.split(rng) + train_rngs = jax.random.split(train_rngs, jax.local_device_count()) - def train_step(state, unet_param, text_encoder_params, vae_params, batch, train_rng): + def train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rng): dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3) def compute_loss(params): @@ -599,7 +666,7 @@ def compute_loss(params): train=False, )[0] - controlnet_image = batch["conditioning_pixel_values"] + controlnet_cond = batch["conditioning_pixel_values"] # Predict the noise residual and compute loss down_block_res_samples, mid_block_res_sample = controlnet.apply( @@ -682,7 +749,7 @@ def compute_loss(params): # train for batch in train_dataloader: batch = shard(batch) - state, train_metric, train_rngs = p_train_step(state, text_encoder_params, vae_params, batch, train_rngs) + state, train_metric, train_rngs = p_train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rngs) train_metrics.append(train_metric) train_step_progress_bar.update(1) @@ -690,7 +757,7 @@ def compute_loss(params): global_step += 1 if args.validation_prompt is not None and global_step % args.validation_steps == 0 and jax.process_index() == 0: - log_validation(unet, state.params, args, rng, weight_dtype) + log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype) if global_step >= args.max_train_steps: break @@ -706,30 +773,9 @@ def 
compute_loss(params): # Create the pipeline using using the trained modules and save it. if jax.process_index() == 0: - scheduler = FlaxPNDMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True - ) - safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker", from_pt=True - ) - pipeline = FlaxStableDiffusionPipeline( - text_encoder=text_encoder, - vae=vae, - unet=unet, - tokenizer=tokenizer, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"), - ) - - pipeline.save_pretrained( + controlnet.save_pretrained( args.output_dir, - params={ - "text_encoder": get_params_to_save(text_encoder_params), - "vae": get_params_to_save(vae_params), - "unet": get_params_to_save(state.params), - "safety_checker": safety_checker.params, - }, + params=get_params_to_save(state.params), ) if args.push_to_hub: From 075f4c4a27440fcc5a8eee0bd57f026c8b09a533 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 24 Mar 2023 16:29:29 +0000 Subject: [PATCH 03/12] fix --- examples/controlnet/train_controlnet_flax.py | 134 +++++++++++-------- 1 file changed, 80 insertions(+), 54 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 775744e865f3..21f4041f7531 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -30,27 +30,25 @@ import transformers from datasets import load_dataset from flax import jax_utils +from flax.core.frozen_dict import unfreeze from flax.training import train_state from flax.training.common_utils import shard -from flax.core.frozen_dict import unfreeze from huggingface_hub import HfFolder, Repository, create_repo, whoami +from PIL import Image from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed -from PIL import Image - from diffusers import ( FlaxAutoencoderKL, + FlaxControlNetModel, FlaxDDPMScheduler, - FlaxPNDMScheduler, FlaxStableDiffusionControlNetPipeline, FlaxUNet2DConditionModel, - FlaxControlNetModel, ) -from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker from diffusers.utils import check_min_version, is_wandb_available + if is_wandb_available(): import wandb @@ -59,19 +57,20 @@ logger = logging.getLogger(__name__) + def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_dtype): logger.info("Running validation... 
") - + pipeline, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( args.pretrained_model_name_or_path, tokenizer=tokenizer, controlnet=controlnet, safety_checker=None, dtype=weight_dtype, - revision=args.revision - ) + revision=args.revision, + ) params = jax_utils.replicate(params) - params['controlnet'] = controlnet_params + params["controlnet"] = controlnet_params num_samples = jax.device_count() prng_seed = jax.random.split(rng, jax.device_count()) @@ -79,24 +78,24 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d prompts = num_samples * [args.validation_prompt] prompt_ids = pipeline.prepare_text_inputs(prompts) prompt_ids = shard(prompt_ids) - + validation_image = Image.open(args.validation_image) processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image]) processed_image = shard(processed_image) images = pipeline( - prompt_ids=prompt_ids, + prompt_ids=prompt_ids, image=processed_image, - params=params, - prng_seed=prng_seed, - num_inference_steps=50, - jit=True).images + params=params, + prng_seed=prng_seed, + num_inference_steps=50, + jit=True, + ).images images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) images = pipeline.numpy_to_pil(images) - - if args.report_to == 'wandb': + if args.report_to == "wandb": images_log = [] images_log.append(wandb.Image(validation_image, caption="Controlnet conditioning")) for i, image in enumerate(images): @@ -105,6 +104,7 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d wandb.log({"Validation": images_log}) + def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( @@ -299,7 +299,8 @@ def parse_args(): default=None, help=( "A prompt evaluated every `--validation_steps` and logged to `--report_to`." - " Used with `--validation_image` "), + " Used with `--validation_image` " + ), ) parser.add_argument( "--validation_image", @@ -323,8 +324,7 @@ def parse_args(): "--tracker_project_name", type=str, default="train_controlnet_flax", - help=( - "The `project` argument passed to wandb"), + help=("The `project` argument passed to wandb"), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") @@ -502,14 +502,14 @@ def main(): transformers.utils.logging.set_verbosity_info() else: transformers.utils.logging.set_verbosity_error() - + # handle wandb init - if jax.process_index() == 0 and args.report_to == 'wandb': + if jax.process_index() == 0 and args.report_to == "wandb": wandb.init( project=args.tracker_project_name, job_type="train", config=args, - ) + ) if args.seed is not None: set_seed(args.seed) @@ -547,15 +547,15 @@ def main(): # Get the datasets: you can either provide your own training and evaluation files (see below) train_dataset = make_train_dataset(args, tokenizer) total_train_batch_size = args.train_batch_size * jax.local_device_count() - + train_dataloader = torch.utils.data.DataLoader( - train_dataset, - shuffle=True, - collate_fn=collate_fn, - batch_size=total_train_batch_size, + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=total_train_batch_size, num_workers=args.dataloader_num_workers, ) - + weight_dtype = jnp.float32 if args.mixed_precision == "fp16": weight_dtype = jnp.float16 @@ -573,12 +573,14 @@ def main(): args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision ) - # to-do: - # 1. add an argument for from_pt - # 2. 
check trainable models (controlnet) are in full precision + # to-do: + # 1. add an argument for from_pt + # 2. check trainable models (controlnet) are in full precision if args.controlnet_model_name_or_path: logger.info("Loading existing controlnet weights") - controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(args.controlnet_model_name_or_path, from_pt=True, dtype=jnp.float32) + controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( + args.controlnet_model_name_or_path, from_pt=True, dtype=jnp.float32 + ) else: logger.info("Initializing controlnet weights from unet") rng, rng_params = jax.random.split(rng) @@ -597,7 +599,15 @@ def main(): ) controlnet_params = controlnet.init_weights(rng=rng_params) controlnet_params = unfreeze(controlnet_params) - for key in ['conv_in', 'time_embedding', 'down_blocks_0', 'down_blocks_1', 'down_blocks_2', 'down_blocks_3', 'mid_block']: + for key in [ + "conv_in", + "time_embedding", + "down_blocks_0", + "down_blocks_1", + "down_blocks_2", + "down_blocks_3", + "mid_block", + ]: controlnet_params[key] = unet_params[key] # Optimization @@ -670,21 +680,22 @@ def compute_loss(params): # Predict the noise residual and compute loss down_block_res_samples, mid_block_res_sample = controlnet.apply( - {'params': params}, - noisy_latents, - timesteps, - encoder_hidden_states, - controlnet_cond, + {"params": params}, + noisy_latents, + timesteps, + encoder_hidden_states, + controlnet_cond, train=True, - return_dict=False,) + return_dict=False, + ) model_pred = unet.apply( - {"params": unet_params}, - noisy_latents, + {"params": unet_params}, + noisy_latents, timesteps, encoder_hidden_states, - down_block_additional_residuals = down_block_res_samples, - mid_block_additional_residual = mid_block_res_sample, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, ).sample # Get the target for loss depending on the prediction type @@ -738,25 +749,42 @@ def compute_loss(params): global_step = 0 - epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0, disable=jax.process_index() > 0,) + epochs = tqdm( + range(args.num_train_epochs), + desc="Epoch ... 
", + position=0, + disable=jax.process_index() > 0, + ) for epoch in epochs: # ======================== Training ================================ train_metrics = [] steps_per_epoch = len(train_dataset) // total_train_batch_size - train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False, disable=jax.process_index() > 0,) + train_step_progress_bar = tqdm( + total=steps_per_epoch, + desc="Training...", + position=1, + leave=False, + disable=jax.process_index() > 0, + ) # train for batch in train_dataloader: batch = shard(batch) - state, train_metric, train_rngs = p_train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rngs) + state, train_metric, train_rngs = p_train_step( + state, unet_params, text_encoder_params, vae_params, batch, train_rngs + ) train_metrics.append(train_metric) train_step_progress_bar.update(1) global_step += 1 - if args.validation_prompt is not None and global_step % args.validation_steps == 0 and jax.process_index() == 0: + if ( + args.validation_prompt is not None + and global_step % args.validation_steps == 0 + and jax.process_index() == 0 + ): log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype) if global_step >= args.max_train_steps: @@ -765,11 +793,9 @@ def compute_loss(params): train_metric = jax_utils.unreplicate(train_metric) train_step_progress_bar.close() - epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})") - if args.report_to == 'wandb' and jax.process_index() == 0: - wandb.log({ - "train/epoch": epoch, - "train/loss": train_metric['loss']}) + epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})") + if args.report_to == "wandb" and jax.process_index() == 0: + wandb.log({"train/epoch": epoch, "train/loss": train_metric["loss"]}) # Create the pipeline using using the trained modules and save it. 
if jax.process_index() == 0: From 6c6718c5dc7be1ac7881ec89240e71f43a0fc888 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 25 Mar 2023 06:27:42 +0000 Subject: [PATCH 04/12] improve logging --- examples/controlnet/train_controlnet_flax.py | 178 ++++++++++++++----- 1 file changed, 131 insertions(+), 47 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 21f4041f7531..2bc2dbad65f1 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -18,6 +18,7 @@ import math import os import random +import time from pathlib import Path from typing import Optional @@ -68,6 +69,8 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d safety_checker=None, dtype=weight_dtype, revision=args.revision, + from_pt=args.from_pt, + ) params = jax_utils.replicate(params) params["controlnet"] = controlnet_params @@ -75,34 +78,61 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d num_samples = jax.device_count() prng_seed = jax.random.split(rng, jax.device_count()) - prompts = num_samples * [args.validation_prompt] - prompt_ids = pipeline.prepare_text_inputs(prompts) - prompt_ids = shard(prompt_ids) - - validation_image = Image.open(args.validation_image) - processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image]) - processed_image = shard(processed_image) - - images = pipeline( - prompt_ids=prompt_ids, - image=processed_image, - params=params, - prng_seed=prng_seed, - num_inference_steps=50, - jit=True, - ).images + if len(args.validation_image) == len(args.validation_prompt): + validation_images = args.validation_image + validation_prompts = args.validation_prompt + elif len(args.validation_image) == 1: + validation_images = args.validation_image * len(args.validation_prompt) + validation_prompts = args.validation_prompt + elif len(args.validation_prompt) == 1: + validation_images = args.validation_image + validation_prompts = args.validation_prompt * len(args.validation_image) + else: + raise ValueError( + "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" + ) - images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) - images = pipeline.numpy_to_pil(images) + image_logs = [] + + for validation_prompt, validation_image in zip(validation_prompts, validation_images): + prompts = num_samples * [validation_prompt] + prompt_ids = pipeline.prepare_text_inputs(prompts) + prompt_ids = shard(prompt_ids) + + validation_image = Image.open(validation_image) + processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image]) + processed_image = shard(processed_image) + images = pipeline( + prompt_ids=prompt_ids, + image=processed_image, + params=params, + prng_seed=prng_seed, + num_inference_steps=50, + jit=True, + ).images + + images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) + images = pipeline.numpy_to_pil(images) + + image_logs.append( + {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} + ) if args.report_to == "wandb": - images_log = [] - images_log.append(wandb.Image(validation_image, caption="Controlnet conditioning")) - for i, image in enumerate(images): - image = wandb.Image(image, caption=f"{i}:{args.validation_prompt}") - images_log.append(image) - - wandb.log({"Validation": images_log}) + formatted_images = [] + for log in image_logs: + 
images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + wandb.log({"validation": formatted_images}) + else: + logger.warn(f"image logging not implemented for {args.report_to}") def parse_args(): @@ -110,7 +140,6 @@ def parse_args(): parser.add_argument( "--pretrained_model_name_or_path", type=str, - default=None, required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) @@ -125,9 +154,13 @@ def parse_args(): "--revision", type=str, default=None, - required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) + parser.add_argument( + "--from_pt", + action="store_true", + help="Load the pretrained model from a pytorch checkpoint.", + ) parser.add_argument( "--tokenizer_name", type=str, @@ -164,7 +197,7 @@ def parse_args(): "--max_train_steps", type=int, default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + help="Total number of training steps to perform.", ) parser.add_argument( "--learning_rate", @@ -175,7 +208,6 @@ def parse_args(): parser.add_argument( "--scale_lr", action="store_true", - default=False, help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", ) parser.add_argument( @@ -217,6 +249,12 @@ def parse_args(): " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." ), ) + parser.add_argument( + "--logging_steps", + type=int, + default=100, + help=("log training metric every X steps to `--report_t`"), + ) parser.add_argument( "--report_to", type=str, @@ -297,18 +335,23 @@ def parse_args(): "--validation_prompt", type=str, default=None, + nargs="+", help=( - "A prompt evaluated every `--validation_steps` and logged to `--report_to`." - " Used with `--validation_image` " + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." ), ) parser.add_argument( "--validation_image", type=str, default=None, + nargs="+", help=( - "path to the controlnet conditioning image evaluated every `--validation_steps`" - " and logged to `--report_to`. Used with `--validation_prompt` " + "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." 
), ) parser.add_argument( @@ -336,6 +379,29 @@ def parse_args(): # Sanity checks if args.dataset_name is None and args.train_data_dir is None: raise ValueError("Need either a dataset name or a training folder.") + if args.dataset_name is not None and args.train_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) return args @@ -564,18 +630,27 @@ def main(): # Load models and create wrapper for stable diffusion text_encoder = FlaxCLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision + args.pretrained_model_name_or_path, + subfolder="text_encoder", + dtype=weight_dtype, + revision=args.revision, + from_pt=args.from_pt, ) vae, vae_params = FlaxAutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, revision=args.revision, subfolder="vae", dtype=weight_dtype + args.pretrained_model_name_or_path, + revision=args.revision, + subfolder="vae", + dtype=weight_dtype, + from_pt=args.from_pt, ) unet, unet_params = FlaxUNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision + args.pretrained_model_name_or_path, + subfolder="unet", + dtype=weight_dtype, + revision=args.revision, + from_pt=args.from_pt, ) - # to-do: - # 1. add an argument for from_pt - # 2. check trainable models (controlnet) are in full precision if args.controlnet_model_name_or_path: logger.info("Loading existing controlnet weights") controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( @@ -748,7 +823,7 @@ def compute_loss(params): logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 - + train_metrics_last_t = time.time() epochs = tqdm( range(args.num_train_epochs), desc="Epoch ... 
", @@ -779,6 +854,8 @@ def compute_loss(params): train_step_progress_bar.update(1) global_step += 1 + if global_step >= args.max_train_steps: + break if ( args.validation_prompt is not None @@ -787,15 +864,22 @@ def compute_loss(params): ): log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype) - if global_step >= args.max_train_steps: - break - - train_metric = jax_utils.unreplicate(train_metric) + if global_step % args.logging_steps == 0 and jax.process_index() == 0: + train_metric = jax_utils.unreplicate(train_metric) + train_step_progress_bar.write(f"Loss: {train_metric['loss']}") + if args.report_to == "wandb": + wandb.log( + { + "train/step": global_step, + "train/epoch": epoch, + "time/seconds_per_step": (time.time() - train_metrics_last_t) / args.logging_steps, + "train/loss": train_metric["loss"], + } + ) + train_metrics_last_t = time.time() train_step_progress_bar.close() - epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})") - if args.report_to == "wandb" and jax.process_index() == 0: - wandb.log({"train/epoch": epoch, "train/loss": train_metric["loss"]}) + epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs})") # Create the pipeline using using the trained modules and save it. if jax.process_index() == 0: From 63bef8d7d7ccdc3658a721781a2df6fe5ac4a3cf Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 25 Mar 2023 09:51:10 +0000 Subject: [PATCH 05/12] gradient_accumulation_steps --- examples/controlnet/train_controlnet_flax.py | 80 ++++++++++++++++---- 1 file changed, 67 insertions(+), 13 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 2bc2dbad65f1..92d6da1626ae 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -369,6 +369,12 @@ def parse_args(): default="train_controlnet_flax", help=("The `project` argument passed to wandb"), ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of steps to accumulate gradients over" + ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") args = parser.parse_args() @@ -612,7 +618,7 @@ def main(): # Get the datasets: you can either provide your own training and evaluation files (see below) train_dataset = make_train_dataset(args, tokenizer) - total_train_batch_size = args.train_batch_size * jax.local_device_count() + total_train_batch_size = args.train_batch_size * jax.local_device_count() * args.gradient_accumulation_steps train_dataloader = torch.utils.data.DataLoader( train_dataset, @@ -716,12 +722,21 @@ def main(): train_rngs = jax.random.split(train_rngs, jax.local_device_count()) def train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rng): - dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3) - - def compute_loss(params): + + print("yiyi testing batch shape, before + after") + print(jax.tree_map(lambda x:x.shape, batch)) + print(' ') + # reshape batch, add grad_step_dim if gradient_accumulation_steps > 1 + if args.gradient_accumulation_steps > 1: + G = args.gradient_accumulation_steps + batch = jax.tree_map( + lambda x: x.reshape((G, x.shape[0]//G) + x.shape[1:]), batch) + print(jax.tree_map(lambda x:x.shape, batch)) + + def compute_loss(params, minibatch, sample_rng): # Convert images to latent space vae_outputs = vae.apply( - {"params": vae_params}, batch["pixel_values"], 
deterministic=True, method=vae.encode + {"params": vae_params}, minibatch["pixel_values"], deterministic=True, method=vae.encode ) latents = vae_outputs.latent_dist.sample(sample_rng) # (NHWC) -> (NCHW) @@ -746,12 +761,12 @@ def compute_loss(params): # Get the text embedding for conditioning encoder_hidden_states = text_encoder( - batch["input_ids"], + minibatch["input_ids"], params=text_encoder_params, train=False, )[0] - controlnet_cond = batch["conditioning_pixel_values"] + controlnet_cond = minibatch["conditioning_pixel_values"] # Predict the noise residual and compute loss down_block_res_samples, mid_block_res_sample = controlnet.apply( @@ -787,7 +802,47 @@ def compute_loss(params): return loss grad_fn = jax.value_and_grad(compute_loss) - loss, grad = grad_fn(state.params) + + # get a minibatch (one gradient accumulation slice) + def get_minibatch(batch, grad_idx): + return jax.tree_util.tree_map( + lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False), + batch, + ) + + def loss_and_grad(grad_idx, train_rng): + #create minibatch for the grad step + minibatch = (get_minibatch(batch, grad_idx) if grad_idx is not None else batch) + sample_rng, train_rng = jax.random.split(train_rng, 2) + loss, grads = grad_fn(state.params, minibatch, sample_rng) + return loss, grads, train_rng + + if args.gradient_accumulation_steps ==1: + loss, grads, new_train_rng = loss_and_grad(None, train_rng) + else: + init_loss_grad_rng = ( + 0.0, # initial value for cumul_loss + jax.tree_map(jnp.zeros_like, state.params), # initial value for cumul_grads + train_rng # initial value for train_rng + ) + + def cumul_grad_step(grad_idx, loss_grads_rng): + cumul_loss, cumul_grads, train_rng = loss_grads_rng + loss,grads,new_train_rng = loss_and_grad(grad_idx, train_rng) + cumul_loss, cumul_grads = jax.tree_map( + jnp.add, (cumul_loss, cumul_grads),(loss,grads)) + return cumul_loss, cumul_grads, new_train_rng + + loss, grads, new_train_rng = jax.lax.fori_loop( + 0, + args.gradient_accumulation_steps, + cumul_grad_step, + init_loss_grad_rng, + ) + loss, grads = jax.tree_map( + lambda x: x / args.gradient_accumulation_steps, (loss, grads) + ) + grad = jax.lax.pmean(grad, "batch") new_state = state.apply_gradients(grads=grad) @@ -865,21 +920,20 @@ def compute_loss(params): log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype) if global_step % args.logging_steps == 0 and jax.process_index() == 0: - train_metric = jax_utils.unreplicate(train_metric) - train_step_progress_bar.write(f"Loss: {train_metric['loss']}") if args.report_to == "wandb": wandb.log( { "train/step": global_step, "train/epoch": epoch, "time/seconds_per_step": (time.time() - train_metrics_last_t) / args.logging_steps, - "train/loss": train_metric["loss"], + "train/loss": jax_utils.unreplicate(train_metric)["loss"], } ) train_metrics_last_t = time.time() - + + train_metric = jax_utils.unreplicate(train_metric) train_step_progress_bar.close() - epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs})") + epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})") # Create the pipeline using using the trained modules and save it. 
if jax.process_index() == 0: From d9cf213e9b8cc59be953e1da7df916346612aef3 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 25 Mar 2023 23:27:36 +0000 Subject: [PATCH 06/12] fix --- examples/controlnet/train_controlnet_flax.py | 91 +++++++++----------- 1 file changed, 43 insertions(+), 48 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 92d6da1626ae..386a356d5423 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -18,7 +18,6 @@ import math import os import random -import time from pathlib import Path from typing import Optional @@ -70,7 +69,6 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d dtype=weight_dtype, revision=args.revision, from_pt=args.from_pt, - ) params = jax_utils.replicate(params) params["controlnet"] = controlnet_params @@ -370,10 +368,7 @@ def parse_args(): help=("The `project` argument passed to wandb"), ) parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of steps to accumulate gradients over" + "--gradient_accumulation_steps", type=int, default=1, help="Number of steps to accumulate gradients over" ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") @@ -575,7 +570,7 @@ def main(): else: transformers.utils.logging.set_verbosity_error() - # handle wandb init + # wandb init if jax.process_index() == 0 and args.report_to == "wandb": wandb.init( project=args.tracker_project_name, @@ -626,6 +621,7 @@ def main(): collate_fn=collate_fn, batch_size=total_train_batch_size, num_workers=args.dataloader_num_workers, + drop_last=True, ) weight_dtype = jnp.float32 @@ -722,16 +718,10 @@ def main(): train_rngs = jax.random.split(train_rngs, jax.local_device_count()) def train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rng): - - print("yiyi testing batch shape, before + after") - print(jax.tree_map(lambda x:x.shape, batch)) - print(' ') - # reshape batch, add grad_step_dim if gradient_accumulation_steps > 1 + # reshape batch, add grad_step_dim if gradient_accumulation_steps > 1 if args.gradient_accumulation_steps > 1: - G = args.gradient_accumulation_steps - batch = jax.tree_map( - lambda x: x.reshape((G, x.shape[0]//G) + x.shape[1:]), batch) - print(jax.tree_map(lambda x:x.shape, batch)) + G = args.gradient_accumulation_steps + batch = jax.tree_map(lambda x: x.reshape((G, x.shape[0] // G) + x.shape[1:]), batch) def compute_loss(params, minibatch, sample_rng): # Convert images to latent space @@ -809,39 +799,36 @@ def get_minibatch(batch, grad_idx): lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False), batch, ) - + def loss_and_grad(grad_idx, train_rng): - #create minibatch for the grad step - minibatch = (get_minibatch(batch, grad_idx) if grad_idx is not None else batch) + # create minibatch for the grad step + minibatch = get_minibatch(batch, grad_idx) if grad_idx is not None else batch sample_rng, train_rng = jax.random.split(train_rng, 2) - loss, grads = grad_fn(state.params, minibatch, sample_rng) - return loss, grads, train_rng - - if args.gradient_accumulation_steps ==1: - loss, grads, new_train_rng = loss_and_grad(None, train_rng) + loss, grad = grad_fn(state.params, minibatch, sample_rng) + return loss, grad, train_rng + + if args.gradient_accumulation_steps == 1: + loss, grad, new_train_rng = loss_and_grad(None, train_rng) else: init_loss_grad_rng = ( - 0.0, # initial 
value for cumul_loss - jax.tree_map(jnp.zeros_like, state.params), # initial value for cumul_grads - train_rng # initial value for train_rng - ) - - def cumul_grad_step(grad_idx, loss_grads_rng): - cumul_loss, cumul_grads, train_rng = loss_grads_rng - loss,grads,new_train_rng = loss_and_grad(grad_idx, train_rng) - cumul_loss, cumul_grads = jax.tree_map( - jnp.add, (cumul_loss, cumul_grads),(loss,grads)) - return cumul_loss, cumul_grads, new_train_rng - - loss, grads, new_train_rng = jax.lax.fori_loop( + 0.0, # initial value for cumul_loss + jax.tree_map(jnp.zeros_like, state.params), # initial value for cumul_grad + train_rng, # initial value for train_rng + ) + + def cumul_grad_step(grad_idx, loss_grad_rng): + cumul_loss, cumul_grad, train_rng = loss_grad_rng + loss, grad, new_train_rng = loss_and_grad(grad_idx, train_rng) + cumul_loss, cumul_grad = jax.tree_map(jnp.add, (cumul_loss, cumul_grad), (loss, grad)) + return cumul_loss, cumul_grad, new_train_rng + + loss, grad, new_train_rng = jax.lax.fori_loop( 0, - args.gradient_accumulation_steps, + args.gradient_accumulation_steps, cumul_grad_step, init_loss_grad_rng, ) - loss, grads = jax.tree_map( - lambda x: x / args.gradient_accumulation_steps, (loss, grads) - ) + loss, grad = jax.tree_map(lambda x: x / args.gradient_accumulation_steps, (loss, grad)) grad = jax.lax.pmean(grad, "batch") @@ -862,7 +849,7 @@ def cumul_grad_step(grad_idx, loss_grads_rng): vae_params = jax_utils.replicate(vae_params) # Train! - num_update_steps_per_epoch = math.ceil(len(train_dataloader)) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) # Scheduler and math around the number of training steps. if args.max_train_steps is None: @@ -875,10 +862,20 @@ def cumul_grad_step(grad_idx, loss_grads_rng): logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}") - logger.info(f" Total optimization steps = {args.max_train_steps}") + logger.info(f" Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}") + + if jax.process_index() == 0: + wandb.define_metric("*", step_metric="train/step") + wandb.config.update( + { + "num_train_examples": len(train_dataset), + "total_train_batch_size": total_train_batch_size, + "total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch, + "num_devices": jax.device_count(), + } + ) global_step = 0 - train_metrics_last_t = time.time() epochs = tqdm( range(args.num_train_epochs), desc="Epoch ... ", @@ -922,15 +919,13 @@ def cumul_grad_step(grad_idx, loss_grads_rng): if global_step % args.logging_steps == 0 and jax.process_index() == 0: if args.report_to == "wandb": wandb.log( - { + { "train/step": global_step, "train/epoch": epoch, - "time/seconds_per_step": (time.time() - train_metrics_last_t) / args.logging_steps, "train/loss": jax_utils.unreplicate(train_metric)["loss"], } ) - train_metrics_last_t = time.time() - + train_metric = jax_utils.unreplicate(train_metric) train_step_progress_bar.close() epochs.write(f"Epoch... 
({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})") From bdc71d62f5f9e86306109878fb6d7f3c42b3085c Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Mon, 27 Mar 2023 06:24:03 -1000 Subject: [PATCH 07/12] Apply suggestions from code review Co-authored-by: Patrick von Platen --- examples/controlnet/train_controlnet_flax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 386a356d5423..60d1e58ad5ea 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -708,7 +708,7 @@ def main(): state = train_state.TrainState.create(apply_fn=controlnet.__call__, params=controlnet_params, tx=optimizer) - noise_scheduler = FlaxDDPMScheduler( + noise_scheduler = FlaxDDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 ) noise_scheduler_state = noise_scheduler.create_state() @@ -720,8 +720,8 @@ def main(): def train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rng): # reshape batch, add grad_step_dim if gradient_accumulation_steps > 1 if args.gradient_accumulation_steps > 1: - G = args.gradient_accumulation_steps - batch = jax.tree_map(lambda x: x.reshape((G, x.shape[0] // G) + x.shape[1:]), batch) + grad_steps = args.gradient_accumulation_steps + batch = jax.tree_map(lambda x: x.reshape((grad_steps, x.shape[0] // grad_steps) + x.shape[1:]), batch) def compute_loss(params, minibatch, sample_rng): # Convert images to latent space From ad4137b9189d47a26b663585649dcfbdb0b5863c Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 28 Mar 2023 02:07:06 +0000 Subject: [PATCH 08/12] save_model_card --- examples/controlnet/train_controlnet_flax.py | 70 ++++++++++++++++++-- 1 file changed, 63 insertions(+), 7 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 60d1e58ad5ea..c6c95170da2d 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -58,6 +58,18 @@ logger = logging.getLogger(__name__) +def image_grid(imgs, rows, cols): + assert len(imgs) == rows * cols + + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + grid_w, grid_h = grid.size + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + + def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_dtype): logger.info("Running validation... 
") @@ -132,6 +144,43 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d else: logger.warn(f"image logging not implemented for {args.report_to}") + return image_logs + + +def save_model_card(repo_name, image_logs=None, base_model=str, repo_folder=None): + img_str = "" + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +tags: +- stable-diffusion +- stable-diffusion-diffusers +- text-to-image +- diffusers +- controlnet +inference: true +--- + """ + model_card = f""" +# controlnet- {repo_name} + +These are controlnet weights trained on {base_model} with new type of conditioning. You can find some example images in the following. \n +{img_str} +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") @@ -590,8 +639,8 @@ def main(): repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) else: repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) + repo_url = create_repo(repo_name, exist_ok=True, token=args.hub_token) + repo = Repository(args.output_dir, clone_from=repo_url, token=args.hub_token) with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: if "step_*" not in gitignore: @@ -708,10 +757,9 @@ def main(): state = train_state.TrainState.create(apply_fn=controlnet.__call__, params=controlnet_params, tx=optimizer) - noise_scheduler = FlaxDDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 + noise_scheduler, noise_scheduler_state = FlaxDDPMScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" ) - noise_scheduler_state = noise_scheduler.create_state() # Initialize our training validation_rng, train_rngs = jax.random.split(rng) @@ -870,7 +918,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng): { "num_train_examples": len(train_dataset), "total_train_batch_size": total_train_batch_size, - "total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch, + "total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch, "num_devices": jax.device_count(), } ) @@ -914,7 +962,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng): and global_step % args.validation_steps == 0 and jax.process_index() == 0 ): - log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype) + _ = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype) if global_step % args.logging_steps == 0 and jax.process_index() == 0: if args.report_to == "wandb": @@ -932,12 +980,20 @@ def cumul_grad_step(grad_idx, loss_grad_rng): # Create the pipeline using using the trained modules and save it. 
From 162697ce7724cf684055cf90c2a067b3dd11201c Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Tue, 28 Mar 2023 02:59:17 +0000
Subject: [PATCH 09/12] readme

---
 examples/controlnet/README.md | 95 +++++++++++++++++++++++++++++++++++
 1 file changed, 95 insertions(+)

diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md
index 32de31e14bbd..276ac000692e 100644
--- a/examples/controlnet/README.md
+++ b/examples/controlnet/README.md
@@ -267,3 +267,98 @@ image = pipe(
 image.save("./output.png")
 ```
+
+## Training with Flax/JAX
+
+For faster training on TPUs and GPUs, you can leverage the Flax training example. Follow the instructions above to get the model and dataset before running the script.
+
+### Running on Google Cloud TPU
+
+See below for commands to set up a TPU VM (`--accelerator-type v4-8`). For more details about how to set up and use TPUs, refer to [Cloud docs for single VM setup](https://cloud.google.com/tpu/docs/run-calculation-jax).
+
+First create a single TPU v4-8 VM and connect to it:
+
+```
+ZONE=us-central2-b
+TPU_TYPE=v4-8
+VM_NAME=hg_flax
+
+gcloud alpha compute tpus tpu-vm create $VM_NAME \
+    --zone $ZONE \
+    --accelerator-type $TPU_TYPE \
+    --version tpu-vm-v4-base
+
+gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE
+```
+
+Once connected, install JAX `0.4.5`:
+
+```
+pip install "jax[tpu]==0.4.5" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+```
+
+To verify that JAX was correctly installed, run the following in a Python shell:
+
+```
+import jax
+jax.device_count()
+```
+
+This should display the number of TPU cores, which should be 4 on a TPU v4-8 VM.
+
+Then install Diffusers and the library's training dependencies:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run:
+
+```bash
+pip install -U -r requirements_flax.txt
+```
+
+Now let's download two conditioning images that we will use to run validation during training in order to track our progress:
+
+```
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
+
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
+```
+
+We encourage you to store or share your model with the community. To use the Hugging Face Hub, please log in to your Hugging Face account, or [create one](https://huggingface.co/join) if you don't have one already:
+
+```
+huggingface-cli login
+```
+
+Make sure you have the `MODEL_DIR`, `OUTPUT_DIR`, and `HUB_MODEL_ID` environment variables set. `OUTPUT_DIR` specifies where the model checkpoints are saved locally, and `HUB_MODEL_ID` names the repository they are pushed to on the Hub:
+
+```
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="control_out"
+export HUB_MODEL_ID="yiyixu/fill-circle-controlnet"
+```
+
+And finally, start the training:
+
+```
+python3 train_controlnet_flax.py \
+  --pretrained_model_name_or_path=$MODEL_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --dataset_name=fusing/fill50k \
+  --resolution=512 \
+  --learning_rate=1e-5 \
+  --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+  --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+  --validation_steps=1000 \
+  --train_batch_size=2 \
+  --revision="non-ema" \
+  --from_pt \
+  --report_to="wandb" \
+  --max_train_steps=10000 \
+  --push_to_hub \
+  --hub_model_id=$HUB_MODEL_ID
+```

From 64a9c020357bd452eefc223c027c3fadecd363e5 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Tue, 28 Mar 2023 03:56:51 +0000
Subject: [PATCH 10/12] fix

---
 examples/controlnet/README.md | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md
index 276ac000692e..ce6f8692ff59 100644
--- a/examples/controlnet/README.md
+++ b/examples/controlnet/README.md
@@ -324,7 +324,6 @@ Now let's download two conditioning images that we will use to run validation
 
 ```
 wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
-
 wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
 ```
 
@@ -339,7 +338,7 @@ Make sure you have the `MODEL_DIR`, `OUTPUT_DIR`, and `HUB_MODEL_ID` environment
 ```
 export MODEL_DIR="runwayml/stable-diffusion-v1-5"
 export OUTPUT_DIR="control_out"
-export HUB_MODEL_ID="yiyixu/fill-circle-controlnet"
+export HUB_MODEL_ID="fill-circle-controlnet"
 ```
 
 And finally, start the training:

From 5eec46ed4366c4f524a93544aa296469e7d46bff Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Tue, 28 Mar 2023 04:46:07 +0000
Subject: [PATCH 11/12] fix

---
 examples/controlnet/README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md
index ce6f8692ff59..d5f528d88725 100644
--- a/examples/controlnet/README.md
+++ b/examples/controlnet/README.md
@@ -361,3 +361,5 @@ python3 train_controlnet_flax.py \
   --push_to_hub \
   --hub_model_id=$HUB_MODEL_ID
 ```
+
+By the end of training, the final checkpoint will be automatically stored on your huggingface hub account, under `$HUB_MODEL_ID` (see an example [here](https://huggingface.co/YiYiXu/fill-circle-controlnet))

From edd8d76ec4a83b6eb13a20d156f3ffcdd12ef158 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Tue, 28 Mar 2023 04:49:43 +0000
Subject: [PATCH 12/12] fix

---
 examples/controlnet/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md
index d5f528d88725..0650c2230b71 100644
--- a/examples/controlnet/README.md
+++ b/examples/controlnet/README.md
@@ -362,4 +362,4 @@ python3 train_controlnet_flax.py \
   --hub_model_id=$HUB_MODEL_ID
 ```
 
-By the end of training, the final checkpoint will be automatically stored on your huggingface hub account, under `$HUB_MODEL_ID` (see an example [here](https://huggingface.co/YiYiXu/fill-circle-controlnet))
+Since we passed the `--push_to_hub` flag, a model repo is automatically created under your Hugging Face account based on `$HUB_MODEL_ID`, and the final checkpoint is automatically stored there by the end of training. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).
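Once the final checkpoint is on the Hub, it can be loaded back for inference with the Flax ControlNet pipeline. The snippet below is a minimal sketch rather than part of the patch series: it assumes the example repo `YiYiXu/fill-circle-controlnet` linked above, a conditioning image downloaded earlier in the README, a fixed `PRNGKey(0)`, and that the base weights are loaded from PyTorch checkpoints (`from_pt=True`, `revision="non-ema"`), mirroring the training command:

```python
import jax
import jax.numpy as jnp
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
from diffusers.utils import load_image
from flax.jax_utils import replicate
from flax.training.common_utils import shard

# Trained ControlNet weights; example repo from the README above.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "YiYiXu/fill-circle-controlnet", dtype=jnp.float32
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    safety_checker=None,
    dtype=jnp.float32,
    revision="non-ema",
    from_pt=True,  # assumes the base weights are stored as a PyTorch checkpoint
)
params["controlnet"] = controlnet_params
params = replicate(params)

# One sample per device, mirroring log_validation in the training script.
num_samples = jax.device_count()
prng_seed = jax.random.split(jax.random.PRNGKey(0), num_samples)

prompt_ids = shard(pipe.prepare_text_inputs(num_samples * ["red circle with blue background"]))
conditioning = load_image("./conditioning_image_1.png")
processed_image = shard(pipe.prepare_image_inputs(num_samples * [conditioning]))

images = pipe(
    prompt_ids=prompt_ids,
    image=processed_image,
    params=params,
    prng_seed=prng_seed,
    num_inference_steps=50,
    jit=True,
).images
# Collapse the (devices, batch, H, W, C) layout produced by the jitted call.
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
pipe.numpy_to_pil(images)[0].save("./output.png")
```

As in `log_validation`, the prompt and image inputs are sharded so that each device generates one sample in parallel.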