From 2aa67b418cb9c77644245b0dee500a0bf5a78071 Mon Sep 17 00:00:00 2001 From: anirudh Date: Sun, 9 Feb 2025 12:56:30 +0530 Subject: [PATCH 01/13] [wip] Added cli args and other changes to eval multi-modal models --- install/requirements.txt | 2 +- torchchat/cli/builder.py | 6 + torchchat/cli/cli.py | 10 ++ torchchat/usages/eval.py | 325 ++++++++++++++++++++++++++++++++++++++- 4 files changed, 339 insertions(+), 4 deletions(-) diff --git a/install/requirements.txt b/install/requirements.txt index bd1e09174..e9df1a209 100644 --- a/install/requirements.txt +++ b/install/requirements.txt @@ -34,4 +34,4 @@ streamlit flask # eval -lm_eval==0.4.2 +lm_eval==0.4.5 diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 1e04800ab..010db5374 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -71,6 +71,7 @@ class BuilderArgs: dynamic_shapes: bool = False max_seq_length: Optional[int] = None attention_backend: str = "math" + modality: Optional[str] = "text" def __post_init__(self): if self.device is None: @@ -146,6 +147,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": aoti_package_path = getattr(args, "aoti_package_path", None) snapshot_path = getattr(args, "snapshot_path", None) + modality = "text" + if args.modality: + modality = args.modality + is_chat_model = False if args.is_chat_model: is_chat_model = True @@ -222,6 +227,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": chpt_from=chpt_from, distribution_path=distribution_path, is_chat_model=is_chat_model, + modality=modality, dynamic_shapes=getattr(args, "dynamic_shapes", False), max_seq_length=getattr(args, "max_seq_length", None), attention_backend=attention_backend, diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index f6bf32e40..559bb645b 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -137,6 +137,16 @@ def _add_model_specification_args(parser) -> None: help=argparse.SUPPRESS, ) + model_specification_parser.add_argument( + "--modality", + type=str, + default="text", + choices=["text", "text-image"], + # help=argparse.SUPPRESS, + help="Modality of the model. Options: text, text-image", + # help="Modality of the model. Options: text, text-image", + ) + # Add CLI Args related to model configuration (compilation, quant, etc) # Excludes compile args if subcommand is export diff --git a/torchchat/usages/eval.py b/torchchat/usages/eval.py index b708e5840..4d17a978e 100644 --- a/torchchat/usages/eval.py +++ b/torchchat/usages/eval.py @@ -4,7 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import argparse -from typing import Callable, Optional +from typing import Callable, Optional, Dict, List import torch import torch._dynamo.config @@ -33,6 +33,21 @@ from lm_eval.evaluator import evaluate from lm_eval.models.huggingface import HFLM as eval_wrapper from lm_eval.tasks import get_task_dict +from lm_eval.models.hf_vlms import HFMultimodalLM +from lm_eval.evaluator import evaluate + +from torchtune.modules.model_fusion import DeepFusionModel +from torchtune.modules.transforms import Transform +from torchtune.data import ( + format_content_with_images, + left_pad_sequence, + Message, + padded_collate_tiled_images_and_mask, +) +from torchtune.generation import generate, sample +from torchtune import utils + +import PIL def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( @@ -89,7 +104,7 @@ def __init__( device="cpu", is_pte_model: bool = False, ): - super().__init__(device=device) + super().__init__(pretrained="gpt2", device=device) self._model = model self._model_forward = ( model_forward @@ -168,6 +183,250 @@ def _model_generate(self, context, max_length, eos_token_id): raise Exception("unimplemented") +# Dummy class which _VLMEvalWrapper can inherit from when the imports don't work +# class HFMultimodalLM(): +# def __init__(self): +# return + +class _VLMEvalWrapper(HFMultimodalLM): + """An EvalWrapper for EleutherAI's eval harness based on gpt-fast's + EvalWrapper: https://github.com/pytorch-labs/gpt-fast/blob/main/eval.py. + + Note: + This is ONLY for vision-language models. + + Args: + model (DeepFusionModel): The VLM to evaluate. + transform (Transform): The transform (tokenizer) to use for preprocessing. + device (torch.device): The device to use. + max_seq_length (int): The maximum sequence length. + batch_size (int): The batch size. + dtype (torch.dtype): dtype for the model caches during generation. + enable_kv_cache (bool): Whether to enable KV cache for generation. + image_tag (str): The string to use for the image token. Default is "", which + is the default used by the MMMU dataset. + max_images_per_sample (int): The maximum number of images per sample. Defaults to + the max number of images in MMMU. + """ + + def __init__( + self, + model: DeepFusionModel, + transform: Transform, + *, + device: torch.device, + max_seq_length: int = 4096, + batch_size: int = 8, + dtype: torch.dtype = torch.bfloat16, + enable_kv_cache: bool = True, + # TODO (@joecummings): Update these defaults once more multimodal + # tasks are added to the eval harness + image_tag: str = "", + max_images_per_sample: int = 7, + ): + self._model = model + self._transform = transform + self._device = device + self._max_seq_length = max_seq_length + self._batch_size = batch_size + self._dtype = dtype + # Defaulting KV cache to True for multimodal + self._enable_kv_cache = True + self._image_tag = image_tag + self._max_images_per_sample = max_images_per_sample + + @property + def model(self): + # Not actually changing the dtype here, just adding it as a + # property on the model + self._model.dtype = self._dtype + return self._model + + @property + def model_transform(self): + return self._transform + + @property + def device(self): + return self._device + + @property + def cache_hook(self): + # Dummy class to appease the Harness + class DummyCacheHook: + def __init__(self): + self.add_partial = lambda x, y, z: True + + return DummyCacheHook() + + @property + def rank(self): + # Hardcoded for now b/c we only support single GPU eval + return 0 + + @property + def world_size(self): + # Hardcoded for now b/c we only support single GPU eval + return 1 + + @property + def batch_size(self): + return self._batch_size + + @property + def eos_token_id(self): + return self._transform.tokenizer.eos_id + + @property + def eot_token_id(self): + return self._transform.tokenizer.eot_id + + @property + def max_length(self): + return self._max_seq_length + + @property + def truncation(self): + return True + + def tok_encode(self, string, **kwargs) -> List[int]: + # This is only used to get a number of tokens for use in sorting samples in dataset + # These values will not actually be used for eval + return self._transform.tokenizer.encode(string, add_bos=False, add_eos=False) + + def tok_decode(self, tokens, skip_special_tokens=True) -> str: + if isinstance(tokens, int): + tokens = [tokens] + return self._transform.tokenizer.decode( + tokens, skip_special_tokens=skip_special_tokens + ) + + def tok_batch_multimodal_encode( + self, + all_texts: List[str], + all_images: List[List[PIL.Image.Image]], + left_truncate_len: int = None, + *args, + **kwargs, + ): + # Eleuther already parses out the text and images, so we just need to get + # it into a Message format for our tokenizer + all_encoded_messages = [] + + for text, images in zip(all_texts, all_images): + # Ensure images are all RGB + proper_images = [] + for image in images: + if image.mode != "RGB": + image = image.convert("RGB") + proper_images.append(image) + + # Construct the messages + messages = [] + content = format_content_with_images( + text, image_tag=self._image_tag, images=proper_images + ) + messages.append(Message(role="user", content=content)) + messages.append(Message(role="assistant", content="")) + + # Transform the messages + tok_batch = self.model_transform({"messages": messages}, inference=True) + all_encoded_messages.append(tok_batch) + + # Pad the encoded messages + tok_batch = padded_collate_tiled_images_and_mask( + all_encoded_messages, + pad_direction="left", + pad_max_images=self._max_images_per_sample, + ) + utils.batch_to_device(tok_batch, self.device) + + # Convert the batch to the format expected by the HF + tok_batch["input_ids"] = tok_batch.pop("tokens") + + # the harness will use left_truncate_len to indicate that the current batch + # needs to be truncated to self.max_seq_len - self.max_gen_toks + if left_truncate_len is not None: + tok_batch["input_ids"] = tok_batch["input_ids"][:, -left_truncate_len:] + + return tok_batch + + @torch.inference_mode() + def _model_multimodal_generate( + self, + batch: Dict[str, torch.Tensor], + max_length: int, + stop: List[str], + **generation_kwargs, + ): + # 1. Validate inputs + prompt = batch.pop("input_ids") + bsz, seq_len = prompt.shape + + temperature = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", False) + if do_sample or temperature != 0.0: + raise RuntimeError( + "Any decoding strategy other than greedy is not supported." + ) + + if bsz > 1: + raise ValueError( + f"Got a batch size of '{bsz}'. Batch size > 1 is not yet supported for " + "multimodal generation." + ) + + # 2. Setup KV cache and masks for bsz 1 + with self.device: + if self.model.caches_are_enabled(): + self.model.reset_caches() + else: + self.model.setup_caches( + batch_size=1, + dtype=self._dtype, + encoder_max_seq_len=self.model_transform.image_seq_len + * self._max_images_per_sample, + decoder_max_seq_len=self.max_length, + ) + causal_mask = torch.tril( + torch.ones( + size=(self.max_length, self.max_length), + dtype=torch.bool, + ) + ) + input_pos = torch.arange(self.max_length) + + batch["input_pos"] = input_pos[None, :seq_len] + batch["mask"] = causal_mask[None, :seq_len] + + # 3. Prefill step + generated_tokens = [] + logits = self.model(prompt, **batch)[:, -1] + token = sample(logits, temperature=0.0, top_k=None) + generated_tokens.append(token.item()) + + cache_mask = batch["encoder_mask"][:, -1:] + + # 4. Continue generating + for _ in range(max_length): + if token.item() in self.model_transform.stop_tokens: + break + logits = self.model( + token, + mask=causal_mask[None, seq_len, None, :], + encoder_input=None, + encoder_mask=cache_mask, + input_pos=input_pos[None, seq_len], + )[:, -1] + token = sample(logits, temperature=0.0, top_k=None) + generated_tokens.append(token.item()) + seq_len += 1 + + # 5. Return generated tokens + return torch.tensor(generated_tokens, dtype=torch.int32).unsqueeze(0) + + + @torch.no_grad() def eval( model: Model, @@ -223,6 +482,54 @@ def eval( return eval_results +def multi_model_eval( + model: Model, + model_forward: Callable, + tokenizer, + tasks: Optional[list] = None, + limit: Optional[int] = None, + max_seq_length: Optional[int] = None, + device: str = "cpu", + is_pte_model: bool = False,): + """ + Evaluates a language model on a specified task using the lm-evaluation-harness library. + + Args: + model (Model): The pre-trained language model to evaluate. + tokenizer: The tokenizer to use for encoding/decoding text. + tasks (Optional[list]): The names of the evaluation tasks to perform. + limit (Optional[int]): The maximum number of samples to evaluate (None for all available). + max_seq_length (Optional[int]): The maximum sequence length allowed for input text. + + Returns: + eval_results (dict): A dictionary of evaluation results for the specified task(s). + """ + if tasks is None: + tasks = ["wikitext"] + + model_eval_wrapper = _VLMEvalWrapper( + model, + transform=tokenizer, # tranform is the tokenizer for multimodal models + max_seq_length=max_seq_length, + device=device, + ) + + try: + lm_eval.tasks.initialize_tasks() + except: + pass + + task_dict = get_task_dict(tasks) + + eval_results = evaluate( + model_eval_wrapper, + task_dict, + limit=limit, + ) + eval_results["times"] = model_eval_wrapper.times + return eval_results + + def main(args) -> None: """Evaluates model on a task from the `lm-evaluation-harness` library. @@ -244,6 +551,10 @@ def main(args) -> None: compile = args.compile max_seq_length = args.max_seq_length + modality = builder_args.modality + print(f"Modality of model={modality}") + assert modality in ["text", "text-image"], "Only text and text-plus-image modality is supported for evaluation" + print(f"Using device={device}") set_precision(builder_args.precision) @@ -267,8 +578,16 @@ def main(args) -> None: ) torch._inductor.config.coordinate_descent_tuning = False if device == "cpu" else True + evaluator = None + if modality == "text": + evaluator = eval + elif modality == "text-image": + evaluator = multi_model_eval + else: + raise ValueError(f"Unsupported modality: {modality}") + with measure_time("Time to run eval: {time:.02f}s."): - result = eval( + result = evaluator( model.to(device), model_forward, tokenizer, From 78bdacf9487740d602b97e96c9babb7c69841b12 Mon Sep 17 00:00:00 2001 From: anirudh Date: Sun, 9 Feb 2025 13:05:13 +0530 Subject: [PATCH 02/13] remove redundant comment --- torchchat/cli/cli.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index 559bb645b..520925e2b 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -144,7 +144,6 @@ def _add_model_specification_args(parser) -> None: choices=["text", "text-image"], # help=argparse.SUPPRESS, help="Modality of the model. Options: text, text-image", - # help="Modality of the model. Options: text, text-image", ) From bfc62dc87961d95118e10c462551534672d676a9 Mon Sep 17 00:00:00 2001 From: anirudh Date: Sun, 9 Feb 2025 21:18:24 +0530 Subject: [PATCH 03/13] Added Llama3VisionTransform in TokenizerArgs and other changes --- torchchat/cli/builder.py | 29 +++++++-- .../model_params/Llama-3.2-11B-Vision.json | 2 +- torchchat/usages/eval.py | 60 ++++++++++--------- 3 files changed, 58 insertions(+), 33 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 010db5374..630af1ead 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -252,13 +252,29 @@ class TokenizerArgs: is_sentencepiece: bool = False is_tiktoken: bool = False is_hf_tokenizer: bool = False + is_llama_3_2_mm: bool = False t: Optional[Any] = None def __post_init__(self): + # special handling for llama-3.2-mm + if "llama-3.2-11b-vision" in str(self.tokenizer_path).lower(): + try: + from torchtune.models.llama3_2_vision import llama3_2_vision_transform + + self.t = llama3_2_vision_transform(path=str(self.tokenizer_path)) + self.is_llama_3_2_mm = True + self.is_tiktoken = False + self.is_sentencepiece = False + self.is_hf_tokenizer = False + return + except: + pass + try: from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path)) + self.is_llama_3_2_mm = False self.is_tiktoken = True self.is_sentencepiece = False self.is_hf_tokenizer = False @@ -270,6 +286,7 @@ def __post_init__(self): from sentencepiece import SentencePieceProcessor self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path)) + self.is_llama_3_2_mm = False self.is_tiktoken = False self.is_sentencepiece = True self.is_hf_tokenizer = False @@ -281,6 +298,7 @@ def __post_init__(self): from tokenizer.hf_tokenizer import HFTokenizer self.t = HFTokenizer(str(self.tokenizer_path)) + self.is_llama_3_2_mm = False self.is_tiktoken = False self.is_sentencepiece = False self.is_hf_tokenizer = True @@ -288,6 +306,7 @@ def __post_init__(self): except: pass + self.is_llama_3_2_mm = False self.is_tiktoken = False self.is_sentencepiece = False self.is_hf_tokenizer = False @@ -302,20 +321,22 @@ def validate_model( if model is None: return - if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1: + if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece, self.is_llama_3_2_mm]) != 1: raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}") is_tiktoken = self.is_tiktoken is_sentencepiece = self.is_sentencepiece is_hf_tokenizer = self.is_hf_tokenizer + is_llama_3_2_mm = self.is_llama_3_2_mm + use_tiktoken = model.config.use_tiktoken use_hf_tokenizer = model.config.use_hf_tokenizer - use_sentencepiece = not (use_tiktoken or use_hf_tokenizer) - + use_other_tokenizer = not (use_tiktoken or use_hf_tokenizer) if ( (is_tiktoken and not use_tiktoken) or (is_hf_tokenizer and not use_hf_tokenizer) or - (is_sentencepiece and not use_sentencepiece) + (is_sentencepiece and not use_other_tokenizer) or + (is_llama_3_2_mm and not use_other_tokenizer) ): raise RuntimeError( "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format( diff --git a/torchchat/model_params/Llama-3.2-11B-Vision.json b/torchchat/model_params/Llama-3.2-11B-Vision.json index 5232e3512..b9b66ba94 100644 --- a/torchchat/model_params/Llama-3.2-11B-Vision.json +++ b/torchchat/model_params/Llama-3.2-11B-Vision.json @@ -1,6 +1,6 @@ { "model_type": "flamingo", - "use_tiktoken": true, + "use_tiktoken": false, "encoder": { "patch_size": 14, "num_heads": 16, diff --git a/torchchat/usages/eval.py b/torchchat/usages/eval.py index 4d17a978e..4f4be9aff 100644 --- a/torchchat/usages/eval.py +++ b/torchchat/usages/eval.py @@ -378,16 +378,13 @@ def _model_multimodal_generate( # 2. Setup KV cache and masks for bsz 1 with self.device: - if self.model.caches_are_enabled(): - self.model.reset_caches() - else: - self.model.setup_caches( - batch_size=1, - dtype=self._dtype, - encoder_max_seq_len=self.model_transform.image_seq_len - * self._max_images_per_sample, - decoder_max_seq_len=self.max_length, - ) + self.model.setup_caches( + batch_size=1, + dtype=self._dtype, + encoder_max_seq_len=self.model_transform.image_seq_len + * self._max_images_per_sample, + decoder_max_seq_len=self.max_length, + ) causal_mask = torch.tril( torch.ones( size=(self.max_length, self.max_length), @@ -506,6 +503,8 @@ def multi_model_eval( """ if tasks is None: tasks = ["wikitext"] + max_seq_length = 4096 if max_seq_length is None else max_seq_length + device = utils.get_device(device) if isinstance(device, str) else device model_eval_wrapper = _VLMEvalWrapper( model, @@ -578,25 +577,30 @@ def main(args) -> None: ) torch._inductor.config.coordinate_descent_tuning = False if device == "cpu" else True - evaluator = None - if modality == "text": - evaluator = eval - elif modality == "text-image": - evaluator = multi_model_eval - else: - raise ValueError(f"Unsupported modality: {modality}") - with measure_time("Time to run eval: {time:.02f}s."): - result = evaluator( - model.to(device), - model_forward, - tokenizer, - tasks, - limit, - max_seq_length, - device=builder_args.device, - is_pte_model=builder_args.pte_path is not None, - ) + if modality == "text": + result = eval( + model.to(device), + model_forward, + tokenizer, + tasks, + limit, + max_seq_length, + device=builder_args.device, + is_pte_model=builder_args.pte_path is not None, + ) + elif modality == "text-image": + result = multi_model_eval( + model.to(device), + model_forward, + tokenizer, + tasks, + limit, + max_seq_length, + device=builder_args.device, + ) + else: + raise ValueError(f"Unsupported modality: {modality}") times = torch.tensor(result["times"]) print( From 8900f8ad4f2257c31a564f9704c568f9c6363e03 Mon Sep 17 00:00:00 2001 From: anirudh Date: Thu, 20 Feb 2025 09:13:28 +0530 Subject: [PATCH 04/13] use kv caching and other minor fixes --- install/requirements.txt | 2 +- torchchat/model.py | 6 +++ torchchat/usages/eval.py | 79 +++++++++++++++++++++------------------- 3 files changed, 49 insertions(+), 38 deletions(-) diff --git a/install/requirements.txt b/install/requirements.txt index e9df1a209..73be68763 100644 --- a/install/requirements.txt +++ b/install/requirements.txt @@ -34,4 +34,4 @@ streamlit flask # eval -lm_eval==0.4.5 +lm_eval==0.4.7 diff --git a/torchchat/model.py b/torchchat/model.py index ce7dcb5e4..9722ca240 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -608,6 +608,12 @@ def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_l decoder_max_seq_len=decoder_max_seq_len, ) + def caches_are_setup(self) -> bool: + return self.model.caches_are_setup() + + def caches_are_enabled(self) -> bool: + return self.model.caches_are_enabled() + def reset_caches(self): self.model.reset_caches() diff --git a/torchchat/usages/eval.py b/torchchat/usages/eval.py index 4f4be9aff..c9f5e034d 100644 --- a/torchchat/usages/eval.py +++ b/torchchat/usages/eval.py @@ -36,6 +36,7 @@ from lm_eval.models.hf_vlms import HFMultimodalLM from lm_eval.evaluator import evaluate +from torchtune.modules.common_utils import local_kv_cache from torchtune.modules.model_fusion import DeepFusionModel from torchtune.modules.transforms import Transform from torchtune.data import ( @@ -183,12 +184,7 @@ def _model_generate(self, context, max_length, eos_token_id): raise Exception("unimplemented") -# Dummy class which _VLMEvalWrapper can inherit from when the imports don't work -# class HFMultimodalLM(): -# def __init__(self): -# return - -class _VLMEvalWrapper(HFMultimodalLM): +class VLMEvalWrapper(HFMultimodalLM): """An EvalWrapper for EleutherAI's eval harness based on gpt-fast's EvalWrapper: https://github.com/pytorch-labs/gpt-fast/blob/main/eval.py. @@ -234,6 +230,7 @@ def __init__( self._enable_kv_cache = True self._image_tag = image_tag self._max_images_per_sample = max_images_per_sample + self.times = [] @property def model(self): @@ -338,6 +335,7 @@ def tok_batch_multimodal_encode( all_encoded_messages, pad_direction="left", pad_max_images=self._max_images_per_sample, + pad_max_tiles=self._transform.max_num_tiles, ) utils.batch_to_device(tok_batch, self.device) @@ -376,15 +374,11 @@ def _model_multimodal_generate( "multimodal generation." ) - # 2. Setup KV cache and masks for bsz 1 + encoder_max_seq_len = ( + self.model_transform.image_seq_len * self._max_images_per_sample + ) + # Setup masks for bsz 1 with self.device: - self.model.setup_caches( - batch_size=1, - dtype=self._dtype, - encoder_max_seq_len=self.model_transform.image_seq_len - * self._max_images_per_sample, - decoder_max_seq_len=self.max_length, - ) causal_mask = torch.tril( torch.ones( size=(self.max_length, self.max_length), @@ -396,28 +390,39 @@ def _model_multimodal_generate( batch["input_pos"] = input_pos[None, :seq_len] batch["mask"] = causal_mask[None, :seq_len] - # 3. Prefill step - generated_tokens = [] - logits = self.model(prompt, **batch)[:, -1] - token = sample(logits, temperature=0.0, top_k=None) - generated_tokens.append(token.item()) - - cache_mask = batch["encoder_mask"][:, -1:] - - # 4. Continue generating - for _ in range(max_length): - if token.item() in self.model_transform.stop_tokens: - break - logits = self.model( - token, - mask=causal_mask[None, seq_len, None, :], - encoder_input=None, - encoder_mask=cache_mask, - input_pos=input_pos[None, seq_len], - )[:, -1] - token = sample(logits, temperature=0.0, top_k=None) - generated_tokens.append(token.item()) - seq_len += 1 + with measure_time(message=None) as measure: + # 2. Setup KV cache + with local_kv_cache( + self.model, + batch_size=self.batch_size, + device=self.device, + dtype=self._dtype, + encoder_max_seq_len=encoder_max_seq_len, + decoder_max_seq_len=self.max_length, + ): + # 3. Prefill step + generated_tokens = [] + logits = self.model(prompt, **batch)[:, -1] + token = sample(logits, temperature=0.0, top_k=None) + generated_tokens.append(token.item()) + + cache_mask = batch["encoder_mask"][:, -1:] + + # 4. Continue generating + for _ in range(max_length): + if token.item() in self.model_transform.stop_tokens: + break + logits = self.model( + token, + mask=causal_mask[None, seq_len, None, :], + encoder_input=None, + encoder_mask=cache_mask, + input_pos=input_pos[None, seq_len], + )[:, -1] + token = sample(logits, temperature=0.0, top_k=None) + generated_tokens.append(token.item()) + seq_len += 1 + self.times.append(measure.get_time()) # 5. Return generated tokens return torch.tensor(generated_tokens, dtype=torch.int32).unsqueeze(0) @@ -506,7 +511,7 @@ def multi_model_eval( max_seq_length = 4096 if max_seq_length is None else max_seq_length device = utils.get_device(device) if isinstance(device, str) else device - model_eval_wrapper = _VLMEvalWrapper( + model_eval_wrapper = VLMEvalWrapper( model, transform=tokenizer, # tranform is the tokenizer for multimodal models max_seq_length=max_seq_length, From 59ce657fddd8b9c8682d8f2a5bcd88bf1b09a82a Mon Sep 17 00:00:00 2001 From: anirudh Date: Sun, 23 Feb 2025 08:44:10 +0530 Subject: [PATCH 05/13] default batch size 1 --- torchchat/usages/eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchchat/usages/eval.py b/torchchat/usages/eval.py index c9f5e034d..730640031 100644 --- a/torchchat/usages/eval.py +++ b/torchchat/usages/eval.py @@ -212,7 +212,7 @@ def __init__( *, device: torch.device, max_seq_length: int = 4096, - batch_size: int = 8, + batch_size: int = 1, dtype: torch.dtype = torch.bfloat16, enable_kv_cache: bool = True, # TODO (@joecummings): Update these defaults once more multimodal From afdb3ceabb46b376e9b92c76aa139817ef314a4b Mon Sep 17 00:00:00 2001 From: anirudh Date: Sun, 23 Feb 2025 09:09:49 +0530 Subject: [PATCH 06/13] lint eval.py and builder.py --- torchchat/cli/builder.py | 91 ++++++++++++++++++++++++---------------- torchchat/usages/eval.py | 36 +++++++++------- 2 files changed, 77 insertions(+), 50 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 630af1ead..3a9dd2b69 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -16,13 +16,22 @@ import torch._inductor.config import torch.distributed as dist -from torchchat.distributed.utils import( +from torchtune.models.convert_weights import meta_to_tune + +from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE + +from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune + +from torchtune.training import set_default_dtype + +from torchchat.distributed.logging_utils import SingletonLogger + +from torchchat.distributed.utils import ( Color as color, CUDATrackTime, - init_distributed, GPUMemoryMonitor, + init_distributed, ) -from torchchat.distributed.logging_utils import SingletonLogger from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs from torchchat.model_config.model_config import resolve_model_config @@ -36,15 +45,6 @@ from torchchat.utils.quantize import quantize_model -from torchtune.models.convert_weights import meta_to_tune - -from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE - -from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune - -from torchtune.training import set_default_dtype - - @dataclass class BuilderArgs: checkpoint_path: Optional[Union[Path, str]] = None @@ -194,15 +194,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": tp = getattr(args, "tp", 1) chpt_from = getattr(args, "chpt_from", "hf") sdp_backend_dict = { - 'math': torch.nn.attention.SDPBackend.MATH, - 'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION, - 'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, - 'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION, + "math": torch.nn.attention.SDPBackend.MATH, + "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION, + "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, + "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION, } attention_backend = sdp_backend_dict[args.attention_backend] - if args.device == "cpu" and (args.attention_backend == "efficient_attention" - or args.attention_backend == "cudnn_attention"): - print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.") + if args.device == "cpu" and ( + args.attention_backend == "efficient_attention" + or args.attention_backend == "cudnn_attention" + ): + print( + f"Warning: {args.attention_backend} is not supported on CPU. Using math instead." + ) attention_backend = torch.nn.attention.SDPBackend.MATH return cls( checkpoint_dir=checkpoint_dir, @@ -321,7 +325,17 @@ def validate_model( if model is None: return - if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece, self.is_llama_3_2_mm]) != 1: + if ( + sum( + [ + self.is_tiktoken, + self.is_hf_tokenizer, + self.is_sentencepiece, + self.is_llama_3_2_mm, + ] + ) + != 1 + ): raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}") is_tiktoken = self.is_tiktoken @@ -333,10 +347,10 @@ def validate_model( use_hf_tokenizer = model.config.use_hf_tokenizer use_other_tokenizer = not (use_tiktoken or use_hf_tokenizer) if ( - (is_tiktoken and not use_tiktoken) or - (is_hf_tokenizer and not use_hf_tokenizer) or - (is_sentencepiece and not use_other_tokenizer) or - (is_llama_3_2_mm and not use_other_tokenizer) + (is_tiktoken and not use_tiktoken) + or (is_hf_tokenizer and not use_hf_tokenizer) + or (is_sentencepiece and not use_other_tokenizer) + or (is_llama_3_2_mm and not use_other_tokenizer) ): raise RuntimeError( "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format( @@ -534,6 +548,7 @@ def _load_model(builder_args: BuilderArgs) -> Model: # AOTI-compoiled model will load its own weights. # Release weights here to avoid OOM import gc + if hasattr(model, "model"): model.model = None gc.collect() @@ -591,6 +606,7 @@ def _initialize_model( def do_nothing(max_batch_size, max_seq_length): pass + model.setup_caches = do_nothing model.forward = torch._export.aot_load( @@ -628,6 +644,7 @@ def do_nothing(max_batch_size, max_seq_length): def do_nothing(max_batch_size, max_seq_length): pass + model.setup_caches = do_nothing model.forward = aoti_compiled_model @@ -702,7 +719,9 @@ def do_nothing(max_batch_size, max_seq_length): logger = SingletonLogger.get_logger() gpu_memory_monitor = GPUMemoryMonitor("cuda") - logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}") + logger.info( + f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}" + ) # Model-level config if builder_args.params_table: @@ -713,20 +732,16 @@ def do_nothing(max_batch_size, max_seq_length): config = TransformerArgs.from_params(model_config.transformer_args["text"]) logger.info(f"Transformer Config: {config}") - #TODO: Move into head of file after solving circular import - from torchchat.distributed.checkpoint_utils import ( - load_model_weights, - ) + # TODO: Move into head of file after solving circular import + from torchchat.distributed.checkpoint_utils import load_model_weights # Validate pipeline degree assert config.n_layers % pp_degree == 0 # Create device mesh device_mesh = dist.init_device_mesh( - "cuda", - (pp_degree, tp_degree), - mesh_dim_names=("pp", "tp") - ) + "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp") + ) tp_mesh = device_mesh["tp"] pp_mesh = device_mesh["pp"] logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}") @@ -755,7 +770,13 @@ def do_nothing(max_batch_size, max_seq_length): # Load weights logger.info(f"Loading weights for {pp_rank=} on {device=}") with CUDATrackTime() as timer: - load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from) + load_model_weights( + model, + builder_args.distribution_path, + device, + config, + builder_args.chpt_from, + ) logger.info( f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" @@ -769,7 +790,7 @@ def do_nothing(max_batch_size, max_seq_length): # lanes. # TODO: bump up the lane count pipeline_lanes = 1 - seqlen_prefill=1024 + seqlen_prefill = 1024 with device: model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes) diff --git a/torchchat/usages/eval.py b/torchchat/usages/eval.py index 730640031..22286293b 100644 --- a/torchchat/usages/eval.py +++ b/torchchat/usages/eval.py @@ -4,7 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import argparse -from typing import Callable, Optional, Dict, List +from typing import Callable, Dict, List, Optional import torch import torch._dynamo.config @@ -30,15 +30,13 @@ import lm_eval +import PIL + from lm_eval.evaluator import evaluate +from lm_eval.models.hf_vlms import HFMultimodalLM from lm_eval.models.huggingface import HFLM as eval_wrapper from lm_eval.tasks import get_task_dict -from lm_eval.models.hf_vlms import HFMultimodalLM -from lm_eval.evaluator import evaluate - -from torchtune.modules.common_utils import local_kv_cache -from torchtune.modules.model_fusion import DeepFusionModel -from torchtune.modules.transforms import Transform +from torchtune import utils from torchtune.data import ( format_content_with_images, left_pad_sequence, @@ -46,9 +44,10 @@ padded_collate_tiled_images_and_mask, ) from torchtune.generation import generate, sample -from torchtune import utils -import PIL +from torchtune.modules.common_utils import local_kv_cache +from torchtune.modules.model_fusion import DeepFusionModel +from torchtune.modules.transforms import Transform def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( @@ -428,7 +427,6 @@ def _model_multimodal_generate( return torch.tensor(generated_tokens, dtype=torch.int32).unsqueeze(0) - @torch.no_grad() def eval( model: Model, @@ -492,7 +490,8 @@ def multi_model_eval( limit: Optional[int] = None, max_seq_length: Optional[int] = None, device: str = "cpu", - is_pte_model: bool = False,): + is_pte_model: bool = False, +): """ Evaluates a language model on a specified task using the lm-evaluation-harness library. @@ -513,7 +512,7 @@ def multi_model_eval( model_eval_wrapper = VLMEvalWrapper( model, - transform=tokenizer, # tranform is the tokenizer for multimodal models + transform=tokenizer, # tranform is the tokenizer for multimodal models max_seq_length=max_seq_length, device=device, ) @@ -557,7 +556,10 @@ def main(args) -> None: modality = builder_args.modality print(f"Modality of model={modality}") - assert modality in ["text", "text-image"], "Only text and text-plus-image modality is supported for evaluation" + assert modality in [ + "text", + "text-image", + ], "Only text and text-image modality is supported for evaluation" print(f"Using device={device}") set_precision(builder_args.precision) @@ -575,12 +577,16 @@ def main(args) -> None: if compile: assert not ( - builder_args.dso_path or builder_args.pte_path or builder_args.aoti_package_path + builder_args.dso_path + or builder_args.pte_path + or builder_args.aoti_package_path ), "cannot compile exported model" model_forward = torch.compile( model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True ) - torch._inductor.config.coordinate_descent_tuning = False if device == "cpu" else True + torch._inductor.config.coordinate_descent_tuning = ( + False if device == "cpu" else True + ) with measure_time("Time to run eval: {time:.02f}s."): if modality == "text": From ae66bafceb9d61d2a57694204606acf99d76f778 Mon Sep 17 00:00:00 2001 From: anirudh Date: Mon, 24 Feb 2025 15:45:24 +0530 Subject: [PATCH 07/13] lm-eval 0.4.2->0.4.7 in install_requirements.sh --- install/install_requirements.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/install/install_requirements.sh b/install/install_requirements.sh index 41fe30baa..f1fb72247 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -130,5 +130,5 @@ if [[ -x "$(command -v nvidia-smi)" ]]; then fi ( set -x - $PIP_EXECUTABLE install evaluate=="0.4.3" lm-eval=="0.4.2" psutil=="6.0.0" + $PIP_EXECUTABLE install evaluate=="0.4.3" lm-eval=="0.4.7" psutil=="6.0.0" ) From 7721be9ab6bd3dbffda1b2c12004536068a319a4 Mon Sep 17 00:00:00 2001 From: anirudh Date: Mon, 3 Mar 2025 22:19:36 +0530 Subject: [PATCH 08/13] fixes from code review --- torchchat/cli/cli.py | 17 +++-- torchchat/usages/eval.py | 135 +++++++++++++-------------------------- 2 files changed, 51 insertions(+), 101 deletions(-) diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index 520925e2b..7fd02eed3 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -137,15 +137,6 @@ def _add_model_specification_args(parser) -> None: help=argparse.SUPPRESS, ) - model_specification_parser.add_argument( - "--modality", - type=str, - default="text", - choices=["text", "text-image"], - # help=argparse.SUPPRESS, - help="Modality of the model. Options: text, text-image", - ) - # Add CLI Args related to model configuration (compilation, quant, etc) # Excludes compile args if subcommand is export @@ -441,6 +432,14 @@ def _add_evaluation_args(parser) -> None: help="Maximum length sequence to evaluate", ) + eval_parser.add_argument( + "--modality", + type=str, + default="text", + choices=["text", "text-image"], + help="Modality of the model. Options: text, text-image", + ) + # Add CLI Args related to distributed inference # This feature is currently a [WIP] and hidden from --help diff --git a/torchchat/usages/eval.py b/torchchat/usages/eval.py index 22286293b..5936727a9 100644 --- a/torchchat/usages/eval.py +++ b/torchchat/usages/eval.py @@ -4,7 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import argparse -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Literal import torch import torch._dynamo.config @@ -184,7 +184,12 @@ def _model_generate(self, context, max_length, eos_token_id): class VLMEvalWrapper(HFMultimodalLM): - """An EvalWrapper for EleutherAI's eval harness based on gpt-fast's + """ + This class is adapted from torchtune. + Source: https://github.com/pytorch/torchtune/blob/main/recipes/eleuther_eval.py + ------------------------------------------------------------------------------- + + An EvalWrapper for EleutherAI's eval harness based on gpt-fast's EvalWrapper: https://github.com/pytorch-labs/gpt-fast/blob/main/eval.py. Note: @@ -437,6 +442,7 @@ def eval( max_seq_length: Optional[int] = None, device: str = "cpu", is_pte_model: bool = False, + modality: Literal["text", "text-image"] = "text", ) -> dict: """ Evaluates a language model on a specified task using the lm-evaluation-harness library. @@ -447,21 +453,33 @@ def eval( tasks (Optional[list]): The names of the evaluation tasks to perform. limit (Optional[int]): The maximum number of samples to evaluate (None for all available). max_seq_length (Optional[int]): The maximum sequence length allowed for input text. + modality (str): The modality of the model. Options: text, text-image Returns: eval_results (dict): A dictionary of evaluation results for the specified task(s). """ if tasks is None: - tasks = ["wikitext"] - - model_eval_wrapper = GPTFastEvalWrapper( - model, - tokenizer, - model_forward=model_forward, - max_seq_length=max_seq_length, - device=device, - is_pte_model=is_pte_model, - ) + if modality == "text": + tasks = ["wikitext"] + elif modality == "text-image": + tasks = ["mmmu-val-art"] + + if modality == "text": + model_eval_wrapper = GPTFastEvalWrapper( + model, + tokenizer, + model_forward=model_forward, + max_seq_length=max_seq_length, + device=device, + is_pte_model=is_pte_model, + ) + elif modality == "text-image": + model_eval_wrapper = VLMEvalWrapper( + model, + transform=tokenizer, + max_seq_length = 4096 if max_seq_length is None else max_seq_length, + device = utils.get_device(device) if isinstance(device, str) else device, + ) try: lm_eval.tasks.initialize_tasks() @@ -482,57 +500,6 @@ def eval( return eval_results -def multi_model_eval( - model: Model, - model_forward: Callable, - tokenizer, - tasks: Optional[list] = None, - limit: Optional[int] = None, - max_seq_length: Optional[int] = None, - device: str = "cpu", - is_pte_model: bool = False, -): - """ - Evaluates a language model on a specified task using the lm-evaluation-harness library. - - Args: - model (Model): The pre-trained language model to evaluate. - tokenizer: The tokenizer to use for encoding/decoding text. - tasks (Optional[list]): The names of the evaluation tasks to perform. - limit (Optional[int]): The maximum number of samples to evaluate (None for all available). - max_seq_length (Optional[int]): The maximum sequence length allowed for input text. - - Returns: - eval_results (dict): A dictionary of evaluation results for the specified task(s). - """ - if tasks is None: - tasks = ["wikitext"] - max_seq_length = 4096 if max_seq_length is None else max_seq_length - device = utils.get_device(device) if isinstance(device, str) else device - - model_eval_wrapper = VLMEvalWrapper( - model, - transform=tokenizer, # tranform is the tokenizer for multimodal models - max_seq_length=max_seq_length, - device=device, - ) - - try: - lm_eval.tasks.initialize_tasks() - except: - pass - - task_dict = get_task_dict(tasks) - - eval_results = evaluate( - model_eval_wrapper, - task_dict, - limit=limit, - ) - eval_results["times"] = model_eval_wrapper.times - return eval_results - - def main(args) -> None: """Evaluates model on a task from the `lm-evaluation-harness` library. @@ -553,13 +520,8 @@ def main(args) -> None: limit = args.limit compile = args.compile max_seq_length = args.max_seq_length + modality = args.modality - modality = builder_args.modality - print(f"Modality of model={modality}") - assert modality in [ - "text", - "text-image", - ], "Only text and text-image modality is supported for evaluation" print(f"Using device={device}") set_precision(builder_args.precision) @@ -588,30 +550,19 @@ def main(args) -> None: False if device == "cpu" else True ) + with measure_time("Time to run eval: {time:.02f}s."): - if modality == "text": - result = eval( - model.to(device), - model_forward, - tokenizer, - tasks, - limit, - max_seq_length, - device=builder_args.device, - is_pte_model=builder_args.pte_path is not None, - ) - elif modality == "text-image": - result = multi_model_eval( - model.to(device), - model_forward, - tokenizer, - tasks, - limit, - max_seq_length, - device=builder_args.device, - ) - else: - raise ValueError(f"Unsupported modality: {modality}") + result = eval( + model.to(device), + model_forward, + tokenizer, + tasks, + limit, + max_seq_length, + device=builder_args.device, + is_pte_model=builder_args.pte_path is not None, + modality=modality, + ) times = torch.tensor(result["times"]) print( From 96ab799d8ae81558d17d98641a1b96bac9110be4 Mon Sep 17 00:00:00 2001 From: anirudh Date: Sun, 9 Mar 2025 12:52:20 +0530 Subject: [PATCH 09/13] remove modality from builder args --- torchchat/cli/builder.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 59a85bc27..65897910b 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -71,7 +71,6 @@ class BuilderArgs: dynamic_shapes: bool = False max_seq_length: Optional[int] = None attention_backend: str = "math" - modality: Optional[str] = "text" def __post_init__(self): if self.device is None: @@ -147,10 +146,6 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": aoti_package_path = getattr(args, "aoti_package_path", None) snapshot_path = getattr(args, "snapshot_path", None) - modality = "text" - if args.modality: - modality = args.modality - is_chat_model = False if args.is_chat_model: is_chat_model = True @@ -231,7 +226,6 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": chpt_from=chpt_from, distribution_path=distribution_path, is_chat_model=is_chat_model, - modality=modality, dynamic_shapes=getattr(args, "dynamic_shapes", False), max_seq_length=getattr(args, "max_seq_length", None), attention_backend=attention_backend, From 51b0e83f7ba82deb226280fef82badc1b2bf0a83 Mon Sep 17 00:00:00 2001 From: anirudh Date: Mon, 17 Mar 2025 23:13:16 +0530 Subject: [PATCH 10/13] use custom prefix token --- torchchat/usages/eval.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchchat/usages/eval.py b/torchchat/usages/eval.py index 5936727a9..5d72890b5 100644 --- a/torchchat/usages/eval.py +++ b/torchchat/usages/eval.py @@ -473,6 +473,8 @@ def eval( device=device, is_pte_model=is_pte_model, ) + # use eot_token_id as prefix_token_id. + model_eval_wrapper.custom_prefix_token_id = model_eval_wrapper.eot_token_id elif modality == "text-image": model_eval_wrapper = VLMEvalWrapper( model, From 51135fdd02bfdcb17dad671054d98df86dabc1de Mon Sep 17 00:00:00 2001 From: anirudh Date: Wed, 19 Mar 2025 00:10:02 +0530 Subject: [PATCH 11/13] move torchtune imports inside VLMEvalWrapper --- torchchat/usages/eval.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/torchchat/usages/eval.py b/torchchat/usages/eval.py index 5d72890b5..a2af33270 100644 --- a/torchchat/usages/eval.py +++ b/torchchat/usages/eval.py @@ -36,18 +36,6 @@ from lm_eval.models.hf_vlms import HFMultimodalLM from lm_eval.models.huggingface import HFLM as eval_wrapper from lm_eval.tasks import get_task_dict -from torchtune import utils -from torchtune.data import ( - format_content_with_images, - left_pad_sequence, - Message, - padded_collate_tiled_images_and_mask, -) -from torchtune.generation import generate, sample - -from torchtune.modules.common_utils import local_kv_cache -from torchtune.modules.model_fusion import DeepFusionModel -from torchtune.modules.transforms import Transform def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( @@ -209,6 +197,20 @@ class VLMEvalWrapper(HFMultimodalLM): the max number of images in MMMU. """ + # Having the imports here allow running other evals without installing torchtune + from torchtune import utils + from torchtune.data import ( + format_content_with_images, + left_pad_sequence, + Message, + padded_collate_tiled_images_and_mask, + ) + from torchtune.generation import generate, sample + + from torchtune.modules.common_utils import local_kv_cache + from torchtune.modules.model_fusion import DeepFusionModel + from torchtune.modules.transforms import Transform + def __init__( self, model: DeepFusionModel, From 14502bfb8117df87b47916710afa4c9db32dd963 Mon Sep 17 00:00:00 2001 From: anirudh Date: Sun, 23 Mar 2025 20:00:23 +0530 Subject: [PATCH 12/13] revert changes from builder.py --- torchchat/cli/builder.py | 116 +++++++++++++-------------------------- 1 file changed, 37 insertions(+), 79 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 65897910b..fcc2d5f66 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -16,22 +16,13 @@ import torch._inductor.config import torch.distributed as dist -from torchtune.models.convert_weights import meta_to_tune - -from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE - -from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune - -from torchtune.training import set_default_dtype - -from torchchat.distributed.logging_utils import SingletonLogger - -from torchchat.distributed.utils import ( +from torchchat.distributed.utils import( Color as color, CUDATrackTime, - GPUMemoryMonitor, init_distributed, + GPUMemoryMonitor, ) +from torchchat.distributed.logging_utils import SingletonLogger from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs from torchchat.model_config.model_config import resolve_model_config @@ -45,6 +36,15 @@ from torchchat.utils.quantize import quantize_model +from torchtune.models.convert_weights import meta_to_tune + +from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE + +from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune + +from torchtune.training import set_default_dtype + + @dataclass class BuilderArgs: checkpoint_path: Optional[Union[Path, str]] = None @@ -189,19 +189,15 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": tp = getattr(args, "tp", 1) chpt_from = getattr(args, "chpt_from", "hf") sdp_backend_dict = { - "math": torch.nn.attention.SDPBackend.MATH, - "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION, - "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, - "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION, + 'math': torch.nn.attention.SDPBackend.MATH, + 'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION, + 'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, + 'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION, } attention_backend = sdp_backend_dict[args.attention_backend] - if args.device == "cpu" and ( - args.attention_backend == "efficient_attention" - or args.attention_backend == "cudnn_attention" - ): - print( - f"Warning: {args.attention_backend} is not supported on CPU. Using math instead." - ) + if args.device == "cpu" and (args.attention_backend == "efficient_attention" + or args.attention_backend == "cudnn_attention"): + print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.") attention_backend = torch.nn.attention.SDPBackend.MATH return cls( checkpoint_dir=checkpoint_dir, @@ -250,29 +246,13 @@ class TokenizerArgs: is_sentencepiece: bool = False is_tiktoken: bool = False is_hf_tokenizer: bool = False - is_llama_3_2_mm: bool = False t: Optional[Any] = None def __post_init__(self): - # special handling for llama-3.2-mm - if "llama-3.2-11b-vision" in str(self.tokenizer_path).lower(): - try: - from torchtune.models.llama3_2_vision import llama3_2_vision_transform - - self.t = llama3_2_vision_transform(path=str(self.tokenizer_path)) - self.is_llama_3_2_mm = True - self.is_tiktoken = False - self.is_sentencepiece = False - self.is_hf_tokenizer = False - return - except: - pass - try: from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path)) - self.is_llama_3_2_mm = False self.is_tiktoken = True self.is_sentencepiece = False self.is_hf_tokenizer = False @@ -284,7 +264,6 @@ def __post_init__(self): from sentencepiece import SentencePieceProcessor self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path)) - self.is_llama_3_2_mm = False self.is_tiktoken = False self.is_sentencepiece = True self.is_hf_tokenizer = False @@ -296,7 +275,6 @@ def __post_init__(self): from tokenizer.hf_tokenizer import HFTokenizer self.t = HFTokenizer(str(self.tokenizer_path)) - self.is_llama_3_2_mm = False self.is_tiktoken = False self.is_sentencepiece = False self.is_hf_tokenizer = True @@ -304,7 +282,6 @@ def __post_init__(self): except: pass - self.is_llama_3_2_mm = False self.is_tiktoken = False self.is_sentencepiece = False self.is_hf_tokenizer = False @@ -319,32 +296,20 @@ def validate_model( if model is None: return - if ( - sum( - [ - self.is_tiktoken, - self.is_hf_tokenizer, - self.is_sentencepiece, - self.is_llama_3_2_mm, - ] - ) - != 1 - ): + if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1: raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}") is_tiktoken = self.is_tiktoken is_sentencepiece = self.is_sentencepiece is_hf_tokenizer = self.is_hf_tokenizer - is_llama_3_2_mm = self.is_llama_3_2_mm - use_tiktoken = model.config.use_tiktoken use_hf_tokenizer = model.config.use_hf_tokenizer - use_other_tokenizer = not (use_tiktoken or use_hf_tokenizer) + use_sentencepiece = not (use_tiktoken or use_hf_tokenizer) + if ( - (is_tiktoken and not use_tiktoken) - or (is_hf_tokenizer and not use_hf_tokenizer) - or (is_sentencepiece and not use_other_tokenizer) - or (is_llama_3_2_mm and not use_other_tokenizer) + (is_tiktoken and not use_tiktoken) or + (is_hf_tokenizer and not use_hf_tokenizer) or + (is_sentencepiece and not use_sentencepiece) ): raise RuntimeError( "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format( @@ -542,7 +507,6 @@ def _load_model(builder_args: BuilderArgs) -> Model: # AOTI-compoiled model will load its own weights. # Release weights here to avoid OOM import gc - if hasattr(model, "model"): model.model = None gc.collect() @@ -600,7 +564,6 @@ def _initialize_model( def do_nothing(max_batch_size, max_seq_length): pass - model.setup_caches = do_nothing model.forward = torch._export.aot_load( @@ -638,7 +601,6 @@ def do_nothing(max_batch_size, max_seq_length): def do_nothing(max_batch_size, max_seq_length): pass - model.setup_caches = do_nothing model.forward = aoti_compiled_model @@ -713,9 +675,7 @@ def do_nothing(max_batch_size, max_seq_length): logger = SingletonLogger.get_logger() gpu_memory_monitor = GPUMemoryMonitor("cuda") - logger.info( - f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}" - ) + logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}") # Model-level config if builder_args.params_table: @@ -726,16 +686,20 @@ def do_nothing(max_batch_size, max_seq_length): config = TransformerArgs.from_params(model_config.transformer_args["text"]) logger.info(f"Transformer Config: {config}") - # TODO: Move into head of file after solving circular import - from torchchat.distributed.checkpoint_utils import load_model_weights + #TODO: Move into head of file after solving circular import + from torchchat.distributed.checkpoint_utils import ( + load_model_weights, + ) # Validate pipeline degree assert config.n_layers % pp_degree == 0 # Create device mesh device_mesh = dist.init_device_mesh( - "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp") - ) + "cuda", + (pp_degree, tp_degree), + mesh_dim_names=("pp", "tp") + ) tp_mesh = device_mesh["tp"] pp_mesh = device_mesh["pp"] logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}") @@ -764,13 +728,7 @@ def do_nothing(max_batch_size, max_seq_length): # Load weights logger.info(f"Loading weights for {pp_rank=} on {device=}") with CUDATrackTime() as timer: - load_model_weights( - model, - builder_args.distribution_path, - device, - config, - builder_args.chpt_from, - ) + load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from) logger.info( f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" @@ -784,7 +742,7 @@ def do_nothing(max_batch_size, max_seq_length): # lanes. # TODO: bump up the lane count pipeline_lanes = 1 - seqlen_prefill = 1024 + seqlen_prefill=1024 with device: model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes) @@ -836,4 +794,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str: return "TikToken" if tokenizers: return "Tokenizers" - return "SentencePiece" + return "SentencePiece" \ No newline at end of file From 815966c7b12367553ffa532a2e05181a8fdba120 Mon Sep 17 00:00:00 2001 From: anirudh Date: Sun, 23 Mar 2025 23:23:58 +0530 Subject: [PATCH 13/13] instantiate transform in eval() --- .../model_params/Llama-3.2-11B-Vision.json | 2 +- torchchat/usages/eval.py | 61 +++++++++++-------- 2 files changed, 37 insertions(+), 26 deletions(-) diff --git a/torchchat/model_params/Llama-3.2-11B-Vision.json b/torchchat/model_params/Llama-3.2-11B-Vision.json index b9b66ba94..5232e3512 100644 --- a/torchchat/model_params/Llama-3.2-11B-Vision.json +++ b/torchchat/model_params/Llama-3.2-11B-Vision.json @@ -1,6 +1,6 @@ { "model_type": "flamingo", - "use_tiktoken": false, + "use_tiktoken": true, "encoder": { "patch_size": 14, "num_heads": 16, diff --git a/torchchat/usages/eval.py b/torchchat/usages/eval.py index a2af33270..882b650d0 100644 --- a/torchchat/usages/eval.py +++ b/torchchat/usages/eval.py @@ -197,24 +197,11 @@ class VLMEvalWrapper(HFMultimodalLM): the max number of images in MMMU. """ - # Having the imports here allow running other evals without installing torchtune - from torchtune import utils - from torchtune.data import ( - format_content_with_images, - left_pad_sequence, - Message, - padded_collate_tiled_images_and_mask, - ) - from torchtune.generation import generate, sample - - from torchtune.modules.common_utils import local_kv_cache - from torchtune.modules.model_fusion import DeepFusionModel - from torchtune.modules.transforms import Transform def __init__( self, - model: DeepFusionModel, - transform: Transform, + model: Model, + transform, *, device: torch.device, max_seq_length: int = 4096, @@ -226,6 +213,25 @@ def __init__( image_tag: str = "", max_images_per_sample: int = 7, ): + # Having the imports here allow running other evals without installing torchtune + from torchtune.utils import batch_to_device + from torchtune.data import ( + format_content_with_images, + left_pad_sequence, + Message, + padded_collate_tiled_images_and_mask, + ) + from torchtune.generation import generate, sample + from torchtune.modules.common_utils import local_kv_cache + self.batch_to_device = batch_to_device + self.format_content_with_images = format_content_with_images + self.left_pad_sequence = left_pad_sequence + self.Message = Message + self.padded_collate_tiled_images_and_mask = padded_collate_tiled_images_and_mask + self.generate = generate + self.sample = sample + self.local_kv_cache = local_kv_cache + self._model = model self._transform = transform self._device = device @@ -326,24 +332,24 @@ def tok_batch_multimodal_encode( # Construct the messages messages = [] - content = format_content_with_images( + content = self.format_content_with_images( text, image_tag=self._image_tag, images=proper_images ) - messages.append(Message(role="user", content=content)) - messages.append(Message(role="assistant", content="")) + messages.append(self.Message(role="user", content=content)) + messages.append(self.Message(role="assistant", content="")) # Transform the messages tok_batch = self.model_transform({"messages": messages}, inference=True) all_encoded_messages.append(tok_batch) # Pad the encoded messages - tok_batch = padded_collate_tiled_images_and_mask( + tok_batch = self.padded_collate_tiled_images_and_mask( all_encoded_messages, pad_direction="left", pad_max_images=self._max_images_per_sample, pad_max_tiles=self._transform.max_num_tiles, ) - utils.batch_to_device(tok_batch, self.device) + self.batch_to_device(tok_batch, self.device) # Convert the batch to the format expected by the HF tok_batch["input_ids"] = tok_batch.pop("tokens") @@ -398,7 +404,7 @@ def _model_multimodal_generate( with measure_time(message=None) as measure: # 2. Setup KV cache - with local_kv_cache( + with self.local_kv_cache( self.model, batch_size=self.batch_size, device=self.device, @@ -409,7 +415,7 @@ def _model_multimodal_generate( # 3. Prefill step generated_tokens = [] logits = self.model(prompt, **batch)[:, -1] - token = sample(logits, temperature=0.0, top_k=None) + token = self.sample(logits, temperature=0.0, top_k=None) generated_tokens.append(token.item()) cache_mask = batch["encoder_mask"][:, -1:] @@ -425,7 +431,7 @@ def _model_multimodal_generate( encoder_mask=cache_mask, input_pos=input_pos[None, seq_len], )[:, -1] - token = sample(logits, temperature=0.0, top_k=None) + token = self.sample(logits, temperature=0.0, top_k=None) generated_tokens.append(token.item()) seq_len += 1 self.times.append(measure.get_time()) @@ -460,6 +466,7 @@ def eval( Returns: eval_results (dict): A dictionary of evaluation results for the specified task(s). """ + if tasks is None: if modality == "text": tasks = ["wikitext"] @@ -478,11 +485,14 @@ def eval( # use eot_token_id as prefix_token_id. model_eval_wrapper.custom_prefix_token_id = model_eval_wrapper.eot_token_id elif modality == "text-image": + from torchtune.utils import get_device + from torchtune.models.llama3_2_vision import llama3_2_vision_transform + model_eval_wrapper = VLMEvalWrapper( model, - transform=tokenizer, + transform=llama3_2_vision_transform(path=str(tokenizer.tokenizer_path)), max_seq_length = 4096 if max_seq_length is None else max_seq_length, - device = utils.get_device(device) if isinstance(device, str) else device, + device = get_device(device) if isinstance(device, str) else device, ) try: @@ -531,6 +541,7 @@ def main(args) -> None: set_precision(builder_args.precision) tokenizer = _initialize_tokenizer(tokenizer_args) + tokenizer.tokenizer_path = tokenizer_args.tokenizer_path builder_args.setup_caches = False model = _initialize_model( builder_args,