Commit 6a5597a
committed
Check shape and remove deprecated APIs in scheduling_ddpm_flax.py
`model_output.shape` may only have rank 1.
There are warnings related to use of random keys.
```
tests/schedulers/test_scheduler_flax.py: 13 warnings
/Users/phillypham/diffusers/src/diffusers/schedulers/scheduling_ddpm_flax.py:268: FutureWarning: normal accepts a single key, but was given a key array of shape (1, 2) != (). Use jax.vmap for batching. In a future JAX version, this will be an error.
noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype)
tests/schedulers/test_scheduler_flax.py::FlaxDDPMSchedulerTest::test_betas
/Users/phillypham/virtualenv/diffusers/lib/python3.9/site-packages/jax/_src/random.py:731: FutureWarning: uniform accepts a single key, but was given a key array of shape (1,) != (). Use jax.vmap for batching. In a future JAX version, this will be an error.
u = uniform(key, shape, dtype, lo, hi) # type: ignore[arg-type]
```1 parent db969cc commit 6a5597a
1 file changed
+7
-3
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
222 | 222 | | |
223 | 223 | | |
224 | 224 | | |
225 | | - | |
| 225 | + | |
226 | 226 | | |
227 | | - | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
228 | 232 | | |
229 | 233 | | |
230 | 234 | | |
| |||
264 | 268 | | |
265 | 269 | | |
266 | 270 | | |
267 | | - | |
| 271 | + | |
268 | 272 | | |
269 | 273 | | |
270 | 274 | | |
| |||
0 commit comments