@@ -2,6 +2,8 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
+import torch.distributed as dist
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -12,6 +14,8 @@
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.shard import ShardConfig
+from ..layer import cross_entropy_1d
 
 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -40,6 +44,7 @@ def llama_model_forward(
     stage_manager: Optional[PipelineStageManager] = None,
     hidden_states: Optional[torch.FloatTensor] = None,
     stage_index: Optional[List[int]] = None,
+    shard_config: ShardConfig = None,
 ):
     logger = logging.get_logger(__name__)
 
@@ -198,6 +203,7 @@ def llama_for_causal_lm_forward(
     stage_manager: Optional[PipelineStageManager] = None,
     hidden_states: Optional[torch.FloatTensor] = None,
     stage_index: Optional[List[int]] = None,
+    shard_config: ShardConfig = None
 ):
     r"""
     Args:
@@ -267,11 +273,17 @@ def llama_for_causal_lm_forward(
             shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
             loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
             shift_labels = shift_labels.view(-1)
             # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
+            if shard_config.enable_tensor_parallelism:
+                new_vocab_size = logits.shape[-1]
+                shift_logits = shift_logits.view(-1, new_vocab_size)
+                loss = cross_entropy_1d(shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group)
+            else:
+                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                loss = loss_fct(shift_logits, shift_labels)
+
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -304,6 +316,7 @@ def llama_for_sequence_classification_forward(
     stage_manager: Optional[PipelineStageManager] = None,
     hidden_states: Optional[torch.FloatTensor] = None,
     stage_index: Optional[List[int]] = None,
+    shard_config: ShardConfig = None,
 ):
     r"""
     labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -476,3 +489,106 @@ def forward(
         return attn_output, None, past_key_value
 
     return forward
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    from transformers import LlamaForCausalLM
+
+    def forward(
+        self: LlamaForCausalLM,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        if self.config.pretraining_tp > 1:
+            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
+            logits = torch.cat(logits, dim=-1)
+        else:
+            logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            if shard_config.enable_tensor_parallelism:
+                new_vocab_size = logits.shape[-1]
+                shift_logits = shift_logits.view(-1, new_vocab_size)
+                loss = cross_entropy_1d(shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group)
+            else:
+                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                loss = loss_fct(shift_logits, shift_labels)
+
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+    return forward
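Note on the change: with tensor parallelism enabled, the `lm_head` output is vocabulary-sharded, so the loss now branches on `shard_config.enable_tensor_parallelism` and calls `cross_entropy_1d` over `shard_config.tensor_parallel_process_group` instead of materializing full-vocabulary logits for `CrossEntropyLoss`. The new `get_lm_forward_with_dist_cross_entropy(shard_config)` factory is presumably wired in by the Llama shardformer policy as a replacement for `LlamaForCausalLM.forward` in the non-pipeline path. The snippet below is a minimal, forward-only sketch of the idea behind `cross_entropy_1d`, not the library's implementation: the function name is hypothetical, it assumes equal contiguous vocab shards, and it omits the custom autograd function the real kernel needs for the backward pass.

```python
import torch
import torch.distributed as dist


def vocab_parallel_cross_entropy(shard_logits, targets, process_group=None, ignore_index=-100):
    """Illustrative sketch: cross entropy over vocab-sharded logits without gathering them.

    shard_logits: (N, V/p) slice of the vocabulary held by this rank.
    targets:      (N,) global vocabulary ids; `ignore_index` entries are skipped.
    """
    rank = dist.get_rank(process_group)
    shard_size = shard_logits.size(-1)
    vocab_start = rank * shard_size  # assumes equal, contiguous shards

    # 1. Stabilise with the global max logit across all vocab shards.
    logits_max = shard_logits.max(dim=-1).values
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group)
    shard_logits = shard_logits - logits_max.unsqueeze(-1)

    # 2. Denominator: sum of exp over the full (distributed) vocabulary.
    sum_exp = shard_logits.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=process_group)

    # 3. Numerator: the target logit, picked up on whichever rank owns it.
    in_shard = (targets >= vocab_start) & (targets < vocab_start + shard_size)
    local_idx = (targets - vocab_start).clamp(0, shard_size - 1)
    target_logits = shard_logits.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1)
    target_logits = torch.where(in_shard, target_logits, torch.zeros_like(target_logits))
    dist.all_reduce(target_logits, op=dist.ReduceOp.SUM, group=process_group)

    # 4. -log softmax at the target, averaged over non-ignored tokens.
    per_token_loss = sum_exp.log() - target_logits
    valid = targets != ignore_index
    return per_token_loss[valid].mean()
```

The design point is that only three small all-reduces of shape `(N,)` cross the tensor-parallel group, rather than an all-gather of the `(N, V)` logits, which is what makes the distributed loss worthwhile for large vocabularies.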