diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py
index f5e337e52ebd..30f4a4c781fa 100644
--- a/src/transformers/models/bamba/modeling_bamba.py
+++ b/src/transformers/models/bamba/modeling_bamba.py
@@ -373,10 +373,11 @@ def forward(
 
 
 class BambaRMSNormGated(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, group_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.group_size = group_size
 
     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
@@ -384,8 +385,12 @@ def forward(self, hidden_states, gate=None):
 
         if gate is not None:
             hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        *prefix_dims, last_dim = hidden_states.shape
+        group_count = last_dim // self.group_size
+        hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size)
+        variance = hidden_states_group.pow(2).mean(-1, keepdim=True)
+        hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size)
 
         return self.weight * hidden_states.to(input_dtype)
 
@@ -524,7 +529,9 @@ def __init__(self, config: BambaConfig, layer_idx: int):
         # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
-        self.norm = BambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
+        self.norm = BambaRMSNormGated(
+            self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=self.layer_norm_epsilon
+        )
         self.D = nn.Parameter(torch.ones(self.num_heads))
 
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py
index aec09861de81..9ce5ba765800 100644
--- a/src/transformers/models/bamba/modular_bamba.py
+++ b/src/transformers/models/bamba/modular_bamba.py
@@ -282,7 +282,9 @@ def __init__(self, config: BambaConfig, layer_idx: int):
         # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
-        self.norm = BambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
+        self.norm = BambaRMSNormGated(
+            self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=self.layer_norm_epsilon
+        )
         self.D = nn.Parameter(torch.ones(self.num_heads))
 
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py
index 865daf384b49..17f2f2179a5e 100644
--- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py
+++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py
@@ -389,11 +389,11 @@ def forward(
 
 
 class FalconH1RMSNormGated(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-6, n_groups=1, norm_before_gate=True):
+    def __init__(self, hidden_size, group_size, eps=1e-6, norm_before_gate=True):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
-        self.n_groups = n_groups
+        self.group_size = group_size
         self.norm_before_gate = norm_before_gate
 
     def forward(self, hidden_states, gate=None):
@@ -409,12 +409,13 @@ def forward(self, hidden_states, gate=None):
             seq_len = 1
 
         hidden_states = hidden_states.to(torch.float32)
-        hidden_states = hidden_states.view(batch_size, seq_len, self.n_groups, int(dim // self.n_groups))
+        group_count = dim // self.group_size
+        hidden_states = hidden_states.view(batch_size, seq_len, group_count, self.group_size)
 
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 
-        hidden_states = self.weight.view(self.n_groups, int(dim // self.n_groups)) * hidden_states
+        hidden_states = self.weight.view(group_count, self.group_size) * hidden_states
         hidden_states = hidden_states.view(batch_size, seq_len, dim)
 
         if seq_len == 1:
@@ -560,8 +561,8 @@ def __init__(self, config: FalconH1Config, layer_idx: int):
         if self.mamba_rms_norm:
             self.norm = FalconH1RMSNormGated(
                 self.intermediate_size,
+                group_size=self.intermediate_size // self.n_groups,
                 eps=self.layer_norm_epsilon,
-                n_groups=self.n_groups,
                 norm_before_gate=config.mamba_norm_before_gate,
             )
         self.D = nn.Parameter(torch.ones(self.num_heads))
diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py
index 8b00de3ab97f..6323c7bd9b4d 100644
--- a/src/transformers/models/falcon_h1/modular_falcon_h1.py
+++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py
@@ -251,11 +251,11 @@ def forward(
 
 
 class FalconH1RMSNormGated(MambaRMSNormGated):
-    def __init__(self, hidden_size, eps=1e-6, n_groups=1, norm_before_gate=True):
+    def __init__(self, hidden_size, group_size, eps=1e-6, norm_before_gate=True):
         super().__init__(hidden_size=hidden_size, eps=eps)
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
-        self.n_groups = n_groups
+        self.group_size = group_size
         self.norm_before_gate = norm_before_gate
 
     def forward(self, hidden_states, gate=None):
@@ -271,12 +271,13 @@ def forward(self, hidden_states, gate=None):
             seq_len = 1
 
         hidden_states = hidden_states.to(torch.float32)
-        hidden_states = hidden_states.view(batch_size, seq_len, self.n_groups, int(dim // self.n_groups))
+        group_count = dim // self.group_size
+        hidden_states = hidden_states.view(batch_size, seq_len, group_count, self.group_size)
 
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 
-        hidden_states = self.weight.view(self.n_groups, int(dim // self.n_groups)) * hidden_states
+        hidden_states = self.weight.view(group_count, self.group_size) * hidden_states
         hidden_states = hidden_states.view(batch_size, seq_len, dim)
 
         if seq_len == 1:
@@ -365,8 +366,8 @@ def __init__(self, config: FalconH1Config, layer_idx: int):
         if self.mamba_rms_norm:
             self.norm = FalconH1RMSNormGated(
                 self.intermediate_size,
+                group_size=self.intermediate_size // self.n_groups,
                 eps=self.layer_norm_epsilon,
-                n_groups=self.n_groups,
                 norm_before_gate=config.mamba_norm_before_gate,
             )
         self.D = nn.Parameter(torch.ones(self.num_heads))
diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
index e3a1e69fc861..e46d74fd30ef 100644
--- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
@@ -451,7 +451,9 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
         # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
-        self.norm = GraniteMoeHybridRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
+        self.norm = GraniteMoeHybridRMSNormGated(
+            self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=self.layer_norm_epsilon
+        )
         self.D = nn.Parameter(torch.ones(self.num_heads))
 
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
@@ -866,10 +868,11 @@ def forward(
 
 
 class GraniteMoeHybridRMSNormGated(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, group_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.group_size = group_size
 
     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
@@ -877,8 +880,12 @@ def forward(self, hidden_states, gate=None):
 
         if gate is not None:
             hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        *prefix_dims, last_dim = hidden_states.shape
+        group_count = last_dim // self.group_size
+        hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size)
+        variance = hidden_states_group.pow(2).mean(-1, keepdim=True)
+        hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size)
 
         return self.weight * hidden_states.to(input_dtype)
 
diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
index 4de1ff253914..2cc4712be2c6 100644
--- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
@@ -51,8 +51,8 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
 
 
 class GraniteMoeHybridRMSNormGated(BambaRMSNormGated):
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__(hidden_size, eps)
+    def __init__(self, hidden_size, group_size, eps=1e-6):
+        super().__init__(hidden_size, group_size, eps)
 
 
 class GraniteMoeHybridMLP(GraniteMoeSharedMLP):
diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py
index 85cf026e49d0..52090d86e796 100644
--- a/src/transformers/models/mamba2/modeling_mamba2.py
+++ b/src/transformers/models/mamba2/modeling_mamba2.py
@@ -203,10 +203,11 @@ def reset(self):
 
 
 class MambaRMSNormGated(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, group_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.group_size = group_size
 
     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
@@ -214,8 +215,12 @@ def forward(self, hidden_states, gate=None):
 
         if gate is not None:
             hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        *prefix_dims, last_dim = hidden_states.shape
+        group_count = last_dim // self.group_size
+        hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size)
+        variance = hidden_states_group.pow(2).mean(-1, keepdim=True)
+        hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size)
 
         return self.weight * hidden_states.to(input_dtype)
 
@@ -279,7 +284,9 @@ def __init__(self, config: Mamba2Config, layer_idx: int):
         # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
-        self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
+        self.norm = MambaRMSNormGated(
+            self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=self.layer_norm_epsilon
+        )
         self.D = nn.Parameter(torch.ones(self.num_heads))
 
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
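
Not part of the patch: the sketch below is a minimal, standalone reproduction of the grouped normalization that the gated RMSNorm classes switch to in this diff. The class name GroupedRMSNormGated, the toy sizes, and the reference check are illustrative assumptions, not code from the PR; it only shows that computing RMS statistics per group of group_size = intermediate_size // n_groups channels is equivalent to normalizing each group independently.

# Sketch, assuming toy sizes; mirrors the patched forward() logic above.
import torch
import torch.nn as nn


class GroupedRMSNormGated(nn.Module):
    # Hypothetical standalone copy of the patched *RMSNormGated math:
    # RMS statistics are taken per group of `group_size` channels instead
    # of over the full hidden dimension.
    def __init__(self, hidden_size, group_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.group_size = group_size

    def forward(self, hidden_states, gate=None):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        if gate is not None:
            hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
        *prefix_dims, last_dim = hidden_states.shape
        group_count = last_dim // self.group_size
        grouped = hidden_states.view(*prefix_dims, group_count, self.group_size)
        variance = grouped.pow(2).mean(-1, keepdim=True)
        grouped = grouped * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = grouped.view(*prefix_dims, group_count * self.group_size)
        return self.weight * hidden_states.to(input_dtype)


if __name__ == "__main__":
    torch.manual_seed(0)
    intermediate_size, n_groups = 64, 4  # assumed toy sizes, not PR values
    norm = GroupedRMSNormGated(intermediate_size, group_size=intermediate_size // n_groups)
    x = torch.randn(2, 5, intermediate_size)
    out = norm(x)

    # Reference: normalize each group independently with plain RMS statistics.
    ref = x.view(2, 5, n_groups, -1)
    ref = ref * torch.rsqrt(ref.pow(2).mean(-1, keepdim=True) + 1e-6)
    ref = ref.reshape(2, 5, intermediate_size)
    print(torch.allclose(out, ref, atol=1e-6))  # True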