Description
Describe the bug
Using a Colab TPU environment with bfloat16 blows up pipes unless I use a packaged, pre-built pipeline with its built-in default scheduler. Even then it can unpredictably produce NaNs instead of images (I can't reliably replicate the NaN thing, but it happens more often than not; see logs/output attached). The errors are often about data type mismatches, but in different places within the pipe or the schedulers. One of the schedulers (FlaxDDPMScheduler) fails for a different reason: the pipe implementation doesn't pass its "key" argument to step().
Prior to the recent release, I had a functional workbook that built the components separately, basically to make it easier to implement a custom VAE and apply my LoRA-combinator model to the UNet and text encoder params. The component-built pipe now breaks 100% of the time. However, a separate issue I logged about whitewashed img2img outputs showed that an all-in-one pipeline functions fine apart from that issue, so I tried that as a fix here. That way, outputs can be OK, but not reliably (the NaNs issue): in the attached outputs I was lucky enough to get images in 3 out of 10 tests, but I can't rely on the scheduler + dtype combos that happened to work, because the same combos produce NaNs on other runs.
Tests: since the errors are consistently in the scheduler itself, or in the pipe's usage of the scheduler, and since they concern the data type, the reproduction script below loops through all Flax schedulers, building each one with either the default dtype or bfloat16. The first time I tried this, I used from_config for the schedulers (there's no weights file, so schedulers aren't really "pretrained"); that caused a whole mess of other problems, which were reduced when I switched to from_pretrained (again, kinda nonsense to use that method in this way). I used the JAX config to debug NaNs, and had to run the pipe with jit=False. That makes the run very slow, but it isolated the problem to this line in the pipe: latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() which again has to do with the scheduler...
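For anyone wanting to repeat the NaN hunt, here is a minimal sketch of the JAX NaN-debugging flag on its own, without diffusers. It assumes the `jax_debug_nans` config option is what "the JAX config" refers to above; `probe_nan` is a hypothetical helper name used only for illustration.

```python
import jax
import jax.numpy as jnp

# With this flag on, JAX raises FloatingPointError as soon as an op
# produces a NaN, instead of silently propagating it.
jax.config.update("jax_debug_nans", True)


def probe_nan():
    """Return the error message if the flag caught a NaN, else None."""
    try:
        jnp.divide(0.0, 0.0)  # 0/0 produces a NaN
    except FloatingPointError as err:
        return str(err)
    return None


caught = probe_nan()
print("caught:", caught)

# Switch the flag back off afterwards; the extra checks slow things down.
jax.config.update("jax_debug_nans", False)
```

In the actual workbook the flag would be set before calling the pipe with jit=False, so that the error points at the offending Python line rather than at a compiled function.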
I also found that some of the combos with data type errors can be resolved by passing custom latents as input to the pipe, but in those cases the outputs were all NaNs (swapping one problem for another), so I didn't include that test method here. I'm also not sure what the utility of custom latents is at all, and IMO it's not something a user should need to do (or even be expected to know how to do) to defeat data type errors...
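For completeness, a rough sketch of what building custom latents with an explicit dtype can look like. `make_latents` is a hypothetical helper, and the shape assumes SD v1.x conventions (4 latent channels, spatial size = image size / 8), which matches the float32[1,4,64,64] arrays in the output below; how the array is then handed to the pipe is not shown here.

```python
import jax
import jax.numpy as jnp


def make_latents(key, batch=1, height=512, width=512, dtype=jnp.bfloat16):
    """Sample Gaussian latents in an explicit dtype (hypothetical helper)."""
    # Assumed SD v1.x layout: (batch, 4 channels, height/8, width/8).
    shape = (batch, 4, height // 8, width // 8)
    return jax.random.normal(key, shape, dtype=dtype)


latents = make_latents(jax.random.PRNGKey(42))
print(latents.shape, latents.dtype)
```

Fixing the dtype up front is what made some of the mismatch errors go away in my tests, but as noted above it only traded them for NaN outputs.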
Reproduction
cell 1:
#@title bfloat + scheduler issue reproduction p1 (installs + maybe colab runtime restart)
!pip install -qq -U jax jaxlib
!pip install -qq flax optax transformers ftfy diffusers
### sometimes requires restart at this point, because google.
cell 2:
#@title bfloat + scheduler issue reproduction p2 (imports + setup)
# setup TPU
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
# jax.tools.colab_tpu.setup_tpu(tpu_driver_version='tpu_driver_nightly')
import os
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
print(f"Found {num_devices} JAX devices of type {device_type}.")
# imports
import PIL, jax, requests
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxStableDiffusionPipeline, FlaxLMSDiscreteScheduler, FlaxDDIMScheduler, FlaxDPMSolverMultistepScheduler, FlaxPNDMScheduler, FlaxDDPMScheduler
from flax import jax_utils
# make a handy display fn:
def display_jnp_image(image):
    image = jnp.clip(np.array(image * 255),0,255).astype(np.uint8)
    if image.shape[0] == 3:
        image = np.transpose(image, axes=(1, 2, 0))
    image = PIL.Image.fromarray(np.array(image),mode="RGB")
    display(image)
# get pretrained
pretrained_model_name_or_path = "stable-diffusion-v1-5/"
if not os.path.exists(pretrained_model_name_or_path):
    !git lfs install
    !git clone -b bf16 https://huggingface.co/runwayml/stable-diffusion-v1-5
# variables and pipe
dtype = jnp.bfloat16
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path, dtype=dtype)
orig_scheduler = pipe.scheduler._class_name
rng = jax.random.PRNGKey(42)
rng_repl = jax.random.split(rng,jax.device_count())
guidance_scale = jax_utils.replicate(jnp.array([7.5]))
prompt = "a pretty goose for you, and a pretty racing car"
prompt_ids = jax_utils.replicate(pipe.tokenizer(prompt,padding="max_length",max_length=pipe.tokenizer.model_max_length,truncation=True,return_tensors="np").input_ids)
cell 3:
#@title bfloat + scheduler issue reproduction p3 (run + scheduler tests)
# loop scheduler tests
for dtype_test in [0,1]:
    print("Testing scheduler data type" + (" default" if dtype_test == 0 else " bfloat16"))
    settings = {} if dtype_test == 0 else {"dtype": jnp.bfloat16} # edit: oops, had reversed.
    for lib_scheduler in [FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDDIMScheduler, FlaxDPMSolverMultistepScheduler, FlaxDDPMScheduler]:
        # build and implant
        pipe.scheduler, params["scheduler"] = lib_scheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler", **settings) # nonsense to need to use from_pretrained instead of from_config here...
        print("Scheduler = " + lib_scheduler.__name__ + (": DEFAULT" if orig_scheduler == lib_scheduler.__name__ else ""))
        try:
            # run
            images_jnp = pipe(
                prompt_ids=prompt_ids,
                params=jax_utils.replicate(params),
                prng_seed=rng_repl,
                num_inference_steps=50,
                guidance_scale=guidance_scale,
                jit=True
            ).images
            # show me what you're made of
            if jnp.any(jnp.isnan(images_jnp)):
                raise ValueError("output image is NaNs")
            display_jnp_image(images_jnp[0].squeeze())
        except Exception as e:
            print("ERROR: " + str(e))
attaching outputs in place of logs...
Logs
don't know how to get logs, but here's output:
Testing scheduler data type bfloat16
Scheduler = FlaxPNDMScheduler: DEFAULT
ERROR: lax.select requires arguments to have the same dtypes, got bfloat16, float32. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).
Scheduler = FlaxLMSDiscreteScheduler
ERROR: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=1/1)>
The problem arose with the `bool` function.
The error occurred while tracing the function scanned_fun at /usr/local/lib/python3.8/dist-packages/jax/_src/lax/control_flow/loops.py:1619 for scan. This concrete value was not available in Python because it depends on the values of the argument 'loop_carry'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
Scheduler = FlaxDDIMScheduler
ERROR: output image is NaNs
Scheduler = FlaxDPMSolverMultistepScheduler
ERROR: scan carry output and input must have identical types, got
('ShapedArray(int32[], weak_type=True)', ('ShapedArray(float32[1,4,64,64])', DPMSolverMultistepSchedulerState(common=CommonSchedulerState(alphas='ShapedArray(bfloat16[1000])', betas='ShapedArray(bfloat16[1000])', alphas_cumprod='ShapedArray(bfloat16[1000])'), alpha_t='ShapedArray(bfloat16[1000])', sigma_t='ShapedArray(bfloat16[1000])', lambda_t='ShapedArray(bfloat16[1000])', init_noise_sigma='ShapedArray(bfloat16[])', timesteps='ShapedArray(int32[50])', num_inference_steps='ShapedArray(int32[], weak_type=True)', model_outputs='ShapedArray(bfloat16[2,1,4,64,64])', lower_order_nums='ShapedArray(int32[])', prev_timestep='ShapedArray(int32[])', cur_sample='DIFFERENT ShapedArray(float32[1,4,64,64]) vs. ShapedArray(bfloat16[1,4,64,64])'))).
Scheduler = FlaxDDPMScheduler
ERROR: step() missing 1 required positional argument: 'key'
Testing scheduler data type default
Scheduler = FlaxPNDMScheduler: DEFAULT
IMAGE WAS OK
Scheduler = FlaxLMSDiscreteScheduler
ERROR: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=1/1)>
The problem arose with the `bool` function.
The error occurred while tracing the function scanned_fun at /usr/local/lib/python3.8/dist-packages/jax/_src/lax/control_flow/loops.py:1619 for scan. This concrete value was not available in Python because it depends on the values of the argument 'loop_carry'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
Scheduler = FlaxDDIMScheduler
IMAGE WAS OK
Scheduler = FlaxDPMSolverMultistepScheduler
IMAGE WAS OK
Scheduler = FlaxDDPMScheduler
ERROR: step() missing 1 required positional argument: 'key'
System Info
Google Colab Pro, TPU environment with High RAM, default setup except what is included in reproduction script
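For reference, three of the JAX errors in the output above (the lax.select dtype mismatch, the scan carry type mismatch, and the tracer-to-bool error) can be reproduced in isolation, without diffusers. This is only a sketch of the error classes, not of what the schedulers actually do internally; every function name below is hypothetical.

```python
import jax.numpy as jnp
from jax import lax

pred = jnp.array([True, False])
x_bf16 = jnp.ones(2, dtype=jnp.bfloat16)
y_f32 = jnp.zeros(2, dtype=jnp.float32)


def try_call(fn):
    """Run fn() and return the exception it raises, or None if it succeeds."""
    try:
        fn()
    except (TypeError, ValueError) as err:
        return err
    return None


# 1) lax.select does no type promotion, so mixed dtypes fail ...
select_err = try_call(lambda: lax.select(pred, x_bf16, y_f32))
# ... while jnp.where promotes both branches to a common dtype.
promoted = jnp.where(pred, x_bf16, y_f32)


def run_scan(step):
    """Scan `step` twice over a float32 carry and return the final carry."""
    carry, _ = lax.scan(step, jnp.ones(3, dtype=jnp.float32), None, length=2)
    return carry


# 2) The scan carry must keep the same dtype on every iteration.
def bad_step(carry, _):
    return (carry * 2).astype(jnp.bfloat16), None  # float32 in, bfloat16 out


def good_step(carry, _):
    return (carry * 2).astype(carry.dtype), None  # cast back to the input dtype


scan_err = try_call(lambda: run_scan(bad_step))
fixed = run_scan(good_step)


# 3) A Python `if` on a traced value calls bool() on a tracer.
def branch_step(carry, _):
    if carry.sum() > 0:
        carry = carry + 1
    return carry, None


def where_step(carry, _):
    # Same logic expressed as data flow, so nothing needs to be concrete.
    return jnp.where(carry.sum() > 0, carry + 1, carry), None


branch_err = try_call(lambda: run_scan(branch_step))
traced_ok = run_scan(where_step)

for name, err in [("select", select_err), ("scan carry", scan_err), ("branch", branch_err)]:
    print(name, "->", type(err).__name__)
```

In each pair, the second variant is the usual way around that class of error: promote or cast explicitly so both sides share a dtype, and replace Python control flow on traced values with jnp.where (or lax.cond).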