Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 38 additions & 20 deletions comfy/audio_encoders/audio_encoders.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .wav2vec2 import Wav2Vec2Model
from .whisper import WhisperLargeV3
import comfy.model_management
import comfy.ops
import comfy.utils
Expand All @@ -11,13 +12,18 @@ def __init__(self, config):
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
model_type = config.pop("model_type")
model_config = dict(config)
model_config.update({
"dtype": self.dtype,
"device": offload_device,
"operations": comfy.ops.manual_cast
})
self.model = Wav2Vec2Model(**model_config)

if model_type == "wav2vec2":
self.model = Wav2Vec2Model(**model_config)
elif model_type == "whisper3":
self.model = WhisperLargeV3(**model_config)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000
Expand All @@ -40,33 +46,45 @@ def encode_audio(self, audio, sample_rate):

def load_audio_encoder_from_sd(sd, prefix=""):
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
embed_dim = sd["encoder.layer_norm.bias"].shape[0]
if embed_dim == 1024:# large
config = {
"embed_dim": 1024,
"num_heads": 16,
"num_layers": 24,
"conv_norm": True,
"conv_bias": True,
"do_normalize": True,
"do_stable_layer_norm": True
if "encoder.layer_norm.bias" in sd: #wav2vec2
embed_dim = sd["encoder.layer_norm.bias"].shape[0]
if embed_dim == 1024:# large
config = {
"model_type": "wav2vec2",
"embed_dim": 1024,
"num_heads": 16,
"num_layers": 24,
"conv_norm": True,
"conv_bias": True,
"do_normalize": True,
"do_stable_layer_norm": True
}
elif embed_dim == 768: # base
config = {
"model_type": "wav2vec2",
"embed_dim": 768,
"num_heads": 12,
"num_layers": 12,
"conv_norm": False,
"conv_bias": False,
"do_normalize": False, # chinese-wav2vec2-base has this False
"do_stable_layer_norm": False
}
elif embed_dim == 768: # base
else:
raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
elif "model.encoder.embed_positions.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
config = {
"embed_dim": 768,
"num_heads": 12,
"num_layers": 12,
"conv_norm": False,
"conv_bias": False,
"do_normalize": False, # chinese-wav2vec2-base has this False
"do_stable_layer_norm": False
"model_type": "whisper3",
}
else:
raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
raise RuntimeError("ERROR: audio encoder not supported.")

audio_encoder = AudioEncoderModel(config)
m, u = audio_encoder.load_sd(sd)
if len(m) > 0:
logging.warning("missing audio encoder: {}".format(m))
if len(u) > 0:
logging.warning("unexpected audio encoder: {}".format(u))

return audio_encoder
186 changes: 186 additions & 0 deletions comfy/audio_encoders/whisper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from typing import Optional
from comfy.ldm.modules.attention import optimized_attention_masked
import comfy.ops

class WhisperFeatureExtractor(nn.Module):
def __init__(self, n_mels=128, device=None):
super().__init__()
self.sample_rate = 16000
self.n_fft = 400
self.hop_length = 160
self.n_mels = n_mels
self.chunk_length = 30
self.n_samples = 480000

self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
sample_rate=self.sample_rate,
n_fft=self.n_fft,
hop_length=self.hop_length,
n_mels=self.n_mels,
f_min=0,
f_max=8000,
norm="slaney",
mel_scale="slaney",
).to(device)

def __call__(self, audio):
audio = torch.mean(audio, dim=1)
batch_size = audio.shape[0]
processed_audio = []

for i in range(batch_size):
aud = audio[i]
if aud.shape[0] > self.n_samples:
aud = aud[:self.n_samples]
elif aud.shape[0] < self.n_samples:
aud = F.pad(aud, (0, self.n_samples - aud.shape[0]))
processed_audio.append(aud)

audio = torch.stack(processed_audio)

mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device)

log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
log_mel_spec = (log_mel_spec + 4.0) / 4.0

return log_mel_spec


class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
super().__init__()
assert d_model % n_heads == 0

self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads

self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device)
self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, seq_len, _ = query.shape

q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)

attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask)
attn_output = self.out_proj(attn_output)

return attn_output


class EncoderLayer(nn.Module):
def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
super().__init__()

self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations)
self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)

self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device)
self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device)
self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)

def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
residual = x
x = self.self_attn_layer_norm(x)
x = self.self_attn(x, x, x, attention_mask)
x = residual + x

residual = x
x = self.final_layer_norm(x)
x = self.fc1(x)
x = F.gelu(x)
x = self.fc2(x)
x = residual + x

return x


class AudioEncoder(nn.Module):
def __init__(
self,
n_mels: int = 128,
n_ctx: int = 1500,
n_state: int = 1280,
n_head: int = 20,
n_layer: int = 32,
dtype=None,
device=None,
operations=None
):
super().__init__()

self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)

self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)

self.layers = nn.ModuleList([
EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
for _ in range(n_layer)
])

self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))

x = x.transpose(1, 2)

x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x)

all_x = ()
for layer in self.layers:
all_x += (x,)
x = layer(x)

x = self.layer_norm(x)
all_x += (x,)
return x, all_x


class WhisperLargeV3(nn.Module):
def __init__(
self,
n_mels: int = 128,
n_audio_ctx: int = 1500,
n_audio_state: int = 1280,
n_audio_head: int = 20,
n_audio_layer: int = 32,
dtype=None,
device=None,
operations=None
):
super().__init__()

self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)

self.encoder = AudioEncoder(
n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer,
dtype=dtype, device=device, operations=operations
)

def forward(self, audio):
mel = self.feature_extractor(audio)
x, all_x = self.encoder(mel)
return x, all_x
Loading