Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
70c6d40
memory benchmark rss
thomwolf Mar 8, 2020
e0c50a6
have both forward pass and line-by-line mem tracing
thomwolf Mar 10, 2020
03e14b2
cleaned up tracing
thomwolf Mar 12, 2020
f77d7d9
refactored and cleaning up API
thomwolf Mar 12, 2020
88ea59f
no f-strings yet...
thomwolf Mar 12, 2020
f46ff48
add GPU mem logging
thomwolf Mar 12, 2020
e9182b4
fix GPU memory monitoring
thomwolf Mar 12, 2020
675dfa2
style and quality
thomwolf Mar 12, 2020
8da965b
clean up and doc
thomwolf Mar 12, 2020
4cd18c1
boom boom
sshleifer Mar 13, 2020
47c25cd
Merge branch 'master' into shleifer-memprof
sshleifer Mar 15, 2020
f6d2c64
add test
sshleifer Mar 15, 2020
6849c29
viewer
sshleifer Mar 15, 2020
727e754
saver
sshleifer Mar 15, 2020
562e6c5
add to mnli
sshleifer Mar 15, 2020
c9c0e74
different fnames
sshleifer Mar 15, 2020
8b7ae1b
use LoggingMixin
sshleifer Mar 15, 2020
a229a27
more logging
sshleifer Mar 15, 2020
bc07602
add preinit
sshleifer Mar 15, 2020
09a6894
dont log preinit
sshleifer Mar 15, 2020
75381a4
only sometimes update layer state
sshleifer Mar 15, 2020
cc0363c
Script
sshleifer Mar 15, 2020
480c8c6
default
sshleifer Mar 15, 2020
3da23d4
Do generate flag
sshleifer Mar 15, 2020
a6a592d
no output_past
sshleifer Mar 15, 2020
681e0a3
fix
sshleifer Mar 15, 2020
8b46f83
no grad
sshleifer Mar 15, 2020
7d669b8
no lm_head
sshleifer Mar 15, 2020
237202f
get it off gpu
sshleifer Mar 15, 2020
3bb13cf
del
sshleifer Mar 15, 2020
8381c9b
Fix
sshleifer Mar 17, 2020
1c5fe4e
new padding strategy
sshleifer Mar 18, 2020
c9d5c63
bugfix
sshleifer Mar 18, 2020
dffc461
undo chg
sshleifer Mar 18, 2020
9590f16
Who knows
sshleifer Mar 18, 2020
bfaae34
del trace
sshleifer Mar 18, 2020
c61f3e4
del trace
sshleifer Mar 18, 2020
0274d58
Fix mask
sshleifer Mar 18, 2020
7629d42
cant be worse
sshleifer Mar 18, 2020
e0fdb76
merge master
sshleifer Mar 18, 2020
2589099
bart mem utests
sshleifer Mar 19, 2020
ae6a7c6
callfwd
sshleifer Mar 19, 2020
a022f5c
merged master
sshleifer Mar 19, 2020
01218ac
new test file
sshleifer Mar 19, 2020
c8cca90
no attn_weights
sshleifer Mar 19, 2020
f4bc62a
undo chg
sshleifer Mar 19, 2020
255ebe1
Delay mem
sshleifer Mar 19, 2020
600d62a
boom boom
sshleifer Mar 19, 2020
6ff2eb5
boom boom
sshleifer Mar 19, 2020
4199a7b
boom boom
sshleifer Mar 19, 2020
8219f5c
boom boom
sshleifer Mar 19, 2020
74dcbb2
boom boom
sshleifer Mar 19, 2020
685d892
boom boom
sshleifer Mar 19, 2020
8fd4be3
undo thom chg
sshleifer Mar 19, 2020
12cf809
boom boom
sshleifer Mar 19, 2020
ed1c07f
boom boom
sshleifer Mar 19, 2020
96e701b
boom boom
sshleifer Mar 19, 2020
975c282
boom boom
sshleifer Mar 19, 2020
3b81c20
boom boom
sshleifer Mar 19, 2020
c572b13
boom boom
sshleifer Mar 19, 2020
32959ea
boom boom
sshleifer Mar 19, 2020
5a91a71
v similar
sshleifer Mar 19, 2020
bcfd0d4
boom boom
sshleifer Mar 19, 2020
24c56e8
boom boom
sshleifer Mar 19, 2020
d78ecc7
boom boom
sshleifer Mar 19, 2020
3537776
boom boom
sshleifer Mar 19, 2020
fb01c11
not verbose encoder
sshleifer Mar 19, 2020
9462adf
not verbose encoder
sshleifer Mar 19, 2020
4fc9477
inline the unsquoze key_padding_mask
sshleifer Mar 20, 2020
0326279
boom boom
sshleifer Mar 20, 2020
189f2f1
merged master
sshleifer Mar 20, 2020
d615cae
rearrange encoder_outputs call
sshleifer Mar 20, 2020
335682f
Merge branch 'master' into shleifer-memprof-nobm
sshleifer Mar 21, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions src/transformers/bart_mem_prof.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from transformers import *
import torch
# Default to the GPU when torch can see one, otherwise fall back to CPU.
DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def runner(source_path, out_file, batch_size=8, device=DEFAULT_DEVICE, prof_generate=False):
    """Profile BART memory usage on a batch of source documents.

    Loads ``bart-large-cnn`` and either runs beam-search summarization via
    ``model.generate`` (when ``prof_generate`` is True) or a single no-grad
    forward pass over the first ``batch_size`` lines of ``source_path``,
    then writes the model's combined memory log to ``out_file`` as CSV.

    Args:
        source_path: text file with one source document per line.
        out_file: destination CSV path for the combined memory log.
        batch_size: number of lines from ``source_path`` to process.
        device: torch device string, e.g. "cuda", "cuda:1", "cpu".
        prof_generate: if True, profile ``model.generate``; otherwise
            profile a plain forward pass.
    """
    tokenizer = BartTokenizer.from_pretrained('bart-large')
    # Prepend a space before each document (BART's BPE tokenization is
    # whitespace-sensitive at the start of a sequence — TODO confirm).
    # Use a context manager so the file handle is not leaked.
    with open(source_path) as f:
        lns = [" " + x.rstrip() for x in f.readlines()][:batch_size]

    dct = tokenizer.batch_encode_plus(lns, max_length=1024, return_tensors="pt", pad_to_max_length=True)
    # BUG FIX: honor the `device` argument — the original always moved
    # tensors and model to DEFAULT_DEVICE, silently ignoring `device`.
    ids = dct['input_ids'].to(device)
    msk = dct['attention_mask'].to(device)
    # output_past is only needed for step-by-step decoding in generate().
    model = BartForConditionalGeneration.from_pretrained('bart-large-cnn', output_past=prof_generate).to(device)
    model.log_mem('starting')
    if prof_generate:
        summaries = model.generate(
            input_ids=ids,
            attention_mask=msk,
            num_beams=4,
            length_penalty=2.0,
            max_length=140 + 2,  # +2 from original because we start at step=1 and stop before max_length
            min_length=55 + 1,  # +1 from original because we start at step=1
            no_repeat_ngram_size=3,
            early_stopping=True,
            do_sample=False,
            decoder_start_token_id=model.config.eos_token_ids[0],
        )
        model.log_mem('done')
        dec = [tokenizer.decode(s) for s in summaries]
        print(dec[0])
    else:
        with torch.no_grad():
            model(
                input_ids=ids,
                attention_mask=msk,
            )

    log_df = model.combine_logs()
    log_df.to_csv(out_file)


import argparse

if __name__ == '__main__':
    # CLI entry point: profile BART memory usage and write the log as CSV.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "output_path", type=str, help="where to save summaries",
    )
    parser.add_argument(
        "--source_path", type=str, default="/home/shleifer/transformers_fork/notebooks/test.source",
        help="like cnn_dm/test.source", required=False
    )
    parser.add_argument(
        "--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
    )
    parser.add_argument(
        "--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
    )
    parser.add_argument(
        # BUG FIX: help text was copy-pasted from --bs; describe the flag itself.
        "--do-generate", action='store_true', required=False,
        help="profile model.generate (beam-search summarization) instead of a single forward pass",
    )
    args = parser.parse_args()
    runner(args.source_path, args.output_path, batch_size=args.bs, device=args.device, prof_generate=args.do_generate)



Loading