diff --git a/src/config/cpm-bee-10b.json b/src/config/cpm-bee-10b.json index c34b2e0..dad7a02 100644 --- a/src/config/cpm-bee-10b.json +++ b/src/config/cpm-bee-10b.json @@ -10,5 +10,6 @@ "position_bias_num_segment_buckets": 256, "position_bias_max_distance" : 2048, "eps" : 1e-6, - "half" : true + "half" : true, + "int4" : false } diff --git a/src/config/cpm-bee-3b.json b/src/config/cpm-bee-3b.json index 55fd0f2..e1c7c42 100644 --- a/src/config/cpm-bee-3b.json +++ b/src/config/cpm-bee-3b.json @@ -10,5 +10,6 @@ "position_bias_num_segment_buckets": 256, "position_bias_max_distance" : 2048, "eps" : 1e-6, - "half" : true + "half" : true, + "int4" : false } diff --git a/src/cpm_live/layers/__init__.py b/src/cpm_live/layers/__init__.py index dbd8dcd..78ea038 100644 --- a/src/cpm_live/layers/__init__.py +++ b/src/cpm_live/layers/__init__.py @@ -1,6 +1,6 @@ from .embedding import Embedding, EmbeddingExt from .position_embedding import SegmentPositionEmbedding, BucketPositionBias, RotaryEmbedding -from .linear import Linear +from .linear import Linear, Linear4bit, Params4bit from .layernorm import LayerNorm from .attention import Attention from .feedforward import FeedForward diff --git a/src/cpm_live/layers/attention.py b/src/cpm_live/layers/attention.py index 241b14e..c9c339b 100644 --- a/src/cpm_live/layers/attention.py +++ b/src/cpm_live/layers/attention.py @@ -17,7 +17,7 @@ import torch import bmtrain as bmt import math -from .linear import Linear +from .linear import Linear, Linear4bit class Attention(bmt.DistributedModule): @@ -28,6 +28,8 @@ def __init__( dim_head: int, dtype: torch.dtype = torch.half, dropout_p: Optional[float] = None, + int4: Optional[bool] = None, + ) -> None: super().__init__() @@ -36,12 +38,17 @@ def __init__( self.num_heads = num_heads self.dim_head = dim_head - self.project_q = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype) - self.project_k = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype) - self.project_v = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype) - - self.attention_out = Linear(self.num_heads * self.dim_head, self.dim_model, dtype=dtype) - + if int4 is None or int4 is False: + self.project_q = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype) + self.project_k = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype) + self.project_v = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype) + self.attention_out = Linear(self.num_heads * self.dim_head, self.dim_model, dtype=dtype) + else: + self.project_q = Linear4bit(self.dim_model, self.num_heads * self.dim_head) + self.project_k = Linear4bit(self.dim_model, self.num_heads * self.dim_head) + self.project_v = Linear4bit(self.dim_model, self.num_heads * self.dim_head) + self.attention_out = Linear4bit(self.num_heads * self.dim_head, self.dim_model) + self.softmax = torch.nn.Softmax(dim=-1) if dropout_p is not None: diff --git a/src/cpm_live/layers/blocks.py b/src/cpm_live/layers/blocks.py index a16abf6..e478de3 100644 --- a/src/cpm_live/layers/blocks.py +++ b/src/cpm_live/layers/blocks.py @@ -31,6 +31,7 @@ class SelfAttentionBlock(bmt.DistributedModule): dtype (optional): Defaults to torch.half. eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5. dropout_p (float, optional): Defaults to 0. + int4 (int, optional): whether to use int4 to load model. Defaults to False. 
""" # noqa: E501 def __init__( @@ -41,6 +42,8 @@ def __init__( dtype=torch.half, eps: float = 1e-6, dropout_p: Optional[float] = None, + int4: Optional[bool] = None, + ): super().__init__() @@ -57,6 +60,7 @@ def __init__( dim_head=dim_head, dtype=dtype, dropout_p=dropout_p, + int4=int4, ) if dropout_p: @@ -108,6 +112,7 @@ class FFNBlock(torch.nn.Module): dtype (optional): Defaults to torch.half. eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5. dropout_p (float, optional): Defaults to 0. + int4 (int, optional): whether to use int4 to load model. Defaults to False. """ # noqa: E501 def __init__( @@ -117,6 +122,7 @@ def __init__( dtype=torch.half, eps: float = 1e-6, dropout_p: Optional[float] = 0, + int4: Optional[bool] = None, ): super().__init__() @@ -131,6 +137,7 @@ def __init__( dim_ff, dtype=dtype, dropout_p=dropout_p, + int4=int4, ) if dropout_p: @@ -169,6 +176,7 @@ class TransformerBlock(torch.nn.Module): dtype (optional): Defaults to torch.half. eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5. dropout_p (float, optional): Defaults to 0. + int4 (int, optional): whether to use int4 to load model. Defaults to False. """ # noqa: E501 def __init__( @@ -182,6 +190,7 @@ def __init__( dropout_p: Optional[float] = None, mask_att: bool = False, mask_ffn: bool = False, + int4: Optional[bool] = None, ): super().__init__() self.mask_att = mask_att @@ -195,6 +204,7 @@ def __init__( dtype=dtype, eps=eps, dropout_p=dropout_p, + int4=int4, ) if not self.mask_ffn: @@ -204,6 +214,7 @@ def __init__( dtype=dtype, eps=eps, dropout_p=dropout_p, + int4=int4, ) def forward( diff --git a/src/cpm_live/layers/feedforward.py b/src/cpm_live/layers/feedforward.py index c015cf3..efc964f 100644 --- a/src/cpm_live/layers/feedforward.py +++ b/src/cpm_live/layers/feedforward.py @@ -16,7 +16,7 @@ from typing import Optional import torch import bmtrain as bmt -from .linear import Linear +from .linear import Linear, Linear4bit class DenseGatedACT(bmt.DistributedModule): @@ -25,22 +25,33 @@ def __init__( dim_in: int, dim_ff: int, dtype=torch.half, + int4: Optional[bool] = None, ): super().__init__() - - self.w_0 = Linear( - dim_in=dim_in, - dim_out=dim_ff, - dtype=dtype, - scale_before=False, - ) - - self.w_1 = Linear( - dim_in=dim_in, - dim_out=dim_ff, - dtype=dtype, - scale_before=False, - ) + if int4 is None or int4 is False: + self.w_0 = Linear( + dim_in=dim_in, + dim_out=dim_ff, + dtype=dtype, + scale_before=False, + ) + + self.w_1 = Linear( + dim_in=dim_in, + dim_out=dim_ff, + dtype=dtype, + scale_before=False, + ) + else: + self.w_0 = Linear4bit( + dim_in=dim_in, + dim_out=dim_ff, + ) + + self.w_1 = Linear4bit( + dim_in=dim_in, + dim_out=dim_ff, + ) self.act = torch.nn.GELU() def forward(self, x: torch.Tensor): @@ -74,6 +85,7 @@ class FeedForward(bmt.DistributedModule): bias (bool, optional): whether to use bias term in fully-connected layers used in feed-forward module. Defaults to False. activate_fn (str, optional): Defaults to `gated_gelu`. dropout_p (int, optional): Defaults to 0. + int4 (int, optional): whether to use int4 to load model. Defaults to False. 
""" # noqa: E501 def __init__( @@ -82,6 +94,7 @@ def __init__( dim_ff: int, dtype=torch.half, dropout_p: Optional[float] = None, + int4: Optional[bool] = None, ): super().__init__() @@ -90,6 +103,7 @@ def __init__( dim_in=dim_model, dim_ff=dim_ff, dtype=dtype, + int4=int4, ) if dropout_p is not None: @@ -97,11 +111,17 @@ def __init__( else: self.dropout = None - self.w_out = Linear( - dim_in=dim_ff, - dim_out=dim_model, - dtype=dtype, - scale_before=False, + if int4 is None or int4 is False: + self.w_out = Linear( + dim_in=dim_ff, + dim_out=dim_model, + dtype=dtype, + scale_before=False, + ) + else: + self.w_out = Linear4bit( + dim_in=dim_ff, + dim_out=dim_model, ) def forward(self, x: torch.Tensor): diff --git a/src/cpm_live/layers/linear.py b/src/cpm_live/layers/linear.py index 120a1a8..90e4c75 100644 --- a/src/cpm_live/layers/linear.py +++ b/src/cpm_live/layers/linear.py @@ -17,7 +17,12 @@ import bmtrain as bmt import math import torch.nn.functional as F - +import bitsandbytes as bnb +from typing import TypeVar,overload,Optional,Union,Callable,Any +from torch import Tensor, device, dtype +from bmtrain.utils import round_up +from bmtrain.global_var import config +T = TypeVar("T", bound="torch.nn.Module") class Linear(bmt.DistributedModule): def __init__( @@ -55,3 +60,181 @@ def forward(self, x: torch.Tensor): x = F.linear(x, self.weight) x = x / math.sqrt(self.dim_in) return x + +class Linear4bit(bmt.DistributedModule): + def __init__( + self, + dim_in: int, + dim_out: int, + compute_dtype: torch.dtype = torch.float32, + compress_statistics: bool = True, + quant_type: str = 'nf4', + ): + super().__init__() + self.dim_in = self.in_features = dim_in + self.dim_out = self.out_features = dim_out + + weight = Params4bit( + data=torch.empty((dim_out * dim_in // 2, 1), dtype=torch.uint8), + requires_grad=False, + compress_statistics=compress_statistics, + quant_type=quant_type, + ) + + self.weight = DistributedParameter4Int8(weight, requires_grad=False, quant_state=weight.quant_state) + self.compute_dtype = compute_dtype + + def forward(self, x: torch.Tensor): + if getattr(self.weight, 'quant_state', None) is None: + print('quantization state not initialized. 
Please ensure that the model parameters you load include the quant_state attribute.') + + inp_dtype = x.dtype + dtype_dict = { + 'torch.float32': torch.float32, + 'torch.float16': torch.float16, + } + if self.compute_dtype is not None: + if isinstance(self.compute_dtype, str): + self.compute_dtype = dtype_dict[self.compute_dtype] + x = x.to(dtype=self.compute_dtype) + + out = bnb.matmul_4bit(x, self.weight.t(), bias=None, quant_state=self.weight.quant_state) + out = out.to(inp_dtype) + out = out / math.sqrt(self.dim_in) + return out + +class Params4bit(torch.nn.Parameter): + def __new__(cls, + data=None, + requires_grad=True, + quant_state=None, + blocksize=64, + compress_statistics=True, + quant_type='nf4', + ): + self = torch.Tensor._make_subclass(cls, data, requires_grad) + self.blocksize = blocksize + self.compress_statistics = compress_statistics + self.quant_type = quant_type + self.quant_state = quant_state + self.data = data + return self + + def cuda(self, device): + w = self.data.contiguous().half().cuda(device) + w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) + self.data = w_4bit + self.quant_state = quant_state + return self + + @overload + def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T: + ... + + @overload + def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: + ... + + @overload + def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: + ... + + def to(self, *args, **kwargs): + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + + if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"): + return self.cuda(device) + else: + s = self.quant_state + if s is not None: + # make sure the quantization state is on the right device + s[0] = s[0].to(device) + if self.compress_statistics: + # TODO: refactor this. This is a nightmare + # for 4-bit: + # state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] + # state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] + #s[-2][0] = s[-2][0].to(device) # offset + #s[-2][1][0] = s[-2][1][0].to(device) # nested absmax + + # for 8-bit + s[-2][0] = s[-2][0].to(device) # offset + s[-2][1][0] = s[-2][1][0].to(device) # nested quantiation state statitics + s[-2][1][1] = s[-2][1][1].to(device) # nested quantiation codebook + new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking), + requires_grad=self.requires_grad, quant_state=self.quant_state, + blocksize=self.blocksize, compress_statistics=self.compress_statistics, + quant_type=self.quant_type) + + return new_param + +class DistributedParameter4Int8(bmt.DistributedParameter): + r""" + DistributedParameter4Int8 is a subclass of DistributedParameter. + + The main difference is the added support for quantization, provided by the quant_state attribute. + + Args: + data (Tensor): Parameter tensor. + requires_grad (bool, optional): If the parameter requires gradient. + init_method (Callable[['DistributedParameter'], None], optional): The method to initialize the parameter. + group (str, optional): The group name of the parameter. + quant_state (Any, optional): The state of quantization for the parameter. + + Note: DistributedParameter4Int8 must be on the CUDA device. It will transfer the data to device automatically when `__init__` called. 
+ """ + + _original_shape : torch.Size + _start_partition : int + _end_partition : int + _init_method : Optional[Callable[['DistributedParameter'], None]] + _in_checkpoint_block : bool + _group : Optional[str] + _quant_state : Optional[Any] + + def __new__(cls, + data : torch.Tensor, + requires_grad : bool = True, + init_method : Optional[Callable[['DistributedParameter'], None]] = None, + group : Optional[str] = None, + quant_state : Optional[Any] = None + ): + if not config["initialized"]: + raise RuntimeError("BMTrain is not initialized") + + num_of_elements = data.numel() + + cuda_tensor = torch.tensor([], dtype=data.dtype, device="cuda") + cuda_storage_size = round_up(num_of_elements, config["world_size"]) // config["world_size"] + + original_shape = data.size() + + cuda_storage = cuda_tensor.storage_type()(cuda_storage_size) + + start_of_partition = cuda_storage_size * config["rank"] + end_of_partition = min(num_of_elements, cuda_storage_size * (config["rank"] + 1)) + + # FX: cuda_tensor_size < 0 if num_of_elements is too small + cuda_tensor_size = max(end_of_partition - start_of_partition, 0) + + cuda_tensor.set_(cuda_storage, 0, (cuda_tensor_size,)) + cuda_tensor.copy_(data.view(-1)[start_of_partition: end_of_partition]) + ret = torch.Tensor._make_subclass(cls, cuda_tensor, requires_grad) + + setattr(ret, "_original_shape", original_shape) + setattr(ret, "_start_partition", start_of_partition) + setattr(ret, "_end_partition", end_of_partition) + setattr(ret, "_init_method", init_method) + setattr(ret, "_in_checkpoint_block", False) + setattr(ret, "_group", group) + setattr(ret, "_quant_state", quant_state) + + return ret + + @property + def quant_state(self) -> Optional[Any]: + return self._quant_state + + @quant_state.setter + def quant_state(self, value: Optional[Any]): + self._quant_state = value diff --git a/src/cpm_live/layers/transformer.py b/src/cpm_live/layers/transformer.py index fef5116..9ffa310 100644 --- a/src/cpm_live/layers/transformer.py +++ b/src/cpm_live/layers/transformer.py @@ -33,6 +33,7 @@ class Encoder(bmt.DistributedModule): dtype (optional): Defaults to torch.half. eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6. dropout_p (float, optional): Defaults to 0. + int4 (int, optional): whether to use int4 to load model. Defaults to False. 
""" # noqa: E501 def __init__( @@ -46,6 +47,7 @@ def __init__( eps: float = 1e-6, dropout_p: Optional[float] = None, mask_modules: Optional[List[Tuple[bool, bool]]] = None, + int4: Optional[bool] = None, ): super().__init__() @@ -76,6 +78,7 @@ def __init__( dropout_p=dropout_p, mask_att=mask_modules[ith][0], mask_ffn=mask_modules[ith][1], + int4=int4, ) ) for ith in range(num_layers) diff --git a/src/cpm_live/models/bee.py b/src/cpm_live/models/bee.py index 098ecc6..efb3dfc 100644 --- a/src/cpm_live/models/bee.py +++ b/src/cpm_live/models/bee.py @@ -48,6 +48,7 @@ def __init__( eps=1e-6, half: bool = True, mask_modules: Optional[List[Tuple[bool, bool]]] = None, + int4: Optional[bool] = None, ): super().__init__() @@ -67,7 +68,7 @@ def __init__( self.dtype = torch.float self.vocab_size = vocab_size self.mask_modules = mask_modules - + self.int4 = int4 class CPMBee(bmt.DistributedModule): def __init__(self, config: CPMBeeConfig): @@ -84,6 +85,7 @@ def __init__(self, config: CPMBeeConfig): eps=config.eps, dropout_p=config.dropout_p, mask_modules=config.mask_modules, + int4 = config.int4, ) self.input_embedding = EmbeddingExt( diff --git a/src/finetune_cpm_bee_qlora.py b/src/finetune_cpm_bee_qlora.py new file mode 100644 index 0000000..2d07c75 --- /dev/null +++ b/src/finetune_cpm_bee_qlora.py @@ -0,0 +1,454 @@ +# coding=utf-8 +# Copyright 2022 The OpenBMB team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import time +from typing import Dict, List, Union +import torch +import torch.nn as nn +import bmtrain as bmt +import os +from opendelta import LoraModel +from opendelta.delta_models.lora import LowRankLinear +from cpm_live.arguments import get_args + +from cpm_live.models import CPMBee, CPMBeeConfig +from cpm_live.tokenizers import CPMBeeTokenizer +from cpm_live.utils import allgather_objects +from cpm_live.training_tasks.bee import FinetuneDataset + +from cpm_live.layers import Linear4bit + +def get_tokenizer(args): + tokenizer = CPMBeeTokenizer() + return tokenizer + +def get_model(args): + config = CPMBeeConfig.from_json_file(args.model_config) + model = CPMBee(config) + model.config = config + if args.load is not None: + state_dict = load_quantize_state_dict(args.load) + model.load_state_dict(state_dict) + for name, param in model.named_parameters(): + if name in state_dict and hasattr(state_dict[name], 'quant_state'): + param.quant_state = state_dict[name].quant_state + else: + bmt.init_parameters(model) + + original_dtype_dict = {} + for name, module in model.named_modules(): + if isinstance(module, Linear4bit): + original_dtype_dict[name] = module.weight.data.dtype + module.weight.data = module.weight.data.to(torch.half) + + # insert LoRA + if args.use_delta: + delta_model = LoraModel( + backbone_model=model, modified_modules=["project_q", "project_v"], backend="bmt" + ) + delta_model.freeze_module(exclude=["deltas"], set_state_dict=True) + delta_model.log() + + for name, module in model.named_modules(): + if name in original_dtype_dict: + module.weight.data = module.weight.data.to(original_dtype_dict[name]) + if isinstance(module, LowRankLinear): + module.lora_A = nn.Parameter(module.lora_A.data.to(model.config.dtype)) + module.lora_B = nn.Parameter(module.lora_B.data.to(model.config.dtype)) + return model + +def get_optimizer(args, model): + optimizer = bmt.optim.AdamOffloadOptimizer( + model.parameters(), weight_decay=args.weight_decay + ) + return optimizer + +def get_learning_rate_scheduler(args, optimizer): + if args.lr_decay_iters is None: + args.lr_decay_iters = args.train_iters + lr_scheduler = bmt.lr_scheduler.Noam( + optimizer, + start_lr=args.lr, + warmup_iter=args.warmup_iters, + end_iter=args.lr_decay_iters, + num_iter=args.start_step, + ) + return lr_scheduler + +def setup_model_and_optimizer(args): + model = get_model(args) + tokenizer = get_tokenizer(args) + bmt.synchronize() + optimizer = get_optimizer(args, model) + lr_scheduler = get_learning_rate_scheduler(args, optimizer) + bmt.synchronize() + optim_manager = bmt.optim.OptimManager( + loss_scale=args.loss_scale, + loss_scale_factor=2, + loss_scale_steps=512, + ) + optim_manager.add_optimizer(optimizer, lr_scheduler) + return tokenizer, model, optimizer, lr_scheduler, optim_manager + +def initialize(): + args = get_args(finetune=True) + bmt.init_distributed(seed=args.seed) + if args.save is not None: + os.makedirs(args.save, exist_ok=True) + return args + +def load_quantize_state_dict(quantize_save): + checkpoint = torch.load(quantize_save) + state_dict = checkpoint["state_dict"] + quant_state_dict = checkpoint["quant_state_dict"] + for key, value in state_dict.items(): + if key in quant_state_dict: + value.quant_state = quant_state_dict[key] + + return state_dict + +def see_memory(detail=False): + if detail: + res = torch.cuda.memory_summary() + else: + res = ( + round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 2), + round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2), + ) + 
torch.cuda.reset_peak_memory_stats() + return res + +def add_mem_time(info, mem_usage, tim_usage): + torch.cuda.synchronize() + mem_usage[info] = see_memory() + tim_usage[info] = time.time() + return mem_usage, tim_usage + +def evaluation(model, args, tokenizer, loss_func): + bmt.print_rank("evaluation begins...") + eval_dataloader = FinetuneDataset( + args.eval_dataset, + 1, + args.max_length, + tokenizer, + max_depth=8, + task_name=args.task_name, + drop_last=args.drop_last, + ) + eval_losses = [] + last_data = None + + with torch.no_grad(): + for iteration, data in enumerate(eval_dataloader): + iteration = iteration + 1 + skip_this_batch = False + if data is None: + if last_data is None: + raise RuntimeError( + "Dataset is too small, please use a smaller batch size or sequence length!" + ) + data = last_data + skip_this_batch = True + else: + last_data = data + + input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32) + input_ids_sub = torch.from_numpy(data["inputs_sub"]).cuda().to(torch.int32) + input_length = torch.from_numpy(data["length"]).cuda().to(torch.int32) + input_context = torch.from_numpy(data["context"]).cuda().bool() + input_sample_ids = torch.from_numpy(data["sample_ids"]).cuda().to(torch.int32) + input_num_segments = torch.from_numpy(data["num_segments"]).cuda().to(torch.int32) + input_segment_ids = torch.from_numpy(data["segment_ids"]).cuda().to(torch.int32) + input_segment_rel_offset = ( + torch.from_numpy(data["segment_rel_offset"]).cuda().to(torch.int32) + ) + input_segment_rel = torch.from_numpy(data["segment_rel"]).cuda().to(torch.int32) + input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32) + targets = torch.from_numpy(data["target"]).cuda().to(torch.int32) + ext_table_ids = torch.from_numpy(data["ext_ids"]).cuda().to(torch.int32) + ext_table_sub = torch.from_numpy(data["ext_sub"]).cuda().to(torch.int32) + # =========== + mem_usage = {} + tim_usage = {} + mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage) + + # =========== + logits, _ = model( + input_ids, + input_ids_sub, + input_length, + input_context, + input_sample_ids, + input_num_segments, + input_segment_ids, + input_segment_rel_offset, + input_segment_rel, + input_span, + ext_table_ids, + ext_table_sub, + ) + loss = loss_func(logits.view(-1, logits.size(-1)), targets.long().view(-1)) + if skip_this_batch: + loss = loss * 0 + eval_losses.append(bmt.sum_loss(loss)) + + overall_loss = torch.stack(eval_losses).mean().item() + return overall_loss + +def finetune( + args, + tokenizer: CPMBeeTokenizer, + model: CPMBee, + optimizer: bmt.optim.AdamOffloadOptimizer, + lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler, + optim_manager: bmt.optim.OptimManager, +): + average_time = bmt.utils.AverageRecorder() + + if model.config.dtype == torch.half: + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100) + else: + loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + + if args.tensorboard is not None and bmt.rank() == 0: + from torch.utils.tensorboard import SummaryWriter + import distutils.version # noqa: F401 + + if not os.path.exists(args.tensorboard): + os.makedirs(args.tensorboard) + writer = SummaryWriter(log_dir=args.tensorboard) + + best_eval_loss, eval_loss_increase = 1e9, 0 + global_token_pass = 0.0 + global_steps = 0 + global_world_size = bmt.world_size() + dataloader = FinetuneDataset( + args.dataset, + args.batch_size, + args.max_length, + tokenizer, + max_depth=8, + task_name=args.task_name, + drop_last=args.drop_last, + ) + + for epoch in 
range(args.epoch): + epoch = epoch + 1 + last_data = None + for iteration, data in enumerate(dataloader): + iteration = iteration + 1 + if global_steps >= args.train_iters: + break + global_steps = global_steps + 1 + skip_this_batch = False + if data is None: + if last_data is None: + raise RuntimeError( + "Dataset is too small, please use a smaller batch size or sequence length!" + ) + data = last_data # use last data + skip_this_batch = True + else: + last_data = data + + input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32) + input_ids_sub = torch.from_numpy(data["inputs_sub"]).cuda().to(torch.int32) + input_length = torch.from_numpy(data["length"]).cuda().to(torch.int32) + input_context = torch.from_numpy(data["context"]).cuda().bool() + input_sample_ids = torch.from_numpy(data["sample_ids"]).cuda().to(torch.int32) + input_num_segments = torch.from_numpy(data["num_segments"]).cuda().to(torch.int32) + input_segment_ids = torch.from_numpy(data["segment_ids"]).cuda().to(torch.int32) + input_segment_rel_offset = ( + torch.from_numpy(data["segment_rel_offset"]).cuda().to(torch.int32) + ) + input_segment_rel = torch.from_numpy(data["segment_rel"]).cuda().to(torch.int32) + input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32) + targets = torch.from_numpy(data["target"]).cuda().to(torch.int32) + ext_table_ids = torch.from_numpy(data["ext_ids"]).cuda().to(torch.int32) + ext_table_sub = torch.from_numpy(data["ext_sub"]).cuda().to(torch.int32) + task_ids = torch.from_numpy(data["task_ids"]).cuda().to(torch.int32) + task_names = data["task_names"] + # =========== + optim_manager.zero_grad() + mem_usage = {} + tim_usage = {} + mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage) + + # =========== + logits, _ = model( + input_ids, + input_ids_sub, + input_length, + input_context, + input_sample_ids, + input_num_segments, + input_segment_ids, + input_segment_rel_offset, + input_segment_rel, + input_span, + ext_table_ids, + ext_table_sub, + ) + loss = loss_func(logits.view(-1, logits.size(-1)), targets.long().view(-1)) + if skip_this_batch: + loss = loss * 0 + + mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage) + + # =========== + optim_manager.backward(loss) + mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage) + # =========== + grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=1.0) + optim_manager.step() + mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage) + # ========== + iteration_time = tim_usage["optim"] - tim_usage["init"] + average_time.record(iteration_time) + + with torch.no_grad(): + task_num = len(task_names) + targets_tmp = targets.expand(task_num, -1, -1) + task = torch.arange(task_num, dtype=torch.int32, device="cuda")[:, None, None] + targets_tmp = torch.where( + task_ids == task, + targets_tmp, + torch.scalar_tensor(-100, dtype=torch.int32, device="cuda"), + ) + + task_loss_map: Dict[str, float] = {} + if not skip_this_batch: + for i in range(task_num): + task_loss = loss_func( + logits.view(-1, logits.size(-1)), targets_tmp[i, :].long().view(-1) + ) + task_loss_map[task_names[i]] = task_loss.item() + gatherd_task_loss_map: List[Dict[str, float]] = allgather_objects(task_loss_map) + + global_task_loss_map: Dict[str, Union[List[float], float]] = {} + for local_task_loss_map in gatherd_task_loss_map: + for task_name, task_loss in local_task_loss_map.items(): + if task_name not in global_task_loss_map: + global_task_loss_map[task_name] = [] + 
global_task_loss_map[task_name].append(task_loss) + + task_loss_map = {} + for task_name in sorted(list(global_task_loss_map.keys())): + avg_loss = sum(global_task_loss_map[task_name]) / len( + global_task_loss_map[task_name] + ) + task_loss_map[task_name] = avg_loss + + local_total_rate = torch.Tensor([input_length.float().mean() / args.max_length]).cuda() + local_total_rate = bmt.sum_loss(local_total_rate).item() + global_token_pass += ( + global_world_size * local_total_rate * args.max_length * args.batch_size + ) + avg_time = average_time.value + + train_info = { + "time": tim_usage["init"], + "epoch": epoch, + "iteration": iteration, + "loss": task_loss_map[args.task_name], + "lr": lr_scheduler.current_lr, + "lr_scale": int(optim_manager.loss_scale), + "time_usage": tim_usage, + "mem_usage": mem_usage, + "avg_time": avg_time, + "token_max": local_total_rate, + "token_pass": global_token_pass, + "throughout": args.max_length * args.batch_size * local_total_rate / avg_time, + "grad_norm": grad_norm.item(), + "mask_max": ((targets >= 0).sum(-1).float().mean() / args.max_length).item(), + "num_gpus": global_world_size, + "task_loss": task_loss_map, + } + + bmt.print_rank( + ( + "| Epoch: {:3d} | Iter: {:6d} | loss: {:.4f} " + + "| lr: {:.4e}, scale: {:10.4f} | time: {:.4f} |" + + " token/max: {:.4f} | mask/max: {:.4f} | grad_norm: {:.10f}" + ).format( + epoch, + iteration, + task_loss_map[args.task_name], + lr_scheduler.current_lr, + int(optim_manager.loss_scale), + avg_time, + input_length.float().mean() / args.max_length, + (targets >= 0).sum(-1).float().mean() / args.max_length, + grad_norm, + ) + ) + bmt.print_rank( + "| " + + " | ".join( + [ + "{} loss: {:.4f}".format(task_name, loss) + for task_name, loss in task_loss_map.items() + ] + ) + ) + # not available for std and var only support floating point and complex dtypes + # if iteration % args.inspect_iters == 0: + # model_inspect = bmt.inspect.inspect_model(model, "*") + # bmt.print_rank(bmt.inspect.format_summary(model_inspect)) + # train_info["model_inspect"] = model_inspect + # print(train_info["mem_usage"]) + + # write log here + if args.tensorboard is not None and bmt.rank() == 0: + writer.add_scalar("Loss/train", task_loss_map[args.task_name], global_steps) + for task_name, loss in task_loss_map.items(): + writer.add_scalar("Loss/train/{}".format(task_name), loss, global_steps) + + # evaluation + if global_steps % args.eval_interval == 0: + eval_loss = evaluation(model, args, tokenizer, loss_func) + if args.tensorboard is not None and bmt.rank() == 0: + writer.add_scalar("Loss/eval", eval_loss, global_steps) + if eval_loss < best_eval_loss: + best_eval_loss = eval_loss + eval_loss_increase = 0 + if args.save is not None: + if not args.use_delta: + bmt.save(model, os.path.join(args.save, args.save_name + "-best.pt")) + else: + state_dict = model.state_dict() + if bmt.rank() == 0: + print("saving_now") + torch.save(state_dict, os.path.join(args.save, args.save_name + "-delta-best.pt")) + else: + eval_loss_increase += 1 + bmt.print_rank( + "| Eval loss: {:.4f} | Increase: {:2d}".format(eval_loss, eval_loss_increase) + ) + if eval_loss_increase == args.early_stop_patience: + bmt.print_rank( + "Eval loss has increased {:d} times, the finetune loop early stopped." 
+                        .format(eval_loss_increase)
+                    )
+                    return
+    # end of finetune
+
+def main():
+    args = initialize()
+    tokenizer, model, optimizer, lr_scheduler, optim_manager = setup_model_and_optimizer(args)
+    finetune(args, tokenizer, model, optimizer, lr_scheduler, optim_manager)
+
+if __name__ == "__main__":
+    main()
diff --git a/src/quantize_state_dict.py b/src/quantize_state_dict.py
new file mode 100644
index 0000000..3b054a6
--- /dev/null
+++ b/src/quantize_state_dict.py
@@ -0,0 +1,31 @@
+from argparse import ArgumentParser
+import torch
+from cpm_live.layers.linear import Params4bit
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument("--input-path", type=str, help="Path to the input state dict", required=True)
+    parser.add_argument("--output-path", type=str, help="Path to the output (quantized) state dict", required=True)
+    args = parser.parse_args()
+    return args
+
+def quantize_state_dict(args):
+    state_dict = torch.load(args.input_path)
+    replace_list = ["project_q", "project_k", "project_v", "attention_out", "w_0", "w_1", "w_out"]
+
+    temp_dict = {}
+    quant_state_dict = {}
+    for key, value in state_dict.items():
+        if any(word in key for word in replace_list):
+            new_value = Params4bit(value, requires_grad=False).cuda("cuda")
+            temp_dict[key] = new_value
+            quant_state_dict[key] = new_value.quant_state
+    state_dict.update(temp_dict)
+    torch.save({"state_dict": state_dict, "quant_state_dict": quant_state_dict}, args.output_path)
+
+def main():
+    args = parse_args()
+    quantize_state_dict(args)
+
+if __name__ == "__main__":
+    main()
diff --git a/src/scripts/finetune_cpm_bee_qlora.sh b/src/scripts/finetune_cpm_bee_qlora.sh
new file mode 100644
index 0000000..f47d91a
--- /dev/null
+++ b/src/scripts/finetune_cpm_bee_qlora.sh
@@ -0,0 +1,29 @@
+#! /bin/bash
+export CUDA_VISIBLE_DEVICES=0
+
+OPTS=""
+OPTS+=" --use-delta"
+OPTS+=" --model-config config/cpm-bee-10b.json"
+OPTS+=" --dataset path/to/dataset"
+OPTS+=" --eval_dataset path/to/eval/dataset"
+OPTS+=" --epoch 10"
+OPTS+=" --batch-size 2"
+OPTS+=" --train-iters 100"
+OPTS+=" --save-name cpm_bee_finetune"
+OPTS+=" --max-length 2048"
+OPTS+=" --save results/"
+OPTS+=" --lr 0.0001"
+OPTS+=" --warmup-iters 1"
+OPTS+=" --eval-interval 10"
+OPTS+=" --early-stop-patience 10"
+OPTS+=" --lr-decay-style noam"
+OPTS+=" --weight-decay 0.01"
+OPTS+=" --clip-grad 1.0"
+OPTS+=" --loss-scale 32768"
+OPTS+=" --start-step 0"
+OPTS+=" --load quantized_model.pt"
+
+CMD="python finetune_cpm_bee_qlora.py ${OPTS}"
+
+echo ${CMD}
+$CMD
\ No newline at end of file
diff --git a/tutorials/basic_task_finetune/README_qlora.md b/tutorials/basic_task_finetune/README_qlora.md
new file mode 100644
index 0000000..cbd7cf2
--- /dev/null
+++ b/tutorials/basic_task_finetune/README_qlora.md
@@ -0,0 +1,38 @@
+## Single-GPU QLoRA fine-tuning of CPM-Bee
+
+### Quantized fine-tuning of CPM-Bee on basic tasks
+
+Building on **basic task fine-tuning with CPM-Bee**, this tutorial adds quantization to delta tuning, reducing GPU memory consumption while preserving training quality. In our tests, this approach supports full-precision int4 quantized fine-tuning of CPM-Bee-10B on a single 24 GB RTX 3090.
+
+The steps are as follows.
+
+First, quantize the model parameter file.
+
+Enter the working directory:
+
+```bash
+$ cd src
+```
+
+Quantize the parameter file:
+
+```bash
+$ python quantize_state_dict.py --input-path your_cpmbee_model.bin --output-path your_cpmbee_quantize_model.bin
+```
+
+Next, set up the model config file.
+
+The example below uses full precision plus int4 quantization (the default compute_dtype is torch.float32, double quantization is enabled, and the quantization type is nf4):
+
+```json
+  "half" : false,
+  "int4" : true,
+```
+
+Finally, with the steps above done, follow the basic fine-tuning tutorial for the remaining parts. We provide an example script `finetune_cpm_bee_qlora.sh` under the `scripts` directory for reference.
+
+Note that in your fine-tuning script you should replace the `--load` argument with
+
+`your_cpmbee_quantize_model.bin`
+
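
A note on how the pieces above fit together: the `int4` switch travels from the JSON config through `CPMBeeConfig` into `CPMBee`, where every attention and feed-forward projection is then built as `Linear4bit` instead of `Linear`. The sketch below is illustrative only and not part of the patch; the config path and seed are placeholders, and it assumes it is launched the same way as `finetune_cpm_bee_qlora.py` so that `bmt.init_distributed()` can set up its environment.

```python
# Illustrative sketch only (not part of the patch); config path and seed are placeholders.
import bmtrain as bmt
from cpm_live.models import CPMBee, CPMBeeConfig
from cpm_live.layers import Linear4bit

bmt.init_distributed(seed=1234)

# A config edited as in the README above: "half": false, "int4": true
config = CPMBeeConfig.from_json_file("config/cpm-bee-3b.json")
model = CPMBee(config)

# With int4 enabled, project_q/k/v, attention_out and w_0/w_1/w_out are Linear4bit modules;
# their packed uint8 weights are filled later from the output of quantize_state_dict.py
# (see load_quantize_state_dict in finetune_cpm_bee_qlora.py).
num_4bit = sum(1 for _, m in model.named_modules() if isinstance(m, Linear4bit))
print("Linear4bit modules:", num_4bit)
```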
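
The packed weights produced by `Params4bit.cuda()` are only usable together with their `quant_state`, which is why `quantize_state_dict.py` saves a separate `quant_state_dict` next to `state_dict`. The standalone sketch below, assuming bitsandbytes >= 0.39 (where `quantize_4bit` and `matmul_4bit` are available) and a CUDA device, shows the round trip that `Linear4bit.forward` relies on; the dimensions are arbitrary.

```python
# Standalone sketch, not part of the patch. Assumes bitsandbytes >= 0.39 and a CUDA device.
import math
import torch
import bitsandbytes as bnb

dim_in, dim_out = 4096, 4096
w = torch.randn(dim_out, dim_in, dtype=torch.half, device="cuda")

# Pack the fp16 weight into NF4 blocks. `w_4bit` is a flat uint8 buffer of shape
# (dim_out * dim_in // 2, 1); `quant_state` carries the per-block absmax, original shape
# and dtype needed to dequantize it again.
w_4bit, quant_state = bnb.functional.quantize_4bit(
    w, blocksize=64, compress_statistics=True, quant_type="nf4"
)

# matmul_4bit dequantizes on the fly, mirroring Linear4bit.forward. Without quant_state the
# packed buffer cannot be interpreted, which is why the checkpoint stores quant_state_dict
# alongside state_dict.
x = torch.randn(2, dim_in, dtype=torch.float32, device="cuda")
out = bnb.matmul_4bit(x, w_4bit.t(), bias=None, quant_state=quant_state)
out = out / math.sqrt(dim_in)  # same 1/sqrt(dim_in) scaling as Linear/Linear4bit in this patch
print(out.shape)  # torch.Size([2, 4096])
```

Because dequantization happens inside `matmul_4bit`, the weights stay 4-bit both on disk and in GPU memory; only the activations and the LoRA deltas use the configured compute dtype.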
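
`DistributedParameter4Int8.__new__` shards the flattened, already-packed buffer across ranks using the same arithmetic as `bmt.DistributedParameter`: pad the element count up to a multiple of the world size, then give each rank one contiguous slice, which can be shorter or empty on the last rank. A small illustration follows; `round_up` here is a local stand-in for `bmtrain.utils.round_up`.

```python
# Pure-Python illustration of the partitioning arithmetic (not part of the patch).
def round_up(x: int, multiple: int) -> int:
    # Round x up to the next multiple of `multiple`, as bmtrain.utils.round_up does.
    return ((x + multiple - 1) // multiple) * multiple

def partition(num_of_elements: int, world_size: int, rank: int):
    storage_size = round_up(num_of_elements, world_size) // world_size
    start = storage_size * rank
    end = min(num_of_elements, storage_size * (rank + 1))
    # The local slice may be empty when the tensor is smaller than the padded storage.
    return start, end, max(end - start, 0)

# Example: a packed 4-bit weight of 10 uint8 elements sharded over 4 ranks.
for rank in range(4):
    print(rank, partition(10, 4, rank))
# 0 (0, 3, 3)
# 1 (3, 6, 3)
# 2 (6, 9, 3)
# 3 (9, 10, 1)
```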