diff --git a/benchmark/benchmark_prefix_translation_orca.py b/benchmark/benchmark_prefix_translation_orca.py new file mode 100644 index 000000000000..4592bddfee1a --- /dev/null +++ b/benchmark/benchmark_prefix_translation_orca.py @@ -0,0 +1,279 @@ +import argparse +import logging +import os +import pickle +import time +from typing import List + +from tqdm import tqdm +from transformers import AutoConfig + +from benchmark.trace import generate_translation_requests_orca +from cacheflow.master.simple_frontend import SimpleFrontend +from cacheflow.master.server import (Server, add_server_arguments, + initialize_ray_cluster) +from cacheflow.sampling_params import SamplingParams +from cacheflow.utils import get_gpu_memory, get_cpu_memory + + +logger = logging.getLogger(__name__) + + +def main(args: argparse.Namespace): + assert args.pipeline_parallel_size == 1, ( + 'Pipeline parallelism is not supported yet.') + + (num_nodes, num_devices_per_node, distributed_init_method, + all_stage_devices) = ( + initialize_ray_cluster( + address='local', + pipeline_parallel_size=args.pipeline_parallel_size, + tensor_parallel_size=args.tensor_parallel_size)) + + # Create a server. + server = Server( + model=args.model, + model_path=args.model_path, + use_dummy_weights=args.use_dummy_weights, + pipeline_parallel_size=args.pipeline_parallel_size, + tensor_parallel_size=args.tensor_parallel_size, + block_size=args.block_size, + dtype=args.dtype, + seed=args.seed, + swap_space=args.swap_space, + max_num_batched_tokens=args.max_num_batched_tokens, + max_num_sequences=args.max_num_sequences, + num_nodes=num_nodes, + num_devices_per_node=num_devices_per_node, + distributed_init_method=distributed_init_method, + all_stage_devices=all_stage_devices, + gpu_memory=get_gpu_memory(), + cpu_memory=get_cpu_memory(), + len_estimator=args.len_estimator, + collect_stats=True, + do_memory_analysis=args.do_memory_analysis, + ) + + # Create a frontend. + frontend = SimpleFrontend( + model_name=args.model, + block_size=args.block_size, + ) + # Generate requests. + requests = generate_translation_requests_orca( + model=args.model, + dataset=args.dataset, + num_examples=args.num_prefix_examples, + request_rate=args.request_rate, + duration=args.duration, + seed=args.seed, + ) + + # Warm up. + logger.info('Warming up.') + num_warmup_requests = 8 + warmup_input_len = 8 + warmup_output_len = 32 + warmup_sampling_params = SamplingParams( + n=1, + temperature=1.0, + top_p=0.99, + max_num_steps=warmup_output_len, + use_beam_search=False, + stop_token_ids=set(), + num_logprobs=0, + context_window_size=None, + ) + for _ in range(num_warmup_requests): + frontend._add_query([0] * warmup_input_len, warmup_sampling_params) + server.add_sequence_groups(frontend.get_inputs()) + while True: + server.step() + if not server.has_unfinished_requests(): + break + + # Start benchmarking. + logger.info('Start benchmarking.') + # Initialize tqdm. + pbar = tqdm(total=len(requests), desc='Finished requests') + + finished = [] + server.scheduler.reset_stats() + start_time = time.time() + while True: + now = time.time() + if args.timeout is not None and now - start_time > args.timeout: + logger.info('Timeout. 
Stop benchmarking.') + break + + while requests: + if requests[0][0] <= now - start_time: + request_time, input_tokens, sampling_params = requests.pop(0) + frontend._add_query( + input_tokens, sampling_params, arrival_time=start_time + request_time) + else: + break + server.add_sequence_groups(frontend.get_inputs()) + updated_seq_groups = server.step() + + now = time.time() + for seq_group in updated_seq_groups: + if not seq_group.is_finished(): + continue + # Print outputs. + # frontend.print_response(seq_group) + + arrival_time = seq_group.arrival_time + finish_time = now + for seq in seq_group.get_seqs(): + seq_len = seq.get_len() + output_len = seq_len - seq.prompt_len + finished.append({ + 'group_id': seq_group.group_id, + 'seq_id': seq.seq_id, + 'arrival_time': arrival_time, + 'finish_time': finish_time, + 'prompt_len': seq.prompt_len, + 'output_len': output_len, + }) + pbar.update(1) + + if not (requests or server.has_unfinished_requests()): + break + pbar.close() + logger.info('Finish benchmarking. Saving stats.') + server.scheduler.save_stats(args.output_dir) + with open(os.path.join(args.output_dir, 'sequences.pkl'), 'wb') as f: + pickle.dump(finished, f) + logger.info('Done.') + + +def get_model_name(model: str) -> str: + OPT_MODELS = [ + 'opt-125m', + 'opt-350m', + 'opt-1.3b', + 'opt-2.7b', + 'opt-6.7b', + 'opt-13b', + 'opt-30b', + 'opt-66b', + 'opt-175b', + ] + for opt_model in OPT_MODELS: + if opt_model in model: + return opt_model + + config = AutoConfig.from_pretrained(model) + assert config.model_type == 'llama' + hidden_size = config.hidden_size + if hidden_size == 4096: + return 'llama-7b' + elif hidden_size == 5120: + return 'llama-13b' + elif hidden_size == 6656: + return 'llama-30b' + elif hidden_size == 8192: + return 'llama-65b' + else: + raise ValueError(f'Unknown model: {model}') + + +def get_dataset_name(dataset: str) -> str: + if 'sharegpt' in dataset.lower(): + return 'sharegpt' + elif 'alpaca' in dataset.lower(): + return 'alpaca' + else: + raise ValueError(f'Unknown dataset: {dataset}') + + +def get_sampling_dir_name( + n1: float, + n2: float, + n3: float, + n4: float, + n6: float, + n2_beam: float, + n4_beam: float, + n6_beam: float, + n8_beam: float, +) -> str: + method = '' + if n1 > 0.0: + method = 'n1' if n1 == 1.0 else method + f'n1-{n1}-' + if n2 > 0.0: + method = 'n2' if n2 == 1.0 else method + f'n2-{n2}-' + if n3 > 0.0: + method = 'n3' if n3 == 1.0 else method + f'n3-{n3}-' + if n4 > 0.0: + method = 'n4' if n4 == 1.0 else method + f'n4-{n4}-' + if n6 > 0.0: + method = 'n6' if n6 == 1.0 else method + f'n6-{n6}-' + if n2_beam > 0.0: + method = 'n2-beam' if n2_beam == 1.0 else method + f'n2-beam-{n2_beam}-' + if n4_beam > 0.0: + method = 'n4-beam' if n4_beam == 1.0 else method + f'n4-beam-{n4_beam}-' + if n6_beam > 0.0: + method = 'n6-beam' if n6_beam == 1.0 else method + f'n6-beam-{n6_beam}-' + if n8_beam > 0.0: + method = 'n8-beam' if n8_beam == 1.0 else method + f'n8-beam-{n8_beam}-' + return method[:-1] if method.endswith('-') else method + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='CacheFlow simple server.') + parser = add_server_arguments(parser) + parser.add_argument('--output-dir', type=str, help='path to output directory', default=None) + parser.add_argument('--len-estimator', type=str, choices=['oracle', 'power2', 'constant'], required=True) + + parser.add_argument('--num-prefix-examples', type=int, help='number of examples to use in prefix', required=True) + parser.add_argument('--dataset', type=str, 
help='path to dataset', default='wmt16') + parser.add_argument('--request-rate', type=float, help='reqs/sec', required=True) + parser.add_argument('--duration', type=int, help='duration in seconds', required=True) + parser.add_argument('--do-memory-analysis', action='store_true', + help='do memory analysis (This will lower the throughput. Use this only for analysis.)') + parser.add_argument('--timeout', type=int, help='time out in seconds', default=None) + + parser.add_argument('--n1', type=float, help='ratio of requests with n=1', default=0.0) + parser.add_argument('--n2', type=float, help='ratio of requests with n=2', default=0.0) + parser.add_argument('--n3', type=float, help='ratio of requests with n=3', default=0.0) + parser.add_argument('--n4', type=float, help='ratio of requests with n=4', default=0.0) + parser.add_argument('--n6', type=float, help='ratio of requests with n=6', default=0.0) + parser.add_argument('--n2-beam', type=float, help='ratio of requests with n=2 & beam search', default=0.0) + parser.add_argument('--n4-beam', type=float, help='ratio of requests with n=4 & beam search', default=0.0) + parser.add_argument('--n6-beam', type=float, help='ratio of requests with n=6 & beam search', default=0.0) + parser.add_argument('--n8-beam', type=float, help='ratio of requests with n=8 & beam search', default=0.0) + args = parser.parse_args() + if args.n1 + args.n2 + args.n3 + args.n4 + args.n6 + args.n2_beam + args.n4_beam + args.n6_beam + args.n8_beam != 1.0: + raise ValueError('The ratios of requests must sum to 1.') + + model_name = get_model_name(args.model) + sample_dir = get_sampling_dir_name( + args.n1, args.n2, args.n3, args.n4, args.n6, args.n2_beam, args.n4_beam, args.n6_beam, args.n8_beam) + if args.output_dir is None: + args.output_dir = os.path.join( + '../prefix_exp', + f'{args.dataset}-{args.num_prefix_examples}shot', + f'{model_name}-tp{args.tensor_parallel_size}', + sample_dir, + f'orca-{args.len_estimator}', + f'req-rate-{args.request_rate}', + f'seed{args.seed}', + f'duration-{args.duration}', + ) + os.makedirs(args.output_dir, exist_ok=True) + + # Set up logging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + handlers=[ + logging.StreamHandler(), + logging.FileHandler(os.path.join(args.output_dir, 'log.txt')), + ], + ) + logger.info(args) + + main(args) diff --git a/benchmark/benchmark_text_completion.py b/benchmark/benchmark_text_completion.py index e6741d577c6e..05b03097b12a 100644 --- a/benchmark/benchmark_text_completion.py +++ b/benchmark/benchmark_text_completion.py @@ -49,6 +49,7 @@ def main(args: argparse.Namespace): all_stage_devices=all_stage_devices, gpu_memory=get_gpu_memory(), cpu_memory=get_cpu_memory(), + len_estimator=args.len_estimator, collect_stats=True, do_memory_analysis=args.do_memory_analysis, ) @@ -228,6 +229,7 @@ def get_sampling_dir_name( parser = argparse.ArgumentParser(description='CacheFlow simple server.') parser = add_server_arguments(parser) parser.add_argument('--output-dir', type=str, help='path to output directory', default=None) + parser.add_argument('--len-estimator', type=str, choices=['oracle', 'power2', 'constant'], required=True) parser.add_argument('--dataset', type=str, help='path to dataset', required=True) parser.add_argument('--request-rate', type=float, help='reqs/sec', required=True) @@ -267,7 +269,7 @@ def get_sampling_dir_name( dataset_name, f'{model_name}-tp{args.tensor_parallel_size}', sample_dir, - 'cacheflow', + f'orca-{args.len_estimator}', f'block{args.block_size}', f'req-rate-{args.request_rate}', f'seed{args.seed}', diff --git a/benchmark/plot_stats.py b/benchmark/plot_stats.py new file mode 100644 index 000000000000..c391571403ec --- /dev/null +++ b/benchmark/plot_stats.py @@ -0,0 +1,52 @@ +import os +import pickle + +import matplotlib.pyplot as plt + +STAT_NAMES = [ + 'input_lens', + 'num_running', + 'num_waiting', + 'num_preemption', + 'gpu_cache_usage', + 'cpu_cache_usage', + 'num_swapped', + 'swap_in_lens', + 'swap_out_lens', +] + + +def plot_stats(output_dir: str): + # Get stats. + with open(os.path.join(output_dir, 'stats.pkl'), 'rb') as f: + stats = pickle.load(f) + timestamps = stats['timestamps'] + + # Draw one figure for each stat. 
+ num_stats = len(STAT_NAMES) + COLORS = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'orange', 'purple', 'pink', 'brown', 'gray'] + fig, axs = plt.subplots(num_stats, 1, figsize=(10, 2 * num_stats)) + for i, stat in enumerate(STAT_NAMES): + data = stats[stat] + if stat in ['gpu_cache_usage', 'cpu_cache_usage']: + data = [x * 100 for x in data] + stat = stat + ' (%)' + axs[i].plot(timestamps, data, color=COLORS[i % len(COLORS)]) + axs[i].set_ylabel(stat.replace('_', ' '), fontdict={'fontsize': 12}) + axs[i].set_ylim(bottom=0) + + plt.xlabel('Time (s)') + plt.tight_layout() + fig_path = os.path.join(output_dir, 'stats.png') + plt.savefig(fig_path) + print(f'Saved stats to {fig_path}') + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('output_dir', type=str, help='Output directory.') + args = parser.parse_args() + + plot_stats(args.output_dir) diff --git a/benchmark/trace.py b/benchmark/trace.py index 42e203235843..87b2254fcd01 100644 --- a/benchmark/trace.py +++ b/benchmark/trace.py @@ -2,7 +2,9 @@ import random from typing import List, Tuple +from datasets import load_dataset import numpy as np +from transformers import AutoTokenizer from cacheflow.sampling_params import SamplingParams @@ -114,3 +116,101 @@ def generate_text_completion_requests( cum_sum += 1 requests.append((timestamp, input_tokens, sampling_params)) return requests + + +def generate_translation_requests_orca( + model: str, + dataset: str, + num_examples: int, + request_rate: float, + duration: int, + seed: int, + max_seq_len: int = 2048, + time_quantum: int = 10, +) -> List[Tuple[float, List[int], SamplingParams]]: + tokenizer = AutoTokenizer.from_pretrained(model) + + random.seed(seed) + np.random.seed(seed) + + # Generate timestamps for requests using Poisson distribution. + lam = request_rate * (time_quantum / 1000) + quantums_per_sec = 1000 / time_quantum + arrival_times = np.random.poisson( + lam=lam, size=int(duration * quantums_per_sec)) + timestamps = [] + for i, n in enumerate(arrival_times): + timestamps += [i * (time_quantum / 1000)] * n + + # Load the training dataset and sample examples. + train_set = load_dataset('wmt16', 'de-en', split='train') + train_size = train_set.num_rows + if num_examples > train_size: + raise ValueError( + f'Number of examples ({num_examples}) is greater than the ' + f'number of training examples ({train_size}).') + + # Add instruction first. + prefix = 'Please translate these following English sentence(s) to German sentence(s):\n' + + # 8 fixed common examples of about the same length + # that was generated independently from the WMT16 dataset. + examples = [ + ('The new shopping center has a wide variety of stores, including clothing, electronics, and a supermarket.' + ' => Das neue Einkaufszentrum bietet eine große Vielfalt an Geschäften, einschließlich Bekleidung, Elektronik und einem Supermarkt.'), + ('For a healthier lifestyle, try incorporating regular exercise, a balanced diet, and stress-reducing activities into your routine.' + ' => Für einen gesünderen Lebensstil versuchen Sie, regelmäßige Bewegung, eine ausgewogene Ernährung und stressreduzierende Aktivitäten in Ihre Routine einzubauen.'), + ('The library will be hosting a series of workshops on various topics, such as creative writing.' + ' => Die Bibliothek veranstaltet eine Reihe von Workshops zu verschiedenen Themen wie kreativem Schreiben.'), + ('The museum offers guided tours every day at 11:00 am and 4:00 pm, and admission is free on Sundays.' 
+ ' => Das Museum bietet jeden Tag um 11:00 Uhr und 16:00 Uhr Führungen an, und der Eintritt ist sonntags kostenlos.'), + ('If you experience any technical difficulties during the conference, please don\'t hesitate to contact the support team.' + ' => Wenn Sie während der Konferenz technische Schwierigkeiten haben, zögern Sie bitte nicht, das Support-Team zu kontaktieren.'), + ('The local farmer\'s market offers fresh fruits, vegetables, and other produce directly from the farms every Saturday morning.' + ' => Der örtliche Bauernmarkt bietet jeden Samstagmorgen frische Früchte, Gemüse und andere landwirtschaftliche Produkte direkt von den Höfen an.'), + ('Remember to set your clocks one hour forward for daylight saving time this weekend to enjoy longer days and more sunlight.' + ' => Denken Sie daran, Ihre Uhren am Wochenende für die Sommerzeit eine Stunde vorzustellen, um längere Tage und mehr Sonnenlicht zu genießen.'), + ('The restaurant offers a diverse menu featuring international cuisine, including Italian, French, and Japanese dishes.' + ' => Das Restaurant bietet eine vielfältige Speisekarte mit internationaler Küche, einschließlich italienischer, französischer und japanischer Gerichte'), + ] + assert num_examples <= len(examples) + prefix += '\n'.join(examples[:num_examples]) + '\n' + print("Prefix length in tokens:", len(tokenizer.encode(prefix))) + + # Tokenize the test set. + test_set = load_dataset(dataset, 'de-en', split='test') + tokenized = [] + for data in test_set: + en = data['translation']['en'] + ' =>' + input_tokens = tokenizer.encode(prefix + en) + + de = data['translation']['de'] + output_tokens = tokenizer.encode(de, add_special_tokens=False) + + # Filter out too long sequences. + if len(input_tokens) + len(output_tokens) > max_seq_len: + continue + tokenized.append((input_tokens, len(output_tokens))) + + # Generate requests. + num_requests = len(timestamps) + while len(tokenized) < num_requests: + tokenized += tokenized + tokenized = tokenized[:num_requests] + # Shuffle the requests. + random.shuffle(tokenized) + random_sampling_params_dict = { + 'temperature': 0.0, + 'top_p': 1.0, + 'use_beam_search': False, + 'stop_token_ids': set(), + 'num_logprobs': 0, + 'context_window_size': None, + } + requests = [] + for timestamp, pair in zip(timestamps, tokenized): + input_tokens, output_len = pair + sampling_params = SamplingParams( + n=1, max_num_steps=output_len, **random_sampling_params_dict) + requests.append((timestamp, input_tokens, sampling_params)) + return requests diff --git a/benchmark_prefix_translation_orca.sh b/benchmark_prefix_translation_orca.sh new file mode 100755 index 000000000000..f867e926f97c --- /dev/null +++ b/benchmark_prefix_translation_orca.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# This script is used to test the performance of prefix translation for Orca. 
+ +#!/bin/bash +num_prefix=1 +for request_rate in 1 24 26 28 29 30 31 32; do + python benchmark/benchmark_prefix_translation_orca.py \ + --model ~/hf-llama/llama-13b/ \ + --num-prefix-examples "$num_prefix" \ + --request-rate "$request_rate" \ + --duration 1200 \ + --n1 1.0 \ + --len-estimator oracle \ + &>> log.txt +done + +num_prefix=5 +for request_rate in 1 4 8 9 10 11 12 13; do + python benchmark/benchmark_prefix_translation_orca.py \ + --model ~/hf-llama/llama-13b/ \ + --num-prefix-examples "$num_prefix" \ + --request-rate "$request_rate" \ + --duration 1200 \ + --n1 1.0 \ + --len-estimator oracle \ + &>> log.txt +done diff --git a/cacheflow/master/block_manager.py b/cacheflow/master/block_manager.py index 0b188508d15c..059c63a057a7 100644 --- a/cacheflow/master/block_manager.py +++ b/cacheflow/master/block_manager.py @@ -1,3 +1,6 @@ +import collections +import copy +import math from typing import Dict, List, Optional, Set, Tuple from cacheflow.block import PhysicalTokenBlock @@ -6,224 +9,227 @@ from cacheflow.sequence import SequenceStatus from cacheflow.utils import Device +_MAX_SEQ_LEN = 2048 -class BlockAllocator: + +class BuddyAllocator: def __init__( self, device: Device, - block_size: int, - num_blocks: int, + token_block_size: int, + num_token_blocks: int, ) -> None: self.device = device - self.block_size = block_size - self.num_blocks = num_blocks - - # Initialize the free blocks. - # TODO(woosuk): Make this a priority queue. - self.free_blocks = [ - PhysicalTokenBlock(device=device, block_number=i, block_size=block_size) - for i in range(num_blocks) - ] - - def allocate(self) -> PhysicalTokenBlock: - if not self.free_blocks: - raise ValueError('Out of memory! ' - f'No more free blocks are available.') - block = self.free_blocks.pop() - block.ref_count = 1 - return block - - def free(self, block: PhysicalTokenBlock) -> None: - if block.ref_count == 0: - raise ValueError('Double free! ' - f'The block {block} is already freed.') - block.ref_count -= 1 - if block.ref_count == 0: - self.free_blocks.append(block) + self.token_block_size = token_block_size + self.num_token_blocks = num_token_blocks + + self.min_block_size = 1 + self.max_block_size = _MAX_SEQ_LEN // token_block_size + self.size_to_free_blocks: Dict[int, List[int]] = collections.defaultdict(list) + self.addr_to_size: Dict[int, int] = {} + + buddy_size = self.max_block_size + last_start_addr = 0 + start_addrs = [] + while buddy_size >= 1: + new_start_addrs = [] + while last_start_addr + buddy_size <= self.num_token_blocks: + new_start_addrs.append(last_start_addr) + last_start_addr += buddy_size + + self.size_to_free_blocks[buddy_size] = new_start_addrs + for addr in new_start_addrs: + self.addr_to_size[addr] = buddy_size + start_addrs.extend(new_start_addrs) + buddy_size //= 2 + + def can_allocate(self, sizes: List[int]) -> bool: + # FIXME(woosuk): Must be fixed for performance. + size_to_free_blocks = copy.deepcopy(self.size_to_free_blocks) + addr_to_size = copy.deepcopy(self.addr_to_size) + for size in sizes: + try: + self.allocate(size, size_to_free_blocks, addr_to_size) + except ValueError: + return False + return True + + def _resize(self, size: int) -> int: + # Bump up the size to the next power of 2. + size = 2 ** math.ceil(math.log2(size)) + # Make sure the size is not smaller than the min block size. 
+ size = max(size, self.min_block_size) + return size + + def allocate( + self, + size: int, + size_to_free_blocks: Optional[Dict[int, List[int]]] = None, + addr_to_size: Optional[Dict[int, int]] = None, + ) -> List[PhysicalTokenBlock]: + if size_to_free_blocks is None: + size_to_free_blocks = self.size_to_free_blocks + if addr_to_size is None: + addr_to_size = self.addr_to_size + + size = self._resize(size) + if size > self.max_block_size: + raise ValueError( + f'Size {size} is larger than max_block_size {self.max_block_size}.') + + # Find the smallest block that can fit the size. + i = size + while True: + if len(size_to_free_blocks[i]) > 0: + # Found a block. + start = size_to_free_blocks[i].pop() + addr_to_size[start] = size + + # Split the block. + while i > size: + i //= 2 + size_to_free_blocks[i].append(start + i) + addr_to_size[start + i] = i + + # Return the blocks. + physical_blocks = [] + for j in range(size): + physical_block = PhysicalTokenBlock( + device=self.device, + block_number=start + j, + block_size=self.token_block_size, + ) + physical_block.ref_count = 1 + physical_blocks.append(physical_block) + return physical_blocks + else: + i *= 2 + if i > self.max_block_size: + raise ValueError(f'Cannot find a block of size {size}.') + + def free(self, start: int) -> None: + size = self.addr_to_size[start] + del self.addr_to_size[start] + + # Merge the block with its buddy. + while size < self.max_block_size: + buddy = start ^ size + if buddy in self.addr_to_size and self.addr_to_size[buddy] == size: + # Found a buddy. + if buddy in self.size_to_free_blocks[size]: + self.size_to_free_blocks[size].remove(buddy) + del self.addr_to_size[buddy] + size *= 2 + start = min(start, buddy) + else: + break + else: + break + self.size_to_free_blocks[size].append(start) + self.addr_to_size[start] = size def get_num_free_blocks(self) -> int: - return len(self.free_blocks) + total = 0 + for size, free_blocks in self.size_to_free_blocks.items(): + total += size * len(free_blocks) + return total -# Mapping: logical block number -> physical block. BlockTable = List[PhysicalTokenBlock] -class BlockSpaceManager: +class BuddyBlockSpaceManager: def __init__( self, block_size: int, num_gpu_blocks: int, num_cpu_blocks: int, - watermark: float = 0.01, + len_estimator: str, ) -> None: self.block_size = block_size self.num_total_gpu_blocks = num_gpu_blocks self.num_total_cpu_blocks = num_cpu_blocks - self.watermark = watermark - assert watermark >= 0.0 + self.len_estimator = len_estimator - self.watermark_blocks = int(watermark * num_gpu_blocks) - self.gpu_allocator = BlockAllocator(Device.GPU, block_size, num_gpu_blocks) - self.cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks) + self.gpu_allocator = BuddyAllocator( + Device.GPU, block_size, num_gpu_blocks) # Mapping: seq_id -> BlockTable. self.block_tables: Dict[int, BlockTable] = {} - def can_allocate(self, seq_group: SequenceGroup) -> bool: - # FIXME(woosuk): Here we assume that all sequences in the group share - # the same prompt. This may not be true for preempted sequences. - seq = seq_group.seqs[0] - num_required_blocks = len(seq.logical_token_blocks) - num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() - # Use watermark to avoid frequent cache eviction. - return num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks + # Mapping src physical block number -> List[dst physical block number]. 
+ self.forked: Dict[int, List[int]] = {} - def allocate(self, seq_group: SequenceGroup) -> None: - # NOTE: Here we assume that all sequences in the group have the same prompt. + def _oracle(self, seq_group: SequenceGroup) -> int: + return seq_group.max_num_steps + + def _next_power_of_two(self, seq_group: SequenceGroup) -> int: + output_len = seq_group.max_num_steps + return 1 << (output_len - 1).bit_length() + + def _constant(self, seq_group: SequenceGroup) -> int: + # FIXME + return _MAX_SEQ_LEN + + def _compute_allocation_size(self, seq_group: SequenceGroup) -> int: + if self.len_estimator == 'oracle': + output_len = self._oracle(seq_group) + elif self.len_estimator == 'power2': + output_len = self._next_power_of_two(seq_group) + elif self.len_estimator == 'constant': + output_len = self._constant(seq_group) seq = seq_group.seqs[0] + seq_len = min(seq.get_len() + output_len, _MAX_SEQ_LEN) + size = (seq_len + self.block_size - 1) // self.block_size + return size - # Allocate new physical token blocks that will store the prompt tokens. - block_table: BlockTable = [] - for _ in range(len(seq.logical_token_blocks)): - block = self.gpu_allocator.allocate() - # Set the reference counts of the token blocks. - block.ref_count = seq_group.num_seqs() - block_table.append(block) + def can_allocate(self, seq_group: SequenceGroup) -> bool: + # NOTE: Here we assume that all sequences in the group have the same prompt. + size = self._compute_allocation_size(seq_group) + return self.gpu_allocator.can_allocate([size] * len(seq_group.seqs)) - # Assign the block table for each sequence. + def allocate(self, seq_group: SequenceGroup) -> None: + # NOTE: Here we assume that all sequences in the group have the same prompt. + size = self._compute_allocation_size(seq_group) for seq in seq_group.seqs: - self.block_tables[seq.seq_id] = block_table.copy() + self.block_tables[seq.seq_id] = self.gpu_allocator.allocate(size) def can_append(self, seq_group: SequenceGroup) -> bool: - # Simple heuristic: If there is at least one free block - # for each sequence, we can append. - num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() - num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) - return num_seqs <= num_free_gpu_blocks - - def append(self, seq: Sequence) -> Optional[Tuple[int, int]]: - """Allocate a physical slot for the new token.""" - logical_blocks = seq.logical_token_blocks - block_table = self.block_tables[seq.seq_id] + return True - if len(block_table) < len(logical_blocks): - # The sequence has a new logical block. - # Allocate a new physical block. - block = self.gpu_allocator.allocate() - block_table.append(block) - return None - - # We want to append the token to the last physical block. - last_block = block_table[-1] - assert last_block.device == Device.GPU - if last_block.ref_count == 1: - # Not shared with other sequences. Appendable. - return None - else: - # The last block is shared with other sequences. - # Copy on Write: Allocate a new block and copy the tokens. 
- new_block = self.gpu_allocator.allocate() - block_table[-1] = new_block - self.gpu_allocator.free(last_block) - return last_block.block_number, new_block.block_number + def append(self, seq: Sequence) -> Dict[int, List[int]]: + ret: Dict[int, List[int]] = {} + block_table = self.block_tables[seq.seq_id] + for block in block_table: + if block.block_number in self.forked: + assert block.block_number not in ret + ret[block.block_number] = self.forked[block.block_number] + del self.forked[block.block_number] + return ret def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: - # NOTE: fork does not allocate a new physical block. - # Thus, it is always safe from OOM. src_block_table = self.block_tables[parent_seq.seq_id] - self.block_tables[child_seq.seq_id] = src_block_table.copy() - for block in src_block_table: - block.ref_count += 1 - - def _get_physical_blocks(self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: - # NOTE: Here, we assume that the physical blocks are only shared by - # the sequences in the same group. - blocks: Set[PhysicalTokenBlock] = set() - for seq in seq_group.seqs: - if seq.status == SequenceStatus.FINISHED: - continue - block_table = self.block_tables[seq.seq_id] - for block in block_table: - blocks.add(block) - return list(blocks) + dst_block_table = self.block_tables[child_seq.seq_id] + for src_block, dst_block in zip(src_block_table, dst_block_table): + if src_block.block_number in self.forked: + self.forked[src_block.block_number].append(dst_block.block_number) + else: + self.forked[src_block.block_number] = [dst_block.block_number] def can_swap_in(self, seq_group: SequenceGroup) -> bool: - blocks = self._get_physical_blocks(seq_group) - num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) - num_free_blocks = self.gpu_allocator.get_num_free_blocks() - # NOTE: Conservatively, we assume that every sequence will allocate - # at least one free block right after the swap-in. - # NOTE: This should match the logic in can_append(). - num_required_blocks = len(blocks) + num_swapped_seqs - return num_free_blocks - num_required_blocks >= self.watermark_blocks - - def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: - # CPU block -> GPU block. - mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} - for seq in seq_group.seqs: - if seq.status == SequenceStatus.FINISHED: - continue - new_block_table: BlockTable = [] - block_table = self.block_tables[seq.seq_id] - - for cpu_block in block_table: - if cpu_block in mapping: - gpu_block = mapping[cpu_block] - gpu_block.ref_count += 1 - else: - gpu_block = self.gpu_allocator.allocate() - mapping[cpu_block] = gpu_block - new_block_table.append(gpu_block) - # Free the CPU block swapped in to GPU. - self.cpu_allocator.free(cpu_block) - self.block_tables[seq.seq_id] = new_block_table - - block_number_mapping = { - cpu_block.block_number: gpu_block.block_number - for cpu_block, gpu_block in mapping.items() - } - return block_number_mapping + return False def can_swap_out(self, seq_group: SequenceGroup) -> bool: - blocks = self._get_physical_blocks(seq_group) - return len(blocks) <= self.cpu_allocator.get_num_free_blocks() - - def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: - # GPU block -> CPU block. 
- mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} - for seq in seq_group.seqs: - if seq.status == SequenceStatus.FINISHED: - continue - new_block_table: BlockTable = [] - block_table = self.block_tables[seq.seq_id] - - for gpu_block in block_table: - if gpu_block in mapping: - cpu_block = mapping[gpu_block] - cpu_block.ref_count += 1 - else: - cpu_block = self.cpu_allocator.allocate() - mapping[gpu_block] = cpu_block - new_block_table.append(cpu_block) - # Free the GPU block swapped out to CPU. - self.gpu_allocator.free(gpu_block) - self.block_tables[seq.seq_id] = new_block_table - - block_number_mapping = { - gpu_block.block_number: cpu_block.block_number - for gpu_block, cpu_block in mapping.items() - } - return block_number_mapping + return False def _free_block_table(self, block_table: BlockTable) -> None: + block = block_table[0] + self.gpu_allocator.free(block.block_number) for block in block_table: - if block.device == Device.GPU: - self.gpu_allocator.free(block) - else: - self.cpu_allocator.free(block) + if block.block_number in self.forked: + del self.forked[block.block_number] def free(self, seq: Sequence) -> None: block_table = self.block_tables[seq.seq_id] @@ -237,10 +243,12 @@ def reset(self) -> None: def get_block_table(self, seq: Sequence) -> List[int]: block_table = self.block_tables[seq.seq_id] + num_blocks = len(seq.logical_token_blocks) + block_table = block_table[:num_blocks] return [block.block_number for block in block_table] def get_num_free_gpu_blocks(self) -> int: return self.gpu_allocator.get_num_free_blocks() def get_num_free_cpu_blocks(self) -> int: - return self.cpu_allocator.get_num_free_blocks() + return self.num_total_cpu_blocks diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index da461798bb6e..8b69924b9706 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -4,7 +4,7 @@ import time from typing import Any, Dict, List, Optional, Tuple -from cacheflow.master.block_manager import BlockSpaceManager +from cacheflow.master.block_manager import BuddyBlockSpaceManager from cacheflow.master.policy import PolicyFactory from cacheflow.sampling_params import SamplingParams from cacheflow.sequence import Sequence @@ -37,6 +37,7 @@ def __init__( num_cpu_blocks: int, max_num_batched_tokens: int, max_num_sequences: int, + len_estimator: str, collect_stats: bool, do_memory_analysis: bool = False, ) -> None: @@ -48,14 +49,16 @@ def __init__( self.max_num_sequences = max_num_sequences self.collect_stats = collect_stats self.do_memory_analysis = do_memory_analysis + self.len_estimator = len_estimator # Instantiate the scheduling policy. self.policy = PolicyFactory.get_policy(policy_name='fcfs') # Create the block space manager. - self.block_manager = BlockSpaceManager( + self.block_manager = BuddyBlockSpaceManager( block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, + len_estimator=len_estimator, ) # Sequence groups in the WAITING state. @@ -125,8 +128,8 @@ def _schedule( # Swap in the sequence groups in the SWAPPED state if possible. self.swapped = self.policy.sort_by_priority(now, self.swapped) - # FCFS - while self.swapped and not blocks_to_swap_out: + assert not self.swapped, 'In Orca, swapping never happens.' + while self.swapped: seq_group = self.swapped[0] # If the sequence group has been preempted in this step, stop. 
             if seq_group in preempted:
