diff --git a/install/requirements.txt b/install/requirements.txt
index bd1e09174..73be68763 100644
--- a/install/requirements.txt
+++ b/install/requirements.txt
@@ -34,4 +34,4 @@ streamlit
 flask

 # eval
-lm_eval==0.4.2
+lm_eval==0.4.7
diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index a5b23dfe3..2fe163023 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -16,13 +16,22 @@
 import torch._inductor.config
 import torch.distributed as dist

-from torchchat.distributed.utils import(
+from torchtune.models.convert_weights import meta_to_tune
+
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+
+from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+
+from torchtune.training import set_default_dtype
+
+from torchchat.distributed.logging_utils import SingletonLogger
+
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    init_distributed,
     GPUMemoryMonitor,
+    init_distributed,
 )
-from torchchat.distributed.logging_utils import SingletonLogger

 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
@@ -36,15 +45,6 @@

 from torchchat.utils.quantize import quantize_model

-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchtune.training import set_default_dtype
-
-
 @dataclass
 class BuilderArgs:
     checkpoint_path: Optional[Union[Path, str]] = None
@@ -70,6 +70,7 @@ class BuilderArgs:
     dynamic_shapes: bool = False
     max_seq_length: Optional[int] = None
     attention_backend: str = "math"
+    modality: Optional[str] = "text"

     def __post_init__(self):
         if self.device is None:
@@ -143,6 +144,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         pte_path = getattr(args, "pte_path", None)
         aoti_package_path = getattr(args, "aoti_package_path", None)

+        modality = "text"
+        if args.modality:
+            modality = args.modality
+
         is_chat_model = False
         if args.is_chat_model:
             is_chat_model = True
@@ -185,15 +190,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            'math': torch.nn.attention.SDPBackend.MATH,
-            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            "math": torch.nn.attention.SDPBackend.MATH,
+            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
-            or args.attention_backend == "cudnn_attention"):
-            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
+        if args.device == "cpu" and (
+            args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"
+        ):
+            print(
+                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
+            )
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
@@ -217,6 +226,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             chpt_from=chpt_from,
             distribution_path=distribution_path,
             is_chat_model=is_chat_model,
+            modality=modality,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
             attention_backend=attention_backend,
@@ -241,13 +251,29 @@ class TokenizerArgs:
     is_sentencepiece: bool = False
     is_tiktoken: bool = False
     is_hf_tokenizer: bool = False
+    is_llama_3_2_mm: bool = False
     t: Optional[Any] = None

     def __post_init__(self):
+        # special handling for llama-3.2-mm
+        if "llama-3.2-11b-vision" in str(self.tokenizer_path).lower():
+            try:
+                from torchtune.models.llama3_2_vision import llama3_2_vision_transform
+
+                self.t = llama3_2_vision_transform(path=str(self.tokenizer_path))
+                self.is_llama_3_2_mm = True
+                self.is_tiktoken = False
+                self.is_sentencepiece = False
+                self.is_hf_tokenizer = False
+                return
+            except:
+                pass
+
         try:
             from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer

             self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
+            self.is_llama_3_2_mm = False
             self.is_tiktoken = True
             self.is_sentencepiece = False
             self.is_hf_tokenizer = False
@@ -259,6 +285,7 @@ def __post_init__(self):
             from sentencepiece import SentencePieceProcessor

             self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
+            self.is_llama_3_2_mm = False
             self.is_tiktoken = False
             self.is_sentencepiece = True
             self.is_hf_tokenizer = False
@@ -270,6 +297,7 @@ def __post_init__(self):
             from tokenizer.hf_tokenizer import HFTokenizer

             self.t = HFTokenizer(str(self.tokenizer_path))
+            self.is_llama_3_2_mm = False
             self.is_tiktoken = False
             self.is_sentencepiece = False
             self.is_hf_tokenizer = True
@@ -277,6 +305,7 @@ def __post_init__(self):
         except:
             pass

+        self.is_llama_3_2_mm = False
         self.is_tiktoken = False
         self.is_sentencepiece = False
         self.is_hf_tokenizer = False
