Skip to content

Conversation

@LouisRouss
Copy link
Owner

@LouisRouss LouisRouss commented Sep 10, 2025

Adds PrefGRPO training to Diffulab.
The implementation is done from scratch but I take inspiration from the official repository.
Tons of work left to do, draft PR just to open the subject.

This PR also:

  • correct GaussianDiffusion class and introduces the Sampler abstractions
  • serves as a test and fix for the text to image logic. I hadn't had the time to fix everything on the master for this text code that was coded long time ago and not tested.

@LouisRouss LouisRouss marked this pull request as draft September 10, 2025 20:33
…g and reward computation - Allow batch processing with multiple prompts
…uler and EulerMaruyama methods for flow based models
…put_size; add PreComputedEmbedder class for handling precomputed embeddings; update SD3TextEmbedder to streamline initialization and embedding retrieval.
…mprove type casting for image and text feature encoding.
… timestep handling in Euler and EulerMaruyama samplers.
…ting methods

- Fix Gaussian Diffusion in general
…asses to remove ModelInputGRPO and standardize on ModelInput, enhancing type consistency and clarity.

- Refactor GRPOTrainer to take into account sampler use and grpo function merged into basic ones
…fusion classes to make data_shape optional, enhancing flexibility in model input handling.
@LouisRouss LouisRouss changed the title [WIP] feature/PrefGRPO [WIP] feature/PrefGRPO-txt2img-cleaning Oct 23, 2025
@LouisRouss LouisRouss marked this pull request as ready for review November 2, 2025 14:41
@LouisRouss LouisRouss changed the title [WIP] feature/PrefGRPO-txt2img-cleaning feature/PrefGRPO-txt2img-cleaning Nov 2, 2025
LouisRouss and others added 2 commits November 18, 2025 22:38
- remove pre compute on dataset for context embedder
- Introduced RotaryPositionalEmbeddingNDim to support N-dimensional rotary embeddings.
- Updated DiTAttention and MMDiTAttention classes to utilize the new rotary embedding structure.
- Modified forward methods to accept precomputed cosine and sine values for rotary embeddings.
- Enhanced PerceiverRotaryPositionalEmbedding to work with N-dimensional inputs.
- Adjusted PerceiverAttention and PerceiverResampler to accommodate new rotary embedding logic.
- Replaced RMSNorm with LayerNorm in DiTBlock and MMDiTBlock for consistency.
- Updated MMDiT class to compute positional encodings for both text and image inputs.
- Added PackedSwiGLU for improved MLP performance.
- Fixed minor issues in the base trainer's model input handling.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants