
Stable Diffusion Flax Img2Img pipeline expects different image format input than output #1882

@krahnikblis


Describe the bug

When running the img2img pipeline, the outputs are all too light, i.e. nothing is darker than middle gray. This isn't very pronounced with a higher/default `strength` value, but for use cases akin to style transfer, where only a little noise is added to the original image, it's very apparent. In other words, the less the image is meant to change, the lighter the output appears.
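A rough model of why the output never goes darker than middle gray: as far as I can tell, the Flax pipeline treats `image` as if it were in the range -1 to 1 and post-processes the decoded result with `x / 2 + 0.5`, clipped to 0 to 1. The sketch below assumes an ideal VAE round trip and almost no denoising (very low `strength`), so the decoded values equal the input values. `simulate_low_strength_output` is a hypothetical helper, not part of diffusers.

```python
import numpy as np


def simulate_low_strength_output(image: np.ndarray) -> np.ndarray:
    """Hypothetical model of the pipeline when almost no noise is added.

    Assumption: the VAE reconstructs its input exactly, and the pipeline
    maps the decoded values (which it assumes are in [-1, 1]) back to
    [0, 1] with x / 2 + 0.5 followed by clipping.
    """
    decoded = image  # ideal round trip: decoded values equal the input values
    return np.clip(decoded / 2 + 0.5, 0.0, 1.0)


# A black-to-white ramp in the 0..1 range.
ramp = np.linspace(0.0, 1.0, 5)
print(simulate_low_strength_output(ramp))          # black comes back as 0.5 (middle gray)
print(simulate_low_strength_output(ramp * 2 - 1))  # rescaled input round-trips to `ramp`
```

The higher the `strength`, the more the original latents are replaced by denoised content, which would explain why the shift is much less visible at the default of 0.8.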

This felt like the old torch copy-paste problem: PyTorch image tensors (at least for these models) tend to be in the range -1 to 1, PIL uses 0 to 255, and numpy float images typically use 0 to 1, or something like that. I'm not sure why no one keeps to a single standard; we always end up swapping channels and converting value ranges. Anyway, I tested it, and things do work correctly when I change my `pixel_values` inputs to the range -1 to 1 instead of the 0 to 1 that numpy arrays use everywhere else. However, the pipeline's outputs are always in the range 0 to 1. In other words, it expects -1 to 1 inputs but gives back 0 to 1 outputs.
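For reference, here is a small sketch of the three value-range conventions mentioned above and the conversions between them. The function names are illustrative (hypothetical), not part of numpy, PIL or diffusers.

```python
import numpy as np


def uint8_to_unit(img: np.ndarray) -> np.ndarray:
    """PIL-style uint8 [0, 255] -> float32 [0, 1]."""
    return img.astype(np.float32) / 255.0


def unit_to_signed(img: np.ndarray) -> np.ndarray:
    """Float [0, 1] -> float [-1, 1], the range the pipeline's `image` argument expects."""
    return img * 2.0 - 1.0


def signed_to_unit(img: np.ndarray) -> np.ndarray:
    """Float [-1, 1] -> float [0, 1], the range the pipeline returns."""
    return np.clip(img / 2.0 + 0.5, 0.0, 1.0)


def unit_to_uint8(img: np.ndarray) -> np.ndarray:
    """Float [0, 1] -> uint8 [0, 255], e.g. for PIL.Image.fromarray."""
    return np.round(np.clip(img, 0.0, 1.0) * 255.0).astype(np.uint8)


pixels = np.array([0, 128, 255], dtype=np.uint8)
signed = unit_to_signed(uint8_to_unit(pixels))
restored = unit_to_uint8(signed_to_unit(signed))
print(signed.min(), signed.max())        # -1.0 1.0
print(np.array_equal(pixels, restored))  # True
```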

Expected behavior: the pipeline should accept input in the same format it returns output. In the JAX/numpy world, that should be the range 0 to 1.

Reproduction

I'm running this on Colab with a TPU, so the sharding/replication below is done accordingly.

To reproduce, swap which `image=` line is commented out. My input array for `pixel_values` is in the range 0 to 1, so the extra `* 2 - 1` rescales it to -1 to 1.

1. With inputs in the range 0 to 1, the outputs are muted and too light.
2. With inputs in the range -1 to 1, the outputs have the correct white balance.

In both cases, the outputs are in the range 0 to 1.
```python
import jax.numpy as jnp
from flax import jax_utils
from flax.training.common_utils import shard
from IPython.display import display

# img2img (the Flax Img2Img pipeline), text_encoder, unet_params, vae_params, scheduler,
# rng, jax_batch and jnp_to_pil_image are defined earlier in the notebook.
# jax_batch is a batched conversion from a TensorFlow data generator reading image files and captions.
images = img2img(
    prompt_ids=shard(jax_batch["input_ids"][:, 0, 0]),  # select one caption per image (batch holds multiple captions)
    # image=shard(jax_batch["pixel_values"].squeeze()),        # normal input, values in range 0 to 1
    image=shard(jax_batch["pixel_values"].squeeze() * 2 - 1),  # rescaled to -1 to 1: proves it expects -1 to 1
    params=jax_utils.replicate({
        "text_encoder": text_encoder.params,
        "unet": unet_params,
        "vae": vae_params,
        "scheduler": scheduler.create_state(),
        "safety_checker": {},
    }),
    prng_seed=jax_utils.replicate(rng),
    strength=0.5,  # default is 0.8
    num_inference_steps=100,
    guidance_scale=jax_utils.replicate(jnp.asarray([5], dtype=jnp.float32)),  # errors with args.dtype, i.e. jnp.bfloat16
    jit=True,  # run the pmap-ed version, needed because the inputs are sharded across TPU cores
).images

for img in images:
    print(img.min(), img.max())     # shows the output is in range 0 to 1
    display(jnp_to_pil_image(img))  # swaps CHW<->HWC as needed, multiplies by 255, converts to PIL
```
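The reproduction relies on `jnp_to_pil_image`, which isn't included in the issue. A minimal sketch of the array handling its comment describes, assuming a single RGB image with values in 0 to 1 in either CHW or HWC layout; `to_uint8_hwc` is a hypothetical name, and `PIL.Image.fromarray` on its result would give the PIL image.

```python
import numpy as np


def to_uint8_hwc(img) -> np.ndarray:
    """Hypothetical stand-in for jnp_to_pil_image's array handling.

    Assumes a single RGB image in [0, 1], laid out either CHW (3, H, W)
    or HWC (H, W, 3). Ambiguous (3, H, 3) shapes are treated as HWC.
    Returns a uint8 HWC array, ready for PIL.Image.fromarray.
    """
    arr = np.asarray(img, dtype=np.float32)  # also accepts jax arrays
    if arr.ndim != 3:
        raise ValueError(f"expected a single 3-D image, got shape {arr.shape}")
    if arr.shape[0] == 3 and arr.shape[-1] != 3:
        arr = arr.transpose(1, 2, 0)  # CHW -> HWC
    return np.round(np.clip(arr, 0.0, 1.0) * 255.0).astype(np.uint8)


chw = np.zeros((3, 4, 6), dtype=np.float32)
chw[0] = 1.0  # pure red in CHW layout
hwc = to_uint8_hwc(chw)
print(hwc.shape)  # (4, 6, 3)
print(hwc[0, 0])  # red channel 255, the others 0
```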

Logs

No response

System Info

Colab TPU runtime with high RAM; latest versions of transformers, diffusers, flax and optax.
