From e94f44cef86466a545e518f903322b2a03f772e3 Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Fri, 12 Apr 2024 10:55:55 +0800
Subject: [PATCH 1/5] Adapted to the baichuan2-7B model
---
colossalai/inference/config.py | 1 +
.../modeling/models/nopadding_baichuan.py | 181 ++++
.../modeling/models/nopadding_llama.py | 245 +++---
.../inference/modeling/policy/__init__.py | 4 +-
.../modeling/policy/nopadding_baichuan.py | 64 ++
examples/inference/benchmark_llama.py | 1 +
.../model_utils/baichuan2_7B/__init__.py | 0
.../baichuan2_7B/configuration_baichuan.py | 66 ++
.../baichuan2_7B/generation_utils.py | 81 ++
.../baichuan2_7B/modeling_baichuan.py | 801 ++++++++++++++++++
tests/test_infer/test_models/test_baichuan.py | 98 +++
11 files changed, 1446 insertions(+), 96 deletions(-)
create mode 100644 colossalai/inference/modeling/models/nopadding_baichuan.py
create mode 100644 colossalai/inference/modeling/policy/nopadding_baichuan.py
create mode 100644 tests/test_infer/test_models/model_utils/baichuan2_7B/__init__.py
create mode 100644 tests/test_infer/test_models/model_utils/baichuan2_7B/configuration_baichuan.py
create mode 100644 tests/test_infer/test_models/model_utils/baichuan2_7B/generation_utils.py
create mode 100644 tests/test_infer/test_models/model_utils/baichuan2_7B/modeling_baichuan.py
create mode 100644 tests/test_infer/test_models/test_baichuan.py
diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 01b1ac53ea7d..8ccac2141562 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -27,6 +27,7 @@
_DEFAULT_PROMPT_TEMPLATES = {
"llama": "[INST] <>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<>\n{input_text}[/INST]",
"vicuna": "USER: {input_text}\n\nASSISTANT: ",
+ "baichuan": "{input_text}",
}
diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py
new file mode 100644
index 000000000000..63cd35252bf2
--- /dev/null
+++ b/colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -0,0 +1,181 @@
+# This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
+from colossalai.inference.modeling.models.nopadding_llama import llama_base_attn_forward
+from colossalai.kernel.kernel_loader import InferenceOpsLoader
+from colossalai.logging import get_dist_logger
+
+inference_ops = InferenceOpsLoader().load()
+
+logger = get_dist_logger(__name__)
+
+# def get_alibi_slopes(n_head):
+
+
+class NopadBaiChuanAttention(nn.Module):
+ def __init__(
+ self,
+ config,
+ attn_qproj_w: torch.Tensor = None,
+ attn_kproj_w: torch.Tensor = None,
+ attn_vproj_w: torch.Tensor = None,
+ attn_oproj_w: torch.Tensor = None,
+ ):
+ """This layer will replace the BaichuanAttention.
+
+ Args:
+ config (BaichuanConfig): Holding the Baichuan model config.
+ attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
+ attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
+ attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
+ attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
+ """
+ super().__init__()
+ self.o_proj_weight = attn_oproj_w
+
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+
+ # Used to adapt llama_base_attn_forward
+ self.num_key_value_heads = self.num_heads
+
+ qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
+ self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
+
+ @staticmethod
+ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
+ """Used for initialize the weight of NopadBaiChuanAttention by origin BaiChuanAttention.
+
+ Args:
+ module (BaiChuanAttention): The origin BaiChuanAttention layer.
+ """
+
+ config = module.config
+
+ q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((3, module.hidden_size, module.hidden_size))
+
+ attn_qproj_w = q_proj_w.transpose(0, 1)
+ attn_kproj_w = k_proj_w.transpose(0, 1)
+ attn_vproj_w = v_proj_w.transpose(0, 1)
+ attn_oproj_w = module.o_proj.weight.transpose(0, 1)
+
+ attn_layer = NopadBaiChuanAttention(
+ config=config,
+ attn_qproj_w=attn_qproj_w,
+ attn_kproj_w=attn_kproj_w,
+ attn_vproj_w=attn_vproj_w,
+ attn_oproj_w=attn_oproj_w,
+ )
+
+ return attn_layer
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ block_tables: torch.Tensor,
+ k_cache: torch.Tensor,
+ v_cache: torch.Tensor,
+ sequence_lengths: torch.Tensor,
+ cos_sin: Tuple[torch.Tensor],
+ fd_inter_tensor: FDIntermTensors,
+ is_prompts: bool = True,
+ kv_seq_len: int = 0,
+ output_tensor: torch.Tensor = None,
+ sm_scale: int = None,
+ use_cuda_kernel: bool = True,
+ cu_seqlens: torch.Tensor = None,
+ high_precision: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """
+ Args:
+ hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
+ block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
+ storing mapping of token_position_id -> block_id.
+ k_cache (torch.Tensor): It holds the GPU memory for the key cache.
+ v_cache (torch.Tensor): It holds the GPU memory for the key cache.
+ sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence.
+ cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin.
+ fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
+ storing intermediate values in flash-decoding.
+ is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
+ kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
+ output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
+ sm_scale (int, optional): Used for flash attention. Defaults to None.
+ use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
+ cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
+ high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
+ """
+
+ return llama_base_attn_forward(
+ self,
+ hidden_states=hidden_states,
+ block_tables=block_tables,
+ k_cache=k_cache,
+ v_cache=v_cache,
+ sequence_lengths=sequence_lengths,
+ cos_sin=cos_sin,
+ fd_inter_tensor=fd_inter_tensor,
+ is_prompts=is_prompts,
+ kv_seq_len=kv_seq_len,
+ output_tensor=output_tensor,
+ sm_scale=sm_scale,
+ use_cuda_kernel=use_cuda_kernel,
+ cu_seqlens=cu_seqlens,
+ high_precision=high_precision,
+ )
+
+
+# NOTE This will cause the result to be different from the transformer in some cases.
+class NopadBaichuanMLP(nn.Module):
+ def __init__(
+ self,
+ mlp_gproj_w: torch.Tensor = None,
+ mlp_uproj_w: torch.Tensor = None,
+ mlp_dproj_w: torch.Tensor = None,
+ ):
+ """This layer will replace the LlamaAttention.
+
+ Args:
+ mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
+ mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
+ mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
+ """
+ super().__init__()
+ self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
+ self.down_proj_weight = mlp_dproj_w
+
+ @staticmethod
+ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
+ """Used for initialize the weight of NopadBaichuanMLP by origin MLP(Baichuan).
+
+ Args:
+ module (nn.Module): The origin MLP(Baichuan) layer.
+ """
+
+ mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
+ mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
+ mlp_dproj_w = module.down_proj.weight.transpose(0, 1)
+
+ mlp_layer = NopadBaichuanMLP(
+ mlp_gproj_w=mlp_gproj_w,
+ mlp_uproj_w=mlp_uproj_w,
+ mlp_dproj_w=mlp_dproj_w,
+ )
+
+ return mlp_layer
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
+ """
+ hidden_states = hidden_states.expand(2, -1, -1)
+ gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
+ act_out = inference_ops.silu_and_mul(gate_up_proj_out)
+ return torch.mm(act_out, self.down_proj_weight)
diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py
index c5b61385f822..a89fd363b6ec 100644
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -301,7 +301,6 @@ def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttentio
return attn_layer
- # Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
@@ -339,100 +338,23 @@ def forward(
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
- token_nums = hidden_states.size(0)
-
- if self.num_heads != self.num_key_value_heads:
- query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim)
- key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
- value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
- else:
- # fused qkv
- hidden_states = hidden_states.expand(3, -1, -1)
- query_states, key_states, value_states = (
- torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
- )
-
- block_size = k_cache.size(-2)
-
- if is_prompts:
- if use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
- # flash attn 2 currently only supports FP16/BF16.
- inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
- inference_ops.context_kv_cache_memcpy(
- key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
- )
-
- attn_output = flash_attn_varlen_func(
- query_states,
- key_states,
- value_states,
- cu_seqlens_q=cu_seqlens,
- cu_seqlens_k=cu_seqlens,
- max_seqlen_q=kv_seq_len,
- max_seqlen_k=kv_seq_len,
- dropout_p=0.0,
- softmax_scale=sm_scale,
- causal=True,
- )
- attn_output = attn_output.view(token_nums, -1)
- else:
- rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
- attn_output = context_attention_unpadded(
- q=query_states,
- k=key_states,
- v=value_states,
- k_cache=k_cache,
- v_cache=v_cache,
- context_lengths=sequence_lengths,
- block_tables=block_tables,
- block_size=block_size,
- output=output_tensor,
- max_seq_len=kv_seq_len,
- sm_scale=sm_scale,
- )
- else:
- if use_cuda_kernel:
- inference_ops.rotary_embedding_and_cache_copy(
- query_states,
- key_states,
- value_states,
- cos_sin[0],
- cos_sin[1],
- k_cache,
- v_cache,
- sequence_lengths,
- block_tables,
- high_precision,
- )
- else:
- decoding_fused_rotary_embedding(
- query_states,
- key_states,
- value_states,
- cos_sin[0],
- cos_sin[1],
- k_cache,
- v_cache,
- block_tables,
- sequence_lengths,
- )
- attn_output = flash_decoding_attention(
- q=query_states,
- k_cache=k_cache,
- v_cache=v_cache,
- kv_seq_len=sequence_lengths,
- block_tables=block_tables,
- block_size=block_size,
- max_seq_len_in_batch=kv_seq_len,
- output=output_tensor,
- mid_output=fd_inter_tensor.mid_output,
- mid_output_lse=fd_inter_tensor.mid_output_lse,
- sm_scale=sm_scale,
- )
-
- attn_output = torch.mm(attn_output, self.o_proj_weight)
-
- return attn_output
+ return llama_base_attn_forward(
+ self,
+ hidden_states=hidden_states,
+ block_tables=block_tables,
+ k_cache=k_cache,
+ v_cache=v_cache,
+ sequence_lengths=sequence_lengths,
+ cos_sin=cos_sin,
+ fd_inter_tensor=fd_inter_tensor,
+ is_prompts=is_prompts,
+ kv_seq_len=kv_seq_len,
+ output_tensor=output_tensor,
+ sm_scale=sm_scale,
+ use_cuda_kernel=use_cuda_kernel,
+ cu_seqlens=cu_seqlens,
+ high_precision=high_precision,
+ )
# NOTE This will cause the result to be different from the transformer in some cases.
@@ -490,3 +412,136 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = inference_ops.silu_and_mul(gate_up_proj_out)
return torch.mm(act_out, self.down_proj_weight)
+
+
+# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
+def llama_base_attn_forward(
+ self,
+ hidden_states: torch.Tensor,
+ block_tables: torch.Tensor,
+ k_cache: torch.Tensor,
+ v_cache: torch.Tensor,
+ sequence_lengths: torch.Tensor,
+ cos_sin: Tuple[torch.Tensor],
+ fd_inter_tensor: FDIntermTensors,
+ is_prompts: bool = True,
+ kv_seq_len: int = 0,
+ output_tensor: torch.Tensor = None,
+ sm_scale: int = None,
+ use_cuda_kernel: bool = True,
+ cu_seqlens: torch.Tensor = None,
+ high_precision: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """
+ Args:
+ hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
+ block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
+ storing mapping of token_position_id -> block_id.
+ k_cache (torch.Tensor): It holds the GPU memory for the key cache.
+ v_cache (torch.Tensor): It holds the GPU memory for the key cache.
+ sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence.
+ cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin.
+ fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
+ storing intermediate values in flash-decoding.
+ is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
+ kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
+ output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
+ sm_scale (int, optional): Used for flash attention. Defaults to None.
+ use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
+ cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
+ high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
+ """
+
+ token_nums = hidden_states.size(0)
+
+ if self.num_heads != self.num_key_value_heads:
+ query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim)
+ key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
+ value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
+ else:
+ # fused qkv
+ hidden_states = hidden_states.expand(3, -1, -1)
+ query_states, key_states, value_states = (
+ torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
+ )
+
+ block_size = k_cache.size(-2)
+
+ if is_prompts:
+ if use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
+ # flash attn 2 currently only supports FP16/BF16.
+ inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
+ inference_ops.context_kv_cache_memcpy(
+ key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
+ )
+
+ attn_output = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens,
+ cu_seqlens_k=cu_seqlens,
+ max_seqlen_q=kv_seq_len,
+ max_seqlen_k=kv_seq_len,
+ dropout_p=0.0,
+ softmax_scale=sm_scale,
+ causal=True,
+ )
+ attn_output = attn_output.view(token_nums, -1)
+ else:
+ rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+ attn_output = context_attention_unpadded(
+ q=query_states,
+ k=key_states,
+ v=value_states,
+ k_cache=k_cache,
+ v_cache=v_cache,
+ context_lengths=sequence_lengths,
+ block_tables=block_tables,
+ block_size=block_size,
+ output=output_tensor,
+ max_seq_len=kv_seq_len,
+ sm_scale=sm_scale,
+ )
+ else:
+ if use_cuda_kernel:
+ inference_ops.rotary_embedding_and_cache_copy(
+ query_states,
+ key_states,
+ value_states,
+ cos_sin[0],
+ cos_sin[1],
+ k_cache,
+ v_cache,
+ sequence_lengths,
+ block_tables,
+ high_precision,
+ )
+ else:
+ decoding_fused_rotary_embedding(
+ query_states,
+ key_states,
+ value_states,
+ cos_sin[0],
+ cos_sin[1],
+ k_cache,
+ v_cache,
+ block_tables,
+ sequence_lengths,
+ )
+ attn_output = flash_decoding_attention(
+ q=query_states,
+ k_cache=k_cache,
+ v_cache=v_cache,
+ kv_seq_len=sequence_lengths,
+ block_tables=block_tables,
+ block_size=block_size,
+ max_seq_len_in_batch=kv_seq_len,
+ output=output_tensor,
+ mid_output=fd_inter_tensor.mid_output,
+ mid_output_lse=fd_inter_tensor.mid_output_lse,
+ sm_scale=sm_scale,
+ )
+
+ attn_output = torch.mm(attn_output, self.o_proj_weight)
+ return attn_output
diff --git a/colossalai/inference/modeling/policy/__init__.py b/colossalai/inference/modeling/policy/__init__.py
index 1b905fdae620..1c45f5f1f3ba 100644
--- a/colossalai/inference/modeling/policy/__init__.py
+++ b/colossalai/inference/modeling/policy/__init__.py
@@ -1,7 +1,9 @@
+from .nopadding_baichuan import NoPaddingBaiChuanModelInferPolicy
from .nopadding_llama import NoPaddingLlamaModelInferPolicy
model_policy_map = {
"nopadding_llama": NoPaddingLlamaModelInferPolicy,
+ "nopadding_baichuan": NoPaddingBaiChuanModelInferPolicy,
}
-__all__ = ["NoPaddingLlamaModelInferPolicy", "model_polic_map"]
+__all__ = ["NoPaddingLlamaModelInferPolicy", "NoPaddingBaiChuanModelInferPolicy", "model_polic_map"]
diff --git a/colossalai/inference/modeling/policy/nopadding_baichuan.py b/colossalai/inference/modeling/policy/nopadding_baichuan.py
new file mode 100644
index 000000000000..a45184b235cc
--- /dev/null
+++ b/colossalai/inference/modeling/policy/nopadding_baichuan.py
@@ -0,0 +1,64 @@
+import torch.nn as nn
+from torch.nn import Parameter
+
+from colossalai.inference.modeling.models.nopadding_baichuan import NopadBaiChuanAttention, NopadBaichuanMLP
+from colossalai.inference.modeling.models.nopadding_llama import (
+ llama_causal_lm_forward,
+ llama_decoder_layer_forward,
+ llama_model_forward,
+ llama_rmsnorm_forward,
+)
+from colossalai.inference.utils import init_to_get_rotary
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
+
+# import colossalai
+from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
+
+
+class NoPaddingBaiChuanModelInferPolicy(LlamaForCausalLMPolicy):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+
+ decoder_attribute_replacement = {
+ "lm_head.weight": Parameter(
+ nn.functional.normalize(self.model.lm_head.weight).transpose(0, 1), requires_grad=False
+ ),
+ }
+ policy["BaichuanForCausalLM"] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ )
+
+ policy["DecoderLayer"] = ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="mlp",
+ target_module=NopadBaichuanMLP,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn",
+ target_module=NopadBaiChuanAttention,
+ ),
+ ]
+ )
+
+ self.append_or_create_method_replacement(
+ description={"forward": llama_causal_lm_forward}, policy=policy, target_key="BaichuanForCausalLM"
+ )
+ self.append_or_create_method_replacement(
+ description={"forward": llama_model_forward}, policy=policy, target_key="BaichuanModel"
+ )
+ self.append_or_create_method_replacement(
+ description={"forward": llama_decoder_layer_forward}, policy=policy, target_key="DecoderLayer"
+ )
+ self.append_or_create_method_replacement(
+ description={"forward": llama_rmsnorm_forward}, policy=policy, target_key="RMSNorm"
+ )
+
+ return policy
+
+ def postprocess(self):
+ init_to_get_rotary(self.model.model)
+ return self.model
diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py
index 448a84c6fa0e..8128ce9f3f76 100644
--- a/examples/inference/benchmark_llama.py
+++ b/examples/inference/benchmark_llama.py
@@ -117,6 +117,7 @@ def benchmark_inference(args):
max_output_len=args.output_len,
prefill_ratio=1.2,
block_size=32,
+ use_cuda_kernel=True,
)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
elif args.mode == "vllm":
diff --git a/tests/test_infer/test_models/model_utils/baichuan2_7B/__init__.py b/tests/test_infer/test_models/model_utils/baichuan2_7B/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/test_infer/test_models/model_utils/baichuan2_7B/configuration_baichuan.py b/tests/test_infer/test_models/model_utils/baichuan2_7B/configuration_baichuan.py
new file mode 100644
index 000000000000..ed499b501ae9
--- /dev/null
+++ b/tests/test_infer/test_models/model_utils/baichuan2_7B/configuration_baichuan.py
@@ -0,0 +1,66 @@
+# Copyright 2023 Baichuan Inc. All Rights Reserved.
+
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class BaichuanConfig(PretrainedConfig):
+ model_type = "baichuan"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=125696,
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ hidden_act="silu",
+ max_position_embeddings=4096,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
diff --git a/tests/test_infer/test_models/model_utils/baichuan2_7B/generation_utils.py b/tests/test_infer/test_models/model_utils/baichuan2_7B/generation_utils.py
new file mode 100644
index 000000000000..5cfafa2b76cd
--- /dev/null
+++ b/tests/test_infer/test_models/model_utils/baichuan2_7B/generation_utils.py
@@ -0,0 +1,81 @@
+from queue import Queue
+from typing import List
+
+import torch
+
+
+def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int = 0):
+ def _parse_messages(messages, split_role="user"):
+ system, rounds = "", []
+ round = []
+ for i, message in enumerate(messages):
+ if message["role"] == "system":
+ assert i == 0
+ system = message["content"]
+ continue
+ if message["role"] == split_role and round:
+ rounds.append(round)
+ round = []
+ round.append(message)
+ if round:
+ rounds.append(round)
+ return system, rounds
+
+ max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
+ max_input_tokens = model.config.model_max_length - max_new_tokens
+ system, rounds = _parse_messages(messages, split_role="user")
+ system_tokens = tokenizer.encode(system)
+ max_history_tokens = max_input_tokens - len(system_tokens)
+
+ history_tokens = []
+ for round in rounds[::-1]:
+ round_tokens = []
+ for message in round:
+ if message["role"] == "user":
+ round_tokens.append(model.generation_config.user_token_id)
+ else:
+ round_tokens.append(model.generation_config.assistant_token_id)
+ round_tokens.extend(tokenizer.encode(message["content"]))
+ if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
+ history_tokens = round_tokens + history_tokens # concat left
+ if len(history_tokens) < max_history_tokens:
+ continue
+ break
+
+ input_tokens = system_tokens + history_tokens
+ if messages[-1]["role"] != "assistant":
+ input_tokens.append(model.generation_config.assistant_token_id)
+ input_tokens = input_tokens[-max_input_tokens:] # truncate left
+ return torch.LongTensor([input_tokens]).to(model.device)
+
+
+class TextIterStreamer:
+ def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
+ self.tokenizer = tokenizer
+ self.skip_prompt = skip_prompt
+ self.skip_special_tokens = skip_special_tokens
+ self.tokens = []
+ self.text_queue = Queue()
+ self.next_tokens_are_prompt = True
+
+ def put(self, value):
+ if self.skip_prompt and self.next_tokens_are_prompt:
+ self.next_tokens_are_prompt = False
+ else:
+ if len(value.shape) > 1:
+ value = value[0]
+ self.tokens.extend(value.tolist())
+ self.text_queue.put(self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
+
+ def end(self):
+ self.text_queue.put(None)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ value = self.text_queue.get()
+ if value is None:
+ raise StopIteration()
+ else:
+ return value
diff --git a/tests/test_infer/test_models/model_utils/baichuan2_7B/modeling_baichuan.py b/tests/test_infer/test_models/model_utils/baichuan2_7B/modeling_baichuan.py
new file mode 100644
index 000000000000..f6db626fb0c4
--- /dev/null
+++ b/tests/test_infer/test_models/model_utils/baichuan2_7B/modeling_baichuan.py
@@ -0,0 +1,801 @@
+# Copyright 2023 Baichuan Inc. All Rights Reserved.
+
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+import os
+from contextlib import contextmanager
+from threading import Thread
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from torch.nn import functional as F
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.activations import ACT2FN
+from transformers.generation.utils import GenerationConfig
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.utils import ContextManagers, logging
+
+from .configuration_baichuan import BaichuanConfig
+from .generation_utils import TextIterStreamer, build_chat_input
+
+logger = logging.get_logger(__name__)
+
+try:
+ from xformers import ops as xops
+except ImportError:
+ xops = None
+ logger.warning(
+ "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
+ )
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ if len(mask.size()) == 3:
+ bsz, src_len, _ = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+ expanded_mask = mask[:, None, :, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+ else:
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+
+ return self.weight * hidden_states
+
+
+class RotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+ self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+ self.max_seq_len_cached = max_position_embeddings
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
+ freqs = torch.outer(t, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
+ self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+ if seq_len > self.max_seq_len_cached:
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
+ freqs = torch.outer(t, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
+ self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
+ elif self.cos_cached.device != x.device:
+ self.cos_cached = self.cos_cached.to(x.device)
+ self.sin_cached = self.sin_cached.to(x.device)
+ return (
+ self.cos_cached[:, :, :seq_len, ...],
+ self.sin_cached[:, :, :seq_len, ...],
+ )
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
+ cos = cos_.squeeze(1).squeeze(0) # [seq_len, dim]
+ sin = sin_.squeeze(1).squeeze(0) # [seq_len, dim]
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
+ k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
+ return q_embed.to(q.dtype), k_embed.to(k.dtype)
+
+
+class MLP(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ ):
+ super().__init__()
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.act_fn = ACT2FN[hidden_act]
+
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: BaichuanConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.max_position_embeddings = config.max_position_embeddings
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+ self.rotary_emb = RotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ proj = self.W_pack(hidden_states)
+ proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+ query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+ if xops is not None and self.training:
+ attn_weights = None
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+ attn_output = xops.memory_efficient_attention(
+ query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
+ )
+ else:
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
+ attn_output = F.scaled_dot_product_attention(
+ query_states, key_states, value_states, attn_mask=attention_mask
+ )
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class DecoderLayer(nn.Module):
+ def __init__(self, config: BaichuanConfig):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = Attention(config=config)
+ self.mlp = MLP(
+ hidden_size=self.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ )
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class BaichuanPreTrainedModel(PreTrainedModel):
+ config_class = BaichuanConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["DecoderLayer"]
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, BaichuanModel):
+ module.gradient_checkpointing = value
+
+
+class BaichuanModel(BaichuanPreTrainedModel):
+ def __init__(self, config: BaichuanConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+ inputs_embeds.device
+ )
+ combined_attention_mask = (
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+ )
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class NormHead(nn.Module):
+ def __init__(self, hidden_size, vocab_size, bias=False):
+ super().__init__()
+ self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size)))
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+ self.first_flag = True
+
+ def forward(self, hidden_states):
+ if self.training:
+ norm_weight = nn.functional.normalize(self.weight)
+ elif self.first_flag:
+ self.first_flag = False
+ self.weight = nn.Parameter(nn.functional.normalize(self.weight))
+ norm_weight = self.weight
+ else:
+ norm_weight = self.weight
+ return nn.functional.linear(hidden_states, norm_weight)
+
+
+_init_weights = True
+
+
+@contextmanager
+def no_init_weights(_enable=True):
+ global _init_weights
+ old_init_weights = _init_weights
+ if _enable:
+ _init_weights = False
+ try:
+ yield
+ finally:
+ _init_weights = old_init_weights
+
+
+class BaichuanForCausalLM(BaichuanPreTrainedModel):
+ def __init__(self, config, *model_args, **model_kwargs):
+ super().__init__(config, *model_args, **model_kwargs)
+ self.model = BaichuanModel(config)
+
+ self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
+ if hasattr(config, "quantization_config") and config.quantization_config["load_in_4bit"]:
+ try:
+ from .quantizer import quantize_offline
+ except ImportError:
+ raise ImportError(f"Needs QLinear to run quantize.")
+ quantize_offline(self, 4)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+ *model_args,
+ config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ ignore_mismatched_sizes: bool = False,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: Optional[Union[str, bool]] = None,
+ revision: str = "main",
+ use_safetensors: bool = None,
+ **kwargs,
+ ):
+ # Load config if we don't provide a configuration
+ if not isinstance(config, PretrainedConfig):
+ config_path = config if config is not None else pretrained_model_name_or_path
+ config, model_kwargs = cls.config_class.from_pretrained(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=False,
+ proxies=None,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder="",
+ _from_auto=False,
+ _from_pipeline=None,
+ **kwargs,
+ )
+ else:
+ pass
+
+ if hasattr(config, "quantization_config") and config.quantization_config["load_in_4bit"]:
+ try:
+ from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
+ from accelerate.utils import CustomDtype, get_balanced_memory
+
+ from .quantizer import init_model_weight_int4
+ except ImportError:
+ raise ImportError(f"Needs import model weight init func to run quantize.")
+ # Instantiate model.
+ init_contexts = [no_init_weights(_enable=True)]
+ init_contexts.append(init_empty_weights())
+ with ContextManagers(init_contexts):
+ model = cls(config)
+
+ model_file = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
+ state_dict = torch.load(model_file, map_location="cpu")
+ model.is_quantized = True
+
+ device_map = kwargs.pop("device_map", None)
+ kwargs.pop("torch_dtype", None)
+
+ kwargs = {"no_split_module_classes": model._no_split_modules}
+ target_dtype = CustomDtype.INT4
+ max_memory = get_balanced_memory(
+ model,
+ dtype=target_dtype,
+ low_zero=(device_map == "balanced_low_0"),
+ max_memory=None,
+ **kwargs,
+ )
+ kwargs["max_memory"] = max_memory
+
+ device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs)
+ model = init_model_weight_int4(config, model, state_dict)
+
+ # Set model in evaluation mode to deactivate DropOut modules by default
+ model.eval()
+ # If it is a model with generation capabilities, attempt to load the generation config
+ if model.can_generate():
+ try:
+ model.generation_config = GenerationConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=False,
+ proxies=None,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder="",
+ _from_auto=False,
+ _from_pipeline=None,
+ **kwargs,
+ )
+ except (OSError, TypeError):
+ logger.info(
+ "Generation config file not found, using a generation config created from the model config."
+ )
+
+ if device_map is not None:
+ dispatch_model(model, device_map=device_map)
+
+ return model
+ return super(BaichuanForCausalLM, cls).from_pretrained(
+ pretrained_model_name_or_path,
+ *model_args,
+ config=config,
+ cache_dir=cache_dir,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ force_download=force_download,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ use_safetensors=use_safetensors,
+ **kwargs,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ softmax_normalizer = shift_logits.max(-1).values ** 2
+ z_loss = self.config.z_loss_weight * softmax_normalizer.mean()
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels) + z_loss
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ if past_key_values:
+ input_ids = input_ids[:, -1:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -1].unsqueeze(-1)
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
+
+ def quantize(self, bits: int):
+ try:
+ from .quantizer import quantize_online
+ except ImportError:
+ raise ImportError(f"Needs QLinear to run quantize.")
+ return quantize_online(self, bits)
+
+ def chat(self, tokenizer, messages: List[dict], stream=False, generation_config: Optional[GenerationConfig] = None):
+ generation_config = generation_config or self.generation_config
+ input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
+ if stream:
+ streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+ Thread(
+ target=self.generate,
+ kwargs=dict(
+ inputs=input_ids,
+ streamer=streamer,
+ generation_config=generation_config,
+ ),
+ ).start()
+ return streamer
+ else:
+ outputs = self.generate(input_ids, generation_config=generation_config)
+ response = tokenizer.decode(outputs[0][len(input_ids[0]) :], skip_special_tokens=True)
+ return response
diff --git a/tests/test_infer/test_models/test_baichuan.py b/tests/test_infer/test_models/test_baichuan.py
new file mode 100644
index 000000000000..fd4c186fc688
--- /dev/null
+++ b/tests/test_infer/test_models/test_baichuan.py
@@ -0,0 +1,98 @@
+import random
+
+import numpy as np
+import pytest
+import torch
+from model_utils.baichuan2_7B.configuration_baichuan import BaichuanConfig
+from model_utils.baichuan2_7B.modeling_baichuan import BaichuanForCausalLM
+from transformers import AutoTokenizer, GenerationConfig
+
+import colossalai
+from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+
+
+def setup_seed(seed):
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+
+
+def check_inference_engine(use_engine=False, prompt_template=None):
+ setup_seed(20)
+ tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-7B-Base", use_fast=False, trust_remote_code=True)
+ model = BaichuanForCausalLM(
+ BaichuanConfig(
+ vocab_size=125696, hidden_size=32, intermediate_size=1376, num_attention_heads=1, num_hidden_layers=1
+ )
+ ).cuda()
+ model = model.eval()
+
+ inputs = [
+ "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
+ ]
+
+ output_len = 38
+ do_sample = True
+ top_p = 0.5
+ top_k = 50
+
+ if use_engine:
+ inference_config = InferenceConfig(
+ max_output_len=output_len, prompt_template=prompt_template, dtype="fp32", use_cuda_kernel=True
+ )
+ inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+ assert inference_engine.generation_config.max_new_tokens == output_len
+ inference_engine.add_request(prompts=inputs)
+ assert inference_engine.request_handler._has_waiting()
+ generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+ outputs = inference_engine.generate(generation_config=generation_config)
+ else:
+ if prompt_template:
+ # apply prompt template
+ inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+ inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
+ inputs = inputs.cuda()
+ generation_config = GenerationConfig(
+ do_sample=do_sample,
+ top_p=top_p,
+ top_k=top_k,
+ pad_token_id=tokenizer.pad_token_id,
+ max_new_tokens=output_len,
+ )
+ outputs = model.generate(inputs, generation_config=generation_config)
+ outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+ return outputs
+
+
+@parameterize("prompt_template", [None, "baichuan"])
+def check_output_consistency(prompt_template):
+ cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
+ transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)
+
+ for s1, s2 in zip(cai_outputs, transformer_outputs):
+ assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
+
+ # clear singleton flash decoding tensors
+ FDIntermTensors._instances = {}
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+ check_output_consistency()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_inference_engine():
+ spawn(run_dist, 1)
+
+
+if __name__ == "__main__":
+ test_inference_engine()
From 9517fa61a01e6115c1c8dabc335276e6e1a73cbf Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Fri, 12 Apr 2024 16:49:18 +0800
Subject: [PATCH 2/5] modified according to the review comments.
---
colossalai/inference/core/engine.py | 1 +
.../modeling/models/nopadding_baichuan.py | 12 +-
.../modeling/models/nopadding_llama.py | 274 +++++++-----------
3 files changed, 116 insertions(+), 171 deletions(-)
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index f6b5a6e7951e..466f6749ba10 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -27,6 +27,7 @@
_supported_models = [
"LlamaForCausalLM",
+ "BaichuanForCausalLM",
]
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py
index 63cd35252bf2..56e6b0ca96ab 100644
--- a/colossalai/inference/modeling/models/nopadding_baichuan.py
+++ b/colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -5,7 +5,7 @@
import torch.nn as nn
from colossalai.inference.flash_decoding_utils import FDIntermTensors
-from colossalai.inference.modeling.models.nopadding_llama import llama_base_attn_forward
+from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaAttention
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.logging import get_dist_logger
@@ -13,8 +13,6 @@
logger = get_dist_logger(__name__)
-# def get_alibi_slopes(n_head):
-
class NopadBaiChuanAttention(nn.Module):
def __init__(
@@ -85,6 +83,8 @@ def forward(
cos_sin: Tuple[torch.Tensor],
fd_inter_tensor: FDIntermTensors,
is_prompts: bool = True,
+ is_verifier: bool = False,
+ tokens_to_verify: int = None,
kv_seq_len: int = 0,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
@@ -112,7 +112,7 @@ def forward(
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
- return llama_base_attn_forward(
+ return NopadLlamaAttention.forward(
self,
hidden_states=hidden_states,
block_tables=block_tables,
@@ -122,6 +122,8 @@ def forward(
cos_sin=cos_sin,
fd_inter_tensor=fd_inter_tensor,
is_prompts=is_prompts,
+ is_verifier=is_verifier,
+ tokens_to_verify=tokens_to_verify,
kv_seq_len=kv_seq_len,
output_tensor=output_tensor,
sm_scale=sm_scale,
@@ -139,7 +141,7 @@ def __init__(
mlp_uproj_w: torch.Tensor = None,
mlp_dproj_w: torch.Tensor = None,
):
- """This layer will replace the LlamaAttention.
+ """This layer will replace the BaiChuanAttention.
Args:
mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py
index c58038070bc9..2b14190daeea 100644
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -330,6 +330,7 @@ def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttentio
return attn_layer
+ # Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
@@ -369,23 +370,113 @@ def forward(
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
- return llama_base_attn_forward(
- self,
- hidden_states=hidden_states,
- block_tables=block_tables,
- k_cache=k_cache,
- v_cache=v_cache,
- sequence_lengths=sequence_lengths,
- cos_sin=cos_sin,
- fd_inter_tensor=fd_inter_tensor,
- is_prompts=is_prompts,
- kv_seq_len=kv_seq_len,
- output_tensor=output_tensor,
- sm_scale=sm_scale,
- use_cuda_kernel=use_cuda_kernel,
- cu_seqlens=cu_seqlens,
- high_precision=high_precision,
- )
+ token_nums = hidden_states.size(0)
+
+ if self.num_heads != self.num_key_value_heads:
+ query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim)
+ key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
+ value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
+ else:
+ # fused qkv
+ hidden_states = hidden_states.expand(3, -1, -1)
+ query_states, key_states, value_states = (
+ torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
+ )
+
+ block_size = k_cache.size(-2)
+
+ if is_prompts:
+ if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
+ # flash attn 2 currently only supports FP16/BF16.
+ inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
+ inference_ops.context_kv_cache_memcpy(
+ key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
+ )
+
+ attn_output = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens,
+ cu_seqlens_k=cu_seqlens,
+ max_seqlen_q=kv_seq_len,
+ max_seqlen_k=kv_seq_len,
+ dropout_p=0.0,
+ softmax_scale=sm_scale,
+ causal=True,
+ )
+ attn_output = attn_output.view(token_nums, -1)
+ else:
+ rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+ attn_output = context_attention_unpadded(
+ q=query_states,
+ k=key_states,
+ v=value_states,
+ k_cache=k_cache,
+ v_cache=v_cache,
+ context_lengths=sequence_lengths,
+ block_tables=block_tables,
+ block_size=block_size,
+ output=output_tensor,
+ max_seq_len=kv_seq_len,
+ sm_scale=sm_scale,
+ )
+ else:
+ q_len = tokens_to_verify + 1 if is_verifier else 1
+
+ if use_cuda_kernel:
+ inference_ops.rotary_embedding_and_cache_copy(
+ query_states,
+ key_states,
+ value_states,
+ cos_sin[0],
+ cos_sin[1],
+ k_cache,
+ v_cache,
+ sequence_lengths,
+ block_tables,
+ high_precision,
+ )
+ else:
+ if is_verifier:
+ rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+ copy_k_to_blocked_cache(
+ key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
+ )
+ copy_k_to_blocked_cache(
+ value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
+ )
+ else:
+ decoding_fused_rotary_embedding(
+ query_states,
+ key_states,
+ value_states,
+ cos_sin[0],
+ cos_sin[1],
+ k_cache,
+ v_cache,
+ block_tables,
+ sequence_lengths,
+ )
+ attn_output = flash_decoding_attention(
+ q=query_states,
+ k_cache=k_cache,
+ v_cache=v_cache,
+ kv_seq_len=sequence_lengths,
+ block_tables=block_tables,
+ block_size=block_size,
+ max_seq_len_in_batch=kv_seq_len,
+ output=output_tensor,
+ mid_output=fd_inter_tensor.mid_output,
+ mid_output_lse=fd_inter_tensor.mid_output_lse,
+ sm_scale=sm_scale,
+ q_len=q_len,
+ )
+
+ attn_output = attn_output.view(-1, self.hidden_size)
+ attn_output = torch.mm(attn_output, self.o_proj_weight)
+
+ return attn_output
# NOTE This will cause the result to be different from the transformer in some cases.
@@ -443,152 +534,3 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = inference_ops.silu_and_mul(gate_up_proj_out)
return torch.mm(act_out, self.down_proj_weight)
-
-
-# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
-def llama_base_attn_forward(
- self,
- hidden_states: torch.Tensor,
- block_tables: torch.Tensor,
- k_cache: torch.Tensor,
- v_cache: torch.Tensor,
- sequence_lengths: torch.Tensor,
- cos_sin: Tuple[torch.Tensor],
- fd_inter_tensor: FDIntermTensors,
- is_prompts: bool = True,
- is_verifier: bool = False,
- tokens_to_verify: int = None,
- kv_seq_len: int = 0,
- output_tensor: torch.Tensor = None,
- sm_scale: int = None,
- use_cuda_kernel: bool = True,
- cu_seqlens: torch.Tensor = None,
- high_precision: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- """
- Args:
- hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
- block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
- storing mapping of token_position_id -> block_id.
- k_cache (torch.Tensor): It holds the GPU memory for the key cache.
- v_cache (torch.Tensor): It holds the GPU memory for the key cache.
- sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence.
- cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin.
- fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
- storing intermediate values in flash-decoding.
- is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
- kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
- output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
- sm_scale (int, optional): Used for flash attention. Defaults to None.
- use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
- cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
- high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
- """
-
- token_nums = hidden_states.size(0)
-
- if self.num_heads != self.num_key_value_heads:
- query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim)
- key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
- value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
- else:
- # fused qkv
- hidden_states = hidden_states.expand(3, -1, -1)
- query_states, key_states, value_states = (
- torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
- )
-
- block_size = k_cache.size(-2)
-
- if is_prompts:
- if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
- # flash attn 2 currently only supports FP16/BF16.
- inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
- inference_ops.context_kv_cache_memcpy(
- key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
- )
-
- attn_output = flash_attn_varlen_func(
- query_states,
- key_states,
- value_states,
- cu_seqlens_q=cu_seqlens,
- cu_seqlens_k=cu_seqlens,
- max_seqlen_q=kv_seq_len,
- max_seqlen_k=kv_seq_len,
- dropout_p=0.0,
- softmax_scale=sm_scale,
- causal=True,
- )
- attn_output = attn_output.view(token_nums, -1)
- else:
- rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
- attn_output = context_attention_unpadded(
- q=query_states,
- k=key_states,
- v=value_states,
- k_cache=k_cache,
- v_cache=v_cache,
- context_lengths=sequence_lengths,
- block_tables=block_tables,
- block_size=block_size,
- output=output_tensor,
- max_seq_len=kv_seq_len,
- sm_scale=sm_scale,
- )
- else:
- q_len = tokens_to_verify + 1 if is_verifier else 1
-
- if use_cuda_kernel:
- inference_ops.rotary_embedding_and_cache_copy(
- query_states,
- key_states,
- value_states,
- cos_sin[0],
- cos_sin[1],
- k_cache,
- v_cache,
- sequence_lengths,
- block_tables,
- high_precision,
- )
- else:
- if is_verifier:
- rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
- copy_k_to_blocked_cache(
- key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
- )
- copy_k_to_blocked_cache(
- value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
- )
- else:
- decoding_fused_rotary_embedding(
- query_states,
- key_states,
- value_states,
- cos_sin[0],
- cos_sin[1],
- k_cache,
- v_cache,
- block_tables,
- sequence_lengths,
- )
- attn_output = flash_decoding_attention(
- q=query_states,
- k_cache=k_cache,
- v_cache=v_cache,
- kv_seq_len=sequence_lengths,
- block_tables=block_tables,
- block_size=block_size,
- max_seq_len_in_batch=kv_seq_len,
- output=output_tensor,
- mid_output=fd_inter_tensor.mid_output,
- mid_output_lse=fd_inter_tensor.mid_output_lse,
- sm_scale=sm_scale,
- q_len=q_len,
- )
-
- attn_output = attn_output.view(-1, self.hidden_size)
- attn_output = torch.mm(attn_output, self.o_proj_weight)
-
- return attn_output
From 25f107d5242f6e0dc0f345001aa8c4b4f7ee252f Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Mon, 15 Apr 2024 15:01:21 +0800
Subject: [PATCH 3/5] Modified the method of obtaining random weights.
---
.../model_utils/baichuan2_7B/__init__.py | 0
.../baichuan2_7B/configuration_baichuan.py | 66 --
.../baichuan2_7B/generation_utils.py | 81 --
.../baichuan2_7B/modeling_baichuan.py | 801 ------------------
tests/test_infer/test_models/test_baichuan.py | 32 +-
5 files changed, 18 insertions(+), 962 deletions(-)
delete mode 100644 tests/test_infer/test_models/model_utils/baichuan2_7B/__init__.py
delete mode 100644 tests/test_infer/test_models/model_utils/baichuan2_7B/configuration_baichuan.py
delete mode 100644 tests/test_infer/test_models/model_utils/baichuan2_7B/generation_utils.py
delete mode 100644 tests/test_infer/test_models/model_utils/baichuan2_7B/modeling_baichuan.py
diff --git a/tests/test_infer/test_models/model_utils/baichuan2_7B/__init__.py b/tests/test_infer/test_models/model_utils/baichuan2_7B/__init__.py
deleted file mode 100644
index e69de29bb2d1..000000000000
diff --git a/tests/test_infer/test_models/model_utils/baichuan2_7B/configuration_baichuan.py b/tests/test_infer/test_models/model_utils/baichuan2_7B/configuration_baichuan.py
deleted file mode 100644
index ed499b501ae9..000000000000
--- a/tests/test_infer/test_models/model_utils/baichuan2_7B/configuration_baichuan.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Copyright 2023 Baichuan Inc. All Rights Reserved.
-
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
-
-
-class BaichuanConfig(PretrainedConfig):
- model_type = "baichuan"
- keys_to_ignore_at_inference = ["past_key_values"]
-
- def __init__(
- self,
- vocab_size=125696,
- hidden_size=4096,
- intermediate_size=11008,
- num_hidden_layers=32,
- num_attention_heads=32,
- hidden_act="silu",
- max_position_embeddings=4096,
- initializer_range=0.02,
- rms_norm_eps=1e-6,
- use_cache=True,
- pad_token_id=0,
- bos_token_id=1,
- eos_token_id=2,
- tie_word_embeddings=False,
- **kwargs,
- ):
- self.vocab_size = vocab_size
- self.max_position_embeddings = max_position_embeddings
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.hidden_act = hidden_act
- self.initializer_range = initializer_range
- self.rms_norm_eps = rms_norm_eps
- self.use_cache = use_cache
- super().__init__(
- pad_token_id=pad_token_id,
- bos_token_id=bos_token_id,
- eos_token_id=eos_token_id,
- tie_word_embeddings=tie_word_embeddings,
- **kwargs,
- )
diff --git a/tests/test_infer/test_models/model_utils/baichuan2_7B/generation_utils.py b/tests/test_infer/test_models/model_utils/baichuan2_7B/generation_utils.py
deleted file mode 100644
index 5cfafa2b76cd..000000000000
--- a/tests/test_infer/test_models/model_utils/baichuan2_7B/generation_utils.py
+++ /dev/null
@@ -1,81 +0,0 @@
-from queue import Queue
-from typing import List
-
-import torch
-
-
-def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int = 0):
- def _parse_messages(messages, split_role="user"):
- system, rounds = "", []
- round = []
- for i, message in enumerate(messages):
- if message["role"] == "system":
- assert i == 0
- system = message["content"]
- continue
- if message["role"] == split_role and round:
- rounds.append(round)
- round = []
- round.append(message)
- if round:
- rounds.append(round)
- return system, rounds
-
- max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
- max_input_tokens = model.config.model_max_length - max_new_tokens
- system, rounds = _parse_messages(messages, split_role="user")
- system_tokens = tokenizer.encode(system)
- max_history_tokens = max_input_tokens - len(system_tokens)
-
- history_tokens = []
- for round in rounds[::-1]:
- round_tokens = []
- for message in round:
- if message["role"] == "user":
- round_tokens.append(model.generation_config.user_token_id)
- else:
- round_tokens.append(model.generation_config.assistant_token_id)
- round_tokens.extend(tokenizer.encode(message["content"]))
- if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
- history_tokens = round_tokens + history_tokens # concat left
- if len(history_tokens) < max_history_tokens:
- continue
- break
-
- input_tokens = system_tokens + history_tokens
- if messages[-1]["role"] != "assistant":
- input_tokens.append(model.generation_config.assistant_token_id)
- input_tokens = input_tokens[-max_input_tokens:] # truncate left
- return torch.LongTensor([input_tokens]).to(model.device)
-
-
-class TextIterStreamer:
- def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
- self.tokenizer = tokenizer
- self.skip_prompt = skip_prompt
- self.skip_special_tokens = skip_special_tokens
- self.tokens = []
- self.text_queue = Queue()
- self.next_tokens_are_prompt = True
-
- def put(self, value):
- if self.skip_prompt and self.next_tokens_are_prompt:
- self.next_tokens_are_prompt = False
- else:
- if len(value.shape) > 1:
- value = value[0]
- self.tokens.extend(value.tolist())
- self.text_queue.put(self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
-
- def end(self):
- self.text_queue.put(None)
-
- def __iter__(self):
- return self
-
- def __next__(self):
- value = self.text_queue.get()
- if value is None:
- raise StopIteration()
- else:
- return value
diff --git a/tests/test_infer/test_models/model_utils/baichuan2_7B/modeling_baichuan.py b/tests/test_infer/test_models/model_utils/baichuan2_7B/modeling_baichuan.py
deleted file mode 100644
index f6db626fb0c4..000000000000
--- a/tests/test_infer/test_models/model_utils/baichuan2_7B/modeling_baichuan.py
+++ /dev/null
@@ -1,801 +0,0 @@
-# Copyright 2023 Baichuan Inc. All Rights Reserved.
-
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import math
-import os
-from contextlib import contextmanager
-from threading import Thread
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.utils.checkpoint
-from torch import nn
-from torch.nn import CrossEntropyLoss
-from torch.nn import functional as F
-from transformers import PretrainedConfig, PreTrainedModel
-from transformers.activations import ACT2FN
-from transformers.generation.utils import GenerationConfig
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.utils import ContextManagers, logging
-
-from .configuration_baichuan import BaichuanConfig
-from .generation_utils import TextIterStreamer, build_chat_input
-
-logger = logging.get_logger(__name__)
-
-try:
- from xformers import ops as xops
-except ImportError:
- xops = None
- logger.warning(
- "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
- )
-
-
-# Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
-):
- """
- Make causal mask used for bi-directional self-attention.
- """
- bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
- mask_cond = torch.arange(mask.size(-1), device=device)
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
- mask = mask.to(dtype)
-
- if past_key_values_length > 0:
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
-
-
-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
- """
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
- """
- if len(mask.size()) == 3:
- bsz, src_len, _ = mask.size()
- tgt_len = tgt_len if tgt_len is not None else src_len
- expanded_mask = mask[:, None, :, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
- else:
- bsz, src_len = mask.size()
- tgt_len = tgt_len if tgt_len is not None else src_len
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
- inverted_mask = 1.0 - expanded_mask
-
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
-
-
-class RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- RMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-
- # convert into half-precision if necessary
- if self.weight.dtype in [torch.float16, torch.bfloat16]:
- hidden_states = hidden_states.to(self.weight.dtype)
-
- return self.weight * hidden_states
-
-
-class RotaryEmbedding(torch.nn.Module):
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
- super().__init__()
- self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
- self.max_seq_len_cached = max_position_embeddings
- t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
- freqs = torch.outer(t, self.inv_freq)
- emb = torch.cat((freqs, freqs), dim=-1)
- self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
- self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
-
- def forward(self, x, seq_len=None):
- # x: [bs, num_attention_heads, seq_len, head_size]
- # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
- if seq_len > self.max_seq_len_cached:
- self.max_seq_len_cached = seq_len
- t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
- freqs = torch.outer(t, self.inv_freq)
- emb = torch.cat((freqs, freqs), dim=-1)
- self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
- self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
- elif self.cos_cached.device != x.device:
- self.cos_cached = self.cos_cached.to(x.device)
- self.sin_cached = self.sin_cached.to(x.device)
- return (
- self.cos_cached[:, :, :seq_len, ...],
- self.sin_cached[:, :, :seq_len, ...],
- )
-
-
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
- cos = cos_.squeeze(1).squeeze(0) # [seq_len, dim]
- sin = sin_.squeeze(1).squeeze(0) # [seq_len, dim]
- cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
- sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
- q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
- k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
- return q_embed.to(q.dtype), k_embed.to(k.dtype)
-
-
-class MLP(nn.Module):
- def __init__(
- self,
- hidden_size: int,
- intermediate_size: int,
- hidden_act: str,
- ):
- super().__init__()
- self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
- self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
- self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
- self.act_fn = ACT2FN[hidden_act]
-
- def forward(self, x):
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
-class Attention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
-
- def __init__(self, config: BaichuanConfig):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.head_dim = self.hidden_size // self.num_heads
- self.max_position_embeddings = config.max_position_embeddings
-
- if (self.head_dim * self.num_heads) != self.hidden_size:
- raise ValueError(
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
- f" and `num_heads`: {self.num_heads})."
- )
- self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
- self.rotary_emb = RotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
-
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- bsz, q_len, _ = hidden_states.size()
-
- proj = self.W_pack(hidden_states)
- proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
- query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
- # [bsz, nh, t, hd]
-
- if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
- past_key_value = (key_states, value_states) if use_cache else None
- if xops is not None and self.training:
- attn_weights = None
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
- attn_output = xops.memory_efficient_attention(
- query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
- )
- else:
- with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
- attn_output = F.scaled_dot_product_attention(
- query_states, key_states, value_states, attn_mask=attention_mask
- )
- attn_output = attn_output.transpose(1, 2)
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-class DecoderLayer(nn.Module):
- def __init__(self, config: BaichuanConfig):
- super().__init__()
- self.hidden_size = config.hidden_size
- self.self_attn = Attention(config=config)
- self.mlp = MLP(
- hidden_size=self.hidden_size,
- intermediate_size=config.intermediate_size,
- hidden_act=config.hidden_act,
- )
- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
- residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
-
- # Self Attention
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- )
- hidden_states = residual + hidden_states
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (self_attn_weights,)
-
- if use_cache:
- outputs += (present_key_value,)
-
- return outputs
-
-
-class BaichuanPreTrainedModel(PreTrainedModel):
- config_class = BaichuanConfig
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
- _no_split_modules = ["DecoderLayer"]
- _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
-
- def _init_weights(self, module):
- std = self.config.initializer_range
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
- def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, BaichuanModel):
- module.gradient_checkpointing = value
-
-
-class BaichuanModel(BaichuanPreTrainedModel):
- def __init__(self, config: BaichuanConfig):
- super().__init__(config)
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
-
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
- self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- self.gradient_checkpointing = False
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.embed_tokens
-
- def set_input_embeddings(self, value):
- self.embed_tokens = value
-
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
- # create causal mask
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- combined_attention_mask = None
- if input_shape[-1] > 1:
- combined_attention_mask = _make_causal_mask(
- input_shape,
- inputs_embeds.dtype,
- device=inputs_embeds.device,
- past_key_values_length=past_key_values_length,
- )
-
- if attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
- inputs_embeds.device
- )
- combined_attention_mask = (
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
- )
-
- return combined_attention_mask
-
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
-
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # retrieve input_ids and inputs_embeds
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
- elif input_ids is not None:
- batch_size, seq_length = input_ids.shape
- elif inputs_embeds is not None:
- batch_size, seq_length, _ = inputs_embeds.shape
- else:
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
- seq_length_with_past = seq_length
- past_key_values_length = 0
-
- if past_key_values is not None:
- past_key_values_length = past_key_values[0][0].shape[2]
- seq_length_with_past = seq_length_with_past + past_key_values_length
-
- if position_ids is None:
- device = input_ids.device if input_ids is not None else inputs_embeds.device
- position_ids = torch.arange(
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
- )
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
- else:
- position_ids = position_ids.view(-1, seq_length).long()
-
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
- # embed positions
- if attention_mask is None:
- attention_mask = torch.ones(
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
- )
- attention_mask = self._prepare_decoder_attention_mask(
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
- )
-
- hidden_states = inputs_embeds
-
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- next_decoder_cache = () if use_cache else None
-
- for idx, decoder_layer in enumerate(self.layers):
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- past_key_value = past_key_values[idx] if past_key_values is not None else None
-
- if self.gradient_checkpointing and self.training:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- # None for past_key_value
- return module(*inputs, output_attentions, None)
-
- return custom_forward
-
- layer_outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(decoder_layer),
- hidden_states,
- attention_mask,
- position_ids,
- None,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- )
-
- hidden_states = layer_outputs[0]
-
- if use_cache:
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
-
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- hidden_states = self.norm(hidden_states)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- next_cache = next_decoder_cache if use_cache else None
- if not return_dict:
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
- return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- )
-
-
-class NormHead(nn.Module):
- def __init__(self, hidden_size, vocab_size, bias=False):
- super().__init__()
- self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size)))
- nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
- self.first_flag = True
-
- def forward(self, hidden_states):
- if self.training:
- norm_weight = nn.functional.normalize(self.weight)
- elif self.first_flag:
- self.first_flag = False
- self.weight = nn.Parameter(nn.functional.normalize(self.weight))
- norm_weight = self.weight
- else:
- norm_weight = self.weight
- return nn.functional.linear(hidden_states, norm_weight)
-
-
-_init_weights = True
-
-
-@contextmanager
-def no_init_weights(_enable=True):
- global _init_weights
- old_init_weights = _init_weights
- if _enable:
- _init_weights = False
- try:
- yield
- finally:
- _init_weights = old_init_weights
-
-
-class BaichuanForCausalLM(BaichuanPreTrainedModel):
- def __init__(self, config, *model_args, **model_kwargs):
- super().__init__(config, *model_args, **model_kwargs)
- self.model = BaichuanModel(config)
-
- self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
- if hasattr(config, "quantization_config") and config.quantization_config["load_in_4bit"]:
- try:
- from .quantizer import quantize_offline
- except ImportError:
- raise ImportError(f"Needs QLinear to run quantize.")
- quantize_offline(self, 4)
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.embed_tokens
-
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
-
- def get_output_embeddings(self):
- return self.lm_head
-
- def set_output_embeddings(self, new_embeddings):
- self.lm_head = new_embeddings
-
- def set_decoder(self, decoder):
- self.model = decoder
-
- def get_decoder(self):
- return self.model
-
- @classmethod
- def from_pretrained(
- cls,
- pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
- *model_args,
- config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
- cache_dir: Optional[Union[str, os.PathLike]] = None,
- ignore_mismatched_sizes: bool = False,
- force_download: bool = False,
- local_files_only: bool = False,
- token: Optional[Union[str, bool]] = None,
- revision: str = "main",
- use_safetensors: bool = None,
- **kwargs,
- ):
- # Load config if we don't provide a configuration
- if not isinstance(config, PretrainedConfig):
- config_path = config if config is not None else pretrained_model_name_or_path
- config, model_kwargs = cls.config_class.from_pretrained(
- config_path,
- cache_dir=cache_dir,
- return_unused_kwargs=True,
- force_download=force_download,
- resume_download=False,
- proxies=None,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- subfolder="",
- _from_auto=False,
- _from_pipeline=None,
- **kwargs,
- )
- else:
- pass
-
- if hasattr(config, "quantization_config") and config.quantization_config["load_in_4bit"]:
- try:
- from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
- from accelerate.utils import CustomDtype, get_balanced_memory
-
- from .quantizer import init_model_weight_int4
- except ImportError:
- raise ImportError(f"Needs import model weight init func to run quantize.")
- # Instantiate model.
- init_contexts = [no_init_weights(_enable=True)]
- init_contexts.append(init_empty_weights())
- with ContextManagers(init_contexts):
- model = cls(config)
-
- model_file = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
- state_dict = torch.load(model_file, map_location="cpu")
- model.is_quantized = True
-
- device_map = kwargs.pop("device_map", None)
- kwargs.pop("torch_dtype", None)
-
- kwargs = {"no_split_module_classes": model._no_split_modules}
- target_dtype = CustomDtype.INT4
- max_memory = get_balanced_memory(
- model,
- dtype=target_dtype,
- low_zero=(device_map == "balanced_low_0"),
- max_memory=None,
- **kwargs,
- )
- kwargs["max_memory"] = max_memory
-
- device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs)
- model = init_model_weight_int4(config, model, state_dict)
-
- # Set model in evaluation mode to deactivate DropOut modules by default
- model.eval()
- # If it is a model with generation capabilities, attempt to load the generation config
- if model.can_generate():
- try:
- model.generation_config = GenerationConfig.from_pretrained(
- pretrained_model_name_or_path,
- cache_dir=cache_dir,
- force_download=force_download,
- resume_download=False,
- proxies=None,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- subfolder="",
- _from_auto=False,
- _from_pipeline=None,
- **kwargs,
- )
- except (OSError, TypeError):
- logger.info(
- "Generation config file not found, using a generation config created from the model config."
- )
-
- if device_map is not None:
- dispatch_model(model, device_map=device_map)
-
- return model
- return super(BaichuanForCausalLM, cls).from_pretrained(
- pretrained_model_name_or_path,
- *model_args,
- config=config,
- cache_dir=cache_dir,
- ignore_mismatched_sizes=ignore_mismatched_sizes,
- force_download=force_download,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- use_safetensors=use_safetensors,
- **kwargs,
- )
-
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- hidden_states = outputs[0]
- logits = self.lm_head(hidden_states)
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- softmax_normalizer = shift_logits.max(-1).values ** 2
- z_loss = self.config.z_loss_weight * softmax_normalizer.mean()
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- loss = loss_fct(shift_logits, shift_labels) + z_loss
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
- def prepare_inputs_for_generation(
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
- ):
- if past_key_values:
- input_ids = input_ids[:, -1:]
-
- position_ids = kwargs.get("position_ids", None)
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
-
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and past_key_values is None:
- model_inputs = {"inputs_embeds": inputs_embeds}
- else:
- model_inputs = {"input_ids": input_ids}
-
- model_inputs.update(
- {
- "position_ids": position_ids,
- "past_key_values": past_key_values,
- "use_cache": kwargs.get("use_cache"),
- "attention_mask": attention_mask,
- }
- )
- return model_inputs
-
- @staticmethod
- def _reorder_cache(past_key_values, beam_idx):
- reordered_past = ()
- for layer_past in past_key_values:
- reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
- return reordered_past
-
- def quantize(self, bits: int):
- try:
- from .quantizer import quantize_online
- except ImportError:
- raise ImportError(f"Needs QLinear to run quantize.")
- return quantize_online(self, bits)
-
- def chat(self, tokenizer, messages: List[dict], stream=False, generation_config: Optional[GenerationConfig] = None):
- generation_config = generation_config or self.generation_config
- input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
- if stream:
- streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
- Thread(
- target=self.generate,
- kwargs=dict(
- inputs=input_ids,
- streamer=streamer,
- generation_config=generation_config,
- ),
- ).start()
- return streamer
- else:
- outputs = self.generate(input_ids, generation_config=generation_config)
- response = tokenizer.decode(outputs[0][len(input_ids[0]) :], skip_special_tokens=True)
- return response
diff --git a/tests/test_infer/test_models/test_baichuan.py b/tests/test_infer/test_models/test_baichuan.py
index fd4c186fc688..91ff2f343b6c 100644
--- a/tests/test_infer/test_models/test_baichuan.py
+++ b/tests/test_infer/test_models/test_baichuan.py
@@ -1,11 +1,10 @@
+import os
import random
import numpy as np
import pytest
import torch
-from model_utils.baichuan2_7B.configuration_baichuan import BaichuanConfig
-from model_utils.baichuan2_7B.modeling_baichuan import BaichuanForCausalLM
-from transformers import AutoTokenizer, GenerationConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
@@ -13,6 +12,13 @@
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+PATH_EXIST = "baichuan-inc/Baichuan2-7B-Base"
+
+if os.path.exists(PATH_EXIST):
+ PATH_EXIST = True
+else:
+ PATH_EXIST = False
+
def setup_seed(seed):
torch.manual_seed(seed)
@@ -23,11 +29,9 @@ def setup_seed(seed):
def check_inference_engine(use_engine=False, prompt_template=None):
setup_seed(20)
- tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-7B-Base", use_fast=False, trust_remote_code=True)
- model = BaichuanForCausalLM(
- BaichuanConfig(
- vocab_size=125696, hidden_size=32, intermediate_size=1376, num_attention_heads=1, num_hidden_layers=1
- )
+ tokenizer = AutoTokenizer.from_pretrained(PATH_EXIST, use_fast=False, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+ PATH_EXIST, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True
).cuda()
model = model.eval()
@@ -36,9 +40,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
]
output_len = 38
- do_sample = True
- top_p = 0.5
- top_k = 50
+ do_sample = False
if use_engine:
inference_config = InferenceConfig(
@@ -48,7 +50,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
- generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+ generation_config = GenerationConfig(do_sample=do_sample)
outputs = inference_engine.generate(generation_config=generation_config)
else:
if prompt_template:
@@ -60,8 +62,6 @@ def check_inference_engine(use_engine=False, prompt_template=None):
inputs = inputs.cuda()
generation_config = GenerationConfig(
do_sample=do_sample,
- top_p=top_p,
- top_k=top_k,
pad_token_id=tokenizer.pad_token_id,
max_new_tokens=output_len,
)
@@ -88,6 +88,10 @@ def run_dist(rank, world_size, port):
check_output_consistency()
+@pytest.mark.skipif(
+ not PATH_EXIST,
+ reason="There is no local model address included, please replace this address with a valid one.",
+)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_inference_engine():
From c1a5fa3f510ee5138e58189074df7e7d1beab2dd Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Mon, 15 Apr 2024 16:17:33 +0800
Subject: [PATCH 4/5] modified according to the review comments.
---
.../inference/modeling/models/nopadding_baichuan.py | 12 ++++++------
colossalai/inference/modeling/policy/__init__.py | 6 +++---
.../inference/modeling/policy/nopadding_baichuan.py | 8 +++-----
tests/test_infer/test_models/test_baichuan.py | 13 ++++---------
4 files changed, 16 insertions(+), 23 deletions(-)
diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py
index 56e6b0ca96ab..e60f368636aa 100644
--- a/colossalai/inference/modeling/models/nopadding_baichuan.py
+++ b/colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -14,7 +14,7 @@
logger = get_dist_logger(__name__)
-class NopadBaiChuanAttention(nn.Module):
+class NopadBaichuanAttention(nn.Module):
def __init__(
self,
config,
@@ -47,11 +47,11 @@ def __init__(
self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
@staticmethod
- def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
- """Used for initialize the weight of NopadBaiChuanAttention by origin BaiChuanAttention.
+ def from_native_module(module: nn.Module, *args, **kwargs) -> "NopadBaichuanAttention":
+ """Used for initialize the weight of NopadBaichuanAttention by origin BaichuanAttention.
Args:
- module (BaiChuanAttention): The origin BaiChuanAttention layer.
+ module (nn.Module): The origin BaichuanAttention layer.
"""
config = module.config
@@ -63,7 +63,7 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
attn_vproj_w = v_proj_w.transpose(0, 1)
attn_oproj_w = module.o_proj.weight.transpose(0, 1)
- attn_layer = NopadBaiChuanAttention(
+ attn_layer = NopadBaichuanAttention(
config=config,
attn_qproj_w=attn_qproj_w,
attn_kproj_w=attn_kproj_w,
@@ -141,7 +141,7 @@ def __init__(
mlp_uproj_w: torch.Tensor = None,
mlp_dproj_w: torch.Tensor = None,
):
- """This layer will replace the BaiChuanAttention.
+ """This layer will replace the BaichuanAttention.
Args:
mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
diff --git a/colossalai/inference/modeling/policy/__init__.py b/colossalai/inference/modeling/policy/__init__.py
index 82e0b69bb19b..fa03955907fe 100644
--- a/colossalai/inference/modeling/policy/__init__.py
+++ b/colossalai/inference/modeling/policy/__init__.py
@@ -1,16 +1,16 @@
from .glide_llama import GlideLlamaModelPolicy
-from .nopadding_baichuan import NoPaddingBaiChuanModelInferPolicy
+from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy
from .nopadding_llama import NoPaddingLlamaModelInferPolicy
model_policy_map = {
"nopadding_llama": NoPaddingLlamaModelInferPolicy,
- "nopadding_baichuan": NoPaddingBaiChuanModelInferPolicy,
+ "nopadding_baichuan": NoPaddingBaichuanModelInferPolicy,
"glide_llama": GlideLlamaModelPolicy,
}
__all__ = [
"NoPaddingLlamaModelInferPolicy",
- "NoPaddingBaiChuanModelInferPolicy",
+ "NoPaddingBaichuanModelInferPolicy",
"GlideLlamaModelPolicy",
"model_polic_map",
]
diff --git a/colossalai/inference/modeling/policy/nopadding_baichuan.py b/colossalai/inference/modeling/policy/nopadding_baichuan.py
index a45184b235cc..64dc40dbc0b9 100644
--- a/colossalai/inference/modeling/policy/nopadding_baichuan.py
+++ b/colossalai/inference/modeling/policy/nopadding_baichuan.py
@@ -1,7 +1,7 @@
import torch.nn as nn
from torch.nn import Parameter
-from colossalai.inference.modeling.models.nopadding_baichuan import NopadBaiChuanAttention, NopadBaichuanMLP
+from colossalai.inference.modeling.models.nopadding_baichuan import NopadBaichuanAttention, NopadBaichuanMLP
from colossalai.inference.modeling.models.nopadding_llama import (
llama_causal_lm_forward,
llama_decoder_layer_forward,
@@ -10,12 +10,10 @@
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
-
-# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
-class NoPaddingBaiChuanModelInferPolicy(LlamaForCausalLMPolicy):
+class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy):
def __init__(self) -> None:
super().__init__()
@@ -39,7 +37,7 @@ def module_policy(self):
),
SubModuleReplacementDescription(
suffix="self_attn",
- target_module=NopadBaiChuanAttention,
+ target_module=NopadBaichuanAttention,
),
]
)
diff --git a/tests/test_infer/test_models/test_baichuan.py b/tests/test_infer/test_models/test_baichuan.py
index 91ff2f343b6c..5ca67c5be7b4 100644
--- a/tests/test_infer/test_models/test_baichuan.py
+++ b/tests/test_infer/test_models/test_baichuan.py
@@ -12,12 +12,7 @@
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-PATH_EXIST = "baichuan-inc/Baichuan2-7B-Base"
-
-if os.path.exists(PATH_EXIST):
- PATH_EXIST = True
-else:
- PATH_EXIST = False
+BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
def setup_seed(seed):
@@ -29,9 +24,9 @@ def setup_seed(seed):
def check_inference_engine(use_engine=False, prompt_template=None):
setup_seed(20)
- tokenizer = AutoTokenizer.from_pretrained(PATH_EXIST, use_fast=False, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
- PATH_EXIST, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True
+ BAICHUAN_MODEL_NAME_OR_PATH, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True
).cuda()
model = model.eval()
@@ -89,7 +84,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.skipif(
- not PATH_EXIST,
+ not os.path.exists(BAICHUAN_MODEL_NAME_OR_PATH),
reason="There is no local model address included, please replace this address with a valid one.",
)
@pytest.mark.dist
From 437ab59192066ae3297d47f978782ea154ee6116 Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Mon, 15 Apr 2024 16:49:25 +0800
Subject: [PATCH 5/5] change mlp layewr 'NOTE'
---
colossalai/inference/modeling/models/nopadding_baichuan.py | 2 +-
colossalai/inference/modeling/models/nopadding_llama.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py
index e60f368636aa..893d45c1f2c4 100644
--- a/colossalai/inference/modeling/models/nopadding_baichuan.py
+++ b/colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -133,7 +133,7 @@ def forward(
)
-# NOTE This will cause the result to be different from the transformer in some cases.
+# NOTE This will cause difference as out length increases.
class NopadBaichuanMLP(nn.Module):
def __init__(
self,
diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py
index 2b14190daeea..010abc1db0b1 100644
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -479,7 +479,7 @@ def forward(
return attn_output
-# NOTE This will cause the result to be different from the transformer in some cases.
+# NOTE This will cause difference as out length increases.
class NopadLlamaMLP(LlamaMLP):
def __init__(
self,