diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 6fb5b08e93c0..0550b2f9bc44 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -1,4 +1,5 @@ from .wav2vec2 import Wav2Vec2Model +from .whisper import WhisperLargeV3 import comfy.model_management import comfy.ops import comfy.utils @@ -11,13 +12,18 @@ def __init__(self, config): self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) + model_type = config.pop("model_type") model_config = dict(config) model_config.update({ "dtype": self.dtype, "device": offload_device, "operations": comfy.ops.manual_cast }) - self.model = Wav2Vec2Model(**model_config) + + if model_type == "wav2vec2": + self.model = Wav2Vec2Model(**model_config) + elif model_type == "whisper3": + self.model = WhisperLargeV3(**model_config) self.model.eval() self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 @@ -40,33 +46,45 @@ def encode_audio(self, audio, sample_rate): def load_audio_encoder_from_sd(sd, prefix=""): sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""}) - embed_dim = sd["encoder.layer_norm.bias"].shape[0] - if embed_dim == 1024:# large - config = { - "embed_dim": 1024, - "num_heads": 16, - "num_layers": 24, - "conv_norm": True, - "conv_bias": True, - "do_normalize": True, - "do_stable_layer_norm": True + if "encoder.layer_norm.bias" in sd: #wav2vec2 + embed_dim = sd["encoder.layer_norm.bias"].shape[0] + if embed_dim == 1024:# large + config = { + "model_type": "wav2vec2", + "embed_dim": 1024, + "num_heads": 16, + "num_layers": 24, + "conv_norm": True, + "conv_bias": True, + "do_normalize": True, + "do_stable_layer_norm": True + } + elif embed_dim == 768: # base + config = { + "model_type": "wav2vec2", + "embed_dim": 768, + "num_heads": 12, + "num_layers": 12, + "conv_norm": False, + "conv_bias": False, + "do_normalize": False, # chinese-wav2vec2-base has this False + "do_stable_layer_norm": False } - elif embed_dim == 768: # base + else: + raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim)) + elif "model.encoder.embed_positions.weight" in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""}) config = { - "embed_dim": 768, - "num_heads": 12, - "num_layers": 12, - "conv_norm": False, - "conv_bias": False, - "do_normalize": False, # chinese-wav2vec2-base has this False - "do_stable_layer_norm": False + "model_type": "whisper3", } else: - raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim)) + raise RuntimeError("ERROR: audio encoder not supported.") audio_encoder = AudioEncoderModel(config) m, u = audio_encoder.load_sd(sd) if len(m) > 0: logging.warning("missing audio encoder: {}".format(m)) + if len(u) > 0: + logging.warning("unexpected audio encoder: {}".format(u)) return audio_encoder diff --git a/comfy/audio_encoders/whisper.py b/comfy/audio_encoders/whisper.py new file mode 100755 index 000000000000..93d3782f14b5 --- /dev/null +++ b/comfy/audio_encoders/whisper.py @@ -0,0 +1,186 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from typing import Optional +from comfy.ldm.modules.attention import optimized_attention_masked +import comfy.ops + +class WhisperFeatureExtractor(nn.Module): + def __init__(self, n_mels=128, device=None): + super().__init__() + self.sample_rate = 16000 + self.n_fft = 400 + self.hop_length = 160 + self.n_mels = n_mels + self.chunk_length = 30 + self.n_samples = 480000 + + self.mel_spectrogram = torchaudio.transforms.MelSpectrogram( + sample_rate=self.sample_rate, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + f_min=0, + f_max=8000, + norm="slaney", + mel_scale="slaney", + ).to(device) + + def __call__(self, audio): + audio = torch.mean(audio, dim=1) + batch_size = audio.shape[0] + processed_audio = [] + + for i in range(batch_size): + aud = audio[i] + if aud.shape[0] > self.n_samples: + aud = aud[:self.n_samples] + elif aud.shape[0] < self.n_samples: + aud = F.pad(aud, (0, self.n_samples - aud.shape[0])) + processed_audio.append(aud) + + audio = torch.stack(processed_audio) + + mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device) + + log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0) + log_mel_spec = (log_mel_spec + 4.0) / 4.0 + + return log_mel_spec + + +class MultiHeadAttention(nn.Module): + def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None): + super().__init__() + assert d_model % n_heads == 0 + + self.d_model = d_model + self.n_heads = n_heads + self.d_k = d_model // n_heads + + self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device) + self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device) + self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device) + self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, seq_len, _ = query.shape + + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask) + attn_output = self.out_proj(attn_output) + + return attn_output + + +class EncoderLayer(nn.Module): + def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None): + super().__init__() + + self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations) + self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device) + + self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device) + self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device) + self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device) + + def forward( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + residual = x + x = self.self_attn_layer_norm(x) + x = self.self_attn(x, x, x, attention_mask) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + x = self.fc1(x) + x = F.gelu(x) + x = self.fc2(x) + x = residual + x + + return x + + +class AudioEncoder(nn.Module): + def __init__( + self, + n_mels: int = 128, + n_ctx: int = 1500, + n_state: int = 1280, + n_head: int = 20, + n_layer: int = 32, + dtype=None, + device=None, + operations=None + ): + super().__init__() + + self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device) + self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device) + + self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device) + + self.layers = nn.ModuleList([ + EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations) + for _ in range(n_layer) + ]) + + self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + + x = x.transpose(1, 2) + + x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x) + + all_x = () + for layer in self.layers: + all_x += (x,) + x = layer(x) + + x = self.layer_norm(x) + all_x += (x,) + return x, all_x + + +class WhisperLargeV3(nn.Module): + def __init__( + self, + n_mels: int = 128, + n_audio_ctx: int = 1500, + n_audio_state: int = 1280, + n_audio_head: int = 20, + n_audio_layer: int = 32, + dtype=None, + device=None, + operations=None + ): + super().__init__() + + self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device) + + self.encoder = AudioEncoder( + n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, + dtype=dtype, device=device, operations=operations + ) + + def forward(self, audio): + mel = self.feature_extractor(audio) + x, all_x = self.encoder(mel) + return x, all_x