From ee34e9c0d4dec891d48f2b9aae29a8eef0df4f05 Mon Sep 17 00:00:00 2001
From: Anentropic
Date: Sun, 23 Apr 2023 19:28:53 +0100
Subject: [PATCH 1/2] fix linear_to_conv2d_map to work with other distilbert
 model types

---
 ane_transformers/huggingface/distilbert.py   | 23 +++++++++++++++++++----
 .../huggingface/test_distilbert.py           | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 74 insertions(+), 5 deletions(-)

diff --git a/ane_transformers/huggingface/distilbert.py b/ane_transformers/huggingface/distilbert.py
index 4845c22..c7a8096 100644
--- a/ane_transformers/huggingface/distilbert.py
+++ b/ane_transformers/huggingface/distilbert.py
@@ -2,6 +2,7 @@
 # For licensing see accompanying LICENSE.md file.
 # Copyright (C) 2022 Apple Inc. All Rights Reserved.
 #
+import re
 
 from ane_transformers.reference.layer_norm import LayerNormANE
 
@@ -520,14 +521,28 @@ def forward(
     return ((loss, ) + output) if loss is not None else output
 
 
+_LINEAR_TO_CONV2D_LAYERS_RE = re.compile(r".*({})\.weight".format(
+    "|".join([
+        "q_lin",
+        "k_lin",
+        "v_lin",
+        "out_lin",
+        "lin1",
+        "lin2",
+        "classifier",
+        "pre_classifier",
+        "vocab_transform",
+        "vocab_projector",
+        "qa_outputs",
+    ])
+))
+
+
 def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
     """ Unsqueeze twice to map nn.Linear weights to nn.Conv2d weights
     """
     for k in state_dict:
-        is_internal_proj = all(substr in k for substr in ['lin', '.weight'])
-        is_output_proj = all(substr in k
-                             for substr in ['classifier', '.weight'])
-        if is_internal_proj or is_output_proj:
+        if _LINEAR_TO_CONV2D_LAYERS_RE.match(k):
             if len(state_dict[k].shape) == 2:
                 state_dict[k] = state_dict[k][:, :, None, None]

diff --git a/ane_transformers/huggingface/test_distilbert.py b/ane_transformers/huggingface/test_distilbert.py
index b14f14d..7e43a07 100644
--- a/ane_transformers/huggingface/test_distilbert.py
+++ b/ane_transformers/huggingface/test_distilbert.py
@@ -9,7 +9,6 @@
 import logging
 import numpy as np
 import unittest
-import time
 
 import torch
 
@@ -32,6 +31,10 @@
     ("This is not what I expected!", "NEGATIVE"),
 ])
 
+MASKED_LM_MODEL = 'distilbert-base-uncased'
+QUESTION_ANSWERING_MODEL = 'distilbert-base-uncased-distilled-squad'
+TOKEN_CLASSIFICATION_MODEL = 'elastic/distilbert-base-uncased-finetuned-conll03-english'
+MULTIPLE_CHOICE_MODEL = 'Gladiator/distilbert-base-uncased_swag_mqa'
 
 class TestDistilBertForSequenceClassification(unittest.TestCase):
     """
@@ -191,5 +194,56 @@ def test_coreml_conversion_and_speedup(self):
         )
 
 
+class TestDistilBertLoadState(unittest.TestCase):
+    """
+    Test load_state_dict compatibility.
+    """
+
+    test_params = (
+        (
+            MASKED_LM_MODEL,
+            transformers.AutoModelForMaskedLM,
+            ane_transformers.DistilBertForMaskedLM,
+        ),
+        (
+            QUESTION_ANSWERING_MODEL,
+            transformers.AutoModelForQuestionAnswering,
+            ane_transformers.DistilBertForQuestionAnswering,
+        ),
+        (
+            TOKEN_CLASSIFICATION_MODEL,
+            transformers.AutoModelForTokenClassification,
+            ane_transformers.DistilBertForTokenClassification,
+        ),
+        (
+            MULTIPLE_CHOICE_MODEL,
+            transformers.AutoModelForMultipleChoice,
+            ane_transformers.DistilBertForMultipleChoice,
+        ),
+    )
+
+    def test_load_state(self):
+        for model_name, auto_model_cls, ane_model_cls in self.test_params:
+            with self.subTest(ane_model_cls=ane_model_cls):
+                try:
+                    # Instantiate the reference model from an exemplar pre-trained
+                    # model hosted on huggingface.co/models
+                    reference_model = auto_model_cls.from_pretrained(
+                        model_name,
+                        return_dict=False,
+                        torchscript=True,
+                    ).eval()
+                except Exception as e:
+                    raise RuntimeError(
+                        "Failed to download reference model from huggingface.co/models!"
+                    ) from e
+                logger.info("Downloaded reference model from huggingface.co/models")
+
+                # Initialize an ANE equivalent model and restore the checkpoint
+                test_model = ane_model_cls(reference_model.config).eval()
+                test_model.load_state_dict(reference_model.state_dict())
+                logger.info("Initialized and restored test model")
+
+
 if __name__ == "__main__":
     unittest.main()

From 4f60bae0a87ce158ddd230c2e0c234f302862a7f Mon Sep 17 00:00:00 2001
From: Anentropic
Date: Tue, 25 Apr 2023 10:24:33 +0100
Subject: [PATCH 2/2] do it without regex

---
 ane_transformers/huggingface/distilbert.py | 31 ++++++++++++++-----------------
 1 file changed, 14 insertions(+), 17 deletions(-)

diff --git a/ane_transformers/huggingface/distilbert.py b/ane_transformers/huggingface/distilbert.py
index c7a8096..c6bd117 100644
--- a/ane_transformers/huggingface/distilbert.py
+++ b/ane_transformers/huggingface/distilbert.py
@@ -2,7 +2,6 @@
 # For licensing see accompanying LICENSE.md file.
 # Copyright (C) 2022 Apple Inc. All Rights Reserved.
 #
-import re
 
 from ane_transformers.reference.layer_norm import LayerNormANE
 
@@ -521,21 +520,19 @@ def forward(
     return ((loss, ) + output) if loss is not None else output
 
 
-_LINEAR_TO_CONV2D_LAYERS_RE = re.compile(r".*({})\.weight".format(
-    "|".join([
-        "q_lin",
-        "k_lin",
-        "v_lin",
-        "out_lin",
-        "lin1",
-        "lin2",
-        "classifier",
-        "pre_classifier",
-        "vocab_transform",
-        "vocab_projector",
-        "qa_outputs",
-    ])
-))
+_LINEAR_TO_CONV2D_LAYERS = [
+    "q_lin.weight",
+    "k_lin.weight",
+    "v_lin.weight",
+    "out_lin.weight",
+    "lin1.weight",
+    "lin2.weight",
+    "classifier.weight",
+    "pre_classifier.weight",
+    "vocab_transform.weight",
+    "vocab_projector.weight",
+    "qa_outputs.weight",
+]
 
 
 def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
@@ -543,6 +540,6 @@ def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
     """ Unsqueeze twice to map nn.Linear weights to nn.Conv2d weights
     """
     for k in state_dict:
-        if _LINEAR_TO_CONV2D_LAYERS_RE.match(k):
+        if any(k.endswith(layer) for layer in _LINEAR_TO_CONV2D_LAYERS):
             if len(state_dict[k].shape) == 2:
                 state_dict[k] = state_dict[k][:, :, None, None]
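
A note on what the hook being patched actually does: linear_to_conv2d_map
rewrites each matched nn.Linear weight of shape (out_features, in_features)
into the equivalent 1x1 nn.Conv2d kernel of shape
(out_features, in_features, 1, 1), which is why it only touches 2-D tensors;
biases and LayerNorm parameters load unchanged. Below is a minimal,
self-contained sketch of that equivalence plus the suffix match used in
PATCH 2. It is not part of the patch series: the width (768) and the example
state-dict key are illustrative assumptions, not taken from the patches.

    import torch
    import torch.nn as nn

    # Arbitrary width for illustration; DistilBERT-base happens to use 768.
    dim = 768
    linear = nn.Linear(dim, dim)
    conv = nn.Conv2d(dim, dim, kernel_size=1)

    with torch.no_grad():
        # The "unsqueeze twice" performed by linear_to_conv2d_map:
        # (out_features, in_features) -> (out_features, in_features, 1, 1)
        conv.weight.copy_(linear.weight[:, :, None, None])
        conv.bias.copy_(linear.bias)

    x = torch.randn(4, dim)
    # On a 1x1 spatial grid the pointwise convolution reproduces the Linear.
    y_conv = conv(x[:, :, None, None])[:, :, 0, 0]
    assert torch.allclose(linear(x), y_conv, atol=1e-5)

    # The endswith() test from PATCH 2 against a typical fully-qualified
    # DistilBERT state-dict key (the key here is an illustrative example):
    key = "distilbert.transformer.layer.0.attention.q_lin.weight"
    assert key.endswith("q_lin.weight")

This also clarifies the motivation for PATCH 1: the old substring test
(both 'lin' and '.weight' in the key, or 'classifier' and '.weight') never
matched heads such as vocab_transform, vocab_projector, or qa_outputs, so
for those model types load_state_dict failed with a shape mismatch against
the ANE model's Conv2d layers; the explicit layer-name list covers every
task head, as the new TestDistilBertLoadState test exercises.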