qlora #100

3 changes: 2 additions & 1 deletion src/config/cpm-bee-10b.json
@@ -10,5 +10,6 @@
   "position_bias_num_segment_buckets": 256,
   "position_bias_max_distance" : 2048,
   "eps" : 1e-6,
-  "half" : true
+  "half" : true,
+  "int4" : false
 }
3 changes: 2 additions & 1 deletion src/config/cpm-bee-3b.json
@@ -10,5 +10,6 @@
   "position_bias_num_segment_buckets": 256,
   "position_bias_max_distance" : 2048,
   "eps" : 1e-6,
-  "half" : true
+  "half" : true,
+  "int4" : false
 }
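
Both configs gain an `int4` entry alongside `half`. A minimal sketch of how the flag could be read from the JSON and mapped onto the dtype and quantization choices the layers below expect; the path assumes the repo root, and the variable names are only for illustration:

import json

import torch

# Read the updated config; "int4" is the key added by this PR and is
# expected to default to False when absent.
with open("src/config/cpm-bee-3b.json") as f:
    cfg = json.load(f)

dtype = torch.half if cfg.get("half", True) else torch.float
use_int4 = cfg.get("int4", False)

# The flag would then be threaded into the blocks changed below,
# e.g. Attention(..., dtype=dtype, int4=use_int4).
print(f"dtype={dtype}, int4={use_int4}")
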
2 changes: 1 addition & 1 deletion src/cpm_live/layers/__init__.py
@@ -1,6 +1,6 @@
 from .embedding import Embedding, EmbeddingExt
 from .position_embedding import SegmentPositionEmbedding, BucketPositionBias, RotaryEmbedding
-from .linear import Linear
+from .linear import Linear, Linear4bit, Params4bit
 from .layernorm import LayerNorm
 from .attention import Attention
 from .feedforward import FeedForward
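
Linear4bit and Params4bit are newly re-exported here, but their definitions are not part of this diff. For orientation, a minimal sketch of what such a layer could look like on top of bitsandbytes, the usual 4-bit backend for QLoRA; the class name and the (dim_in, dim_out) signature mirror the call sites added below, while the bitsandbytes-based body is purely an assumption, not the PR's actual implementation:

import torch
import bitsandbytes as bnb  # assumption: a bitsandbytes-backed implementation


class Linear4bit(torch.nn.Module):
    # Sketch only: matches the Linear4bit(dim_in, dim_out) call sites in this PR.
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # bnb.nn.Linear4bit stores its weight as bnb.nn.Params4bit; the weight is
        # quantized to NF4 when the module is moved to a CUDA device and
        # dequantized on the fly inside forward().
        self.inner = bnb.nn.Linear4bit(
            dim_in,
            dim_out,
            bias=False,
            compute_dtype=torch.half,
            quant_type="nf4",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner(x)
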
21 changes: 14 additions & 7 deletions src/cpm_live/layers/attention.py
@@ -17,7 +17,7 @@
 import torch
 import bmtrain as bmt
 import math
-from .linear import Linear
+from .linear import Linear, Linear4bit


 class Attention(bmt.DistributedModule):
@@ -28,6 +28,8 @@ def __init__(
         dim_head: int,
         dtype: torch.dtype = torch.half,
         dropout_p: Optional[float] = None,
+        int4: Optional[bool] = None,
+
     ) -> None:

         super().__init__()
@@ -36,12 +38,17 @@ def __init__(
         self.num_heads = num_heads
         self.dim_head = dim_head

-        self.project_q = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype)
-        self.project_k = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype)
-        self.project_v = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype)
-
-        self.attention_out = Linear(self.num_heads * self.dim_head, self.dim_model, dtype=dtype)
-
+        if int4 is None or int4 is False:
+            self.project_q = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype)
+            self.project_k = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype)
+            self.project_v = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype)
+            self.attention_out = Linear(self.num_heads * self.dim_head, self.dim_model, dtype=dtype)
+        else:
+            self.project_q = Linear4bit(self.dim_model, self.num_heads * self.dim_head)
+            self.project_k = Linear4bit(self.dim_model, self.num_heads * self.dim_head)
+            self.project_v = Linear4bit(self.dim_model, self.num_heads * self.dim_head)
+            self.attention_out = Linear4bit(self.num_heads * self.dim_head, self.dim_model)
+
         self.softmax = torch.nn.Softmax(dim=-1)

         if dropout_p is not None:
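
The attention change repeats a pattern that also appears in DenseGatedACT and FeedForward further down: an `if int4 is None or int4 is False` branch around every projection. A possible simplification, sketched here and not part of the PR, is a small factory that picks the implementation once; the helper name is hypothetical:

from typing import Optional

import torch

from cpm_live.layers.linear import Linear, Linear4bit  # Linear4bit newly exported by this PR


def make_linear(dim_in: int, dim_out: int,
                dtype: torch.dtype = torch.half,
                int4: Optional[bool] = None,
                **kwargs) -> torch.nn.Module:
    """Hypothetical helper (not in the PR): choose Linear or Linear4bit once.

    None and False both mean "no int4", mirroring the explicit
    `int4 is None or int4 is False` checks in the diff.
    """
    if int4:
        return Linear4bit(dim_in, dim_out)
    return Linear(dim_in, dim_out, dtype=dtype, **kwargs)


# Example call site, equivalent to the branch added in Attention.__init__:
# self.project_q = make_linear(self.dim_model, self.num_heads * self.dim_head,
#                              dtype=dtype, int4=int4)
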
11 changes: 11 additions & 0 deletions src/cpm_live/layers/blocks.py
@@ -31,6 +31,7 @@ class SelfAttentionBlock(bmt.DistributedModule):
         dtype (optional): Defaults to torch.half.
         eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5.
         dropout_p (float, optional): Defaults to 0.
+        int4 (bool, optional): whether to load the model with int4 quantization. Defaults to False.
     """  # noqa: E501

     def __init__(
@@ -41,6 +42,8 @@ def __init__(
         dtype=torch.half,
         eps: float = 1e-6,
         dropout_p: Optional[float] = None,
+        int4: Optional[bool] = None,
+
     ):

         super().__init__()
@@ -57,6 +60,7 @@ def __init__(
             dim_head=dim_head,
             dtype=dtype,
             dropout_p=dropout_p,
+            int4=int4,
         )

         if dropout_p:
@@ -108,6 +112,7 @@ class FFNBlock(torch.nn.Module):
         dtype (optional): Defaults to torch.half.
         eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5.
         dropout_p (float, optional): Defaults to 0.
+        int4 (bool, optional): whether to load the model with int4 quantization. Defaults to False.
     """  # noqa: E501

     def __init__(
@@ -117,6 +122,7 @@ def __init__(
         dtype=torch.half,
         eps: float = 1e-6,
         dropout_p: Optional[float] = 0,
+        int4: Optional[bool] = None,
     ):
         super().__init__()

@@ -131,6 +137,7 @@ def __init__(
             dim_ff,
             dtype=dtype,
             dropout_p=dropout_p,
+            int4=int4,
         )

         if dropout_p:
@@ -169,6 +176,7 @@ class TransformerBlock(torch.nn.Module):
         dtype (optional): Defaults to torch.half.
         eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5.
         dropout_p (float, optional): Defaults to 0.
+        int4 (bool, optional): whether to load the model with int4 quantization. Defaults to False.
     """  # noqa: E501

     def __init__(
@@ -182,6 +190,7 @@ def __init__(
         dropout_p: Optional[float] = None,
         mask_att: bool = False,
         mask_ffn: bool = False,
+        int4: Optional[bool] = None,
     ):
         super().__init__()
         self.mask_att = mask_att
@@ -195,6 +204,7 @@ def __init__(
                 dtype=dtype,
                 eps=eps,
                 dropout_p=dropout_p,
+                int4=int4,
             )

         if not self.mask_ffn:
@@ -204,6 +214,7 @@ def __init__(
                 dtype=dtype,
                 eps=eps,
                 dropout_p=dropout_p,
+                int4=int4,
             )

     def forward(
60 changes: 40 additions & 20 deletions src/cpm_live/layers/feedforward.py
@@ -16,7 +16,7 @@
 from typing import Optional
 import torch
 import bmtrain as bmt
-from .linear import Linear
+from .linear import Linear, Linear4bit


 class DenseGatedACT(bmt.DistributedModule):
@@ -25,22 +25,33 @@ def __init__(
         dim_in: int,
         dim_ff: int,
         dtype=torch.half,
+        int4: Optional[bool] = None,
     ):
         super().__init__()

-        self.w_0 = Linear(
-            dim_in=dim_in,
-            dim_out=dim_ff,
-            dtype=dtype,
-            scale_before=False,
-        )
-
-        self.w_1 = Linear(
-            dim_in=dim_in,
-            dim_out=dim_ff,
-            dtype=dtype,
-            scale_before=False,
-        )
+        if int4 is None or int4 is False:
+            self.w_0 = Linear(
+                dim_in=dim_in,
+                dim_out=dim_ff,
+                dtype=dtype,
+                scale_before=False,
+            )
+
+            self.w_1 = Linear(
+                dim_in=dim_in,
+                dim_out=dim_ff,
+                dtype=dtype,
+                scale_before=False,
+            )
+        else:
+            self.w_0 = Linear4bit(
+                dim_in=dim_in,
+                dim_out=dim_ff,
+            )
+
+            self.w_1 = Linear4bit(
+                dim_in=dim_in,
+                dim_out=dim_ff,
+            )
         self.act = torch.nn.GELU()

     def forward(self, x: torch.Tensor):
@@ -74,6 +85,7 @@ class FeedForward(bmt.DistributedModule):
         bias (bool, optional): whether to use bias term in fully-connected layers used in feed-forward module. Defaults to False.
         activate_fn (str, optional): Defaults to `gated_gelu`.
         dropout_p (int, optional): Defaults to 0.
+        int4 (bool, optional): whether to load the model with int4 quantization. Defaults to False.
     """  # noqa: E501

     def __init__(
@@ -82,6 +94,7 @@ def __init__(
         dim_ff: int,
         dtype=torch.half,
         dropout_p: Optional[float] = None,
+        int4: Optional[bool] = None,
     ):

         super().__init__()
@@ -90,18 +103,25 @@ def __init__(
             dim_in=dim_model,
             dim_ff=dim_ff,
             dtype=dtype,
+            int4=int4,
         )

         if dropout_p is not None:
             self.dropout = torch.nn.Dropout(dropout_p)
         else:
             self.dropout = None

-        self.w_out = Linear(
-            dim_in=dim_ff,
-            dim_out=dim_model,
-            dtype=dtype,
-            scale_before=False,
-        )
+        if int4 is None or int4 is False:
+            self.w_out = Linear(
+                dim_in=dim_ff,
+                dim_out=dim_model,
+                dtype=dtype,
+                scale_before=False,
+            )
+        else:
+            self.w_out = Linear4bit(
+                dim_in=dim_ff,
+                dim_out=dim_model,
+            )

     def forward(self, x: torch.Tensor):
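
Putting the pieces together, setting `"int4": true` in the config should make every projection in the blocks above a Linear4bit. A rough usage sketch; CPMBeeConfig, CPMBee and from_json_file are assumed from the existing repo, and the freeze-the-base step only illustrates how QLoRA-style fine-tuning (the PR title) would typically sit on top, none of which is shown in this diff:

import torch

# CPMBeeConfig / CPMBee are assumed from the existing cpm_live.models package;
# only the "int4" config entry and the Linear4bit plumbing are new in this PR.
from cpm_live.models import CPMBee, CPMBeeConfig

config = CPMBeeConfig.from_json_file("src/config/cpm-bee-3b.json")  # set "int4": true to enable
model = CPMBee(config)  # Linear4bit is used wherever the int4 flag reaches the blocks above

# QLoRA-style training would freeze the 4-bit base weights and train only
# LoRA adapters added on top (adapters are not part of this diff).
for p in model.parameters():
    p.requires_grad_(False)
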