We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 73e0bc6 commit f9675e9Copy full SHA for f9675e9
src/diffusers/modeling_flax_utils.py
@@ -393,8 +393,7 @@ def from_pretrained(
393
# flatten dicts
394
state = flatten_dict(state)
395
396
- prng_key = jax.random.PRNGKey(0)
397
- params_shape_tree = jax.eval_shape(model.init_weights, prng_key)
+ params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
398
required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
399
400
shape_state = flatten_dict(unfreeze(params_shape_tree))
0 commit comments