[inference/model]Adapted to the baichuan2-7B model #5591

Merged 6 commits on Apr 15, 2024
1 change: 1 addition & 0 deletions colossalai/inference/config.py
@@ -26,6 +26,7 @@

_DEFAULT_PROMPT_TEMPLATES = {
"llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
"baichuan": "<reserved_106>{input_text}<reserved_107>",
"vicuna": "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user input. USER: {input_text}\nASSISTANT: ",
}
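For reference, a minimal sketch of how a template entry is applied to raw input text; this mirrors the prompt handling in the new test further down this diff.

    from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES

    prompt = _DEFAULT_PROMPT_TEMPLATES["baichuan"].format(input_text="介绍一下北京")
    # -> "<reserved_106>介绍一下北京<reserved_107>"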

1 change: 1 addition & 0 deletions colossalai/inference/core/engine.py
@@ -27,6 +27,7 @@

_supported_models = [
"LlamaForCausalLM",
"BaichuanForCausalLM",
]

_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
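As context for the new entry, _supported_models is the list the engine checks incoming models against. A hypothetical sketch of such a check (not the engine's actual code; `model` stands in for any loaded Hugging Face causal LM):

    arch = model.__class__.__name__  # e.g. "BaichuanForCausalLM"
    if arch not in _supported_models:
        raise ValueError(f"{arch} is not in the supported model list: {_supported_models}")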
183 changes: 183 additions & 0 deletions colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -0,0 +1,183 @@
# This code is adapted from the huggingface baichuan model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
from typing import Optional, Tuple

import torch
import torch.nn as nn

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaAttention
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.logging import get_dist_logger

inference_ops = InferenceOpsLoader().load()

logger = get_dist_logger(__name__)


class NopadBaiChuanAttention(nn.Module):
def __init__(
self,
config,
attn_qproj_w: torch.Tensor = None,
attn_kproj_w: torch.Tensor = None,
attn_vproj_w: torch.Tensor = None,
attn_oproj_w: torch.Tensor = None,
):
"""This layer will replace the BaichuanAttention.

Args:
config (BaichuanConfig): Holding the Baichuan model config.
attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
"""
super().__init__()
self.o_proj_weight = attn_oproj_w

self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads

# Used to adapt llama_base_attn_forward
self.num_key_value_heads = self.num_heads

qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
self.qkv_weight = torch.stack(qkv_weight_list, dim=0)

@staticmethod
def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
"""Used for initialize the weight of NopadBaiChuanAttention by origin BaiChuanAttention.

Args:
module (BaiChuanAttention): The original BaiChuanAttention layer.
"""

config = module.config

q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((3, module.hidden_size, module.hidden_size))

attn_qproj_w = q_proj_w.transpose(0, 1)
attn_kproj_w = k_proj_w.transpose(0, 1)
attn_vproj_w = v_proj_w.transpose(0, 1)
attn_oproj_w = module.o_proj.weight.transpose(0, 1)

attn_layer = NopadBaiChuanAttention(
config=config,
attn_qproj_w=attn_qproj_w,
attn_kproj_w=attn_kproj_w,
attn_vproj_w=attn_vproj_w,
attn_oproj_w=attn_oproj_w,
)

return attn_layer
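As a sanity check on the W_pack split performed above, a small self-contained sketch (hidden_size shrunk to 8 purely for illustration) of the resulting qkv_weight layout, which the shared llama attention forward is expected to consume via a batched matmul:

    import torch

    hidden_size = 8  # illustrative; Baichuan2-7B uses 4096
    w_pack = torch.randn(3 * hidden_size, hidden_size)        # fused [q; k; v] weight, [out, in]
    q_w, k_w, v_w = w_pack.view(3, hidden_size, hidden_size)  # one [out, in] block per projection
    qkv_weight = torch.stack(
        [q_w.transpose(0, 1), k_w.transpose(0, 1), v_w.transpose(0, 1)], dim=0
    )  # [3, in, out]
    assert qkv_weight.shape == (3, hidden_size, hidden_size)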

def forward(
self,
hidden_states: torch.Tensor,
block_tables: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
sequence_lengths: torch.Tensor,
cos_sin: Tuple[torch.Tensor],
fd_inter_tensor: FDIntermTensors,
is_prompts: bool = True,
is_verifier: bool = False,
tokens_to_verify: int = None,
kv_seq_len: int = 0,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
use_cuda_kernel: bool = True,
cu_seqlens: torch.Tensor = None,
high_precision: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id.
k_cache (torch.Tensor): It holds the GPU memory for the key cache.
v_cache (torch.Tensor): It holds the GPU memory for the value cache.
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence.
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin.
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
storing intermediate values in flash-decoding.
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
output_tensor (torch.Tensor, optional): The intermediate tensor that holds the attention output. Defaults to None.
sm_scale (int, optional): The scale factor used for flash attention. Defaults to None.
use_cuda_kernel (bool, optional): Whether to use the CUDA kernel. Defaults to True.
cu_seqlens (torch.Tensor, optional): The cumulative sum of the sequence lengths.
high_precision (Optional[bool]): Whether to use float32 for the underlying float16 computations to achieve higher precision. Defaults to False.
"""

return NopadLlamaAttention.forward(
self,
hidden_states=hidden_states,
block_tables=block_tables,
k_cache=k_cache,
v_cache=v_cache,
sequence_lengths=sequence_lengths,
cos_sin=cos_sin,
fd_inter_tensor=fd_inter_tensor,
is_prompts=is_prompts,
is_verifier=is_verifier,
tokens_to_verify=tokens_to_verify,
kv_seq_len=kv_seq_len,
output_tensor=output_tensor,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
cu_seqlens=cu_seqlens,
high_precision=high_precision,
)


# NOTE: In some cases this produces results that differ from the transformers implementation.
class NopadBaichuanMLP(nn.Module):
def __init__(
self,
mlp_gproj_w: torch.Tensor = None,
mlp_uproj_w: torch.Tensor = None,
mlp_dproj_w: torch.Tensor = None,
):
"""This layer will replace the BaiChuanAttention.

Args:
mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
"""
super().__init__()
self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
self.down_proj_weight = mlp_dproj_w

@staticmethod
def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
"""Used for initialize the weight of NopadBaichuanMLP by origin MLP(Baichuan).

Args:
module (nn.Module): The original Baichuan MLP layer.
"""

mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
mlp_dproj_w = module.down_proj.weight.transpose(0, 1)

mlp_layer = NopadBaichuanMLP(
mlp_gproj_w=mlp_gproj_w,
mlp_uproj_w=mlp_uproj_w,
mlp_dproj_w=mlp_dproj_w,
)

return mlp_layer

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
"""
hidden_states = hidden_states.expand(2, -1, -1)
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = inference_ops.silu_and_mul(gate_up_proj_out)
return torch.mm(act_out, self.down_proj_weight)
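For clarity, a plain-PyTorch reference of the fused computation above; a sketch assuming inference_ops.silu_and_mul computes SiLU(gate) * up on the stacked projection output, with illustrative shapes:

    import torch
    import torch.nn.functional as F

    token_num, hidden_size, inter_size = 4, 8, 16
    x = torch.randn(token_num, hidden_size)
    w_gate = torch.randn(hidden_size, inter_size)  # transposed gate_proj weight
    w_up = torch.randn(hidden_size, inter_size)    # transposed up_proj weight
    w_down = torch.randn(inter_size, hidden_size)  # transposed down_proj weight

    gate_up_weight = torch.stack([w_gate, w_up], dim=0)       # [2, hidden, inter]
    gate_up = torch.bmm(x.expand(2, -1, -1), gate_up_weight)  # [2, token_num, inter]
    act = F.silu(gate_up[0]) * gate_up[1]                     # what silu_and_mul is assumed to fuse
    out = torch.mm(act, w_down)                               # [token_num, hidden]
    assert out.shape == (token_num, hidden_size)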
9 changes: 8 additions & 1 deletion colossalai/inference/modeling/policy/__init__.py
@@ -1,9 +1,16 @@
from .glide_llama import GlideLlamaModelPolicy
from .nopadding_baichuan import NoPaddingBaiChuanModelInferPolicy
from .nopadding_llama import NoPaddingLlamaModelInferPolicy

model_policy_map = {
"nopadding_llama": NoPaddingLlamaModelInferPolicy,
"nopadding_baichuan": NoPaddingBaiChuanModelInferPolicy,
"glide_llama": GlideLlamaModelPolicy,
}

__all__ = ["NoPaddingLlamaModelInferPolicy", "GlideLlamaModelPolicy", "model_polic_map"]
__all__ = [
"NoPaddingLlamaModelInferPolicy",
"NoPaddingBaiChuanModelInferPolicy",
"GlideLlamaModelPolicy",
"model_polic_map",
]
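A brief usage sketch of the registry above, looking up and instantiating the new policy by name (the engine's own selection logic may differ):

    from colossalai.inference.modeling.policy import model_policy_map

    policy_cls = model_policy_map["nopadding_baichuan"]  # -> NoPaddingBaiChuanModelInferPolicy
    policy = policy_cls()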
64 changes: 64 additions & 0 deletions colossalai/inference/modeling/policy/nopadding_baichuan.py
@@ -0,0 +1,64 @@
import torch.nn as nn
from torch.nn import Parameter

from colossalai.inference.modeling.models.nopadding_baichuan import NopadBaiChuanAttention, NopadBaichuanMLP
from colossalai.inference.modeling.models.nopadding_llama import (
llama_causal_lm_forward,
llama_decoder_layer_forward,
llama_model_forward,
llama_rmsnorm_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription

# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy


class NoPaddingBaiChuanModelInferPolicy(LlamaForCausalLMPolicy):
def __init__(self) -> None:
super().__init__()

def module_policy(self):
policy = super().module_policy()

decoder_attribute_replacement = {
"lm_head.weight": Parameter(
nn.functional.normalize(self.model.lm_head.weight).transpose(0, 1), requires_grad=False
),
}
policy["BaichuanForCausalLM"] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)

policy["DecoderLayer"] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="mlp",
target_module=NopadBaichuanMLP,
),
SubModuleReplacementDescription(
suffix="self_attn",
target_module=NopadBaiChuanAttention,
),
]
)

self.append_or_create_method_replacement(
description={"forward": llama_causal_lm_forward}, policy=policy, target_key="BaichuanForCausalLM"
)
self.append_or_create_method_replacement(
description={"forward": llama_model_forward}, policy=policy, target_key="BaichuanModel"
)
self.append_or_create_method_replacement(
description={"forward": llama_decoder_layer_forward}, policy=policy, target_key="DecoderLayer"
)
self.append_or_create_method_replacement(
description={"forward": llama_rmsnorm_forward}, policy=policy, target_key="RMSNorm"
)

return policy

def postprocess(self):
init_to_get_rotary(self.model.model)
return self.model
1 change: 1 addition & 0 deletions examples/inference/benchmark_llama.py
@@ -117,6 +117,7 @@ def benchmark_inference(args):
max_output_len=args.output_len,
prefill_ratio=1.2,
block_size=32,
use_cuda_kernel=True,
)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
elif args.mode == "vllm":
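The use_cuda_kernel flag added above switches the benchmark onto the fused CUDA kernels. A standalone sketch of building such a config and engine for Baichuan2-7B (values are illustrative, and a colossalai distributed context is assumed to be launched already, as in the test below):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from colossalai.inference.config import InferenceConfig
    from colossalai.inference.core.engine import InferenceEngine

    model_path = "baichuan-inc/Baichuan2-7B-Base"
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
    ).cuda()

    inference_config = InferenceConfig(
        max_output_len=128,  # illustrative
        prefill_ratio=1.2,
        block_size=32,
        use_cuda_kernel=True,
    )
    engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)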
102 changes: 102 additions & 0 deletions tests/test_infer/test_models/test_baichuan.py
@@ -0,0 +1,102 @@
import os
import random

import numpy as np
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

PATH_EXIST = "baichuan-inc/Baichuan2-7B-Base"

if os.path.exists(PATH_EXIST):
PATH_EXIST = True
else:
PATH_EXIST = False


def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)


def check_inference_engine(use_engine=False, prompt_template=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True
).cuda()
model = model.eval()

inputs = [
"介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
]

output_len = 38
do_sample = False

if use_engine:
inference_config = InferenceConfig(
max_output_len=output_len, prompt_template=prompt_template, dtype="fp32", use_cuda_kernel=True
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(do_sample=do_sample)
outputs = inference_engine.generate(generation_config=generation_config)
else:
if prompt_template:
# apply prompt template
inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
inputs = inputs.cuda()
generation_config = GenerationConfig(
do_sample=do_sample,
pad_token_id=tokenizer.pad_token_id,
max_new_tokens=output_len,
)
outputs = model.generate(inputs, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

return outputs


@parameterize("prompt_template", [None, "baichuan"])
def check_output_consistency(prompt_template):
cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)

for s1, s2 in zip(cai_outputs, transformer_outputs):
assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"

# clear singleton flash decoding tensors
FDIntermTensors._instances = {}


def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
check_output_consistency()


@pytest.mark.skipif(
not PATH_EXIST,
reason="There is no local model address included, please replace this address with a valid one.",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_inference_engine():
spawn(run_dist, 1)


if __name__ == "__main__":
test_inference_engine()