Added support for Multimodal eval #1499
Changes from 7 commits
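In brief, this change bumps lm_eval from 0.4.2 to 0.4.7, adds a --modality CLI flag (choices "text" and "text-image", default "text"), threads the value through BuilderArgs.from_args as a new modality field, and teaches TokenizerArgs to recognize the Llama 3.2 11B Vision tokenizer via torchtune's llama3_2_vision_transform. Below is a minimal, self-contained sketch of how the new flag is meant to flow into the builder arguments; the parser and dataclass are stand-ins built only from the hunks in this diff, not the actual torchchat classes.

import argparse
from dataclasses import dataclass

# Stand-in parser mirroring the --modality flag added in this PR.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--modality",
    type=str,
    default="text",
    choices=["text", "text-image"],
    help="Modality of the model. Options: text, text-image",
)

# Stand-in for the new BuilderArgs.modality field (defaults to "text").
@dataclass
class BuilderArgsSketch:
    modality: str = "text"

args = parser.parse_args(["--modality", "text-image"])
# from_args keeps "text" unless the flag provides a value.
modality = args.modality if args.modality else "text"
print(BuilderArgsSketch(modality=modality))  # BuilderArgsSketch(modality='text-image')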
@@ -34,4 +34,4 @@ streamlit
 flask
 
 # eval
-lm_eval==0.4.2
+lm_eval==0.4.7
@@ -16,13 +16,22 @@
 import torch._inductor.config
 import torch.distributed as dist
 
-from torchchat.distributed.utils import(
+from torchtune.models.convert_weights import meta_to_tune
+
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+
+from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+
+from torchtune.training import set_default_dtype
+
+from torchchat.distributed.logging_utils import SingletonLogger
+
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    init_distributed,
     GPUMemoryMonitor,
+    init_distributed,
 )
-from torchchat.distributed.logging_utils import SingletonLogger
 
 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
@@ -36,15 +45,6 @@
 from torchchat.utils.quantize import quantize_model
 
 
-
-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchtune.training import set_default_dtype
-
 @dataclass
 class BuilderArgs:
     checkpoint_path: Optional[Union[Path, str]] = None
@@ -71,6 +71,7 @@ class BuilderArgs:
     dynamic_shapes: bool = False
     max_seq_length: Optional[int] = None
     attention_backend: str = "math"
+    modality: Optional[str] = "text"
 
     def __post_init__(self):
         if self.device is None:
@@ -146,6 +147,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         aoti_package_path = getattr(args, "aoti_package_path", None)
         snapshot_path = getattr(args, "snapshot_path", None)
 
+        modality = "text"
+        if args.modality:
+            modality = args.modality
+
         is_chat_model = False
         if args.is_chat_model:
             is_chat_model = True
@@ -189,15 +194,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            'math': torch.nn.attention.SDPBackend.MATH,
-            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            "math": torch.nn.attention.SDPBackend.MATH,
+            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
-                                     or args.attention_backend == "cudnn_attention"):
-            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
+        if args.device == "cpu" and (
+            args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"
+        ):
+            print(
+                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
+            )
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
@@ -222,6 +231,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             chpt_from=chpt_from,
             distribution_path=distribution_path,
             is_chat_model=is_chat_model,
+            modality=modality,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
             attention_backend=attention_backend,
@@ -246,13 +256,29 @@ class TokenizerArgs:
     is_sentencepiece: bool = False
     is_tiktoken: bool = False
     is_hf_tokenizer: bool = False
+    is_llama_3_2_mm: bool = False
     t: Optional[Any] = None
 
     def __post_init__(self):
+        # special handling for llama-3.2-mm
+        if "llama-3.2-11b-vision" in str(self.tokenizer_path).lower():
+            try:
+                from torchtune.models.llama3_2_vision import llama3_2_vision_transform
+
+                self.t = llama3_2_vision_transform(path=str(self.tokenizer_path))
+                self.is_llama_3_2_mm = True
+                self.is_tiktoken = False
+                self.is_sentencepiece = False
+                self.is_hf_tokenizer = False
+                return
+            except:
+                pass
+
         try:
             from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer
 
             self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
+            self.is_llama_3_2_mm = False
             self.is_tiktoken = True
             self.is_sentencepiece = False
             self.is_hf_tokenizer = False
@@ -264,6 +290,7 @@ def __post_init__(self):
             from sentencepiece import SentencePieceProcessor
 
             self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
+            self.is_llama_3_2_mm = False
             self.is_tiktoken = False
             self.is_sentencepiece = True
             self.is_hf_tokenizer = False
@@ -275,13 +302,15 @@ def __post_init__(self):
             from tokenizer.hf_tokenizer import HFTokenizer
 
             self.t = HFTokenizer(str(self.tokenizer_path))
+            self.is_llama_3_2_mm = False
             self.is_tiktoken = False
             self.is_sentencepiece = False
             self.is_hf_tokenizer = True
             return
         except:
             pass
 
+        self.is_llama_3_2_mm = False
         self.is_tiktoken = False
         self.is_sentencepiece = False
         self.is_hf_tokenizer = False
@@ -296,20 +325,32 @@ def validate_model(
         if model is None:
             return
 
-        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
+        if (
+            sum(
+                [
+                    self.is_tiktoken,
+                    self.is_hf_tokenizer,
+                    self.is_sentencepiece,
+                    self.is_llama_3_2_mm,
+                ]
+            )
+            != 1
+        ):
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
         is_tiktoken = self.is_tiktoken
         is_sentencepiece = self.is_sentencepiece
         is_hf_tokenizer = self.is_hf_tokenizer
+        is_llama_3_2_mm = self.is_llama_3_2_mm
 
         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
-        use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
-
+        use_other_tokenizer = not (use_tiktoken or use_hf_tokenizer)
         if (
-            (is_tiktoken and not use_tiktoken) or
-            (is_hf_tokenizer and not use_hf_tokenizer) or
-            (is_sentencepiece and not use_sentencepiece)
+            (is_tiktoken and not use_tiktoken)
+            or (is_hf_tokenizer and not use_hf_tokenizer)
+            or (is_sentencepiece and not use_other_tokenizer)
+            or (is_llama_3_2_mm and not use_other_tokenizer)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -507,6 +548,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compoiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
+
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -564,6 +606,7 @@ def _initialize_model(
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
+
         model.setup_caches = do_nothing
 
         model.forward = torch._export.aot_load(
@@ -601,6 +644,7 @@ def do_nothing(max_batch_size, max_seq_length):
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
+
         model.setup_caches = do_nothing
 
         model.forward = aoti_compiled_model
@@ -675,7 +719,9 @@ def do_nothing(max_batch_size, max_seq_length):
     logger = SingletonLogger.get_logger()
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
-    logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
+    logger.info(
+        f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}"
+    )
 
     # Model-level config
     if builder_args.params_table:
@@ -686,20 +732,16 @@ def do_nothing(max_batch_size, max_seq_length):
         config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    #TODO: Move into head of file after solving circular import
-    from torchchat.distributed.checkpoint_utils import (
-        load_model_weights,
-    )
+    # TODO: Move into head of file after solving circular import
+    from torchchat.distributed.checkpoint_utils import load_model_weights
 
     # Validate pipeline degree
     assert config.n_layers % pp_degree == 0
 
     # Create device mesh
     device_mesh = dist.init_device_mesh(
-        "cuda",
-        (pp_degree, tp_degree),
-        mesh_dim_names=("pp", "tp")
-    )
+        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
+    )
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
     logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
@@ -728,7 +770,13 @@ def do_nothing(max_batch_size, max_seq_length):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
+        load_model_weights(
+            model,
+            builder_args.distribution_path,
+            device,
+            config,
+            builder_args.chpt_from,
+        )
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
|
@@ -742,7 +790,7 @@ def do_nothing(max_batch_size, max_seq_length): | |
# lanes. | ||
# TODO: bump up the lane count | ||
pipeline_lanes = 1 | ||
seqlen_prefill=1024 | ||
seqlen_prefill = 1024 | ||
with device: | ||
model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes) | ||
|
||
|
@@ -137,6 +137,15 @@ def _add_model_specification_args(parser) -> None:
         help=argparse.SUPPRESS,
     )
 
+    model_specification_parser.add_argument(
+        "--modality",
+        type=str,
+        default="text",
+        choices=["text", "text-image"],
+        # help=argparse.SUPPRESS,
+        help="Modality of the model. Options: text, text-image",
+    )
+
 
     # Add CLI Args related to model configuration (compilation, quant, etc)
     # Excludes compile args if subcommand is export

Review comment: Since this arg is only used for evaluation, let's bump it into
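The reviewer's point is that only eval consumes this flag, so it belongs with the evaluation-specific arguments rather than the model-specification group. A minimal sketch of that relocation, assuming an evaluation-args helper shaped like the existing _add_model_specification_args; the _add_evaluation_args name and the group title below are assumptions, not verified against the torchchat CLI.

import argparse

def _add_evaluation_args(parser) -> None:
    # Hypothetical evaluation-argument group; the flag itself mirrors the one defined in this PR.
    evaluation_parser = parser.add_argument_group(
        "Evaluation", description="Evaluation-related options"
    )
    evaluation_parser.add_argument(
        "--modality",
        type=str,
        default="text",
        choices=["text", "text-image"],
        help="Modality of the model. Options: text, text-image",
    )

if __name__ == "__main__":
    p = argparse.ArgumentParser()
    _add_evaluation_args(p)
    print(p.parse_args(["--modality", "text-image"]).modality)  # text-image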
Review comment: Beyond the scope of this PR, but the duplicated requirements in here vs requirements.txt will be collapsed when we introduce packaging.