
Commit bacca27

blepping authored and adlerfaulkner committed
Add support for Chroma Radiance (comfyanonymous#9682)
* Initial Chroma Radiance support
* Minor Chroma Radiance cleanups
* Update Radiance nodes to ensure latents/images are on the intermediate device
* Fix Chroma Radiance memory estimation.
* Increase Chroma Radiance memory usage factor
* Increase Chroma Radiance memory usage factor once again
* Ensure images are multiples of 16 for Chroma Radiance
  Add batch dimension and fix channels when necessary in ChromaRadianceImageToLatent node
* Tile Chroma Radiance NeRF to reduce memory consumption, update memory usage factor
* Update Radiance to support conv nerf final head type.
* Allow setting NeRF embedder dtype for Radiance
  Bump Radiance nerf tile size to 32
  Support EasyCache/LazyCache on Radiance (maybe)
* Add ChromaRadianceStubVAE node
* Crop Radiance image inputs to multiples of 16 instead of erroring to be in line with existing VAE behavior
* Convert Chroma Radiance nodes to V3 schema.
* Add ChromaRadianceOptions node and backend support.
  Cleanups/refactoring to reduce code duplication with Chroma.
* Fix overriding the NeRF embedder dtype for Chroma Radiance
* Minor Chroma Radiance cleanups
* Move Chroma Radiance to its own directory in ldm
  Minor code cleanups and tooltip improvements
* Fix Chroma Radiance embedder dtype overriding
* Remove Radiance dynamic nerf_embedder dtype override feature
* Unbork Radiance NeRF embedder init
* Remove Chroma Radiance image conversion and stub VAE nodes
  Add a chroma_radiance option to the VAELoader builtin node which uses comfy.sd.PixelspaceConversionVAE
  Add a PixelspaceConversionVAE to comfy.sd for converting BHWC 0..1 <-> BCHW -1..1
1 parent 95a5630 commit bacca27

File tree

10 files changed, +770 -9 lines

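The final bullet of the commit message describes a PixelspaceConversionVAE in comfy.sd that maps between ComfyUI's IMAGE layout (BHWC, 0..1) and the BCHW, -1..1 pixel-space tensors Chroma Radiance works on. That file is not among the hunks excerpted below, so the following is only a minimal sketch of the described conversion; the function names are illustrative, not ComfyUI API.

# Hedged sketch of the BHWC 0..1 <-> BCHW -1..1 conversion named in the commit
# message. The real comfy.sd.PixelspaceConversionVAE is not shown in this diff;
# these helper names are hypothetical.
import torch

def image_to_pixel_latent(image: torch.Tensor) -> torch.Tensor:
    # ComfyUI IMAGE tensors are (B, H, W, C) with values in 0..1.
    return image.movedim(-1, 1) * 2.0 - 1.0  # -> (B, C, H, W) in -1..1

def pixel_latent_to_image(latent: torch.Tensor) -> torch.Tensor:
    # Undo the conversion: (B, C, H, W) in -1..1 back to (B, H, W, C) in 0..1.
    return (latent.movedim(1, -1) + 1.0) / 2.0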

comfy/latent_formats.py

Lines changed: 17 additions & 0 deletions
@@ -629,3 +629,20 @@ class Hunyuan3Dv2mini(LatentFormat):
 class ACEAudio(LatentFormat):
     latent_channels = 8
     latent_dimensions = 2
+
+class ChromaRadiance(LatentFormat):
+    latent_channels = 3
+
+    def __init__(self):
+        self.latent_rgb_factors = [
+            #   R    G    B
+            [ 1.0, 0.0, 0.0 ],
+            [ 0.0, 1.0, 0.0 ],
+            [ 0.0, 0.0, 1.0 ]
+        ]
+
+    def process_in(self, latent):
+        return latent
+
+    def process_out(self, latent):
+        return latent
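ChromaRadiance works directly in pixel space, so this latent format is an identity mapping: the three "latent" channels are just RGB, process_in/process_out pass tensors through unchanged, and latent_rgb_factors is the 3x3 identity. A hedged shape check (assuming ComfyUI is importable so the class above can be loaded):

# Hedged sketch: ChromaRadiance "latents" are pixels, so the format is a no-op,
# and the identity latent_rgb_factors map channels straight to R, G, B when a
# preview is derived from them (preview plumbing itself not shown in this diff).
import torch
from comfy.latent_formats import ChromaRadiance

fmt = ChromaRadiance()
latent = torch.rand(1, 3, 64, 64) * 2 - 1            # pixel-space "latent"
assert torch.equal(fmt.process_in(latent), latent)   # no scaling on the way in
assert torch.equal(fmt.process_out(latent), latent)  # or on the way out

rgb = torch.tensor(fmt.latent_rgb_factors)           # 3x3 identity matrix
preview = torch.einsum("bchw,cr->brhw", latent, rgb)
assert torch.allclose(preview, latent)               # identity factors leave pixels unchanged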

comfy/ldm/chroma/model.py

Lines changed: 6 additions & 4 deletions
@@ -151,8 +151,6 @@ def forward_orig(
         attn_mask: Tensor = None,
     ) -> Tensor:
         patches_replace = transformer_options.get("patches_replace", {})
-        if img.ndim != 3 or txt.ndim != 3:
-            raise ValueError("Input img and txt tensors must have 3 dimensions.")

         # running on sequences img
         img = self.img_in(img)
@@ -254,8 +252,9 @@ def block_wrap(args):
                        img[:, txt.shape[1] :, ...] += add

         img = img[:, txt.shape[1] :, ...]
-        final_mod = self.get_modulations(mod_vectors, "final")
-        img = self.final_layer(img, vec=final_mod)  # (N, T, patch_size ** 2 * out_channels)
+        if hasattr(self, "final_layer"):
+            final_mod = self.get_modulations(mod_vectors, "final")
+            img = self.final_layer(img, vec=final_mod)  # (N, T, patch_size ** 2 * out_channels)
         return img

     def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
@@ -271,6 +270,9 @@ def _forward(self, x, timestep, context, guidance, control=None, transformer_opt

         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)

+        if img.ndim != 3 or context.ndim != 3:
+            raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
         h_len = ((h + (self.patch_size // 2)) // self.patch_size)
         w_len = ((w + (self.patch_size // 2)) // self.patch_size)
         img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
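The moved ndim check now runs right after patchification: the rearrange above turns a (B, C, H, W) image into a (B, h*w, C*ph*pw) token sequence, which is why both img and context are expected to be 3-dimensional at that point. A small shape sketch with illustrative values (not part of the commit), assuming einops is available as it is in ComfyUI:

# Shape sketch for the patchify step above. With patch_size = 2, a (B, C, H, W)
# input becomes (H/2)*(W/2) tokens of width C*2*2, i.e. ndim == 3.
import torch
from einops import rearrange

B, C, H, W, patch_size = 1, 3, 64, 64, 2
x = torch.randn(B, C, H, W)
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
assert img.shape == (B, (H // patch_size) * (W // patch_size), C * patch_size * patch_size)
assert img.ndim == 3  # the condition the relocated ValueError check enforces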
Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
# Adapted from https://github.com/lodestone-rock/flow
from functools import lru_cache

import torch
from torch import nn

from comfy.ldm.flux.layers import RMSNorm


class NerfEmbedder(nn.Module):
    """
    An embedder module that combines input features with a 2D positional
    encoding that mimics the Discrete Cosine Transform (DCT).

    This module takes an input tensor of shape (B, P^2, C), where P is the
    patch size, and enriches it with positional information before projecting
    it to a new hidden size.
    """
    def __init__(
        self,
        in_channels: int,
        hidden_size_input: int,
        max_freqs: int,
        dtype=None,
        device=None,
        operations=None,
    ):
        """
        Initializes the NerfEmbedder.

        Args:
            in_channels (int): The number of channels in the input tensor.
            hidden_size_input (int): The desired dimension of the output embedding.
            max_freqs (int): The number of frequency components to use for both
                the x and y dimensions of the positional encoding.
                The total number of positional features will be max_freqs^2.
        """
        super().__init__()
        self.dtype = dtype
        self.max_freqs = max_freqs
        self.hidden_size_input = hidden_size_input

        # A linear layer to project the concatenated input features and
        # positional encodings to the final output dimension.
        self.embedder = nn.Sequential(
            operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
        )

    @lru_cache(maxsize=4)
    def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """
        Generates and caches 2D DCT-like positional embeddings for a given patch size.

        The LRU cache is a performance optimization that avoids recomputing the
        same positional grid on every forward pass.

        Args:
            patch_size (int): The side length of the square input patch.
            device: The torch device to create the tensors on.
            dtype: The torch dtype for the tensors.

        Returns:
            A tensor of shape (1, patch_size^2, max_freqs^2) containing the
            positional embeddings.
        """
        # Create normalized 1D coordinate grids from 0 to 1.
        pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)

        # Create a 2D meshgrid of coordinates.
        pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")

        # Reshape positions to be broadcastable with frequencies.
        # Shape becomes (patch_size^2, 1, 1).
        pos_x = pos_x.reshape(-1, 1, 1)
        pos_y = pos_y.reshape(-1, 1, 1)

        # Create a 1D tensor of frequency values from 0 to max_freqs-1.
        freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device)

        # Reshape frequencies to be broadcastable for creating 2D basis functions.
        # freqs_x shape: (1, max_freqs, 1)
        # freqs_y shape: (1, 1, max_freqs)
        freqs_x = freqs[None, :, None]
        freqs_y = freqs[None, None, :]

        # A custom weighting coefficient, not part of standard DCT.
        # This seems to down-weight the contribution of higher-frequency interactions.
        coeffs = (1 + freqs_x * freqs_y) ** -1

        # Calculate the 1D cosine basis functions for x and y coordinates.
        # This is the core of the DCT formulation.
        dct_x = torch.cos(pos_x * freqs_x * torch.pi)
        dct_y = torch.cos(pos_y * freqs_y * torch.pi)

        # Combine the 1D basis functions to create 2D basis functions by element-wise
        # multiplication, and apply the custom coefficients. Broadcasting handles the
        # combination of all (pos_x, freqs_x) with all (pos_y, freqs_y).
        # The result is flattened into a feature vector for each position.
        dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)

        return dct

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the embedder.

        Args:
            inputs (Tensor): The input tensor of shape (B, P^2, C).

        Returns:
            Tensor: The output tensor of shape (B, P^2, hidden_size_input).
        """
        # Get the batch size, number of pixels, and number of channels.
        B, P2, C = inputs.shape

        # Infer the patch side length from the number of pixels (P^2).
        patch_size = int(P2 ** 0.5)

        input_dtype = inputs.dtype
        inputs = inputs.to(dtype=self.dtype)

        # Fetch the pre-computed or cached positional embeddings.
        dct = self.fetch_pos(patch_size, inputs.device, self.dtype)

        # Repeat the positional embeddings for each item in the batch.
        dct = dct.repeat(B, 1, 1)

        # Concatenate the original input features with the positional embeddings
        # along the feature dimension.
        inputs = torch.cat((inputs, dct), dim=-1)

        # Project the combined tensor to the target hidden size.
        return self.embedder(inputs).to(dtype=input_dtype)


class NerfGLUBlock(nn.Module):
    """
    A NerfBlock using a Gated Linear Unit (GLU) like MLP.
    """
    def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None, device=None, operations=None):
        super().__init__()
        # The total number of parameters for the MLP is increased to accommodate
        # the gate, value, and output projection matrices.
        # We now need to generate parameters for 3 matrices.
        total_params = 3 * hidden_size_x**2 * mlp_ratio
        self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
        self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
        self.mlp_ratio = mlp_ratio


    def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        batch_size, num_x, hidden_size_x = x.shape
        mlp_params = self.param_generator(s)

        # Split the generated parameters into three parts for the gate, value, and output projection.
        fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1)

        # Reshape the parameters into matrices for batch matrix multiplication.
        fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
        fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
        fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x)

        # Normalize the generated weight matrices as in the original implementation.
        fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2)
        fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2)
        fc2 = torch.nn.functional.normalize(fc2, dim=-2)

        res_x = x
        x = self.norm(x)

        # Apply the final output projection.
        x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)

        return x + res_x


class NerfFinalLayer(nn.Module):
    def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
        self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
        # So we temporarily move the channel dimension to the end for the norm operation.
        return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1)


class NerfFinalLayerConv(nn.Module):
    def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
        self.conv = operations.Conv2d(
            in_channels=hidden_size,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
            dtype=dtype,
            device=device,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
        # So we temporarily move the channel dimension to the end for the norm operation.
        return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1))
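A hedged shape check for the layers above: NerfEmbedder concatenates max_freqs^2 DCT-style positional features onto each of the P^2 pixels of a patch before projecting to the hidden size, and NerfGLUBlock generates its gated-MLP weights from a per-sample conditioning vector s. Passing torch.nn as the `operations` argument is only an approximation of ComfyUI's ops wrappers, chosen here so the sketch runs standalone (apart from the RMSNorm import above); the numeric values are illustrative.

# Hedged usage sketch for NerfEmbedder / NerfGLUBlock, assuming the new module
# above is importable. Weights are not meaningfully initialized; only the
# tensor shapes are being demonstrated.
import torch
from torch import nn

B, P, C = 2, 16, 3                       # batch, patch side length, input channels
hidden, max_freqs, mlp_ratio = 64, 8, 2

embedder = NerfEmbedder(C, hidden, max_freqs, dtype=torch.float32, operations=nn)
pixels = torch.rand(B, P * P, C)         # (B, P^2, C) flattened patch
tokens = embedder(pixels)
assert tokens.shape == (B, P * P, hidden)

block = NerfGLUBlock(hidden_size_s=128, hidden_size_x=hidden, mlp_ratio=mlp_ratio, operations=nn)
s = torch.randn(B, 128)                  # per-sample conditioning vector
out = block(tokens, s)
assert out.shape == tokens.shape         # residual GLU block preserves shape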
