
Pull change from upstream #1

Open · wants to merge 153 commits into base: master
Commits (153)

| Commit | Message | Author | Date |
| ------ | ------- | ------ | ---- |
| b371f8f | Adding fine-tuning support for Llama on the client side (#657) | minjiaz | Aug 1, 2023 |
| 1a0b896 | Add step 2 sweep script, clean up scripts (#664) | lekurile | Aug 2, 2023 |
| 97a36bc | fix only optimize lora and ack-ckpting compatible (#658) | yaozhewei | Aug 3, 2023 |
| d652031 | Refactor load_hf_tokenizer for use with all 3 steps. (#666) | awan-10 | Aug 3, 2023 |
| 258182f | Update Stable Diffusion Example (#662) | lekurile | Aug 3, 2023 |
| 835af7c | Update DS Chat transformers req (#667) | lekurile | Aug 3, 2023 |
| d2373cd | Fix output arg in step 3 llama script (#668) | lekurile | Aug 3, 2023 |
| c7bcbc8 | Fix typos in README.md (#635) | Embracing | Aug 7, 2023 |
| 0e0e2ef | Update CIFAR example with changes required in torch 1.13 (#676) | loadams | Aug 7, 2023 |
| 04b1036 | Update inference test and add ds-hf-compare (#671) | molly-smith | Aug 8, 2023 |
| 042443d | add deepspeed chat arxiv report (#679) | conglongli | Aug 8, 2023 |
| 3952c9a | 10-20x faster load checkpoint (for critic/reward model) (#675) | awan-10 | Aug 9, 2023 |
| ef87284 | Fix for cannot import name 'LlamaTokenizerFast' (#647) (#648) | jagane-infinstor | Aug 9, 2023 |
| f8e25c5 | Update script location and docs for all 3 steps (#681) | awan-10 | Aug 10, 2023 |
| c984c8d | Add test mode args to DS Chat step 3 (#682) | lekurile | Aug 10, 2023 |
| 78a8e2d | Add Hybrid Engine mode + move arguments to a file (#684) | awan-10 | Aug 11, 2023 |
| f98077d | Add LoRA LR for DS Chat steps 1-3 (#685) | lekurile | Aug 14, 2023 |
| c94321a | Enable non-CUDA device for CIFAR10 and HelloDeepSpeed training exampl… | delock | Aug 15, 2023 |
| 80eb89f | DS Chat Step 3 Unit Test (#677) | lekurile | Aug 15, 2023 |
| a143850 | Change run_training_test.py to test_training.py (#686) | lekurile | Aug 16, 2023 |
| 7a21aea | Add minimalistic naming convention for tests | lekurile | Aug 16, 2023 |
| 3daac59 | Remove skipping of Z3 case in DS Chat step 3 training (#687) | lekurile | Aug 16, 2023 |
| 30e6735 | Add DS-Chat CI badge and documentation (#697) | lekurile | Aug 17, 2023 |
| 927690e | Update requirements.txt | awan-10 | Aug 17, 2023 |
| 9186347 | Rename Step 3 DS-Chat args for clarification (#698) | lekurile | Aug 17, 2023 |
| 32083e5 | Mixed Precision ZeRO++ (#689) | HeyangQin | Aug 18, 2023 |
| 81a8521 | add timers and performance metrics (#688) | awan-10 | Aug 22, 2023 |
| bd1c82c | Fix calculations (include critic model) for performance (#706) | awan-10 | Aug 25, 2023 |
| 845ac3e | add mixz script (#711) | HeyangQin | Aug 31, 2023 |
| 4355784 | add updates for the new release. (#712) | awan-10 | Aug 31, 2023 |
| 6c83ac2 | Update README.md (#714) | awan-10 | Aug 31, 2023 |
| 20f0a85 | Skip hierarchical partitioning ZeRO (hpZ) for single node (#717) | HeyangQin | Sep 5, 2023 |
| 44ade96 | Update README.md (#727) | NinoRisteski | Sep 11, 2023 |
| 7526d55 | Update README.md (#726) | NinoRisteski | Sep 11, 2023 |
| 12f78ec | ZeRO-Inference refresh (#722) | tjruwase | Sep 11, 2023 |
| 902a0f6 | Update README.md to add python formatting in code examples (#729) | imJunaidAfzal | Sep 12, 2023 |
| 0b30bcb | FlexGen reference (#730) | tjruwase | Sep 13, 2023 |
| bae2afb | Update Llama check to use module instead of model (#734) | lekurile | Sep 14, 2023 |
| 27b60d2 | Filter transformers version 4.33.2 due to bug (#735) | lekurile | Sep 19, 2023 |
| 9c94044 | Update README.md (#739) | NinoRisteski | Sep 22, 2023 |
| db56381 | support DeepSpeedChat to run on different device besides cuda (#736) | ys950902 | Sep 22, 2023 |
| 9b3d898 | support bf16 for RLHF training (#733) | ys950902 | Sep 22, 2023 |
| d8f3f73 | deepspeed-chat: support any model in chatbot (#744) | mosheisland | Oct 2, 2023 |
| 58e4e9c | Fix padding and dtype issues (#738) | tjruwase | Oct 2, 2023 |
| 2f99dcd | deepspeed-chat: handle overflow for bf16_optimizer (#745) | mosheisland | Oct 3, 2023 |
| 4bf1924 | deepspeed-chat: support explicit configuration of dropout (#746) | mosheisland | Oct 3, 2023 |
| ca03bd7 | deepspeed-chat: fix incorrect lr when using lora only (#756) | mosheisland | Oct 3, 2023 |
| 0d11c63 | Add default value for tokenizer path (#699) | xu-song | Oct 3, 2023 |
| ca41e8b | support `trust_remote_code` in inference test (#709) | wangruohui | Oct 3, 2023 |
| 6c05e03 | Deepspeed-VisualChat (#753) | yaozhewei | Oct 3, 2023 |
| 4364031 | Update README.md (#757) | xiaoxiawu-microsoft | Oct 3, 2023 |
| e6f400a | deepspeed-chat: calculate loss in fp32 (#754) | mosheisland | Oct 4, 2023 |
| bfad08f | deepspeed-chat: support periodic eval in stage2 (#747) | mosheisland | Oct 4, 2023 |
| 10aef97 | add the path to load the local dataset (#761) | ys950902 | Oct 8, 2023 |
| 0855679 | Fix typo (#749) | xu-song | Oct 9, 2023 |
| 1ba50ed | Resolving epochs being hard-coded (#759) | PareesaMS | Oct 11, 2023 |
| 3517c6d | Resolves the issue with evaluation on step2 for single GPU (#766) | PareesaMS | Oct 12, 2023 |
| 5161c0f | deepspeed-chat: train v_head when only optimizing lora (#758) | mosheisland | Oct 16, 2023 |
| 8d850ba | deepspeed-chat: fix weight decay configuration (#755) | mosheisland | Oct 16, 2023 |
| 185e25c | deepspeed-chat: fix bf16 stage2 accuracy for bloom-560m (#772) | mosheisland | Oct 17, 2023 |
| f7ff9dd | deepspeed-chat: fix training stage1 ppl calculation (#773) | mosheisland | Oct 17, 2023 |
| e8d879e | deepspeed-chat: add end-of-text special token (#775) | mosheisland | Oct 17, 2023 |
| 737c674 | Update requirement.txt (#789) | xiaoxiawu-microsoft | Oct 23, 2023 |
| bb6eb4d | deepspeed-chat: support print answers interval (#781) | mosheisland | Oct 31, 2023 |
| 70a24ed | deepspeed-chat: fix compute_fp32_loss with llama (#784) | mosheisland | Oct 31, 2023 |
| 86ccf61 | deepspeed-chat: handle stage3 generate too short (#778) | mosheisland | Oct 31, 2023 |
| f52b725 | deepspeed-chat: display reward ema in stage3 (#779) | mosheisland | Nov 1, 2023 |
| 0e1bb1f | Mii examples (#181) (#797) | mrwyattii | Nov 3, 2023 |
| 60e412e | Update MII Example (#798) | mrwyattii | Nov 3, 2023 |
| e86d0c6 | Refactor deepspeed-chat into a python package. (#731) | microsoft-fevieira | Nov 6, 2023 |
| 4dbff68 | Update transformers dependency in deepspeed-chat install (#802) | lekurile | Nov 6, 2023 |
| ff0e254 | fix: using DistributedSampler when evaluating the reward model (#804) | xffxff | Nov 7, 2023 |
| 089baad | Add benchmark scripts for DeepSpeed-FastGen (#805) | tohtana | Nov 8, 2023 |
| fe7a76d | Fix SD example imports for latest diffusers (#806) | lekurile | Nov 8, 2023 |
| ccb2a34 | Adding Imagenet Example (#680) | PareesaMS | Nov 8, 2023 |
| 09af71a | Adds script as an example of a run of DS-FastGen (#810) | PareesaMS | Nov 17, 2023 |
| 8c551d2 | deepspeed-chat: filter stage3 too long prompts (#782) | mosheisland | Nov 21, 2023 |
| b116838 | update MII benchmark to reflect changes in output type (#812) | mrwyattii | Nov 21, 2023 |
| 0e10c4b | Adding LoRA-Distillation SD training example (#788) | PareesaMS | Dec 4, 2023 |
| dd0f181 | Correction training script filename in README and Fix Bug for Step Ru… | UEFI-code | Dec 11, 2023 |
| 8e4cdd8 | Improve Comms Benchmark Timing (#833) | Quentin-Anthony | Dec 20, 2023 |
| abd7502 | fix: typo in sa (#838) | A-Cepheus | Jan 3, 2024 |
| ff9a023 | Update README.md (#827) | chinainfant | Jan 3, 2024 |
| 05120bb | Update MII Inference Examples (#837) | mrwyattii | Jan 10, 2024 |
| 6c31d8d | Modify codes so that different accelerators can be called according t… | foin6 | Jan 11, 2024 |
| 57dd8fb | deepspeed-chat: Support zero3 params initialization in the last LN (#… | deepcharm | Jan 17, 2024 |
| 8216f5f | Generalize MII benchmark for any model (#851) | mrwyattii | Jan 19, 2024 |
| 107681e | [Example] Refactor and Polish Cifar10-DeepSpeed Code Example. (#843) | keli-wen | Jan 26, 2024 |
| 6863634 | Not a bug, just missing a space in README.md (#857) | stceum | Feb 1, 2024 |
| 19e0efb | Fix errors of AttributeError: 'str' object has no attribute 'stdout' … | mlzoo | Feb 1, 2024 |
| b338d1e | Control the kernel injection with new argument. And compare the outpu… | foin6 | Feb 5, 2024 |
| 48177db | Different accelerators can be called according to specific device con… | foin6 | Feb 9, 2024 |
| 0b1ea40 | Add Human Eval Example (#856) | lekurile | Feb 22, 2024 |
| 0ac02da | Fix path in human-eval example README (#862) | lekurile | Feb 22, 2024 |
| 6540db6 | <fill-mask>Modify codes so that different accelerators can be called … | foin6 | Feb 26, 2024 |
| 8182a8b | Extend FastGen benchmark to use AML endpoints (#865) | mrwyattii | Feb 29, 2024 |
| ffb8a4b | catch AML error response, add aml script (#869) | mrwyattii | Mar 1, 2024 |
| b7ec5c3 | Remove AML key from args dict when saving results (#870) | lekurile | Mar 6, 2024 |
| 6e9ada6 | Update Inference Benchmarking Scripts - Support AML (#868) | lekurile | Mar 6, 2024 |
| f415ec8 | Xiaoxia/fp v1 (#871) | xiaoxiawu-microsoft | Mar 8, 2024 |
| b0a5533 | Fix AML benchmark E2E measurment (#874) | mrwyattii | Mar 14, 2024 |
| c3ffec2 | Update README.md | awan-10 | Mar 18, 2024 |
| 18200d5 | Improve robustness of infernece AML benchmark (#875) | HeyangQin | Mar 19, 2024 |
| 279a8fe | change kwargs for AML call to match vllm kwargs (#876) | mrwyattii | Mar 19, 2024 |
| 02fc578 | dynamic setting of requst num and formatting (#880) | mrwyattii | Mar 27, 2024 |
| df7119e | Fix response check in call_aml function (#882) | HeyangQin | Mar 29, 2024 |
| fab5d06 | Update throughput-latency plot script (#881) | lekurile | Apr 9, 2024 |
| 1be0fc7 | updating tokens per second to include the token count of generated to… | guptha23 | Apr 29, 2024 |
| fdb8ee2 | Update tokens_per_sec calculation to work w/ stream and non-stream ca… | lekurile | Apr 30, 2024 |
| cce6223 | fix bug with queue.empty not being reliable (#898) | mrwyattii | May 1, 2024 |
| 75df1d7 | add client-only mode to mii benchmark (#900) | delock | Jun 6, 2024 |
| bbab278 | Refactored LLM benchmark code (#899) | mrwyattii | Jun 26, 2024 |
| b04fedd | Enable cpu/xpu support for the benchmarking suite (#905) | louie-tsai | Aug 14, 2024 |
| 8d91a5a | Update README.md (#916) | keshavkowshik | Aug 16, 2024 |
| 9563904 | Add openai client to deepspeedometer (#913) | delock | Aug 21, 2024 |
| 0d40b31 | DeepNVMe example scripts (#914) | tjruwase | Aug 21, 2024 |
| 957ae31 | DeepNVMe README.md add xref (#919) | stas00 | Aug 24, 2024 |
| 1293d45 | extend max_prompt_length and input text for 128k evaluation (#891) | HeyangQin | Sep 3, 2024 |
| c961379 | Update requirements for opencv-python CVE (#925) | loadams | Sep 3, 2024 |
| a256c04 | Fix labels & eos_token for SFT (#819) | li-plus | Sep 10, 2024 |
| 90c2a9f | DeepNVMe ZeRO-inf Tutorial (#921) | jomayeri | Sep 17, 2024 |
| f73a6ed | Enable overlap_comm for better performance (#846) | li-plus | Sep 17, 2024 |
| 130fb58 | [cifar ds training]: Set cuda device during initialization of distrib… | jagadish-amd | Oct 29, 2024 |
| 5a61193 | Fixed mistake in readme (#933) | SCheekati | Oct 29, 2024 |
| cab3361 | Replace deprecated transformers.deepspeed module (#872) | HollowMan6 | Oct 29, 2024 |
| aa4459f | Εnable reward model offloading option (#930) | kfertakis | Oct 29, 2024 |
| eefb0ef | Remove the fixed `eot_token` mechanism for SFT (#927) | Xingfu-Yi | Oct 30, 2024 |
| faa0420 | DeepSpeed-Domino (#929) | zhangsmallshark | Nov 7, 2024 |
| be0a0e1 | Update DeepSpeed version requirement to >=0.16.0 in requirements.txt … | shenzheyu | Nov 27, 2024 |
| fd79b31 | Example and benchmark of APIs to offload states (#942) | tohtana | Dec 13, 2024 |
| 476f600 | remove-redundant-code (#947) | simonJJJ | Dec 24, 2024 |
| 1842b4f | Add DPO support for DeepSpeed-Chat (#828) | stceum | Jan 6, 2025 |
| b965b9c | Update references to torchvision (#949) | loadams | Jan 21, 2025 |
| a85b5e6 | Cleanup CODEOWNERS (#953) | loadams | Jan 24, 2025 |
| 8075143 | fix: the json format of the training imagenet configuration file (#954) | navanis | Jan 30, 2025 |
| b90ffab | Update references to deepspeedai GH org (#955) | loadams | Feb 7, 2025 |
| 83757d9 | Update weights_only due to change in default in torch 2.6+ (#957) | loadams | Feb 12, 2025 |
| 420352c | Variable batch size and LR example for DeepSpeed PR #7104 (#963) | bm-synth | Mar 11, 2025 |
| b623258 | Fix: Add output_folder parameter and correct print statement (#962) | majianpeng | Mar 14, 2025 |
| 223665c | run domino example on amd (#958) | hwchen2017 | Mar 27, 2025 |
| 7b34e07 | update runner image (#968) | tohtana | Apr 16, 2025 |
| b76c7cc | Add example of DeepCompile (#967) | tohtana | Apr 16, 2025 |
| 93ebac3 | fix links (#970) | tohtana | Apr 17, 2025 |
| ce39bf0 | Update description of versions for deepcompile (#971) | tohtana | Apr 18, 2025 |
| 65bc536 | Fix DeepCompile benchmark script (#973) | tohtana | Apr 20, 2025 |
| bd47e5b | Add example for Deepspeed-AutoTP (#964) | inkcherry | May 23, 2025 |
| 86aeab2 | fix: Fix: Correctly define choices as tuple for reward-model arg Fi… | Flink-ddd | Jun 9, 2025 |
| 207c93c | DeepNVMe update (#966) | tjruwase | Jun 9, 2025 |
| b018de1 | Update domino example (#976) | hwchen2017 | Jun 12, 2025 |
| 28a984e | Simplify and add README (#978) | tjruwase | Jun 18, 2025 |
| b99d653 | Add file extension (#980) | hwchen2017 | Jun 21, 2025 |
| 4579df3 | Update submodule link to reflect https style (#981) | raviguptaamd | Jul 4, 2025 |
| 3d83278 | fix init weights issue for critic/reward model (#983) | jouw | Jul 8, 2025 |

Files changed

2 changes: 1 addition & 1 deletion .github/workflows/formatting.yml
@@ -16,7 +16,7 @@ jobs:

   # formatting and basic install on cpu-only machine
   formatting:
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-22.04

     steps:
       - uses: actions/checkout@v2
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
+[submodule "training/DeepSpeed-Domino/Megatron-LM"]
+	path = training/DeepSpeed-Domino/Megatron-LM
+	url = https://github.com/NVIDIA/Megatron-LM.git
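
For anyone checking out this branch: a plain clone will not populate the new Megatron-LM submodule automatically. The standard git commands for that (generic git behavior, not something added by this PR) are `git submodule update --init --recursive` in an existing checkout, or `git clone --recurse-submodules` for a fresh clone.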
2 changes: 1 addition & 1 deletion CODEOWNERS
@@ -1 +1 @@
-* @jeffra @samyam @tjruwase @ShadenSmith @conglongli @awan-10 @cli99 @eltonzheng @minjiaz @RezaYazdaniAminabadi @duli2012 @mrwyattii @yaozhewei @arashb @xiaoxiawu-microsoft
+* @tjruwase @ShadenSmith @awan-10 @minjiaz
9 changes: 7 additions & 2 deletions README.md
@@ -1,5 +1,5 @@
 # DeepSpeed Examples
-This repository contains various examples including training, inference, compression, benchmarks, and applications that use [DeepSpeed](https://github.com/microsoft/DeepSpeed).
+This repository contains various examples including training, inference, compression, benchmarks, and applications that use [DeepSpeed](https://github.com/deepspeedai/DeepSpeed).

 ## 1. Applications
 This folder contains end-to-end applications that use DeepSpeed to train and use cutting-edge models.
@@ -8,14 +8,19 @@ This folder contains end-to-end applications that use DeepSpeed to train and use
 There are several training and finetuning examples so please see the individual folders for specific instructions.

 ## 3. Inference
-The DeepSpeed Huggingface inference [README](./inference/huggingface/README.md) explains how to get started with running DeepSpeed Huggingface inference examples.
+- The DeepSpeed-MII inference [README](./inference/mii/README.md) explains how to get started with running model inference with [DeepSpeed-MII](https://github.com/deepspeedai/DeepSpeed-MII) and [DeepSpeed-FastGen](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fastgen).
+- The DeepSpeed Huggingface inference [README](./inference/huggingface/README.md) explains how to get started with running DeepSpeed Huggingface inference examples.

 ## 4. Compression
 Model compression examples.

 ## 5. Benchmarks
 All benchmarks that use the DeepSpeed library are maintained in this folder.

+# Build Pipeline Status
+| Description | Status |
+| ----------- | ------ |
+| Integrations | [![nv-ds-chat](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-ds-chat.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-ds-chat.yml) |
+
 # Contributing
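
For orientation alongside the new MII bullet above, here is a minimal sketch of non-persistent DeepSpeed-MII inference using the public `mii.pipeline` API. The model id is only an illustration, not something this PR prescribes; see the linked MII README for supported models.

```python
# Minimal DeepSpeed-MII (FastGen) inference sketch.
# Assumes `deepspeed-mii` is installed and a GPU is available;
# the model id below is illustrative.
import mii

pipe = mii.pipeline("mistralai/Mistral-7B-v0.1")
responses = pipe(["DeepSpeed is"], max_new_tokens=64)
print(responses[0].generated_text)
```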
137 changes: 137 additions & 0 deletions applications/DeepSpeed-Chat/.gitignore
@@ -0,0 +1,137 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

results/
outputs/

.amltconfig
.test_output
*.hdf5
*.h5
127 changes: 86 additions & 41 deletions applications/DeepSpeed-Chat/README.md

Large diffs are not rendered by default.

applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py

@@ -4,20 +4,18 @@
 # DeepSpeed Team
 import torch
 import torch.nn.functional as F
-import sys
-import os
+import time
 import deepspeed
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+from deepspeed.accelerator import get_accelerator

-sys.path.append(
-    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
-
-from utils.utils import print_rank_0
+from dschat.utils.utils import print_rank_0


 def print_all_ranks(tag, value, rank):
     world_size = torch.distributed.get_world_size()
-    all_tensor = torch.zeros(world_size, dtype=torch.float32).cuda()
+    all_tensor = torch.zeros(world_size, dtype=torch.float32).to(
+        get_accelerator().current_device_name())
     all_tensor[rank] = value
     torch.distributed.all_reduce(all_tensor, op=torch.distributed.ReduceOp.SUM)
     print_rank_0(f'{tag} {all_tensor}', rank)
@@ -57,6 +55,12 @@ def __init__(self, rlhf_engine, args):
         self.end_of_conversation_token_id = self.tokenizer(
             args.end_of_conversation_token)['input_ids'][-1]
         self.z3_enabled = args.actor_zero_stage == 3
+        self.compute_fp32_loss = self.args.compute_fp32_loss
+
+        # In case the generated experience is not valid (too short), we use the last valid
+        # generated experience. Alternatively, we can skip the step (on all workers).
+        # For now, use the last valid experience which is a simpler solution
+        self.last_generated_experience = None

         # Those value can be changed
         self.kl_ctl = 0.1
@@ -65,18 +69,28 @@ def __init__(self, rlhf_engine, args):
         self.cliprange_value = 0.2
         self.gamma = 1.0
         self.lam = 0.95
+        self.generate_time = 0.0

     def _generate_sequence(self, prompts, mask, step):

         max_min_length = self.max_answer_seq_len + prompts.shape[1]

+        # This has been added due to a probability/nan error that happens after
+        # meta-llama/Llama-2-7b-hf enabled do_sample:
+        # https://huggingface.co/meta-llama/Llama-2-7b-hf/commit/6fdf2e60f86ff2481f2241aaee459f85b5b0bbb9
+        if self.actor_model.module.config.model_type == "llama":
+            kwargs = dict(do_sample=False)
+        else:
+            kwargs = dict()
+
         with torch.no_grad():
             seq = self.actor_model.module.generate(
                 prompts,
                 attention_mask=mask,
                 max_length=max_min_length,
                 pad_token_id=self.tokenizer.pad_token_id,
-                synced_gpus=self.z3_enabled)
+                synced_gpus=self.z3_enabled,
+                **kwargs)

         # Filter out seq with no answers (or very short). This happens when users directly use the pre-training ckpt without supervised finetuning
         # NOTE: this will causes each GPU has different number of examples
@@ -86,7 +100,8 @@ def _generate_sequence(self, prompts, mask, step):
         ans = seq[:, prompt_length:]
         valid_ans_len = (ans != self.tokenizer.pad_token_id).sum(dim=-1)

-        if self.args.print_answers:
+        if self.args.print_answers and (step % self.args.print_answers_interval
+                                        == 0):
             print(
                 f"--- prompt --> step={step}, rank={torch.distributed.get_rank()}, {self.tokenizer.batch_decode(prompts, skip_special_tokens=True)}"
             )
@@ -98,16 +113,38 @@ def _generate_sequence(self, prompts, mask, step):
         for i in range(batch_size):
             if valid_ans_len[
                     i] <= 1:  # if the answer is shorter than 1 token, drop it
+                print(
+                    f'Dropping too short generated answer: {step=}: \n'
+                    f'prompts: {self.tokenizer.batch_decode(prompts, skip_special_tokens=False)}\n'
+                    f'answers: {self.tokenizer.batch_decode(ans, skip_special_tokens=False)}'
+                )
                 continue
             else:
                 out_seq.append(seq[i:i + 1])
-        out_seq = torch.cat(out_seq, dim=0)  # concate output in the batch dim
+
+        if not out_seq:
+            print(
+                f'All generated results are too short for rank={self.args.local_rank} step={step}\n'
+                f'-> prompts: {self.tokenizer.batch_decode(prompts, skip_special_tokens=False)}\n'
+                f'-> answers: {self.tokenizer.batch_decode(ans, skip_special_tokens=False)}'
+            )
+            return None
+
+        out_seq = torch.cat(out_seq, dim=0)  # concat output in the batch dim

         return out_seq

     def generate_experience(self, prompts, mask, step):
         self.eval()
+        generate_start = time.time()
         seq = self._generate_sequence(prompts, mask, step)
+        generate_end = time.time()
+        if seq is None:
+            assert self.last_generated_experience is not None, f'Invalid generated experience at {step=}'
+            prompts = self.last_generated_experience['prompts']
+            seq = self.last_generated_experience['seq']
+        else:
+            self.last_generated_experience = {'prompts': prompts, 'seq': seq}
         self.train()

         pad_token_id = self.tokenizer.pad_token_id
@@ -124,6 +161,11 @@ def generate_experience(self, prompts, mask, step):

         logits = output.logits
         logits_ref = output_ref.logits
+        if self.compute_fp32_loss:
+            logits = logits.to(torch.float)
+            logits_ref = logits_ref.to(torch.float)
+
+        self.generate_time = generate_end - generate_start

         return {
             'prompts': prompts,
@@ -226,6 +268,17 @@ def train_rlhf(self, inputs):

         return actor_loss, critic_loss

+    def get_overflow(self):
+        # Overflow is not expected when using bf16
+        # Therefore, DeepSpeed's BF16_Optimizer does not maintain an overflow indication
+        if self.args.dtype == "bf16":
+            return False, False
+
+        actor_overflow = self.actor_model.optimizer.overflow
+        critic_overflow = self.critic_model.optimizer.overflow
+
+        return actor_overflow, critic_overflow
+
     def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
         ## policy gradient loss
         log_ratio = (logprobs - old_logprobs) * mask
@@ -243,6 +296,9 @@ def critic_loss_fn(self, values, old_values, returns, mask):
             old_values - self.cliprange_value,
             old_values + self.cliprange_value,
         )
+        if self.compute_fp32_loss:
+            values = values.float()
+            values_clipped = values_clipped.float()
         vf_loss1 = (values - returns)**2
         vf_loss2 = (values_clipped - returns)**2
         vf_loss = 0.5 * torch.sum(
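
A recurring theme in these hunks is the `compute_fp32_loss` flag: logits and values are upcast to fp32 just for the loss computation while the models themselves keep running in bf16. A standalone sketch of that pattern (illustrative only, not the repo's code):

```python
# Sketch: upcast to fp32 only for the numerically sensitive loss step.
import torch
import torch.nn.functional as F

def loss_fp32(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Cross-entropy in low precision can lose accuracy; casting the logits
    # to float32 before the loss mirrors the compute_fp32_loss changes above.
    return F.cross_entropy(logits.float(), labels)

logits = torch.randn(4, 10, dtype=torch.bfloat16)  # e.g. bf16 model output
labels = torch.randint(0, 10, (4,))
print(loss_fp32(logits, labels))
```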