@@ -291,20 +320,32 @@ def validate_model(
         if model is None:
             return

-        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
+        if (
+            sum(
+                [
+                    self.is_tiktoken,
+                    self.is_hf_tokenizer,
+                    self.is_sentencepiece,
+                    self.is_llama_3_2_mm,
+                ]
+            )
+            != 1
+        ):
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")

         is_tiktoken = self.is_tiktoken
         is_sentencepiece = self.is_sentencepiece
         is_hf_tokenizer = self.is_hf_tokenizer
+        is_llama_3_2_mm = self.is_llama_3_2_mm
+
         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
-        use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
-
+        use_other_tokenizer = not (use_tiktoken or use_hf_tokenizer)
         if (
-            (is_tiktoken and not use_tiktoken) or
-            (is_hf_tokenizer and not use_hf_tokenizer) or
-            (is_sentencepiece and not use_sentencepiece)
+            (is_tiktoken and not use_tiktoken)
+            or (is_hf_tokenizer and not use_hf_tokenizer)
+            or (is_sentencepiece and not use_other_tokenizer)
+            or (is_llama_3_2_mm and not use_other_tokenizer)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -502,6 +543,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compoiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
+
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -559,6 +601,7 @@ def _initialize_model(

             def do_nothing(max_batch_size, max_seq_length):
                 pass
+
             model.setup_caches = do_nothing

             model.forward = torch._export.aot_load(
@@ -596,6 +639,7 @@ def do_nothing(max_batch_size, max_seq_length):

             def do_nothing(max_batch_size, max_seq_length):
                 pass
+
             model.setup_caches = do_nothing

             model.forward = aoti_compiled_model
@@ -642,7 +686,9 @@ def do_nothing(max_batch_size, max_seq_length):
     logger = SingletonLogger.get_logger()

     gpu_memory_monitor = GPUMemoryMonitor("cuda")
-    logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
+    logger.info(
+        f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}"
+    )

     # Model-level config
     if builder_args.params_table:
@@ -653,20 +699,16 @@ def do_nothing(max_batch_size, max_seq_length):
     config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")

-    #TODO: Move into head of file after solving circular import
-    from torchchat.distributed.checkpoint_utils import (
-        load_model_weights,
-    )
+    # TODO: Move into head of file after solving circular import
+    from torchchat.distributed.checkpoint_utils import load_model_weights

     # Validate pipeline degree
     assert config.n_layers % pp_degree == 0

     # Create device mesh
     device_mesh = dist.init_device_mesh(
-        "cuda",
-        (pp_degree, tp_degree),
-        mesh_dim_names=("pp", "tp")
-    )
+        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
+    )
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
     logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
@@ -695,7 +737,13 @@ def do_nothing(max_batch_size, max_seq_length):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
+        load_model_weights(
+            model,
+            builder_args.distribution_path,
+            device,
+            config,
+            builder_args.chpt_from,
+        )

     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -709,7 +757,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    seqlen_prefill=1024
+    seqlen_prefill = 1024

     with device:
         model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py
index 70f404635..63bf224a3 100644
--- a/torchchat/cli/cli.py
+++ b/torchchat/cli/cli.py
@@ -137,6 +137,15 @@ def _add_model_specification_args(parser) -> None:
         help=argparse.SUPPRESS,
     )

+    model_specification_parser.add_argument(
+        "--modality",
+        type=str,
+        default="text",
+        choices=["text", "text-image"],
+        # help=argparse.SUPPRESS,
+        help="Modality of the model. Options: text, text-image",
+    )
+

 # Add CLI Args related to model configuration (compilation, quant, etc)
 # Excludes compile args if subcommand is export
diff --git a/torchchat/model.py b/torchchat/model.py
index ce7dcb5e4..9722ca240 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -608,6 +608,12 @@ def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_l
             decoder_max_seq_len=decoder_max_seq_len,
         )

+    def caches_are_setup(self) -> bool:
+        return self.model.caches_are_setup()
+
+    def caches_are_enabled(self) -> bool:
+        return self.model.caches_are_enabled()
+
     def reset_caches(self):
         self.model.reset_caches()

