15 changes: 11 additions & 4 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -373,19 +373,24 @@ def forward(


 class BambaRMSNormGated(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, group_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.group_size = group_size

     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)

         if gate is not None:
             hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        *prefix_dims, last_dim = hidden_states.shape
+        group_count = last_dim // self.group_size
+        hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size)
+        variance = hidden_states_group.pow(2).mean(-1, keepdim=True)
+        hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size)

         return self.weight * hidden_states.to(input_dtype)

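The change above makes the gated RMSNorm compute its statistic over each contiguous block of `group_size` channels instead of the full hidden dimension. A minimal standalone sketch of that operation (a toy function for illustration, not the transformers module):

```python
import torch

def grouped_rms_norm(x: torch.Tensor, group_size: int, eps: float = 1e-6) -> torch.Tensor:
    # Each contiguous block of `group_size` channels gets its own RMS statistic.
    *prefix, dim = x.shape
    group_count = dim // group_size
    xg = x.view(*prefix, group_count, group_size)
    xg = xg * torch.rsqrt(xg.pow(2).mean(-1, keepdim=True) + eps)
    return xg.view(*prefix, dim)

x = torch.randn(2, 5, 8)
# With group_size == dim this reduces to the previous full-width behavior.
old = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
print(torch.allclose(grouped_rms_norm(x, group_size=8), old))  # True
# With group_size == 4, each half of the channel dimension is normalized independently.
print(grouped_rms_norm(x, group_size=4).shape)  # torch.Size([2, 5, 8])
```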
@@ -524,7 +529,9 @@ def __init__(self, config: BambaConfig, layer_idx: int):
         # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
-        self.norm = BambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
+        self.norm = BambaRMSNormGated(
+            self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=self.layer_norm_epsilon
+        )
         self.D = nn.Parameter(torch.ones(self.num_heads))

         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
4 changes: 3 additions & 1 deletion src/transformers/models/bamba/modular_bamba.py
@@ -282,7 +282,9 @@ def __init__(self, config: BambaConfig, layer_idx: int):
         # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
-        self.norm = BambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
+        self.norm = BambaRMSNormGated(
+            self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=self.layer_norm_epsilon
+        )
         self.D = nn.Parameter(torch.ones(self.num_heads))

         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
11 changes: 6 additions & 5 deletions src/transformers/models/falcon_h1/modeling_falcon_h1.py
@@ -389,11 +389,11 @@ def forward(


 class FalconH1RMSNormGated(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-6, n_groups=1, norm_before_gate=True):
+    def __init__(self, hidden_size, group_size, eps=1e-6, norm_before_gate=True):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
-        self.n_groups = n_groups
+        self.group_size = group_size
         self.norm_before_gate = norm_before_gate

     def forward(self, hidden_states, gate=None):
@@ -409,12 +409,13 @@ def forward(self, hidden_states, gate=None):
             seq_len = 1
         hidden_states = hidden_states.to(torch.float32)

-        hidden_states = hidden_states.view(batch_size, seq_len, self.n_groups, int(dim // self.n_groups))
+        group_count = dim // self.group_size
+        hidden_states = hidden_states.view(batch_size, seq_len, group_count, self.group_size)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)

         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

-        hidden_states = self.weight.view(self.n_groups, int(dim // self.n_groups)) * hidden_states
+        hidden_states = self.weight.view(group_count, self.group_size) * hidden_states
         hidden_states = hidden_states.view(batch_size, seq_len, dim)

         if seq_len == 1:
@@ -560,8 +561,8 @@ def __init__(self, config: FalconH1Config, layer_idx: int):
         if self.mamba_rms_norm:
             self.norm = FalconH1RMSNormGated(
                 self.intermediate_size,
+                group_size=self.intermediate_size // self.n_groups,
                 eps=self.layer_norm_epsilon,
-                n_groups=self.n_groups,
                 norm_before_gate=config.mamba_norm_before_gate,
             )
         self.D = nn.Parameter(torch.ones(self.num_heads))
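FalconH1's gated norm keeps its per-group learned weight: the activations are viewed as (batch, seq_len, group_count, group_size), normalized within each group, and scaled by the weight reshaped to (group_count, group_size). A rough standalone sketch of that path (illustrative only; gating, 2D inputs, and the norm_before_gate option are omitted):

```python
import torch

def falcon_style_grouped_norm(x, weight, group_size, eps=1e-6):
    # x: (batch, seq_len, dim); weight: (dim,); dim must be divisible by group_size.
    batch_size, seq_len, dim = x.shape
    group_count = dim // group_size
    xg = x.view(batch_size, seq_len, group_count, group_size)
    xg = xg * torch.rsqrt(xg.pow(2).mean(-1, keepdim=True) + eps)
    # Learned weight applied group-wise, then flattened back to (batch, seq_len, dim).
    xg = weight.view(group_count, group_size) * xg
    return xg.view(batch_size, seq_len, dim)

x = torch.randn(2, 3, 12)
w = torch.ones(12)
print(falcon_style_grouped_norm(x, w, group_size=4).shape)  # torch.Size([2, 3, 12])
```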
11 changes: 6 additions & 5 deletions src/transformers/models/falcon_h1/modular_falcon_h1.py
@@ -251,11 +251,11 @@ def forward(


 class FalconH1RMSNormGated(MambaRMSNormGated):
-    def __init__(self, hidden_size, eps=1e-6, n_groups=1, norm_before_gate=True):
+    def __init__(self, hidden_size, group_size, eps=1e-6, norm_before_gate=True):
         super().__init__(hidden_size=hidden_size, eps=eps)
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
-        self.n_groups = n_groups
+        self.group_size = group_size
         self.norm_before_gate = norm_before_gate

     def forward(self, hidden_states, gate=None):
@@ -271,12 +271,13 @@ def forward(self, hidden_states, gate=None):
             seq_len = 1
         hidden_states = hidden_states.to(torch.float32)

-        hidden_states = hidden_states.view(batch_size, seq_len, self.n_groups, int(dim // self.n_groups))
+        group_count = dim // self.group_size
+        hidden_states = hidden_states.view(batch_size, seq_len, group_count, self.group_size)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)

         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

-        hidden_states = self.weight.view(self.n_groups, int(dim // self.n_groups)) * hidden_states
+        hidden_states = self.weight.view(group_count, self.group_size) * hidden_states
         hidden_states = hidden_states.view(batch_size, seq_len, dim)

         if seq_len == 1:
@@ -365,8 +366,8 @@ def __init__(self, config: FalconH1Config, layer_idx: int):
         if self.mamba_rms_norm:
             self.norm = FalconH1RMSNormGated(
                 self.intermediate_size,
+                group_size=self.intermediate_size // self.n_groups,
                 eps=self.layer_norm_epsilon,
-                n_groups=self.n_groups,
                 norm_before_gate=config.mamba_norm_before_gate,
             )
         self.D = nn.Parameter(torch.ones(self.num_heads))
src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
@@ -451,7 +451,9 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
         # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
-        self.norm = GraniteMoeHybridRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
+        self.norm = GraniteMoeHybridRMSNormGated(
+            self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=self.layer_norm_epsilon
+        )
         self.D = nn.Parameter(torch.ones(self.num_heads))

         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
@@ -866,19 +868,24 @@ def forward(


 class GraniteMoeHybridRMSNormGated(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, group_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.group_size = group_size

     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)

         if gate is not None:
             hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        *prefix_dims, last_dim = hidden_states.shape
+        group_count = last_dim // self.group_size
+        hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size)
+        variance = hidden_states_group.pow(2).mean(-1, keepdim=True)
+        hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size)

         return self.weight * hidden_states.to(input_dtype)

src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
@@ -51,8 +51,8 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):


 class GraniteMoeHybridRMSNormGated(BambaRMSNormGated):
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__(hidden_size, eps)
+    def __init__(self, hidden_size, group_size, eps=1e-6):
+        super().__init__(hidden_size, group_size, eps)
Comment on lines +54 to +55

Contributor:

Suggested change
-    def __init__(self, hidden_size, group_size, eps=1e-6):
-        super().__init__(hidden_size, group_size, eps)
+    pass

Weird that this even used the __init__ here; we shouldn't need to do anything on the modular side.

Contributor:

Same for the other modular code.


 class GraniteMoeHybridMLP(GraniteMoeSharedMLP):
15 changes: 11 additions & 4 deletions src/transformers/models/mamba2/modeling_mamba2.py
@@ -203,19 +203,24 @@ def reset(self):


 class MambaRMSNormGated(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, group_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.group_size = group_size

     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)

         if gate is not None:
             hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        *prefix_dims, last_dim = hidden_states.shape
+        group_count = last_dim // self.group_size
+        hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size)
+        variance = hidden_states_group.pow(2).mean(-1, keepdim=True)
+        hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size)

         return self.weight * hidden_states.to(input_dtype)

@@ -279,7 +284,9 @@ def __init__(self, config: Mamba2Config, layer_idx: int):
         # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
-        self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
+        self.norm = MambaRMSNormGated(
+            self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=self.layer_norm_epsilon
+        )
         self.D = nn.Parameter(torch.ones(self.num_heads))

         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
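Passing group_size=self.intermediate_size // self.n_groups matters whenever n_groups > 1: normalizing each group over its own slice of channels no longer matches the old full-width statistic, while single-group models (n_groups == 1) produce identical outputs. A small illustration with toy tensors (independent of the transformers modules):

```python
import torch

torch.manual_seed(0)
dim, n_groups = 8, 2
group_size = dim // n_groups

x = torch.randn(1, 3, dim)
x[..., group_size:] *= 10.0  # give the two groups very different magnitudes

def rms_norm(t, eps=1e-6):
    # RMS-normalize over the last dimension of `t`.
    return t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps)

full = rms_norm(x)  # old behavior: one statistic across all channels
grouped = rms_norm(x.view(1, 3, n_groups, group_size)).view(1, 3, dim)  # new behavior

print(torch.allclose(full, grouped))  # False: groups are now normalized separately
single_group = rms_norm(x.view(1, 3, 1, dim)).view(1, 3, dim)
print(torch.allclose(full, single_group))  # True: n_groups == 1 is unchanged
```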