From 336ae35e4a03caadde551b228ca9e5de41efb688 Mon Sep 17 00:00:00 2001
From: Cristian Garcia
Date: Fri, 14 Apr 2023 20:01:51 +0000
Subject: [PATCH] extract pipeline from log_validation

---
 examples/controlnet/train_controlnet_flax.py | 35 +++++++++++---------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index 0b413ace09d2..24b32e7f4301 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -76,20 +76,11 @@ def image_grid(imgs, rows, cols):
     return grid
 
 
-def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_dtype):
-    logger.info("Running validation... ")
+def log_validation(pipeline, pipeline_params, controlnet_params, tokenizer, args, rng, weight_dtype):
+    logger.info("Running validation...")
 
-    pipeline, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
-        args.pretrained_model_name_or_path,
-        tokenizer=tokenizer,
-        controlnet=controlnet,
-        safety_checker=None,
-        dtype=weight_dtype,
-        revision=args.revision,
-        from_pt=args.from_pt,
-    )
-    params = jax_utils.replicate(params)
-    params["controlnet"] = controlnet_params
+    pipeline_params = pipeline_params.copy()
+    pipeline_params["controlnet"] = controlnet_params
 
     num_samples = jax.device_count()
     prng_seed = jax.random.split(rng, jax.device_count())
@@ -121,7 +112,7 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d
         images = pipeline(
             prompt_ids=prompt_ids,
             image=processed_image,
-            params=params,
+            params=pipeline_params,
             prng_seed=prng_seed,
             num_inference_steps=50,
             jit=True,
@@ -176,6 +167,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
 - text-to-image
 - diffusers
 - controlnet
+- jax-diffusers-event
 inference: true
 ---
 """
@@ -800,6 +792,17 @@ def main():
         ]:
             controlnet_params[key] = unet_params[key]
 
+    pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        tokenizer=tokenizer,
+        controlnet=controlnet,
+        safety_checker=None,
+        dtype=weight_dtype,
+        revision=args.revision,
+        from_pt=args.from_pt,
+    )
+    pipeline_params = jax_utils.replicate(pipeline_params)
+
     # Optimization
     if args.scale_lr:
         args.learning_rate = args.learning_rate * total_train_batch_size
@@ -1073,7 +1076,7 @@ def l2(xs):
                 and global_step % args.validation_steps == 0
                 and jax.process_index() == 0
             ):
-                _ = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
+                _ = log_validation(pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype)
 
             if global_step % args.logging_steps == 0 and jax.process_index() == 0:
                 if args.report_to == "wandb":
@@ -1105,7 +1108,7 @@ def l2(xs):
     if args.validation_prompt is not None:
         if args.profile_validation:
             jax.profiler.start_trace(args.output_dir)
-        image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
+        image_logs = log_validation(pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype)
         if args.profile_validation:
             jax.profiler.stop_trace()
     else:
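
Note on the change: the heavyweight FlaxStableDiffusionControlNetPipeline.from_pretrained call and the jax_utils.replicate of its params now run once in main() rather than on every validation run; log_validation reuses the prebuilt pipeline and only swaps the current ControlNet training params into a copy of the replicated pipeline params. Below is a minimal sketch of that call pattern, not part of the patch; the helper names build_validation_pipeline and run_validation are hypothetical, and arguments such as args, tokenizer, and controlnet are assumed to come from the surrounding training script as in the patched file.

import jax
from flax import jax_utils
from diffusers import FlaxStableDiffusionControlNetPipeline


def build_validation_pipeline(args, tokenizer, controlnet, weight_dtype):
    # Done once in main(): the expensive load from disk plus replication
    # of the pipeline params across devices.
    pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        tokenizer=tokenizer,
        controlnet=controlnet,
        safety_checker=None,
        dtype=weight_dtype,
        revision=args.revision,
        from_pt=args.from_pt,
    )
    return pipeline, jax_utils.replicate(pipeline_params)


def run_validation(pipeline, pipeline_params, controlnet_params, prompt_ids, processed_image, rng):
    # Done on every validation step: reuse the prebuilt pipeline and point
    # its "controlnet" entry at the current (already replicated) training params.
    params = pipeline_params.copy()
    params["controlnet"] = controlnet_params
    prng_seed = jax.random.split(rng, jax.device_count())
    return pipeline(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=params,
        prng_seed=prng_seed,
        num_inference_steps=50,
        jit=True,
    ).images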