
Commit 486d06a

github-actions[bot] and github-actions authored
[format] applied code formatting on changed files in pull request 4820 (#4886)
Co-authored-by: github-actions <[email protected]>
1 parent c7aa319 commit 486d06a

13 files changed: +298 additions, −259 deletions


colossalai/inference/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 from .pipeline import PPInferEngine
 
-__all__ = ['PPInferEngine']
+__all__ = ["PPInferEngine"]
Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 from .engine import PPInferEngine
 
-__all__ = ['PPInferEngine']
+__all__ = ["PPInferEngine"]
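The two hunks above are pure quote normalization, and the larger hunks below (double quotes, `1024**3`, spaces around `=` in annotated defaults, long calls wrapped one argument per line) match what black produces. The commit message does not name the formatter, so black here is an assumption; a minimal sketch that reproduces this kind of change through black's Python API:

```python
# Assumption: the formatter is black (the commit does not say so explicitly).
import black

src = "__all__ = ['PPInferEngine']\n"

# format_str applies the same rules as the `black` CLI to an in-memory string.
formatted = black.format_str(src, mode=black.Mode())
print(formatted)  # __all__ = ["PPInferEngine"]
```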

colossalai/inference/pipeline/benchmark/benchmark.py

Lines changed: 58 additions & 39 deletions
@@ -1,28 +1,32 @@
+import argparse
+import time
+
 import torch
 import torch.distributed as dist
 import transformers
 
 import colossalai
-import time
 from colossalai.inference import PPInferEngine
 from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy
-import argparse
-GIGABYTE = 1024 ** 3
+
+GIGABYTE = 1024**3
 MEGABYTE = 1024 * 1024
 
 colossalai.launch_from_torch(config={})
 
-def data_gen(batch_size: int=4, seq_len: int=512):
+
+def data_gen(batch_size: int = 4, seq_len: int = 512):
     input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32)
     attention_mask = torch.ones((1, seq_len), dtype=torch.int32)
     data = dict(input_ids=input_ids, attention_mask=attention_mask)
     for k, v in data.items():
-        if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
+        if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
             new_shape = [1] * v.dim()
             new_shape[0] = batch_size
-            data[k] = v.to('cuda').repeat(*new_shape)
+            data[k] = v.to("cuda").repeat(*new_shape)
     return data
 
+
 def print_details_info(timestamps, model_config, args, whole_end2end):
     if dist.get_rank() == 0:
         prefill = []
@@ -31,32 +35,37 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
         for timestamp in timestamps:
             prefill.append(timestamp[1] - timestamp[0])
             encoder.append(
-                sum(timestamp[i + 1] - timestamp[i] for i in range(1,len(timestamp) - 1)) / (len(timestamp) - 2))
+                sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
+            )
             end2end.append(timestamp[-1] - timestamp[0])
         print(whole_end2end)
-        with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f:
-            mb_avg_end2end = sum(end2end)/len(end2end)
-            mb_avg_latency = mb_avg_end2end/(args.new_length * args.mb_size)
-            whole_avg_latency = whole_end2end/(args.new_length * args.batch_size)
+        with open(
+            f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log",
+            "w+",
+        ) as f:
+            mb_avg_end2end = sum(end2end) / len(end2end)
+            mb_avg_latency = mb_avg_end2end / (args.new_length * args.mb_size)
+            whole_avg_latency = whole_end2end / (args.new_length * args.batch_size)
             num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
             num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
-            if args.dtype in ['fp16','bf16']:
+            if args.dtype in ["fp16", "bf16"]:
                 num_bytes = 2
             else:
                 num_bytes = 4
 
-            f.write(f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n")
-            f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill)/len(prefill)*1000))
-            f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder)/len(encoder)*1000))
-            f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end*1000))
+            f.write(
+                f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n"
+            )
+            f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill) / len(prefill) * 1000))
+            f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder) / len(encoder) * 1000))
+            f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end * 1000))
             f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000))
-            f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end*1000))
+            f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end * 1000))
             f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000))
-            f.write("Throughput: {} tokens/s\n".format((1000/(whole_avg_latency * 1000))))
-            f.write("flops: {0:8.2f} TFlops/s\n".format(1/whole_avg_latency * num_parameters * num_bytes / 1e12))
+            f.write("Throughput: {} tokens/s\n".format((1000 / (whole_avg_latency * 1000))))
+            f.write("flops: {0:8.2f} TFlops/s\n".format(1 / whole_avg_latency * num_parameters * num_bytes / 1e12))
             f.write("----------------------------------------------------------\n")
 
-
     if torch.cuda.is_available():
         current_device = torch.cuda.current_device()
 
@@ -66,7 +75,10 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
         max_memory_allocated = torch.cuda.max_memory_allocated()
         memory_reserved = torch.cuda.memory_reserved()
         max_memory_reserved = torch.cuda.max_memory_reserved()
-        with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f:
+        with open(
+            f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log",
+            "a",
+        ) as f:
            f.write(
                f"\nCurrently using GPU: {current_device}\n"
                f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n"
@@ -77,29 +89,37 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
                f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n"
            )
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model', default='toy', help='the size of model')
-    parser.add_argument('-b', '--batch_size', type=int, default=8, help='batch size')
-    parser.add_argument('-s', '--seq_len', type=int, default=8, help='sequence length')
-    parser.add_argument('--new_length', type=int, default=4, help='new tokens length')
-    parser.add_argument('--mb_size', type=int, default=1, help='micro_batch_size')
-    parser.add_argument('--pp_size', type=int, default=2, help='pipeline size')
-    parser.add_argument('--log_path', type=str, default='./log' ,help='where to store the benchmark log')
-    parser.add_argument('--dtype', type=str, default='fp16', help='data type')
+    parser.add_argument("--model", default="toy", help="the size of model")
+    parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
+    parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
+    parser.add_argument("--new_length", type=int, default=4, help="new tokens length")
+    parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
+    parser.add_argument("--pp_size", type=int, default=2, help="pipeline size")
+    parser.add_argument("--log_path", type=str, default="./log", help="where to store the benchmark log")
+    parser.add_argument("--dtype", type=str, default="fp16", help="data type")
    args = parser.parse_args()
 
-    if args.model == 'toy':
+    if args.model == "toy":
        model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8))
-    elif args.model == '7b':
-        model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-7b-hf'))
-    elif args.model == '13b':
-        model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-13b-hf'))
+    elif args.model == "7b":
+        model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-7b-hf"))
+    elif args.model == "13b":
+        model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-13b-hf"))
     else:
        raise NotImplementedError
-
-
-    engine = PPInferEngine(pp_size=args.pp_size, dtype=args.dtype, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True)
+
+    engine = PPInferEngine(
+        pp_size=args.pp_size,
+        dtype=args.dtype,
+        micro_batch_size=args.mb_size,
+        new_length=args.new_length,
+        model=model,
+        model_policy=LlamaForCausalLMPipelinePolicy(),
+        verbose=True,
+    )
    data = data_gen(args.batch_size, args.seq_len)
 
    torch.cuda.synchronize()
@@ -109,4 +129,3 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
    whole_end2end = time.time() - whole_end2end
 
    print_details_info(timestamps, model.config, args, whole_end2end)
-
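The metrics written by `print_details_info` reduce to simple per-token arithmetic. A worked example with made-up numbers (purely illustrative, not measurements from this commit):

```python
# Illustrative inputs: batch of 8 prompts, 4 new tokens each, 0.2 s end-to-end for the whole batch.
new_length, batch_size = 4, 8
whole_end2end = 0.2  # seconds, measured around the inference call in the script

whole_avg_latency = whole_end2end / (new_length * batch_size)  # 0.00625 s per generated token
throughput = 1000 / (whole_avg_latency * 1000)                 # same as 1 / whole_avg_latency -> 160 tokens/s
print(f"{whole_avg_latency * 1000:.2f} ms/token, {throughput:.0f} tokens/s")
```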

colossalai/inference/pipeline/engine.py

Lines changed: 10 additions & 11 deletions
@@ -1,5 +1,3 @@
-from typing import Callable, List, Optional, Set, Union
-
 import torch
 import torch.nn as nn
 
@@ -13,7 +11,7 @@
 
 
 class PPInferEngine:
-    '''
+    """
     PPInferEngine is a class that handles the pipeline parallel inference.
 
     Args:
@@ -41,20 +39,20 @@ class PPInferEngine:
    output = engine.inference([tokenized_input])
    ```
 
-    '''
+    """
 
     def __init__(
         self,
         pp_size: int,
-        dtype: str = 'fp16',
+        dtype: str = "fp16",
         pp_model: nn.Module = None,
         model: nn.Module = None,
         model_policy: Policy = None,
         new_length: int = 32,
         micro_batch_size: int = 1,
         micro_batch_buffer_size: int = None,
         verbose: bool = False,
-        # TODO: implement early_stopping, and various gerneration options
+        # TODO: implement early_stopping, and various gerneration options
         early_stopping: bool = False,
         do_sample: bool = False,
         num_beams: int = 1,
@@ -63,15 +61,16 @@ def __init__(
         self.pp_size = pp_size
         self.pg_mesh = ProcessGroupMesh(pp_size)
         self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
-        self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size,
-                                            micro_batch_buffer_size or pp_size)
+        self.mb_manager = MicroBatchManager(
+            self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size
+        )
         self.verbose = verbose
         self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
 
-        assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'"
-        if dtype == 'fp16':
+        assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
+        if dtype == "fp16":
             model.half()
-        elif dtype == 'bf16':
+        elif dtype == "bf16":
             model.to(torch.bfloat16)
         self.model = pp_model or self._shardformer(model, model_policy)
 
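The constructor above, together with the benchmark script earlier in this commit, is enough to sketch a minimal use of `PPInferEngine`. This is a sketch only: the input dict format is assumed from the benchmark's `data_gen`, the `inference([...])` call shape comes from the class docstring, and the script has to be launched with `torchrun` so that `pp_size` ranks exist.

```python
# Sketch assuming a torchrun launch with 2 processes (pp_size=2); values are illustrative.
import torch
import transformers

import colossalai
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy

colossalai.launch_from_torch(config={})

model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8))  # the benchmark's "toy" model
engine = PPInferEngine(
    pp_size=2,
    dtype="fp16",
    micro_batch_size=1,
    new_length=8,
    model=model,
    model_policy=LlamaForCausalLMPipelinePolicy(),
)

# Input format assumed from data_gen() in the benchmark: input_ids / attention_mask tensors on CUDA.
data = dict(
    input_ids=torch.randint(10, 30000, (2, 16), dtype=torch.int32, device="cuda"),
    attention_mask=torch.ones((2, 16), dtype=torch.int32, device="cuda"),
)
output = engine.inference([data])  # call shape taken from the class docstring example
```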
colossalai/inference/pipeline/microbatch_manager.py

Lines changed: 19 additions & 17 deletions
@@ -3,7 +3,7 @@
 
 import torch
 
-__all__ = 'MicroBatchManager'
+__all__ = "MicroBatchManager"
 
 
 class Status(Enum):
@@ -13,7 +13,7 @@ class Status(Enum):
     COOLDOWN = 4
 
 
-class MicroBatchDescription():
+class MicroBatchDescription:
     """
     This is the class to record the infomation of each microbatch, and also do some update operation.
     This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
@@ -30,14 +30,14 @@ def __init__(
         output_dict: Dict[str, torch.Tensor],
         new_length: int,
     ) -> None:
-        assert output_dict.get('hidden_states') is not None
-        self.mb_length = output_dict['hidden_states'].shape[-2]
+        assert output_dict.get("hidden_states") is not None
+        self.mb_length = output_dict["hidden_states"].shape[-2]
         self.target_length = self.mb_length + new_length
         self.kv_cache = ()
 
     def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
         if output_dict is not None:
-            self._update_kvcache(output_dict['past_key_values'])
+            self._update_kvcache(output_dict["past_key_values"])
 
     def _update_kvcache(self, kv_cache: Tuple):
         assert type(kv_cache) == tuple
@@ -64,7 +64,6 @@ def cur_length(self):
         Return the current sequnence length of micro batch
 
         """
-        pass
 
 
 class HeadMicroBatchDescription(MicroBatchDescription):
@@ -80,13 +79,14 @@ class HeadMicroBatchDescription(MicroBatchDescription):
 
     """
 
-    def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
-                 new_length: int) -> None:
+    def __init__(
+        self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
+    ) -> None:
         super().__init__(inputs_dict, output_dict, new_length)
         assert inputs_dict is not None
-        assert inputs_dict.get('input_ids') is not None and inputs_dict.get('attention_mask') is not None
-        self.input_ids = inputs_dict['input_ids']
-        self.attn_mask = inputs_dict['attention_mask']
+        assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
+        self.input_ids = inputs_dict["input_ids"]
+        self.attn_mask = inputs_dict["attention_mask"]
         self.new_tokens = None
 
     def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
@@ -104,7 +104,8 @@ def _update_newtokens(self, new_token: torch.Tensor):
 
     def _update_attnmask(self):
         self.attn_mask = torch.cat(
-            (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1)
+            (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device="cuda")), dim=-1
+        )
 
     @property
     def cur_length(self):
@@ -127,8 +128,9 @@ class BodyMicroBatchDescription(MicroBatchDescription):
         output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
     """
 
-    def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
-                 new_length: int) -> None:
+    def __init__(
+        self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
+    ) -> None:
         super().__init__(inputs_dict, output_dict, new_length)
 
     def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
@@ -146,8 +148,8 @@ def cur_length(self):
         return self.kv_cache[0][0].shape[-2] + 1
 
 
-class MicroBatchManager():
-    '''
+class MicroBatchManager:
+    """
     MicroBatchManager is a class that manages the micro batch.
 
     Args:
@@ -156,7 +158,7 @@ class MicroBatchManager():
         micro_batch_size (int): the micro batch size.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
 
-    '''
+    """
 
     def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
         self.stage = stage
0 commit comments