[Inference] Add the logic of the inference engine #5173
Merged: FrankLeeeee merged 31 commits into hpcaitech:feature/colossal-infer from isky-cd:infer_struct_and_config on Dec 18, 2023.
Changes from all commits (31 commits, all authored by isky-cd):
- 2a52867  add infer_struct and infer_config
- fbe49fc  update codes
- 2ee231d  change InferConfig
- c500891  Add hf_model_config to the engine
- 9aeb8e8  rm _get_hf_model_config
- c53bfe2  update codes
- c09279e  made adjustments according to the feedback from the reviewer
- 736dcbb  update codes
- 138a2be  Merge branch 'feature/colossal-infer' into infer_struct_and_config
- d5058a3  add ci test for config and struct
- 9608ec2  Merge branch 'feature/colossal-infer' into infer_struct_and_config
- 3d7f8b4  Add the logic of the inference engine
- 43ff010  fix conflict
- 687a3b6  update engine and test
- 6d6de07  Recover cache_manager.py
- 7a71a3e  add logger
- 12bc611  fix conflict
- 3d93c65  fix conflict
- d970a56  fix conflict
- 93536cc  update codes
- 4524fe9  update codes
- a1cf6e2  update model and tokenizer
- 95485ce  fix add the logic about shardformer
- e5f60d6  change kvcache_manager docstring
- f4c4e64  add policy
- 74dde72  fix ci bug in test_kvcache_manager.py
- cc61172  remove codes related o tokenizer and move model_policy
- 1ca8115  fix code style
- 177fcd8  add ordered_set to requirements-infer.txt
- ed2db8e  Delete extra empty lines
- 99e30ed  add ordered_set to requirements-test.txt
The diff replaces the engine skeleton with a working implementation; per the hunk header, the file grows from 65 to 232 lines. Reconstructed final state of the engine file after all commits:

from itertools import count
from typing import List, Optional, Union

import torch
import torch.nn as nn
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast

from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import InferenceConfig
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy

from .request_handler import RequestHandler

PP_AXIS, TP_AXIS = 0, 1

_supported_models = [
    "LlamaForCausalLM",
]


class InferenceEngine:
    """
    InferenceEngine manages the inference process.

    Args:
        model (nn.Module): The model for inference.
        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): The tokenizer to use.
        inference_config (Optional[InferenceConfig], optional): Configuration information related to inference.
        verbose (bool): Whether to log the generation process.
        model_policy (Policy): The policy used to shard the model with ShardFormer. Determined by the model type if not provided.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        inference_config: Optional["InferenceConfig"] = None,
        verbose: bool = False,
        model_policy: Policy = None,
    ) -> None:
        assert inference_config, "Please provide inference_config."

        self.tokenizer = tokenizer
        self.inference_config = inference_config
        self.model_config = model.config

        # Cast the model to the dtype requested in the config.
        if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
            self.dtype = torch.float32
        elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
            self.dtype = torch.float16
            model.half()
        else:
            self.dtype = torch.bfloat16
            model.to(torch.bfloat16)

        # Fall back to the registered policy for this model type.
        if model_policy is None:
            model_policy = model_policy_map[self.model_config.model_type]()

        pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)

        # Shard the model; the TP group is only needed when parallelism is enabled.
        self.model = self._shardformer(
            model,
            model_policy,
            None,
            pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
        )

        self.verbose = verbose
        if verbose:
            self.logger = get_dist_logger(__name__)

        self.request_handler = RequestHandler(self.inference_config, self.model_config)
        self.counter = count()

    def _verify_config(self) -> None:
        """
        Verify the input config.
        """
        if not isinstance(self.model, nn.Module):
            raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
        if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
            self.tokenizer, PreTrainedTokenizer
        ):
            raise TypeError(
                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
            )
        assert (
            self.model.__class__.__name__ in _supported_models
        ), f"Model {self.model.__class__.__name__} is not supported."

    def _shardformer(
        self,
        model: nn.Module,
        model_policy: Policy,
        stage_manager: PipelineStageManager = None,
        tp_group: ProcessGroupMesh = None,
    ) -> nn.Module:
        """
        Initialize ShardConfig and replace the model with shardformer.

        Args:
            model (nn.Module): The model to shard.
            model_policy (Policy): The policy to shard the model, determined by the model type.
            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
            tp_group (ProcessGroupMesh, optional): Used to manage the TP process group mesh. Defaults to None.

        Returns:
            nn.Module: The model optimized by ShardFormer.
        """
        shardconfig = ShardConfig(
            tensor_parallel_process_group=tp_group,
            pipeline_stage_manager=stage_manager,
            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
            enable_fused_normalization=False,
            enable_all_optimization=False,
            enable_flash_attention=False,
            enable_jit_fused=False,
            enable_sequence_parallelism=False,
            extra_kwargs={"quant": self.inference_config.quant_mode},
        )
        shardformer = ShardFormer(shard_config=shardconfig)
        shard_model, _ = shardformer.optimize(model, model_policy)
        return shard_model.cuda()

    def generate(
        self,
        generation_config: GenerationConfig = None,
    ) -> List[str]:
        """
        Run inference steps until all requests are finished.

        Args:
            generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.

        Returns:
            List[str]: Inference results returned by one generation.
        """
        self.generation_config = generation_config

        output_list = []

        while self.request_handler.check_unfinished_seqs():
            output_list += self.step()

        return output_list

    def add_request(
        self,
        requests_id: List[int] = None,
        prompts: List[str] = None,
        prompts_token_ids: List[List[int]] = None,
    ) -> None:
        """
        Add requests.

        Args:
            requests_id (List[int], optional): The request IDs. Defaults to None.
            prompts (List[str], optional): Input prompts. Defaults to None.
            prompts_token_ids (List[List[int]], optional): Token ids of the input prompts. Defaults to None.
        """
        block_size = self.inference_config.block_size

        # Tokenize the prompts if token ids were not supplied directly.
        if prompts_token_ids is None:
            assert prompts, "When prompts_token_ids is None, the input prompt list must be provided."
            prompts_token_ids = [self.tokenizer.encode(prompt) for prompt in prompts]

        prompts_num = len(prompts_token_ids)

        for i in range(prompts_num):
            request_id = requests_id[i] if requests_id else next(self.counter)
            prompt = prompts[i] if prompts is not None else None
            sequence = Sequence(
                request_id,
                prompt,
                prompts_token_ids[i],
                block_size,
                None,
                None,
                self.tokenizer.eos_token_id,
                self.inference_config.max_output_len,
            )
            self.request_handler.add_sequence(sequence)

    def step(self) -> List[str]:
        """
        In each step, do the following:
            1. Run RequestHandler.schedule() and get the batch used for inference.
            2. Run the model to generate the next token.
            3. Update the waiting and running lists in RequestHandler and get finished sequences.
            4. Decode and return the finished sequences.

        Returns:
            List[str]: Decoded finished sequences generated by one step.
        """
        if self.verbose:
            self.logger.info("Running generation step")

        output_list = []
        self.request_handler.schedule()

        # Uncomment once the development of RequestHandler is completed.
        # logits = self.model(batch)
        # self.request_handler.search_tokens(logits, self.generation_config)

        finished_sequences = self.request_handler.update()

        # Decode completed sequences.
        for seq in finished_sequences:
            if seq.prompt:
                output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
                output_list.append(seq.prompt + output_str)
            else:
                output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
                output_list.append(output_str)

        return output_list
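RequestHandler itself is not part of this diff; the engine relies only on the small surface below. A sketch of the implied contract, with method names and argument lists taken from the call sites above, and everything else (return types, bodies) assumed:

from typing import List

from colossalai.inference.config import InferenceConfig
from colossalai.inference.struct import Sequence


class RequestHandler:
    """Interface implied by InferenceEngine; the real class lives in request_handler.py."""

    def __init__(self, inference_config: InferenceConfig, model_config) -> None:
        ...

    def add_sequence(self, sequence: Sequence) -> None:
        """Enqueue a new sequence on the waiting list."""

    def check_unfinished_seqs(self) -> bool:
        """Return True while any sequence is still waiting or running."""

    def schedule(self):
        """Select the batch to run next and update KV-cache bookkeeping."""

    def search_tokens(self, logits, generation_config) -> None:
        """Pick the next token for each running sequence from the model logits."""

    def update(self) -> List[Sequence]:
        """Move finished sequences off the running list and return them."""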
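For orientation, a hypothetical usage sketch. The dtype and max_output_len fields are attributes the engine reads above, but the InferenceConfig constructor arguments, the engine's import path, and the checkpoint path are assumptions, not taken from this diff:

import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine  # assumed import path

# The engine builds a ProcessGroupMesh, so the distributed environment must
# already be initialized, even for a single process.
colossalai.launch_from_torch(config={})

model = LlamaForCausalLM.from_pretrained("path/to/llama")  # only LlamaForCausalLM is supported
tokenizer = LlamaTokenizer.from_pretrained("path/to/llama")

inference_config = InferenceConfig(dtype="fp16", max_output_len=64)

engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
engine.add_request(prompts=["Introduce Colossal-AI in one sentence."])
outputs = engine.generate()  # drives step() until every sequence finishes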