21 changes: 17 additions & 4 deletions ane_transformers/huggingface/distilbert.py
@@ -2,6 +2,7 @@
 # For licensing see accompanying LICENSE.md file.
 # Copyright (C) 2022 Apple Inc. All Rights Reserved.
 #
+import re
 
 from ane_transformers.reference.layer_norm import LayerNormANE
 
@@ -520,14 +521,26 @@ def forward(
         return ((loss, ) + output) if loss is not None else output
 
 
+_LINEAR_TO_CONV2D_LAYERS = [
+    "q_lin.weight",
+    "k_lin.weight",
+    "v_lin.weight",
+    "out_lin.weight",
+    "lin1.weight",
+    "lin2.weight",
+    "classifier.weight",
+    "pre_classifier.weight",
+    "vocab_transform.weight",
+    "vocab_projector.weight",
+    "qa_outputs.weight",
+]
+
+
 def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
     """ Unsqueeze twice to map nn.Linear weights to nn.Conv2d weights
     """
     for k in state_dict:
-        is_internal_proj = all(substr in k for substr in ['lin', '.weight'])
-        is_output_proj = all(substr in k
-                             for substr in ['classifier', '.weight'])
-        if is_internal_proj or is_output_proj:
+        if any(k.endswith(layer) for layer in _LINEAR_TO_CONV2D_LAYERS):
             if len(state_dict[k].shape) == 2:
                 state_dict[k] = state_dict[k][:, :, None, None]
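For context on what this hook achieves: HuggingFace DistilBert checkpoints store each projection weight as the 2D [out, in] tensor of an nn.Linear, while the ANE-optimized model implements the same projections as 1x1 nn.Conv2d layers whose weights are 4D [out, in, 1, 1]. Unsqueezing twice therefore preserves the computation exactly, provided activations use the (batch, channels, 1, seq) layout the ANE-friendly modules expect. A minimal sketch of the equivalence (illustrative only, not part of this diff; the dimensions are arbitrary):

import torch
import torch.nn as nn

linear = nn.Linear(64, 128, bias=False)
conv = nn.Conv2d(64, 128, kernel_size=1, bias=False)

# The same reshape the pre-hook applies: [out, in] -> [out, in, 1, 1]
with torch.no_grad():
    conv.weight.copy_(linear.weight[:, :, None, None])

x = torch.randn(2, 8, 64)                # (batch, seq, channels) for nn.Linear
x_bc1s = x.transpose(1, 2).unsqueeze(2)  # (batch, channels, 1, seq) for nn.Conv2d

y_linear = linear(x)
y_conv = conv(x_bc1s).squeeze(2).transpose(1, 2)
assert torch.allclose(y_linear, y_conv, atol=1e-5)

Matching each key against an explicit suffix allowlist, instead of the previous 'lin'/'classifier' substring test, keeps the reshape from touching unrelated 2D parameters whose names merely contain those substrings.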
56 changes: 55 additions & 1 deletion ane_transformers/huggingface/test_distilbert.py
@@ -9,7 +9,6 @@
 import logging
 import numpy as np
 import unittest
-import time
 
 import torch
 
@@ -32,6 +31,10 @@
("This is not what I expected!", "NEGATIVE"),
])

MASKED_LM_MODEL = 'distilbert-base-uncased'
QUESTION_ANSWERING_MODEL = 'distilbert-base-uncased-distilled-squad'
TOKEN_CLASSIFICATION_MODEL = 'elastic/distilbert-base-uncased-finetuned-conll03-english'
MULTIPLE_CHOICE_MODEL = 'Gladiator/distilbert-base-uncased_swag_mqa'

class TestDistilBertForSequenceClassification(unittest.TestCase):
"""
@@ -191,5 +194,56 @@ def test_coreml_conversion_and_speedup(self):
         )
 
 
+class TestDistilBertLoadState(unittest.TestCase):
+    """
+    Test load_state_dict compatibility.
+    """
+
+    test_params = (
+        (
+            MASKED_LM_MODEL,
+            transformers.AutoModelForMaskedLM,
+            ane_transformers.DistilBertForMaskedLM,
+        ),
+        (
+            QUESTION_ANSWERING_MODEL,
+            transformers.AutoModelForQuestionAnswering,
+            ane_transformers.DistilBertForQuestionAnswering,
+        ),
+        (
+            TOKEN_CLASSIFICATION_MODEL,
+            transformers.AutoModelForTokenClassification,
+            ane_transformers.DistilBertForTokenClassification,
+        ),
+        (
+            MULTIPLE_CHOICE_MODEL,
+            transformers.AutoModelForMultipleChoice,
+            ane_transformers.DistilBertForMultipleChoice,
+        ),
+    )
+
+    def test_load_state(self):
+        for model_name, auto_model_cls, ane_model_cls in self.test_params:
+            with self.subTest(ane_model_cls=ane_model_cls):
+                try:
+                    # Instantiate the reference model from an exemplar pre-trained
+                    # model hosted on huggingface.co/models
+                    reference_model = auto_model_cls.from_pretrained(
+                        model_name,
+                        return_dict=False,
+                        torchscript=True,
+                    ).eval()
+                except Exception as e:
+                    raise RuntimeError(
+                        "Failed to download reference model from huggingface.co/models!"
+                    ) from e
+                logger.info("Downloaded reference model from huggingface.co/models")
+
+                # Initialize an ANE equivalent model and restore the checkpoint
+                test_model = ane_model_cls(reference_model.config).eval()
+                test_model.load_state_dict(reference_model.state_dict())
+                logger.info("Initialized and restored test model")
+
+
 if __name__ == "__main__":
     unittest.main()
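The new test feeds a stock HuggingFace state_dict directly into load_state_dict on the ANE model, which only succeeds because linear_to_conv2d_map runs first and reshapes every mapped weight. A minimal sketch of how such a pre-hook is wired up, assuming PyTorch's _register_load_state_dict_pre_hook; the toy module and its dimensions are hypothetical:

import torch.nn as nn

from ane_transformers.huggingface.distilbert import linear_to_conv2d_map

class ToyProjection(nn.Module):
    """Hypothetical one-layer stand-in for the ANE DistilBert modules."""

    def __init__(self):
        super().__init__()
        # 1x1 conv plays the role of the checkpoint's q_lin nn.Linear
        self.q_lin = nn.Conv2d(64, 128, kernel_size=1)
        # Reshape 2D nn.Linear weights to 4D before shapes are validated
        self._register_load_state_dict_pre_hook(linear_to_conv2d_map)

checkpoint = nn.ModuleDict({'q_lin': nn.Linear(64, 128)}).state_dict()
model = ToyProjection()
model.load_state_dict(checkpoint)  # hook maps [128, 64] -> [128, 64, 1, 1]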