diff --git a/torchchat/model_params/Llama-3.2-11B-Vision.json b/torchchat/model_params/Llama-3.2-11B-Vision.json
index 5232e3512..b9b66ba94 100644
--- a/torchchat/model_params/Llama-3.2-11B-Vision.json
+++ b/torchchat/model_params/Llama-3.2-11B-Vision.json
@@ -1,6 +1,6 @@
 {
   "model_type": "flamingo",
-  "use_tiktoken": true,
+  "use_tiktoken": false,
   "encoder": {
     "patch_size": 14,
     "num_heads": 16,
diff --git a/torchchat/usages/eval.py b/torchchat/usages/eval.py
index b708e5840..22286293b 100644
--- a/torchchat/usages/eval.py
+++ b/torchchat/usages/eval.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
-from typing import Callable, Optional
+from typing import Callable, Dict, List, Optional

 import torch
 import torch._dynamo.config
@@ -30,9 +30,24 @@

 import lm_eval

+import PIL
+
 from lm_eval.evaluator import evaluate
+from lm_eval.models.hf_vlms import HFMultimodalLM
 from lm_eval.models.huggingface import HFLM as eval_wrapper
 from lm_eval.tasks import get_task_dict
+from torchtune import utils
+from torchtune.data import (
+    format_content_with_images,
+    left_pad_sequence,
+    Message,
+    padded_collate_tiled_images_and_mask,
+)
+from torchtune.generation import generate, sample
+
+from torchtune.modules.common_utils import local_kv_cache
+from torchtune.modules.model_fusion import DeepFusionModel
+from torchtune.modules.transforms import Transform


 def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
@@ -89,7 +104,7 @@ def __init__(
         device="cpu",
         is_pte_model: bool = False,
     ):
-        super().__init__(device=device)
+        super().__init__(pretrained="gpt2", device=device)
         self._model = model
         self._model_forward = (
             model_forward
@@ -168,6 +183,250 @@ def _model_generate(self, context, max_length, eos_token_id):
         raise Exception("unimplemented")


+class VLMEvalWrapper(HFMultimodalLM):
+    """An EvalWrapper for EleutherAI's eval harness based on gpt-fast's
+    EvalWrapper: https://github.com/pytorch-labs/gpt-fast/blob/main/eval.py.
+
+    Note:
+        This is ONLY for vision-language models.
+
+    Args:
+        model (DeepFusionModel): The VLM to evaluate.
+        transform (Transform): The transform (tokenizer) to use for preprocessing.
+        device (torch.device): The device to use.
+        max_seq_length (int): The maximum sequence length.
+        batch_size (int): The batch size.
+        dtype (torch.dtype): dtype for the model caches during generation.
+        enable_kv_cache (bool): Whether to enable KV cache for generation.
+        image_tag (str): The string to use for the image token. Default is "<image>", which
+            is the default used by the MMMU dataset.
+        max_images_per_sample (int): The maximum number of images per sample. Defaults to
+            the max number of images in MMMU.
+    """
+
+    def __init__(
+        self,
+        model: DeepFusionModel,
+        transform: Transform,
+        *,
+        device: torch.device,
+        max_seq_length: int = 4096,
+        batch_size: int = 1,
+        dtype: torch.dtype = torch.bfloat16,
+        enable_kv_cache: bool = True,
+        # TODO (@joecummings): Update these defaults once more multimodal
+        # tasks are added to the eval harness
+        image_tag: str = "<image>",
+        max_images_per_sample: int = 7,
+    ):
+        self._model = model
+        self._transform = transform
+        self._device = device
+        self._max_seq_length = max_seq_length
+        self._batch_size = batch_size
+        self._dtype = dtype
+        # Defaulting KV cache to True for multimodal
+        self._enable_kv_cache = True
+        self._image_tag = image_tag
+        self._max_images_per_sample = max_images_per_sample
+        self.times = []
+
+    @property
+    def model(self):
+        # Not actually changing the dtype here, just adding it as a
+        # property on the model
+        self._model.dtype = self._dtype
+        return self._model
+
+    @property
+    def model_transform(self):
+        return self._transform
+
+    @property
+    def device(self):
+        return self._device
+
+    @property
+    def cache_hook(self):
+        # Dummy class to appease the Harness
+        class DummyCacheHook:
+            def __init__(self):
+                self.add_partial = lambda x, y, z: True
+
+        return DummyCacheHook()
+
+    @property
+    def rank(self):
+        # Hardcoded for now b/c we only support single GPU eval
+        return 0
+
+    @property
+    def world_size(self):
+        # Hardcoded for now b/c we only support single GPU eval
+        return 1
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+    @property
+    def eos_token_id(self):
+        return self._transform.tokenizer.eos_id
+
+    @property
+    def eot_token_id(self):
+        return self._transform.tokenizer.eot_id
+
+    @property
+    def max_length(self):
+        return self._max_seq_length
+
+    @property
+    def truncation(self):
+        return True
+
+    def tok_encode(self, string, **kwargs) -> List[int]:
+        # This is only used to get a number of tokens for use in sorting samples in dataset
+        # These values will not actually be used for eval
+        return self._transform.tokenizer.encode(string, add_bos=False, add_eos=False)
+
+    def tok_decode(self, tokens, skip_special_tokens=True) -> str:
+        if isinstance(tokens, int):
+            tokens = [tokens]
+        return self._transform.tokenizer.decode(
+            tokens, skip_special_tokens=skip_special_tokens
+        )
+
+    def tok_batch_multimodal_encode(
+        self,
+        all_texts: List[str],
+        all_images: List[List[PIL.Image.Image]],
+        left_truncate_len: int = None,
+        *args,
+        **kwargs,
+    ):
+        # Eleuther already parses out the text and images, so we just need to get
+        # it into a Message format for our tokenizer
+        all_encoded_messages = []
+
+        for text, images in zip(all_texts, all_images):
+            # Ensure images are all RGB
+            proper_images = []
+            for image in images:
+                if image.mode != "RGB":
+                    image = image.convert("RGB")
+                proper_images.append(image)
+
+            # Construct the messages
+            messages = []
+            content = format_content_with_images(
+                text, image_tag=self._image_tag, images=proper_images
+            )
+            messages.append(Message(role="user", content=content))
+            messages.append(Message(role="assistant", content=""))
+
+            # Transform the messages
+            tok_batch = self.model_transform({"messages": messages}, inference=True)
+            all_encoded_messages.append(tok_batch)
+
+        # Pad the encoded messages
+        tok_batch = padded_collate_tiled_images_and_mask(
+            all_encoded_messages,
+            pad_direction="left",
+            pad_max_images=self._max_images_per_sample,
+            pad_max_tiles=self._transform.max_num_tiles,
+        )
+        utils.batch_to_device(tok_batch, self.device)
+
+        # Convert the batch to the format expected by the HF
+        tok_batch["input_ids"] = tok_batch.pop("tokens")
+
+        # the harness will use left_truncate_len to indicate that the current batch
+        # needs to be truncated to self.max_seq_len - self.max_gen_toks
+        if left_truncate_len is not None:
+            tok_batch["input_ids"] = tok_batch["input_ids"][:, -left_truncate_len:]
+
+        return tok_batch
+
+    @torch.inference_mode()
+    def _model_multimodal_generate(
+        self,
+        batch: Dict[str, torch.Tensor],
+        max_length: int,
+        stop: List[str],
+        **generation_kwargs,
+    ):
+        # 1. Validate inputs
+        prompt = batch.pop("input_ids")
+        bsz, seq_len = prompt.shape
+
+        temperature = generation_kwargs.get("temperature", 0.0)
+        do_sample = generation_kwargs.get("do_sample", False)
+        if do_sample or temperature != 0.0:
+            raise RuntimeError(
+                "Any decoding strategy other than greedy is not supported."
+            )
+
+        if bsz > 1:
+            raise ValueError(
+                f"Got a batch size of '{bsz}'. Batch size > 1 is not yet supported for "
+                "multimodal generation."
+            )
+
+        encoder_max_seq_len = (
+            self.model_transform.image_seq_len * self._max_images_per_sample
+        )
+        # Setup masks for bsz 1
+        with self.device:
+            causal_mask = torch.tril(
+                torch.ones(
+                    size=(self.max_length, self.max_length),
+                    dtype=torch.bool,
+                )
+            )
+            input_pos = torch.arange(self.max_length)
+
+        batch["input_pos"] = input_pos[None, :seq_len]
+        batch["mask"] = causal_mask[None, :seq_len]
+
+        with measure_time(message=None) as measure:
+            # 2. Setup KV cache
+            with local_kv_cache(
+                self.model,
+                batch_size=self.batch_size,
+                device=self.device,
+                dtype=self._dtype,
+                encoder_max_seq_len=encoder_max_seq_len,
+                decoder_max_seq_len=self.max_length,
+            ):
+                # 3. Prefill step
+                generated_tokens = []
+                logits = self.model(prompt, **batch)[:, -1]
+                token = sample(logits, temperature=0.0, top_k=None)
+                generated_tokens.append(token.item())
+
+                cache_mask = batch["encoder_mask"][:, -1:]
+
+                # 4. Continue generating
+                for _ in range(max_length):
+                    if token.item() in self.model_transform.stop_tokens:
+                        break
+                    logits = self.model(
+                        token,
+                        mask=causal_mask[None, seq_len, None, :],
+                        encoder_input=None,
+                        encoder_mask=cache_mask,
+                        input_pos=input_pos[None, seq_len],
+                    )[:, -1]
+                    token = sample(logits, temperature=0.0, top_k=None)
+                    generated_tokens.append(token.item())
+                    seq_len += 1
+        self.times.append(measure.get_time())
+
+        # 5. Return generated tokens
+        return torch.tensor(generated_tokens, dtype=torch.int32).unsqueeze(0)
+
+
 @torch.no_grad()
 def eval(
     model: Model,
@@ -223,6 +482,57 @@ def eval(
     return eval_results


+def multi_model_eval(
+    model: Model,
+    model_forward: Callable,
+    tokenizer,
+    tasks: Optional[list] = None,
+    limit: Optional[int] = None,
+    max_seq_length: Optional[int] = None,
+    device: str = "cpu",
+    is_pte_model: bool = False,
+):
+    """
+    Evaluates a language model on a specified task using the lm-evaluation-harness library.
+
+    Args:
+        model (Model): The pre-trained language model to evaluate.
+        tokenizer: The tokenizer to use for encoding/decoding text.
+        tasks (Optional[list]): The names of the evaluation tasks to perform.
+        limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
+        max_seq_length (Optional[int]): The maximum sequence length allowed for input text.
+
+    Returns:
+        eval_results (dict): A dictionary of evaluation results for the specified task(s).
+    """
+    if tasks is None:
+        tasks = ["wikitext"]
+    max_seq_length = 4096 if max_seq_length is None else max_seq_length
+    device = utils.get_device(device) if isinstance(device, str) else device
+
+    model_eval_wrapper = VLMEvalWrapper(
+        model,
+        transform=tokenizer,  # tranform is the tokenizer for multimodal models
+        max_seq_length=max_seq_length,
+        device=device,
+    )
+
+    try:
+        lm_eval.tasks.initialize_tasks()
+    except:
+        pass
+
+    task_dict = get_task_dict(tasks)
+
+    eval_results = evaluate(
+        model_eval_wrapper,
+        task_dict,
+        limit=limit,
+    )
+    eval_results["times"] = model_eval_wrapper.times
+    return eval_results
+
+
 def main(args) -> None:
     """Evaluates model on a task from the `lm-evaluation-harness` library.

@@ -244,6 +554,13 @@ def main(args) -> None:
     compile = args.compile
     max_seq_length = args.max_seq_length

+    modality = builder_args.modality
+    print(f"Modality of model={modality}")
+    assert modality in [
+        "text",
+        "text-image",
+    ], "Only text and text-image modality is supported for evaluation"
+
     print(f"Using device={device}")
     set_precision(builder_args.precision)

@@ -260,24 +577,41 @@ def main(args) -> None:

     if compile:
         assert not (
-            builder_args.dso_path or builder_args.pte_path or builder_args.aoti_package_path
+            builder_args.dso_path
+            or builder_args.pte_path
+            or builder_args.aoti_package_path
         ), "cannot compile exported model"
         model_forward = torch.compile(
             model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True
         )
-        torch._inductor.config.coordinate_descent_tuning = False if device == "cpu" else True
+        torch._inductor.config.coordinate_descent_tuning = (
+            False if device == "cpu" else True
+        )

     with measure_time("Time to run eval: {time:.02f}s."):
-        result = eval(
-            model.to(device),
-            model_forward,
-            tokenizer,
-            tasks,
-            limit,
-            max_seq_length,
-            device=builder_args.device,
-            is_pte_model=builder_args.pte_path is not None,
-        )
+        if modality == "text":
+            result = eval(
+                model.to(device),
+                model_forward,
+                tokenizer,
+                tasks,
+                limit,
+                max_seq_length,
+                device=builder_args.device,
+                is_pte_model=builder_args.pte_path is not None,
+            )
+        elif modality == "text-image":
+            result = multi_model_eval(
+                model.to(device),
+                model_forward,
+                tokenizer,
+                tasks,
+                limit,
+                max_seq_length,
+                device=builder_args.device,
+            )
+        else:
+            raise ValueError(f"Unsupported modality: {modality}")

     times = torch.tensor(result["times"])
     print(