5 changes: 5 additions & 0 deletions applications/ColossalChat/coati/distributed/consumer.py
@@ -168,6 +168,11 @@ def setup(self):
         self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)

     def step(self, step_idx: int, **kwargs) -> Optional[float]:
+        labels = kwargs["input_ids"].clone()
+        labels[kwargs["attention_mask"] == 0] = -100
+        kwargs["labels"] = labels
+        assert kwargs.pop("action_mask").shape == kwargs.pop("action_log_probs").shape
+
         need_update = (step_idx + 1) % self.num_microbatches == 0

         ctx = nullcontext() if need_update else self.booster.no_sync(self.model, self.optimizer)
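Note on the labels convention added above: -100 is the default ignore_index of PyTorch's cross-entropy, so positions masked out of attention contribute nothing to the loss. A minimal self-contained sketch (tensors and pad id are made up for illustration):

```python
import torch
import torch.nn.functional as F

# hypothetical batch; 0 plays the role of the pad token id
input_ids = torch.tensor([[11, 12, 13, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0]])

labels = input_ids.clone()
labels[attention_mask == 0] = -100  # -100 = F.cross_entropy's default ignore_index

logits = torch.randn(1, 4, 32)  # [batch, seq, vocab]
loss = F.cross_entropy(logits.view(-1, 32), labels.view(-1))  # padded position is skipped
```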
140 changes: 100 additions & 40 deletions applications/ColossalChat/coati/distributed/inference_backend.py
@@ -2,10 +2,12 @@

 import torch
 import torch.nn.functional as F
-from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer

 from colossalai.utils import get_current_device

+from .utils import log_probs_from_logits, update_by_default
+
 try:
     import sglang as sgl
 except ImportError:
@@ -22,37 +24,73 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any]
         pass

     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
-        pass
+        """Generate new tokens given input_ids and attention_mask.
+
+        Args:
+            input_ids (torch.Tensor): shape [B, S]
+            attention_mask (torch.Tensor): shape [B, S]
+
+        Returns:
+            Dict[str, torch.Tensor]: a dict containing
+                - input_ids (torch.Tensor): shape [B, S+N]
+                - attention_mask (torch.Tensor): shape [B, S+N]
+                - action_log_probs (torch.Tensor): shape [B, N]
+                - action_mask (torch.Tensor): shape [B, N]
+            where N is the number of generated tokens and all tensors are on CUDA.
+        """

     def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
         pass


 class TransformersInferenceBackend(BaseInferenceBackend):
+    DEFAULT_MODEL_CONFIG = dict(
+        trust_remote_code=True,
+        torch_dtype=torch.bfloat16,
+    )
+    FORCE_MODEL_CONFIG = dict(
+        device_map="auto",
+    )
+    FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)

     def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+        model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
+        model_config.update(self.FORCE_MODEL_CONFIG)
         path = model_config.pop("path")
-        defaut_config = dict(
-            trust_remote_code=True,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-        )
-        defaut_config.update(model_config)
-        self.model: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained(path, **defaut_config)
-        self.generate_config = generate_config
+        self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.generate_config = generate_config.copy()
+        self.generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.tokenizer = tokenizer

     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         input_ids = input_ids.to(get_current_device())
         attention_mask = attention_mask.to(get_current_device())
         out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config)
         input_len = input_ids.shape[-1]
-        labels = out.clone()
-        labels[..., :input_len] = -100
-        attention_mask = F.pad(attention_mask, (0, out.shape[-1] - input_len), value=1)
-        attention_mask = attention_mask.expand_as(labels)
+        new_token_ids = out.sequences[:, input_len:]
+        # get log probs
+        assert new_token_ids.shape[-1] == len(out.logits)
+        action_log_probs = []
+        for i, logits in enumerate(out.logits):
+            action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1]))
+        action_log_probs = torch.cat(action_log_probs, dim=1)
+        # get action mask
+        action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype)
+        if self.tokenizer.eos_token_id is not None:
+            for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id):
+                action_mask[indices[0], indices[1] + 1 :] = 0
+
+        if attention_mask.size(0) != action_mask.size(0):
+            assert action_mask.size(0) % attention_mask.size(0) == 0
+            attention_mask = attention_mask.repeat_interleave(action_mask.size(0) // attention_mask.size(0), dim=0)
+
+        attention_mask = torch.cat((attention_mask, action_mask), dim=1)
         data = {
-            "input_ids": out,
+            "input_ids": out.sequences,
             "attention_mask": attention_mask,
-            "labels": labels,
+            "action_log_probs": action_log_probs,
+            "action_mask": action_mask,
         }
         return data
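For reviewers unfamiliar with the flags forced above: with return_dict_in_generate=True and output_logits=True (available in transformers >= 4.38), generate() returns out.sequences (prompt plus new tokens) and out.logits, a tuple holding one [batch, vocab] tensor per generated step. A rough sketch of the alignment the code relies on (gpt2 is just a stand-in model):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in model
model = AutoModelForCausalLM.from_pretrained("gpt2")

enc = tokenizer("hello world", return_tensors="pt")
with torch.no_grad():
    out = model.generate(
        **enc,
        max_new_tokens=4,
        return_dict_in_generate=True,
        output_logits=True,
    )

new_tokens = out.sequences[:, enc["input_ids"].shape[-1] :]
# one raw-logits tensor per generated token, each of shape [batch, vocab]
assert len(out.logits) == new_tokens.shape[-1]
```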

@@ -75,6 +113,7 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any]
         self.tokenizer = tokenizer
         self.config = AutoConfig.from_pretrained(path)

+    @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         outputs = self.llm.generate(input_ids=input_ids.tolist(), sampling_params=self.generate_config)
         out_tokens = []
@@ -110,45 +149,66 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:


 class VLLMInferenceBackend(BaseInferenceBackend):
+    DEFAULT_MODEL_CONFIG = dict(
+        trust_remote_code=True,
+    )
+    FORCE_GENERATE_CONFIG = dict(
+        logprobs=0,
+    )
+
     def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
         if LLM is None:
             raise ImportError("vllm is not installed")
+        model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         path = model_config.pop("path")
-        defaut_config = dict(
-            trust_remote_code=True,
-            # skip_tokenizer_init=True,
-        )
-        defaut_config.update(model_config)
-        self.llm = LLM(path, **defaut_config)
-        self.generate_config = SamplingParams(**generate_config, stop_token_ids=[tokenizer.eos_token_id])
+        self.llm = LLM(path, **model_config)
+        generate_config = generate_config.copy()
+        generate_config.update(self.FORCE_GENERATE_CONFIG)
+        self.generate_config = SamplingParams(**generate_config)
         self.tokenizer = tokenizer
         self.config = AutoConfig.from_pretrained(path)

+    @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         outputs = self.llm.generate(
             prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
         )
         out_tokens = []
         out_len = []
+        log_probs = []
         for out in outputs:
-            out_tokens.append(list(out.outputs[0].token_ids))
-            out_len.append(len(out.outputs[0].token_ids))
+            for output_i in out.outputs:
+                out_len.append(len(output_i.token_ids))
+                out_tokens.append(list(output_i.token_ids))
+                assert len(output_i.logprobs) == len(output_i.token_ids)
+                p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)]
+                log_probs.append(p)

         # pad them
         max_len = max(out_len)
-        input_len = input_ids.shape[-1]
-        attention_mask = F.pad(attention_mask, (0, max_len), value=1)
-        for i in range(len(out_tokens)):
-            out_tokens[i] = out_tokens[i] + [self.tokenizer.pad_token_id] * (max_len - out_len[i])
-            attention_mask[i, input_len + out_len[i] :] = 0
-        out = torch.tensor(out_tokens)
-        out = torch.cat((input_ids, out), dim=1)
-        labels = out.clone()
-        labels[..., :input_len] = -100
-        for i in range(len(out_len)):
-            labels[i, input_len + out_len[i] :] = -100
+        action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)
+
+        for i, new_token_ids in enumerate(out_tokens):
+            pad_len = max_len - out_len[i]
+            out_tokens[i] = new_token_ids + [self.tokenizer.pad_token_id] * pad_len
+            log_probs[i] = log_probs[i] + [0.0] * pad_len
+            action_mask[i, out_len[i] :] = 0
+
+        out_tokens = torch.tensor(out_tokens)
+        log_probs = torch.tensor(log_probs)
+        if attention_mask.size(0) != action_mask.size(0):
+            assert action_mask.size(0) % attention_mask.size(0) == 0
+            num_returns = action_mask.size(0) // attention_mask.size(0)
+            attention_mask = attention_mask.repeat_interleave(num_returns, dim=0)
+            input_ids = input_ids.repeat_interleave(num_returns, dim=0)
+
+        out_tokens = torch.cat((input_ids, out_tokens), dim=1)
+        attention_mask = torch.cat((attention_mask, action_mask), dim=1)
+
         data = {
-            "input_ids": out,
+            "input_ids": out_tokens,
             "attention_mask": attention_mask,
-            "labels": labels,
+            "action_log_probs": log_probs,
+            "action_mask": action_mask,
         }
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data
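Side note on the vLLM contract assumed here: SamplingParams(logprobs=0) requests the logprob of just the sampled token at each step, so completion.logprobs[j][token_id].logprob lines up one-to-one with completion.token_ids; with n > 1 each prompt yields several completions, which is what the repeat_interleave above compensates for. A rough usage sketch (model path and prompt are placeholders):

```python
from vllm import LLM, SamplingParams

llm = LLM("facebook/opt-125m")  # placeholder model
params = SamplingParams(max_tokens=8, n=2, logprobs=0)

outputs = llm.generate(prompts=["Hello"], sampling_params=params)
for out in outputs:
    for completion in out.outputs:  # n completions per prompt
        logps = [
            step[tok].logprob
            for step, tok in zip(completion.logprobs, completion.token_ids)
        ]
        assert len(logps) == len(completion.token_ids)
```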
@@ -159,6 +219,6 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:

 BACKEND_MAP = {
     "transformers": TransformersInferenceBackend,
-    "sglang": SGLangInferenceBackend,
+    # "sglang": SGLangInferenceBackend,  # disabled: the sglang backend hangs the process for an unknown reason
     "vllm": VLLMInferenceBackend,
 }
40 changes: 33 additions & 7 deletions applications/ColossalChat/coati/distributed/utils.py
@@ -1,4 +1,4 @@
-from typing import Dict, List
+from typing import Any, Dict, List

import torch

@@ -25,16 +25,42 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor


 def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-    # compress attention_mask to save bandwidth
+    # compress masks to save bandwidth
     if "attention_mask" in batch:
-        attention_mask = batch["attention_mask"]
-        batch["attention_mask"] = attention_mask.to(torch.bool)
+        batch["attention_mask"] = batch["attention_mask"].to(torch.bool)
+    if "action_mask" in batch:
+        batch["action_mask"] = batch["action_mask"].to(torch.bool)
     return batch


 def post_recv(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-    # decompress attention_mask
+    # decompress masks
     if "attention_mask" in batch:
-        attention_mask = batch["attention_mask"]
-        batch["attention_mask"] = attention_mask.to(torch.int)
+        batch["attention_mask"] = batch["attention_mask"].to(torch.int)
+    if "action_mask" in batch:
+        batch["action_mask"] = batch["action_mask"].to(torch.int)
     return batch
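Round-trip sketch of the compression, in case it helps review: a bool tensor serializes at one byte per element versus four for int32, and post_recv widens it back on arrival:

```python
import torch

from coati.distributed.utils import post_recv, pre_send

batch = {
    "attention_mask": torch.ones(2, 6, dtype=torch.int),
    "action_mask": torch.ones(2, 3, dtype=torch.int),
}
batch = pre_send(batch)  # masks shrink to torch.bool for transport
assert batch["attention_mask"].dtype == torch.bool
batch = post_recv(batch)  # widened back to torch.int on arrival
assert batch["attention_mask"].dtype == torch.int
```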


+def update_by_default(data: Dict[str, Any], default: Dict[str, Any]) -> Dict[str, Any]:
+    data = data.copy()
+    for k, v in default.items():
+        if k not in data:
+            data[k] = v
+    return data
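Behavioral note on update_by_default: defaults only fill keys the caller did not set, and the caller's dict is never mutated. For example:

```python
from coati.distributed.utils import update_by_default

model_config = {"path": "my/model", "torch_dtype": "float16"}
merged = update_by_default(model_config, {"trust_remote_code": True, "torch_dtype": "bfloat16"})

assert merged == {"path": "my/model", "torch_dtype": "float16", "trust_remote_code": True}
assert "trust_remote_code" not in model_config  # original left untouched
```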


+def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the log probabilities from logits for the given labels.
+
+    Args:
+        logits (torch.Tensor): The input logits.
+        labels (torch.Tensor): The target labels.
+
+    Returns:
+        torch.Tensor: The log probabilities corresponding to the labels.
+    """
+    log_probs = torch.log_softmax(logits, dim=-1)
+    per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
+    return per_label_logps.squeeze(-1)
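Sanity check for log_probs_from_logits: per position it returns log p(label), i.e. the negative of an unreduced cross-entropy, which gives a quick equivalence test:

```python
import torch
import torch.nn.functional as F

from coati.distributed.utils import log_probs_from_logits

logits = torch.randn(2, 5, 100)  # [batch, seq, vocab]
labels = torch.randint(0, 100, (2, 5))  # [batch, seq]

logps = log_probs_from_logits(logits, labels)
assert logps.shape == (2, 5)

# cross_entropy expects [batch, vocab, seq] and returns -log p(label) per position
ref = -F.cross_entropy(logits.transpose(1, 2), labels, reduction="none")
assert torch.allclose(logps, ref, atol=1e-5)
```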