Description
Describe the bug
when running the img2img pipeline, the outputs are all too light, i.e. nothing is darker than middle gray. it's not very pronounced with a higher/default "strength" value, but for use cases akin to style transfer, where only minor noise is added to the original image, it's very apparent. in other words, the less the image is meant to be changed, the lighter the output appears.
this felt like ye olde torch copypasta problem, wherein pytorch tensors for images like to be in a range of -1 to 1, PIL uses 0 to 255 and numpy uses 0 to 1. or something, i'm not sure; i don't know why no one keeps to a single standard and we always swap channels and convert value ranges... anyway, i tested it, and indeed things work correctly when i alter my pixel_values inputs to be in a range of -1 to 1 instead of the 0 to 1 that numpy arrays use everywhere. however, the outputs of the pipeline are always in the range 0 to 1, i.e. it expects -1 to 1 inputs but gives back 0 to 1 outputs.
expected behavior is that the pipeline takes the same format in as it gives out. in the jax/numpy world, that should be ranged 0 to 1.
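a quick sketch of why a range mismatch would give exactly this "nothing darker than middle gray" look. it assumes the pipeline maps its decoded output back with x / 2 + 0.5 and a clip (that mapping is my assumption, not something confirmed from the diffusers source), and it uses plain numpy with a made-up 3-pixel "image" instead of the real pipeline, so it only models the limit where the image is barely changed:

```python
import numpy as np

def decode_to_unit_range(x):
    # assumed output mapping: values in [-1, 1] -> [0, 1], then clipped
    return np.clip(x / 2 + 0.5, 0.0, 1.0)

# a toy image already in [0, 1]: black, middle gray, white
img01 = np.array([0.0, 0.5, 1.0], dtype=np.float32)

# passed in unchanged, the values get treated as if they were in [-1, 1]
as_misread = decode_to_unit_range(img01)

# rescaled to [-1, 1] first, the values round-trip unchanged
as_rescaled = decode_to_unit_range(img01 * 2 - 1)

print(as_misread)   # black lands on middle gray, everything else gets lighter
print(as_rescaled)  # same as img01
```

under that assumption, black ends up at 0.5, which would match what i'm seeing at low strength.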
Reproduction
i'm running this on Colab with TPU, so sharding/replication is done accordingly below...
swap the commented line in for the image argument. my input array for pixel_values is ranged 0 to 1, so the extra *2-1 makes it -1 to 1.
- when running this pipe with inputs ranged 0 to 1, the outputs are muted and too light
- when running this pipe with inputs ranged -1 to 1, the outputs have the correct white balance
in both cases, the outputs are ranged 0 to 1
images = img2img(  ### jax_batch is a batched conversion from a tensorflow data generator, reading image files and captions.
    prompt_ids=shard(jax_batch["input_ids"][:,0,0]),  # [:,:,0] because of multi captions
    # image=shard(jax_batch["pixel_values"].squeeze()),  ### this is normal input, values in range 0 to 1
    image=shard(jax_batch["pixel_values"].squeeze() * 2 - 1),  ### this proves it expects inputs -1 to 1
    params=jax_utils.replicate({
        'text_encoder': text_encoder.params,
        'unet': unet_params,
        'vae': vae_params,
        'scheduler': scheduler.create_state(),
        'safety_checker': {}
    }),
    prng_seed=jax_utils.replicate(rng),
    strength=0.5,  # default = 0.8
    num_inference_steps=100,
    guidance_scale=jax_utils.replicate(jnp.asarray([5], dtype=jnp.float32)),  # error with args.dtype i.e. jnp.bfloat16
    jit=True  # needed for pmap cuz that's what "jit" means.
).images
for img in images:
    print(img.min(), img.max())  ### proves output is range 0 to 1
    display(jnp_to_pil_image(img))  ### jnp_to_pil_image just swaps CHW<>HWC as needed, multiplies by 255, converts to PIL
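as a stopgap on the caller side, the * 2 - 1 from the repro can be wrapped in a small guard so an array that is already in -1 to 1 (or in 0 to 255) doesn't get rescaled by accident. to_signed_range is a hypothetical helper name, not part of diffusers, and the sketch uses numpy so it runs without a TPU; the same arithmetic should apply to jnp arrays:

```python
import numpy as np

def to_signed_range(pixels):
    """hypothetical helper: map floats in [0, 1] to [-1, 1], refuse anything else."""
    pixels = np.asarray(pixels, dtype=np.float32)
    lo, hi = float(pixels.min()), float(pixels.max())
    if lo < 0.0 or hi > 1.0:
        # probably already in [-1, 1] or still in [0, 255]; don't guess
        raise ValueError(f"expected values in [0, 1], got [{lo}, {hi}]")
    return pixels * 2.0 - 1.0

# stand-in for jax_batch["pixel_values"]: one row of values spanning 0 to 1
batch = np.linspace(0.0, 1.0, 5, dtype=np.float32).reshape(1, 5)
signed = to_signed_range(batch)
print(signed.min(), signed.max())
```

the result would then be what goes into shard(...) for the image argument.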
Logs
No response
System Info
colab TPU, high-ram, latest versions of transformers and diffusers and flax and optax.