@@ -210,8 +213,8 @@
                 num_logical_tokens = 0
                 num_physical_blocks = 0
                 num_physical_tokens = 0
-                physical_block_numbers = set()
                 num_reserved_tokens = 0
+                num_internal_tokens = 0
                 for seq_group in self.running:
                     group_id = seq_group.group_id
                     sampling_params = self.sampling_params[group_id]
@@ -223,18 +226,32 @@
                         seq_id = seq.seq_id
                         block_table = block_tables[seq_id]
                         for i, block in enumerate(block_table):
-                            if block.block_number in physical_block_numbers:
-                                continue
-                            physical_block_numbers.add(block.block_number)
                             num_physical_blocks += 1
-                            num_physical_tokens += seq.logical_token_blocks[i].num_tokens
-
-                assert num_physical_blocks == num_used_gpu_blocks
+                            if i < len(seq.logical_token_blocks):
+                                num_physical_tokens += seq.logical_token_blocks[i].num_tokens
+
+                        reserved = seq.prompt_len + max_num_steps - seq.get_len()
+                        num_reserved_tokens += reserved
+                        if self.len_estimator == 'oracle':
+                            output_len = max_num_steps
+                        elif self.len_estimator == 'power2':
+                            output_len = 1 << (max_num_steps - 1).bit_length()
+                        elif self.len_estimator == 'constant':
+                            output_len = 2048
+                        else:
+                            assert False
+                        allocated = min(seq.prompt_len + output_len, 2048)
+                        internal = allocated - (seq.prompt_len + max_num_steps)
+                        num_internal_tokens += internal
+
+                assert num_physical_blocks == num_used_gpu_blocks, \
+                    f'{num_physical_blocks} != {num_used_gpu_blocks}'
                 self.stats.num_logical_blocks.append(num_logical_blocks)
                 self.stats.num_logical_tokens.append(num_logical_tokens)
                 self.stats.num_physical_blocks.append(num_physical_blocks)
                 self.stats.num_physical_tokens.append(num_physical_tokens)
                 self.stats.num_reserved_tokens.append(num_reserved_tokens)
