diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py
index 3e9ceb34af5..4ba645ffd87 100644
--- a/examples/models/llama/runner/generation.py
+++ b/examples/models/llama/runner/generation.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import time
 from abc import ABC, abstractmethod
 from typing import List, Optional
 
@@ -97,6 +98,7 @@ def generate(  # noqa: C901
         pos_base: int = 0,
     ) -> List[int]:
         # Prefill
+        prefill_start = time.time()
         logits = self.forward(
             tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
@@ -105,11 +107,13 @@ def generate(  # noqa: C901
                 else None
             ),
         )
+        prefill_time = time.time() - prefill_start
 
         current_token = next_token(logits, temperature, top_p)
         print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]
 
+        generate_start = time.time()
         while len(tokens) < max_seq_len:
             if self.use_kv_cache:
                 logits = self.forward(
@@ -140,6 +144,10 @@ def generate(  # noqa: C901
             print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         print("\n")
 
+        generate_time = time.time() - generate_start
+        print(f"Prefill time: {prefill_time}")
+        print(f"Generation tok/s: {len(tokens) / generate_time}")
+
         return tokens if echo else tokens[len(prompt_tokens) :]
 
     def text_completion(
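
For context, below is a minimal standalone sketch of the timing pattern this patch introduces, with the model calls replaced by hypothetical stand-ins (fake_prefill and fake_decode_step are not ExecuTorch APIs, just placeholders for self.forward). It also illustrates one caveat in the reported metric: the patch divides len(tokens), which includes the prompt tokens and the token produced during prefill, by the decode-loop time, so the printed tok/s slightly overstates the decode-only rate.

import time

def fake_prefill(prompt_tokens):
    # Hypothetical stand-in for the prefill forward() call.
    time.sleep(0.05)
    return 42  # first generated token

def fake_decode_step(token):
    # Hypothetical stand-in for one decode-loop forward() call.
    time.sleep(0.01)
    return token + 1

prompt_tokens = [1, 2, 3, 4]
max_seq_len = 16

# Time the prefill pass, as the patch does.
prefill_start = time.time()
current = fake_prefill(prompt_tokens)
prefill_time = time.time() - prefill_start

tokens = prompt_tokens + [current]

# Time the decode loop, as the patch does.
generate_start = time.time()
while len(tokens) < max_seq_len:
    current = fake_decode_step(current)
    tokens.append(current)
generate_time = time.time() - generate_start

# The patch computes len(tokens) / generate_time; a decode-only rate
# would exclude the prompt tokens and the prefill-produced token:
decode_tokens = len(tokens) - len(prompt_tokens) - 1
print(f"Prefill time: {prefill_time:.3f}s")
print(f"Decode-only tok/s: {decode_tokens / generate_time:.1f}")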