From 9b9e6d4f93d0f0b5df1543157003977e85836f40 Mon Sep 17 00:00:00 2001 From: Philip Pham Date: Wed, 17 Apr 2024 12:20:13 -0700 Subject: [PATCH] Check shape and remove deprecated APIs in scheduling_ddpm_flax.py `model_output.shape` may only have rank 1. There are warnings related to use of random keys. ``` tests/schedulers/test_scheduler_flax.py: 13 warnings /Users/phillypham/diffusers/src/diffusers/schedulers/scheduling_ddpm_flax.py:268: FutureWarning: normal accepts a single key, but was given a key array of shape (1, 2) != (). Use jax.vmap for batching. In a future JAX version, this will be an error. noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype) tests/schedulers/test_scheduler_flax.py::FlaxDDPMSchedulerTest::test_betas /Users/phillypham/virtualenv/diffusers/lib/python3.9/site-packages/jax/_src/random.py:731: FutureWarning: uniform accepts a single key, but was given a key array of shape (1,) != (). Use jax.vmap for batching. In a future JAX version, this will be an error. u = uniform(key, shape, dtype, lo, hi) # type: ignore[arg-type] ``` --- src/diffusers/schedulers/scheduling_ddpm_flax.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index 6bdfa5eb5e0d..d06a171159ee 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -222,9 +222,13 @@ def step( t = timestep if key is None: - key = jax.random.PRNGKey(0) + key = jax.random.key(0) - if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]: + if ( + len(model_output.shape) > 1 + and model_output.shape[1] == sample.shape[1] * 2 + and self.config.variance_type in ["learned", "learned_range"] + ): model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1) else: predicted_variance = None @@ -264,7 +268,7 @@ def step( # 6. Add noise def random_variance(): - split_key = jax.random.split(key, num=1) + split_key = jax.random.split(key, num=1)[0] noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype) return (self._get_variance(state, t, predicted_variance=predicted_variance) ** 0.5) * noise