+                self.stats.num_internal_tokens.append(num_internal_tokens)
 
         return (blocks_to_swap_in,
                 blocks_to_swap_out,
@@ -316,8 +333,6 @@ def post_step(
             output = seq_outputs[seq.seq_id]
             if seq.seq_id != output.parent_seq_id:
                 # The sequence is a fork of the parent sequence (beam search).
-                # Free the current sequence.
-                self.block_manager.free(seq)
                 # Fork the parent sequence.
                 parent_seq = seq_group.find(output.parent_seq_id)
                 parent_seq.fork(seq)
@@ -366,13 +381,12 @@ def _append(
         blocks_to_copy: Dict[int, List[int]],
     ) -> None:
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-            ret = self.block_manager.append(seq)
-            if ret is not None:
-                src_block, dst_block = ret
+            mapping = self.block_manager.append(seq)
+            for src_block, dst_blocks in mapping.items():
                 if src_block in blocks_to_copy:
-                    blocks_to_copy[src_block].append(dst_block)
+                    blocks_to_copy[src_block] += dst_blocks
                 else:
-                    blocks_to_copy[src_block] = [dst_block]
+                    blocks_to_copy[src_block] = dst_blocks
 
     def _preempt(
         self,
@@ -380,6 +394,7 @@ def _preempt(
         blocks_to_swap_out: Dict[int, int],
         preemption_mode: Optional[PreemptionMode] = None,
     ) -> None:
+        assert False, 'In Orca, preemption never happens.'
        # If preemption mode is not specified, we determine the mode as follows:
        # We use recomputation by default since it incurs lower overhead than
        # swapping.
However, when the sequence group has multiple sequences @@ -494,6 +509,7 @@ def __init__( self.num_physical_blocks: List[int] = [] self.num_physical_tokens: List[int] = [] self.num_reserved_tokens: List[int] = [] + self.num_internal_tokens: List[int] = [] def reset( self, @@ -522,6 +538,7 @@ def to_dict(self) -> Dict[str, Any]: 'num_physical_blocks': self.num_physical_blocks, 'num_physical_tokens': self.num_physical_tokens, 'num_reserved_tokens': self.num_reserved_tokens, + 'num_internal_tokens': self.num_internal_tokens, } def save(self, output_dir: str) -> None: diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index 5b8110a3dab4..54f09c1642d1 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -31,6 +31,7 @@ def __init__( all_stage_devices: List[List[DeviceID]], gpu_memory: int, cpu_memory: int, + len_estimator: str = 'oracle', collect_stats: bool = False, do_memory_analysis: bool = False, ): @@ -83,6 +84,7 @@ def __init__( num_cpu_blocks=self.num_cpu_blocks, max_num_batched_tokens=max_num_batched_tokens, max_num_sequences=max_num_sequences, + len_estimator=len_estimator, collect_stats=collect_stats, do_memory_analysis=do_memory_analysis, ) diff --git a/cacheflow/master/simple_frontend.py b/cacheflow/master/simple_frontend.py index f8396269874f..9551be8c4037 100644 --- a/cacheflow/master/simple_frontend.py +++ b/cacheflow/master/simple_frontend.py @@ -50,7 +50,8 @@ def _add_query( seqs.append(seq) group_id = next(self.seq_group_counter) - seq_group = SequenceGroup(group_id, seqs, arrival_time) + seq_group = SequenceGroup( + group_id, seqs, arrival_time=arrival_time, max_num_steps=sampling_params.max_num_steps) self.inputs.append((seq_group, sampling_params)) def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]: diff --git a/cacheflow/models/sample.py b/cacheflow/models/sample.py index 1e358c7e5278..538f68ca327c 100644 --- a/cacheflow/models/sample.py +++ b/cacheflow/models/sample.py @@ -80,13 +80,7 @@ def _get_temperatures( # (i.e., greedy sampling or beam search). # Set the temperature to 1 to avoid division by zero. temperature = 1.0 - - if i < input_metadata.num_prompts: - # A prompt input. - temperatures.append(temperature) - else: - # A generation token. - temperatures += [temperature] * len(seq_ids) + temperatures += [temperature] * len(seq_ids) return temperatures @@ -96,12 +90,7 @@ def _get_top_ps( top_ps: List[float] = [] for i, seq_group in enumerate(input_metadata.seq_groups): seq_ids, sampling_params = seq_group - if i < input_metadata.num_prompts: - # A prompt input. - top_ps.append(sampling_params.top_p) - else: - # A generation token. - top_ps += [sampling_params.top_p] * len(seq_ids) + top_ps += [sampling_params.top_p] * len(seq_ids) return top_ps @@ -234,12 +223,14 @@ def _sample( idx = 0 for i, seq_group in enumerate(input_metadata.seq_groups): seq_ids, sampling_params = seq_group - if i < input_metadata.num_prompts: + # NOTE(woosuk): In Orca, we must use idx instead of i because + # each beam is considered as a separate prompt. + if idx < input_metadata.num_prompts: # Generate the next tokens for a prompt input. assert len(seq_ids) == sampling_params.n prob = probs[idx] logprob = logprobs[idx] - idx += 1 + idx += len(seq_ids) # Sample the next tokens. 
next_token_ids = _sample_from_prompt(prob, sampling_params) diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py index 6f5501a99468..a299c94e8a2f 100644 --- a/cacheflow/sequence.py +++ b/cacheflow/sequence.py @@ -89,10 +89,12 @@ def __init__( self, group_id: int, seqs: List[Sequence], + max_num_steps: int, arrival_time: float, ) -> None: self.group_id = group_id self.seqs = seqs + self.max_num_steps = max_num_steps self.arrival_time = arrival_time def get_seqs( diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py index 95ce2c6a869e..2ee41afc30bc 100644 --- a/cacheflow/worker/worker.py +++ b/cacheflow/worker/worker.py @@ -114,25 +114,25 @@ def prepare_inputs( seq_groups.append((seq_ids, sampling_params)) seq_logprobs.update(input_seq_group.seq_logprobs) - # Use any sequence in the group. - seq_id = seq_ids[0] - - prompt_tokens = input_seq_group.input_tokens[seq_id] - prompt_len = len(prompt_tokens) - prompt_lens.append(prompt_len) - - input_tokens.extend(prompt_tokens) - # NOTE(woosuk): Here we assume that the first token in the prompt - # is always the first token in the sequence. - input_positions.extend(range(len(prompt_tokens))) - - # Compute the slot mapping. - block_table = input_seq_group.block_tables[seq_id] - for i in range(prompt_len): - block_number = block_table[i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) + # In Orca, we treat all sequences in a sequence group as if they + # were independent. + for seq_id in seq_ids: + prompt_tokens = input_seq_group.input_tokens[seq_id] + prompt_len = len(prompt_tokens) + prompt_lens.append(prompt_len) + + input_tokens.extend(prompt_tokens) + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.extend(range(len(prompt_tokens))) + + # Compute the slot mapping. + block_table = input_seq_group.block_tables[seq_id] + for i in range(prompt_len): + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) cumulative_prompt_lens: List[int] = [0] for prompt_len in prompt_lens: