1 change: 1 addition & 0 deletions src/transformers/configuration_bart.py
@@ -26,6 +26,7 @@
"bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json",
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
"bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
}
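With this map entry, `bart-large-xsum` resolves like any other BART checkpoint name. A minimal sketch, assuming the config file is live at the S3 URL above (the generation defaults are the ones asserted in `test_xsum_config_generation_params` further down):

from transformers import BartConfig

# "bart-large-xsum" is looked up in the pretrained config map extended above
config = BartConfig.from_pretrained("bart-large-xsum")
print(config.num_beams, config.early_stopping)  # 6 True, per the new config test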


src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py
@@ -17,6 +17,7 @@

 import argparse
 import logging
+import os
 from pathlib import Path
 
 import fairseq
@@ -30,10 +31,11 @@
     BartModel,
     BartTokenizer,
 )
+from transformers.modeling_bart import _make_linear_from_emb
 
 
-FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn"]
-
+FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
+extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification}
 if version.parse(fairseq.__version__) < version.parse("0.9.0"):
     raise Exception("requires fairseq >= 0.9.0")

@@ -57,62 +59,79 @@ def rename_key(dct, old, new):
     dct[new] = val
 
 
-def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
+def load_xsum_checkpoint(checkpoint_path):
+    """Checkpoint path should end in model.pt"""
+    sd = torch.load(checkpoint_path, map_location="cpu")
+    hub_interface = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval()
+    hub_interface.model.load_state_dict(sd["model"])
+    return hub_interface
+
+
+@torch.no_grad()
+def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
     """
     Copy/paste/tweak model's weights to our BERT structure.
     """
-    bart = torch.hub.load("pytorch/fairseq", checkpoint_path)
-    bart.eval()  # disable dropout
+    if not os.path.exists(checkpoint_path):
+        bart = torch.hub.load("pytorch/fairseq", checkpoint_path).eval()
+    else:
+        bart = load_xsum_checkpoint(checkpoint_path)
 
     bart.model.upgrade_state_dict(bart.model.state_dict())
-    hf_model_name = checkpoint_path.replace(".", "-")
-    config = BartConfig.from_pretrained(hf_model_name)
+    if hf_checkpoint_name is None:
+        hf_checkpoint_name = checkpoint_path.replace(".", "-")
+    config = BartConfig.from_pretrained(hf_checkpoint_name)
     tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
-    tokens2 = BartTokenizer.from_pretrained(hf_model_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
+    tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
     assert torch.eq(tokens, tokens2).all()
 
-    if checkpoint_path in ["bart.large", "bart.large.cnn"]:
-        state_dict = bart.model.state_dict()
-        for k in IGNORE_KEYS:
-            state_dict.pop(k, None)
-        state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
-        model = BartModel(config)
-        their_output = bart.extract_features(tokens)
-    else:  # MNLI Case
+    if checkpoint_path == "bart.large.mnli":
         state_dict = bart.state_dict()
-        for k in IGNORE_KEYS:
-            state_dict.pop(k, None)
+        remove_ignore_keys_(state_dict)
         state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
         for src, dest in rename_keys:
             rename_key(state_dict, src, dest)
-        model = BartForSequenceClassification(config)
-        their_output = bart.predict("mnli", tokens, return_logits=True)
+        model = BartForSequenceClassification(config).eval()
+        model.load_state_dict(state_dict)
+        fairseq_output = bart.predict("mnli", tokens, return_logits=True)
+        new_model_outputs = model(tokens)[0]  # logits
+    else:  # no classification heads to worry about
+        state_dict = bart.model.state_dict()
+        remove_ignore_keys_(state_dict)
+        state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
+        fairseq_output = bart.extract_features(tokens)
+        if hf_checkpoint_name == "bart-large":
+            model = BartModel(config).eval()
+            model.load_state_dict(state_dict)
+            new_model_outputs = model(tokens)[0]
+        else:
+            model = BartForConditionalGeneration(config).eval()  # an existing summarization ckpt
+            model.model.load_state_dict(state_dict)
+            if hasattr(model, "lm_head"):
+                model.lm_head = _make_linear_from_emb(model.model.shared)
+            new_model_outputs = model.model(tokens)[0]
 
-    # Load state dict
-    model.load_state_dict(state_dict)
-    model.eval()
     # Check results
-
-    if checkpoint_path == "bart.large.cnn":
-        model = BartForConditionalGeneration(config, base_model=model)
-        assert "lm_head.weight" in model.state_dict()
-        assert model.lm_head.out_features == config.max_position_embeddings
-        model.eval()
-        our_outputs = model.model(tokens)[0]
-    else:
-        our_outputs = model(tokens)[0]
-    assert their_output.shape == our_outputs.shape
-    assert (their_output == our_outputs).all().item()
+    assert fairseq_output.shape == new_model_outputs.shape
+    assert (fairseq_output == new_model_outputs).all().item()
     Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
     model.save_pretrained(pytorch_dump_folder_path)
 
 
+def remove_ignore_keys_(state_dict):
+    for k in IGNORE_KEYS:
+        state_dict.pop(k, None)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     # Required parameters
-    parser.add_argument("fairseq_path", choices=FAIRSEQ_MODELS, type=str, help="")
-
+    parser.add_argument(
+        "fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
    )
     parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
-    args = parser.parse_args()
-    convert_bart_checkpoint(
-        args.fairseq_path, args.pytorch_dump_folder_path,
-    )
+    parser.add_argument(
+        "--hf_config", default=None, type=str, help="Which huggingface architecture to use: bart-large-xsum"
+    )
+    args = parser.parse_args()
+    convert_bart_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, hf_checkpoint_name=args.hf_config)
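Unlike the hub aliases in FAIRSEQ_MODELS, the XSum checkpoint has no torch.hub entry, which is why `load_xsum_checkpoint` builds a `bart.large.cnn` hub interface (same architecture) and overwrites its weights with the local `model.pt`. A hypothetical invocation of the script, assuming the fairseq XSum weights have been downloaded to `bart_xsum/model.pt`:

python convert_bart_original_pytorch_checkpoint_to_pytorch.py bart_xsum/model.pt ./bart-large-xsum --hf_config bart-large-xsum

Because the first argument exists on disk, the `os.path.exists` check routes it through `load_xsum_checkpoint`, and `--hf_config` pins the Hugging Face config name; the default `checkpoint_path.replace(".", "-")` derivation only makes sense for hub aliases like `bart.large.cnn`.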
1 change: 1 addition & 0 deletions src/transformers/modeling_bart.py
@@ -34,6 +34,7 @@
"bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/pytorch_model.bin",
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin",
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin",
"bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/pytorch_model.bin",
}

BART_START_DOCSTRING = r"""
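With the weight map extended, the fine-tuned model loads by name. A short end-to-end sketch (ARTICLE is a placeholder, the generation settings mirror the config defaults, and the actual output text is only verified in the slow test below):

import torch
from transformers import BartForConditionalGeneration, BartTokenizer

model = BartForConditionalGeneration.from_pretrained("bart-large-xsum").eval()
tok = BartTokenizer.from_pretrained("bart-large-xsum")

ARTICLE = "..."  # any news article to summarize
batch = tok.batch_encode_plus([ARTICLE], max_length=1024, pad_to_max_length=True, return_tensors="pt")
with torch.no_grad():
    summary_ids = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        num_beams=6,  # config default asserted in test_xsum_config_generation_params
        early_stopping=True,
        decoder_start_token_id=model.config.eos_token_ids[0],  # as in the xsum test below
    )
print(tok.decode(summary_ids[0], skip_special_tokens=True))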
2 changes: 1 addition & 1 deletion src/transformers/tokenization_bart.py
@@ -19,7 +19,7 @@
 # vocab and merges same as roberta
 vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
 merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
-_all_bart_models = ["bart-large", "bart-large-mnli", "bart-large-cnn"]
+_all_bart_models = ["bart-large", "bart-large-mnli", "bart-large-cnn", "bart-large-xsum"]
 
 
 class BartTokenizer(RobertaTokenizer):
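The tokenizer side is just a whitelist update: `bart-large-xsum` reuses the RoBERTa vocab and merges URLs above, so it tokenizes identically to the other BART checkpoints. A quick sketch:

from transformers import BartTokenizer

# works now that "bart-large-xsum" is in _all_bart_models
tok = BartTokenizer.from_pretrained("bart-large-xsum")
assert tok.encode("Hello world!") == BartTokenizer.from_pretrained("bart-large").encode("Hello world!")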
32 changes: 32 additions & 0 deletions tests/test_modeling_bart.py
@@ -450,6 +450,38 @@ def test_model_from_pretrained(self):
             model = BartModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
 
+    @slow
+    def test_xsum_summarization_same_as_fairseq(self):
+        model = BartForConditionalGeneration.from_pretrained("bart-large-xsum").to(torch_device)
+        tok = BartTokenizer.from_pretrained("bart-large")
+
+        PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
+        EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
+        dct = tok.batch_encode_plus([PGE_ARTICLE], max_length=1024, pad_to_max_length=True, return_tensors="pt",)
+
+        hypotheses_batch = model.generate(
+            input_ids=dct["input_ids"].to(torch_device),
+            attention_mask=dct["attention_mask"].to(torch_device),
+            num_beams=2,
+            max_length=62,
+            min_length=11,
+            length_penalty=1.0,
+            no_repeat_ngram_size=3,
+            early_stopping=True,
+            decoder_start_token_id=model.config.eos_token_ids[0],
+        )
+
+        decoded = [
+            tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
+        ]
+        self.assertEqual(EXPECTED_SUMMARY, decoded[0])
+
+    def test_xsum_config_generation_params(self):
+        config = BartConfig.from_pretrained("bart-large-xsum")
+        expected_params = dict(num_beams=6, do_sample=False, early_stopping=True, length_penalty=1.0)
+        config_params = {k: getattr(config, k, "MISSING") for k, v in expected_params.items()}
+        self.assertDictEqual(expected_params, config_params)
+
     @slow
     def test_cnn_summarization_same_as_fairseq(self):
         hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)