Merged
Changes from all commits
Commits
31 commits
2a52867
add infer_struct and infer_config
isky-cd Dec 1, 2023
fbe49fc
update codes
isky-cd Dec 4, 2023
2ee231d
change InferConfig
isky-cd Dec 6, 2023
c500891
Add hf_model_config to the engine
isky-cd Dec 6, 2023
9aeb8e8
rm _get_hf_model_config
isky-cd Dec 6, 2023
c53bfe2
update codes
isky-cd Dec 6, 2023
c09279e
made adjustments according to the feedback from the reviewer.
isky-cd Dec 6, 2023
736dcbb
update codes
isky-cd Dec 6, 2023
138a2be
Merge branch 'feature/colossal-infer' into infer_struct_and_config
isky-cd Dec 7, 2023
d5058a3
add ci test for config and struct
isky-cd Dec 7, 2023
9608ec2
Merge branch 'feature/colossal-infer' into infer_struct_and_config
isky-cd Dec 7, 2023
3d7f8b4
Add the logic of the inference engine
isky-cd Dec 8, 2023
43ff010
fix conflict
isky-cd Dec 11, 2023
687a3b6
update engine and test
isky-cd Dec 12, 2023
6d6de07
Recover cache_manager.py
isky-cd Dec 12, 2023
7a71a3e
add logger
isky-cd Dec 13, 2023
12bc611
fix conflict
isky-cd Dec 13, 2023
3d93c65
fix conflict
isky-cd Dec 13, 2023
d970a56
fix conflict
isky-cd Dec 13, 2023
93536cc
update codes
isky-cd Dec 14, 2023
4524fe9
update codes
isky-cd Dec 14, 2023
a1cf6e2
update model and tokenizer
isky-cd Dec 14, 2023
95485ce
fix add the logic about shardformer
isky-cd Dec 14, 2023
e5f60d6
change kvcache_manager docstring
isky-cd Dec 14, 2023
f4c4e64
add policy
isky-cd Dec 14, 2023
74dde72
fix ci bug in test_kvcache_manager.py
isky-cd Dec 14, 2023
cc61172
remove codes related o tokenizer and move model_policy
isky-cd Dec 15, 2023
1ca8115
fix code style
isky-cd Dec 15, 2023
177fcd8
add ordered_set to requirements-infer.txt
isky-cd Dec 17, 2023
ed2db8e
Delete extra empty lines
isky-cd Dec 17, 2023
99e30ed
add ordered_set to requirements-test.txt
isky-cd Dec 17, 2023
78 changes: 44 additions & 34 deletions colossalai/inference/config.py
@@ -3,7 +3,7 @@
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.distributed as dist

GibiByte = 1024**3

@@ -15,44 +15,44 @@ class InferenceConfig:
"""The inference configuration.

Args:
model: Path or nn.Module of this model.
tokenizer: Path of the tokenizer to use.
tokenizer_mode: "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer.
trust_remote_code: Whether to trust remote code from huggingface.
max_batch_size: Maximum batch size.
max_output_len: Maximum output length.
max_input_len: Maximum input length.
block_size: The number of blocks in a logical block.
dtype: The data type for weights and activations.
tp_size: Tensor parallel size.
pp_size: Pipeline parallel size.
max_seq_len: Maximum length of input sentence.
quant_mode: Quantization mode.
revision: The specific version(a branch, name, a commit id, or a tag name) of model to use.
beam_width: The maximum beam width used to initialize KV Cache.
micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
max_batch_size (int): Maximum batch size.
max_output_len (int): Maximum output length.
max_input_len (int): Maximum input length.
block_size (int): The number of blocks in a logical block.
dtype (Union[str, torch.dtype]): The data type for weights and activations.
tp_size (int): Tensor parallel size.
pp_size (int): Pipeline parallel size.
max_seq_len (int): Maximum length of input sentence.
beam_width (int): The maximum beam width used to initialize KV Cache.
During generation, the beam width provided as a sampling parameter should be less than or equal to this value.
prefill_ratio: A controling ratio for prefill and decoding in running list, we will do a step of prefill
prefill_ratio (Optional[float]): A controlling ratio for prefill and decoding in the running list; we do a prefill step
when the actual value exceeds this ratio.
quant_mode (Optional[str]): Quantization mode.
revision (Optional[str]): The specific version (a branch name, a commit id, or a tag name) of the model to use.
"""

model: Union[str, nn.Module]
tokenizer: str = None
tokenizer_mode: str = "auto"
trust_remote_code: bool = False
max_batch_size: int = None
micro_batch_size: int = 1
micro_batch_buffer_size: int = None
max_batch_size: int = 8
max_output_len: int = 256
max_input_len: int = 256
block_size: int = 16
dtype: Union[str, torch.dtype] = torch.float32
tp_size: int = 1
pp_size: int = 1
max_seq_len: Optional[int] = None
quant_mode: Optional[str] = None
revision: Optional[str] = None
beam_width: int = 1
max_seq_len: int = 512
# TODO: beam search is not supported for now
prefill_ratio: Optional[float] = 1.2
beam_width: int = 1
# the ratio of prefill sequences to decoding sequences; we do a prefill step once the actual value exceeds this ratio
prefill_ratio: Optional[float] = 1.2
quant_mode: Optional[str] = None
revision: Optional[str] = None

def __post_init__(self):
self._init_batch_size()
self._verify_config()

def _init_batch_size(self):
"""
@@ -75,10 +75,20 @@ def _init_batch_size(self):
f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user."
)

def __post_init__(self):
self._init_batch_size()
self._verify_args()

def _verify_args(self):
if self.tokenizer_mode not in ["auto", "slow"]:
raise ValueError("Tokenizer mode must be " "either 'auto' or 'slow'," f"but got {self.tokenizer_mode}")
def _verify_config(self) -> None:
"""
Verify the input config
"""
assert (
self.tp_size * self.pp_size == dist.get_world_size()
), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
assert self.dtype in [
"fp16",
"fp32",
"bf16",
torch.float32,
torch.float16,
torch.bfloat16,
], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16"
assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
assert self.quant_mode in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
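
For orientation, a minimal sketch of how the reworked InferenceConfig might be instantiated. It assumes the process group has already been initialized (the colossalai.launch_from_torch call below, run under torchrun, is an illustrative way to do that) so that the tp_size * pp_size == dist.get_world_size() assertion in _verify_config passes; the concrete values are illustrative only.

import colossalai
from colossalai.inference.config import InferenceConfig

# Assumed to run under torchrun with a single process, so tp_size * pp_size == 1
# matches dist.get_world_size() inside _verify_config().
colossalai.launch_from_torch(config={})

inference_config = InferenceConfig(
    max_batch_size=8,    # asserted to be <= 64
    max_input_len=256,
    max_output_len=256,
    block_size=16,
    dtype="fp16",        # one of "fp16", "fp32", "bf16" or the matching torch dtypes
    tp_size=1,
    pp_size=1,
    max_seq_len=512,
    quant_mode=None,     # or "smoothquant" / "gptq"
)
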
231 changes: 199 additions & 32 deletions colossalai/inference/core/engine.py
@@ -1,65 +1,232 @@
from logging import Logger
from typing import Optional
from itertools import count
from typing import List, Optional, Union

from transformers import AutoConfig
import torch
import torch.nn as nn
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast

from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import InferenceConfig
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy

from .request_handler import RequestHandler

PP_AXIS, TP_AXIS = 0, 1

_supported_models = [
"LlamaForCausalLM",
]


class InferenceEngine:
"""
InferenceEngine is the core component for Inference.

It is responsible for launch the inference process, including:
- Initialize model and distributed training environment(if needed)
- Launch request_handler and corresponding kv cache manager
- Receive requests and generate texts.
- Log the generation process
"""
InferenceEngine which manages the inference process.

Args:
tokenizer: Path of the tokenizer to use.
inference_config: We provide a unified config api for that wrapped all the configs. You can use it to replace the below configs.
model (nn.Module): The model for inference.
tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): The tokenizer to use.
inference_config (Optional[InferenceConfig], optional): Stores the configuration information related to inference.
verbose (bool): Whether to log the generation process.
model_policy ("Policy"): The policy used to shard the model with shardformer. It will be determined by the model type if not provided.
"""

def __init__(
self,
tokenizer: str = None,
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: Optional["InferenceConfig"] = None,
verbose: bool = False,
model_policy: Policy = None,
) -> None:
assert inference_config, "Please provide inference_config."

self._init_model()
# cache_config may need to be modified later.
# self.request_handler = RequestHandler(cache_config)
self.tokenizer = tokenizer
self.hf_model_config = AutoConfig.from_pretrained(
self.model, trust_remote_code=self.trust_remote_code, revision=self.revision
self.inference_config = inference_config
self.model_config = model.config

if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
self.dtype = torch.float32
elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
self.dtype = torch.float16
model.half()
else:
self.dtype = torch.bfloat16
model.to(torch.bfloat16)

if model_policy is None:
model_policy = model_policy_map[self.model_config.model_type]()

pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)

self.model = self._shardformer(
model,
model_policy,
None,
pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
)

self.verbose = verbose
if verbose:
self.logger = Logger()
self.logger = get_dist_logger(__name__)

self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.counter = count()

def _verify_config(self) -> None:
"""
Verify the input config
"""
if not isinstance(self.model, nn.Module):
raise TypeError(f"the model type must be nn.Module, but get {type(self.model)}")
if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
self.tokenizer, PreTrainedTokenizer
):
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but get {type(self.tokenizer)}"
)
assert (
self.model.__class__.__name__ in _supported_models
), f"Model {self.model.__class__.__name__} is not supported."

def _shardformer(
self,
model: nn.Module,
model_policy: Policy,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
) -> nn.Module:
"""
Initialize ShardConfig and replace the model with shardformer.

Args:
model (nn.Module): The model to be sharded.
model_policy (Policy): The policy used to shard the model with shardformer, determined by the model type.
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.

Returns:
nn.Module: The sharded model.
"""
shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
enable_fused_normalization=False,
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"quant": self.inference_config.quant_mode},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model.cuda()

def _init_model(self):
def generate(
self,
generation_config: GenerationConfig = None,
) -> List[str]:
"""
Initialize model and distributed training environment(if needed).
May need to provide two different initialization methods:
1. User-defined (from local path)
2. Load from checkpoint (Hugging Face)
Executing the inference step.

Args:
generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.

Returns:
List[str]: Inference result returned by one generation.
"""

def _verify_config(self):
self.generation_config = generation_config

output_list = []

while self.request_handler.check_unfinished_seqs():
output_list += self.step()

return output_list

def add_request(
self,
requests_id: List[int] = None,
prompts: List[str] = None,
prompts_token_ids: List[int] = None,
) -> None:
"""
Verify the configuration to avoid potential bugs.
Add requests.

Args:
requests_id (List[int], optional): The request IDs. Defaults to None.
prompts (List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (List[List[int]], optional): Token ids of the input prompts. Defaults to None.
"""

def generate(self):
pass
block_size = self.inference_config.block_size
Reviewer comment (Contributor): block size is a weird arg, what does it mean?

def step(self):
if prompts_token_ids is None:
assert prompts, "When prompts_token_ids is None, the input prompt list must be provided."
prompts_token_ids = []
for prompt in prompts:
prompts_token_ids.append(self.tokenizer.encode(prompt))

prompts_num = len(prompts_token_ids)

for i in range(prompts_num):
if requests_id:
request_id = requests_id[i]
else:
request_id = next(self.counter)
if prompts is None:
prompt = None
else:
prompt = prompts[i]
sequence = Sequence(
request_id,
prompt,
prompts_token_ids[i],
block_size,
None,
None,
self.tokenizer.eos_token_id,
self.inference_config.max_output_len,
)
self.request_handler.add_sequence(sequence)

def step(self) -> List[str]:
"""
In each step, do the following:
1. Run request_handler to update the kv cache and running input_ids
1. Run RequestHandler.schedule() and get the batch used for inference.
2. Run model to generate the next token
3. Check whether there are finished requests and decode
3. Update waiting list and running list in RequestHandler and get finished sequences.
4. Decode and return finished sequences.

Returns:
List[str]: Decoded finished sequences generated by one step.
"""

if self.verbose:
self.logger.info("Running generation step")

output_list = []
self.request_handler.schedule()

# Uncomment if the development of RequestHandler is completed.
# logits = self.model(batch)
# self.request_handler.search_tokens(logits, self.generation_config)

finished_sequences = self.request_handler.update()

# Decode completed sentences.
for seq in finished_sequences:
if seq.prompt:
output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
output_list.append(seq.prompt + output_str)
else:
output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
output_list.append(output_str)

return output_list
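
To see how the pieces introduced here fit together, a hedged end-to-end sketch of the engine API from this diff. The checkpoint path, launcher call, and prompt are illustrative assumptions, and since the model forward in step() is still commented out in this PR, the snippet shows the intended usage rather than a complete generation run.

import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# Assumed to run under torchrun with a single process (tp_size = pp_size = 1).
colossalai.launch_from_torch(config={})

# Placeholder checkpoint path; LlamaForCausalLM is the only model listed in
# _supported_models at this point.
model = LlamaForCausalLM.from_pretrained("path/to/llama")
tokenizer = LlamaTokenizer.from_pretrained("path/to/llama")

inference_config = InferenceConfig(dtype="fp16", max_batch_size=8)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

# add_request() tokenizes raw prompts, wraps them into Sequence objects and
# hands them to the RequestHandler; generate() then loops step() until no
# unfinished sequences remain and returns the decoded outputs.
engine.add_request(prompts=["Introduce some landmarks in Beijing."])
outputs = engine.generate()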