Added support for Multimodal eval #1498

Closed · wants to merge 6 commits

2 changes: 1 addition & 1 deletion install/requirements.txt
@@ -34,4 +34,4 @@ streamlit
flask

# eval
lm_eval==0.4.2
lm_eval==0.4.7
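
The only change here is bumping the pinned lm_eval to 0.4.7, the version of EleutherAI's evaluation harness that the new multimodal eval path builds on. For orientation, a minimal sketch of the harness's Python entry point that such an eval path typically drives; the model backend and task below are placeholders for a quick smoke test, not what this PR wires up:

# Illustrative only: lm_eval 0.4.x Python entry point (model/task here are
# placeholders for a smoke test, not the torchchat wiring in this PR).
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",                                      # built-in Hugging Face backend
    model_args="pretrained=EleutherAI/pythia-70m",   # tiny placeholder model
    tasks=["lambada_openai"],                        # placeholder task
    limit=2,                                         # evaluate only two examples
)
print(results["results"])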
120 changes: 84 additions & 36 deletions torchchat/cli/builder.py
@@ -16,13 +16,22 @@
import torch._inductor.config
import torch.distributed as dist

from torchchat.distributed.utils import(
from torchtune.models.convert_weights import meta_to_tune

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune

from torchtune.training import set_default_dtype

from torchchat.distributed.logging_utils import SingletonLogger

from torchchat.distributed.utils import (
Color as color,
CUDATrackTime,
init_distributed,
GPUMemoryMonitor,
init_distributed,
)
from torchchat.distributed.logging_utils import SingletonLogger

from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
from torchchat.model_config.model_config import resolve_model_config
@@ -36,15 +45,6 @@
from torchchat.utils.quantize import quantize_model


from torchtune.models.convert_weights import meta_to_tune

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune

from torchtune.training import set_default_dtype


@dataclass
class BuilderArgs:
checkpoint_path: Optional[Union[Path, str]] = None
@@ -70,6 +70,7 @@ class BuilderArgs:
dynamic_shapes: bool = False
max_seq_length: Optional[int] = None
attention_backend: str = "math"
modality: Optional[str] = "text"

def __post_init__(self):
if self.device is None:
@@ -143,6 +144,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
pte_path = getattr(args, "pte_path", None)
aoti_package_path = getattr(args, "aoti_package_path", None)

modality = "text"
if args.modality:
modality = args.modality

is_chat_model = False
if args.is_chat_model:
is_chat_model = True
@@ -185,15 +190,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
tp = getattr(args, "tp", 1)
chpt_from = getattr(args, "chpt_from", "hf")
sdp_backend_dict = {
'math': torch.nn.attention.SDPBackend.MATH,
'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
"math": torch.nn.attention.SDPBackend.MATH,
"flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
"efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
"cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
}
attention_backend = sdp_backend_dict[args.attention_backend]
if args.device == "cpu" and (args.attention_backend == "efficient_attention"
or args.attention_backend == "cudnn_attention"):
print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
if args.device == "cpu" and (
args.attention_backend == "efficient_attention"
or args.attention_backend == "cudnn_attention"
):
print(
f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
)
attention_backend = torch.nn.attention.SDPBackend.MATH
return cls(
checkpoint_dir=checkpoint_dir,
@@ -217,6 +226,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
chpt_from=chpt_from,
distribution_path=distribution_path,
is_chat_model=is_chat_model,
modality=modality,
dynamic_shapes=getattr(args, "dynamic_shapes", False),
max_seq_length=getattr(args, "max_seq_length", None),
attention_backend=attention_backend,
@@ -241,13 +251,29 @@ class TokenizerArgs:
is_sentencepiece: bool = False
is_tiktoken: bool = False
is_hf_tokenizer: bool = False
is_llama_3_2_mm: bool = False
t: Optional[Any] = None

def __post_init__(self):
# special handling for llama-3.2-mm
if "llama-3.2-11b-vision" in str(self.tokenizer_path).lower():
try:
from torchtune.models.llama3_2_vision import llama3_2_vision_transform

self.t = llama3_2_vision_transform(path=str(self.tokenizer_path))
self.is_llama_3_2_mm = True
self.is_tiktoken = False
self.is_sentencepiece = False
self.is_hf_tokenizer = False
return
except:
pass

try:
from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer

self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
self.is_llama_3_2_mm = False
self.is_tiktoken = True
self.is_sentencepiece = False
self.is_hf_tokenizer = False
@@ -259,6 +285,7 @@ def __post_init__(self):
from sentencepiece import SentencePieceProcessor

self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
self.is_llama_3_2_mm = False
self.is_tiktoken = False
self.is_sentencepiece = True
self.is_hf_tokenizer = False
@@ -270,13 +297,15 @@ def __post_init__(self):
from tokenizer.hf_tokenizer import HFTokenizer

self.t = HFTokenizer(str(self.tokenizer_path))
self.is_llama_3_2_mm = False
self.is_tiktoken = False
self.is_sentencepiece = False
self.is_hf_tokenizer = True
return
except:
pass

self.is_llama_3_2_mm = False
self.is_tiktoken = False
self.is_sentencepiece = False
self.is_hf_tokenizer = False
@@ -291,20 +320,32 @@ def validate_model(
if model is None:
return

if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
if (
sum(
[
self.is_tiktoken,
self.is_hf_tokenizer,
self.is_sentencepiece,
self.is_llama_3_2_mm,
]
)
!= 1
):
raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")

is_tiktoken = self.is_tiktoken
is_sentencepiece = self.is_sentencepiece
is_hf_tokenizer = self.is_hf_tokenizer
is_llama_3_2_mm = self.is_llama_3_2_mm

use_tiktoken = model.config.use_tiktoken
use_hf_tokenizer = model.config.use_hf_tokenizer
use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)

use_other_tokenizer = not (use_tiktoken or use_hf_tokenizer)
if (
(is_tiktoken and not use_tiktoken) or
(is_hf_tokenizer and not use_hf_tokenizer) or
(is_sentencepiece and not use_sentencepiece)
(is_tiktoken and not use_tiktoken)
or (is_hf_tokenizer and not use_hf_tokenizer)
or (is_sentencepiece and not use_other_tokenizer)
or (is_llama_3_2_mm and not use_other_tokenizer)
):
raise RuntimeError(
"model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -502,6 +543,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
# AOTI-compiled model will load its own weights.
# Release weights here to avoid OOM
import gc

if hasattr(model, "model"):
model.model = None
gc.collect()
@@ -559,6 +601,7 @@ def _initialize_model(

def do_nothing(max_batch_size, max_seq_length):
pass

model.setup_caches = do_nothing

model.forward = torch._export.aot_load(
@@ -596,6 +639,7 @@ def do_nothing(max_batch_size, max_seq_length):

def do_nothing(max_batch_size, max_seq_length):
pass

model.setup_caches = do_nothing

model.forward = aoti_compiled_model
@@ -642,7 +686,9 @@ def do_nothing(max_batch_size, max_seq_length):
logger = SingletonLogger.get_logger()

gpu_memory_monitor = GPUMemoryMonitor("cuda")
logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
logger.info(
f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}"
)

# Model-level config
if builder_args.params_table:
@@ -653,20 +699,16 @@ def do_nothing(max_batch_size, max_seq_length):
config = TransformerArgs.from_params(model_config.transformer_args["text"])
logger.info(f"Transformer Config: {config}")

#TODO: Move into head of file after solving circular import
from torchchat.distributed.checkpoint_utils import (
load_model_weights,
)
# TODO: Move into head of file after solving circular import
from torchchat.distributed.checkpoint_utils import load_model_weights

# Validate pipeline degree
assert config.n_layers % pp_degree == 0

# Create device mesh
device_mesh = dist.init_device_mesh(
"cuda",
(pp_degree, tp_degree),
mesh_dim_names=("pp", "tp")
)
"cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
)
tp_mesh = device_mesh["tp"]
pp_mesh = device_mesh["pp"]
logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
@@ -695,7 +737,13 @@ def do_nothing(max_batch_size, max_seq_length):
# Load weights
logger.info(f"Loading weights for {pp_rank=} on {device=}")
with CUDATrackTime() as timer:
load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
load_model_weights(
model,
builder_args.distribution_path,
device,
config,
builder_args.chpt_from,
)

logger.info(
f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -709,7 +757,7 @@ def do_nothing(max_batch_size, max_seq_length):
# lanes.
# TODO: bump up the lane count
pipeline_lanes = 1
seqlen_prefill=1024
seqlen_prefill = 1024
with device:
model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)

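Summarizing the builder changes above: BuilderArgs gains a modality field fed from the new CLI flag, and TokenizerArgs.__post_init__ gets a dedicated branch that loads torchtune's Llama 3.2 vision transform whenever the tokenizer path looks like the 11B vision model. A standalone sketch of that branch, assuming torchtune is installed; the checkpoint path below is a placeholder:

# Standalone sketch of the new tokenizer branch (path is a placeholder).
from torchtune.models.llama3_2_vision import llama3_2_vision_transform

tokenizer_path = "checkpoints/Llama-3.2-11B-Vision/original/tokenizer.model"

if "llama-3.2-11b-vision" in tokenizer_path.lower():
    # The transform bundles the tiktoken text tokenizer with image preprocessing,
    # so it stands in for the tokenizer whenever the modality is text-image.
    transform = llama3_2_vision_transform(path=tokenizer_path)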
9 changes: 9 additions & 0 deletions torchchat/cli/cli.py
@@ -137,6 +137,15 @@ def _add_model_specification_args(parser) -> None:
help=argparse.SUPPRESS,
)

model_specification_parser.add_argument(
"--modality",
type=str,
default="text",
choices=["text", "text-image"],
# help=argparse.SUPPRESS,
help="Modality of the model. Options: text, text-image",
)


# Add CLI Args related to model configuration (compilation, quant, etc)
# Excludes compile args if subcommand is export
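With --modality registered on the model-specification group, a text-image eval would be requested roughly as follows; the subcommand shape follows torchchat's existing CLI, while the model alias is a placeholder:

python3 torchchat.py eval llama3.2-11B --modality text-image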
6 changes: 6 additions & 0 deletions torchchat/model.py
@@ -608,6 +608,12 @@ def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_l
decoder_max_seq_len=decoder_max_seq_len,
)

def caches_are_setup(self) -> bool:
return self.model.caches_are_setup()

def caches_are_enabled(self) -> bool:
return self.model.caches_are_enabled()

def reset_caches(self):
self.model.reset_caches()

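The two new pass-throughs expose the wrapped torchtune decoder's cache state, which torchtune-style generation code checks before decoding. A short sketch of the caller pattern they enable; the batch size, dtype, and sequence lengths are placeholder values:

# Sketch of guarding cache setup through the wrapper (placeholder values).
import torch

if not model.caches_are_setup():
    model.setup_caches(
        batch_size=1,
        dtype=torch.bfloat16,
        encoder_max_seq_len=4096,
        decoder_max_seq_len=2048,
    )
assert model.caches_are_enabled()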
2 changes: 1 addition & 1 deletion torchchat/model_params/Llama-3.2-11B-Vision.json
@@ -1,6 +1,6 @@
{
"model_type": "flamingo",
"use_tiktoken": true,
"use_tiktoken": false,
"encoder": {
"patch_size": 14,
"num_heads": 16,
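Flipping use_tiktoken to false routes Llama-3.2-11B-Vision through the relaxed tokenizer check in validate_model above: the vision transform registers as is_llama_3_2_mm, which is only accepted on the use_other_tokenizer path. Restating that predicate from the builder diff:

# Predicate restated from TokenizerArgs.validate_model in builder.py above.
use_tiktoken = model.config.use_tiktoken          # now False for Llama-3.2-11B-Vision
use_hf_tokenizer = model.config.use_hf_tokenizer
use_other_tokenizer = not (use_tiktoken or use_hf_tokenizer)
# The llama-3.2-mm tokenizer passes only when use_other_tokenizer is True,
# i.e. when use_tiktoken is disabled in the model params JSON.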