Skip to content

Commit f9675e9

Browse files
Mishig Davaadorjpatil-suraj
andauthored
Update src/diffusers/modeling_flax_utils.py
Co-authored-by: Suraj Patil <[email protected]>
1 parent 73e0bc6 commit f9675e9

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

src/diffusers/modeling_flax_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,8 +393,7 @@ def from_pretrained(
393393
# flatten dicts
394394
state = flatten_dict(state)
395395

396-
prng_key = jax.random.PRNGKey(0)
397-
params_shape_tree = jax.eval_shape(model.init_weights, prng_key)
396+
params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
398397
required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
399398

400399
shape_state = flatten_dict(unfreeze(params_shape_tree))

0 commit comments

Comments
 (0)