From c41119355f21e8ea8a603ec4af07c6f533e8c95d Mon Sep 17 00:00:00 2001 From: Luke Date: Wed, 11 Sep 2019 20:44:54 +0000 Subject: [PATCH] Updated encode text script --- gpt2/encode_text.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/gpt2/encode_text.py b/gpt2/encode_text.py index 0150306..1cab98b 100644 --- a/gpt2/encode_text.py +++ b/gpt2/encode_text.py @@ -1,32 +1,32 @@ import sys from pytorch_pretrained_bert import GPT2Tokenizer import regex as re +import argparse + +parser = argparse.ArgumentParser(description='Encode text') +parser.add_argument('--input_file', type=str, help='input file') +parser.add_argument('--output_file', type=str, help='full output filename (usually ends in .bpe)') +parser.add_argument('--add_tldr', action='store_true', help='adds \nTL;DR') +parser.add_argument('--replace_newline', action='store_true', help='replace with \\n') +parser.add_argument('--tok_trunc', type=int, default=1000000, help='truncate tokens') +args = parser.parse_args() pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") enc = GPT2Tokenizer.from_pretrained('gpt2') -filename = sys.argv[1] - -with_tldr = False -replace_newline = False -tok_trunc = 1000000 - -write_name = file_prefix+filename+'.bpe' -if with_tldr and 'src' in filename: - write_name += '.tldr' - -with open(file_prefix+filename, 'r') as f: - with open(write_name, 'w') as fw: +with open(args.input_file, 'r') as f: + with open(args.output_file, 'w') as fw: for line in f: txt = line.strip() - if with_tldr and 'src' in filename: + + if args.add_tldr: txt += '\nTL;DR:' - if replace_newline: + if args.replace_newline: txt = txt.replace('', '\n') bpe_tokens = [] for token in re.findall(pat, txt): # line.strip() to make sure newline is not encoded token = ''.join(enc.byte_encoder[b] for b in token.encode('utf-8')) bpe_tokens.extend(enc.bpe(token).split(' ')) - fw.write(' '.join(bpe_tokens[:tok_trunc]) + '\n') + fw.write(' '.join(bpe_tokens[:args.tok_trunc]) + '\n')