Closed
Description
For example, here diffusers uses a plain Python dataclass (from dataclasses import dataclass):
diffusers/src/diffusers/models/unet_2d_condition_flax.py, lines 22 to 23 in d8b0e4f:

    @dataclass
    class FlaxUNet2DConditionOutput(BaseOutput):
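A minimal sketch of why a plain dataclass output is a problem under jax.jit, assuming only jax is installed. PlainOutput, forward and try_forward are hypothetical names; PlainOutput is a simplified stand-in for an output class like FlaxUNet2DConditionOutput (the real class also inherits from BaseOutput, which is left out here). Because a plain dataclass is not registered as a JAX pytree, JAX cannot flatten it into arrays, and returning one from a jitted function raises a TypeError.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical stand-in for a plain-dataclass model output.
@dataclasses.dataclass
class PlainOutput:
    sample: jnp.ndarray


@jax.jit
def forward(x):
    # JAX treats the unregistered dataclass as a single leaf,
    # which is not a valid JAX type, so tracing fails.
    return PlainOutput(sample=x * 2.0)


def try_forward():
    try:
        forward(jnp.ones((2, 2)))
    except TypeError as err:
        return f"TypeError: {err}"
    return "ok"


print(try_forward())
```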
The transformers equivalents, however, use @flax.struct.dataclass. For example, here:

    @flax.struct.dataclass
    class FlaxBertForPreTrainingOutput(ModelOutput):

The benefit of @flax.struct.dataclass over a plain Python dataclass is that it registers the class as a JAX pytree, so jax.jit can consume it: instances can be passed into and returned from jit-compiled functions.
So the question is: should diffusers use @flax.struct.dataclass as well?