1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -208,6 +208,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
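
Since the docs change above adds Phi-3 to the SDPA support list, here is a minimal usage sketch (not part of the diff) showing how the SDPA backend is selected at load time; the checkpoint name mirrors the one used in the Phi-3 integration tests below and the device handling is illustrative:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "microsoft/phi-3-mini-4k-instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Explicitly request the PyTorch scaled_dot_product_attention backend.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    attn_implementation="sdpa",
).to(device)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

inputs = tokenizer("Simply put, the theory of relativity states that ", return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
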
2 changes: 0 additions & 2 deletions src/transformers/models/llama/modeling_llama.py
@@ -100,8 +100,6 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings

@torch.no_grad()
def forward(self, x, position_ids):
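
For context, the removed lines only drop backward-compatibility bookkeeping (`max_seq_len_cached`); cos/sin are computed on the fly from `position_ids` in `forward`. A minimal, illustrative sketch of that pattern (not the actual modeling code in this diff; shapes follow the standard Llama-style implementation):

```python
import torch
import torch.nn as nn


class RotaryEmbeddingSketch(nn.Module):
    """Minimal RoPE module: no cached cos/sin tables, everything derived from position_ids."""

    def __init__(self, dim, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # inv_freq: (batch, dim/2, 1), positions: (batch, 1, seq_len)
        inv_freq = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        pos = position_ids[:, None, :].float()
        freqs = (inv_freq @ pos).transpose(1, 2)   # (batch, seq_len, dim/2)
        emb = torch.cat((freqs, freqs), dim=-1)    # (batch, seq_len, dim)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


# Example: batched cos/sin with shape (1, 10, 64), matching the indexing used in the tests below.
rope = RotaryEmbeddingSketch(dim=64)
cos, sin = rope(torch.randn(1), torch.arange(10)[None, :])
```
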
2 changes: 0 additions & 2 deletions src/transformers/models/olmo/modeling_olmo.py
@@ -98,8 +98,6 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings

@torch.no_grad()
def forward(self, x, position_ids):
422 changes: 229 additions & 193 deletions src/transformers/models/phi/modeling_phi.py

Large diffs are not rendered by default.

324 changes: 201 additions & 123 deletions src/transformers/models/phi3/modeling_phi3.py

Large diffs are not rendered by default.
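
The Phi and Phi-3 modeling diffs are collapsed here, but together with the docs change above they indicate an SDPA attention path was added. As a rough, hedged sketch of what such a path typically looks like in a Transformers attention class (the function name and signature are illustrative, not the actual diff):

```python
import torch.nn.functional as F


def sdpa_attention_sketch(query, key, value, attention_mask=None):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    # With no explicit mask, let the fused kernel apply causal masking itself.
    is_causal = attention_mask is None and query.shape[2] > 1
    return F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=0.0,
        is_causal=is_causal,
    )
```
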

91 changes: 78 additions & 13 deletions tests/models/phi/test_modeling_phi.py
@@ -18,6 +18,7 @@
import unittest

import pytest
from packaging import version
from parameterized import parameterized

from transformers import PhiConfig, is_torch_available, set_seed
@@ -397,7 +398,7 @@ def test_model_rope_scaling_from_config(self, scaling_type):
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

# Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->Phi
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling with Llama->Phi
def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
hidden_size = config.hidden_size
@@ -409,17 +410,21 @@ def test_model_rope_scaling(self):

# Inputs
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exclusively to get the dtype and the device
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
position_ids_short = position_ids_short.unsqueeze(0)
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
position_ids_long = position_ids_long.unsqueeze(0)

# Sanity check original RoPE
original_rope = PhiRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, short_input_length)
original_cos_long, original_sin_long = original_rope(x, long_input_length)
torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])

# Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor"
@@ -429,14 +434,14 @@
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
for new_position in range(0, long_input_length, scaling_factor):
original_position = int(new_position // scaling_factor)
torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])
torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])

# Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
@@ -447,8 +452,8 @@
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
torch.testing.assert_close(ntk_cos_short, original_cos_short)
torch.testing.assert_close(ntk_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
@@ -498,6 +503,16 @@ def test_flash_attn_2_generate_padding_right(self):
@slow
@require_torch
class PhiIntegrationTest(unittest.TestCase):
# This variable is used to determine which CUDA device we are using for our runners (A10 or T4)
# Depending on the hardware we get different logits / generations
cuda_compute_capability_major_version = None

@classmethod
def setUpClass(cls):
if is_torch_available() and torch.cuda.is_available():
# 8 is for A100 / A10 and 7 for T4
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

def test_model_phi_1_logits(self):
input_ids = {
"input_ids": torch.tensor(
@@ -564,3 +579,53 @@ def test_phi_2_generation(self):
]

self.assertListEqual(output_text, EXPECTED_OUTPUT)

@slow
@require_torch_gpu
def test_compile_static_cache(self):
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
if version.parse(torch.__version__) < version.parse("2.3.0"):
self.skipTest("This test requires torch >= 2.3 to run.")

NUM_TOKENS_TO_GENERATE = 40
# Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
# was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
EXPECTED_TEXT_COMPLETION = {
8: [
"Simply put, the theory of relativity states that \n\n$$\nE = mc^2\n$$\n\nwhere $E$ is the energy of an object, $m$ is its mass, and $c$ is the speed of light",
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my fries, I love it on my burgers, I love it on my hot dogs, I love it on my chicken nuggets, I love it",
],
7: [
"Simply put, the theory of relativity states that \n\n$$\nE = mc^2\n$$\n\nwhere $E$ is the energy of an object, $m$ is its mass, and $c$ is the speed of light",
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my fries, I love it on my burgers, I love it on my hot dogs, I love it on my chicken nuggets, I love it",
],
}

prompts = [
"Simply put, the theory of relativity states that ",
"My favorite all time favorite condiment is ketchup.",
]
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", pad_token="<|endoftext|>", padding_side="left")
model = PhiForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype=torch.float16).to(torch_device)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

# Dynamic Cache
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text) # Both GPU architectures have the same output

# Static Cache
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)

# Static Cache + compile
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
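
A note on why the test pairs `cache_implementation="static"` with `torch.compile(..., fullgraph=True)`: the static cache pre-allocates fixed-size key/value tensors, so every decoding step runs with stable shapes and the compiled graph is reused instead of being retraced. A usage sketch mirroring the test above (illustrative, CUDA assumed):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", pad_token="<|endoftext|>", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype=torch.float16).to("cuda")

# Compile once; the static cache keeps KV tensor shapes fixed across decoding steps.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer(["Simply put, the theory of relativity states that "], return_tensors="pt").to(model.device)
generated = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
```
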
65 changes: 65 additions & 0 deletions tests/models/phi3/test_modeling_phi3.py
@@ -22,6 +22,7 @@
from transformers import Phi3Config, is_torch_available, set_seed
from transformers.testing_utils import (
require_torch,
require_torch_gpu,
slow,
torch_device,
)
@@ -396,6 +397,16 @@ def test_model_rope_scaling_from_config(self, scaling_type):
@slow
@require_torch
class Phi3IntegrationTest(unittest.TestCase):
# This variable is used to determine which CUDA device we are using for our runners (A10 or T4)
# Depending on the hardware we get different logits / generations
cuda_compute_capability_major_version = None

@classmethod
def setUpClass(cls):
if is_torch_available() and torch.cuda.is_available():
# 8 is for A100 / A10 and 7 for T4
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

def test_model_phi3_mini_4k_instruct_logits(self):
input_ids = {
"input_ids": torch.tensor(
@@ -471,3 +482,57 @@ def test_phi3_mini_128k_instruct_generation(self):
]

self.assertListEqual(output_text, EXPECTED_OUTPUT)

@slow
@require_torch_gpu
def test_compile_static_cache(self):
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
# if version.parse(torch.__version__) < version.parse("2.3.0"):
# self.skipTest("This test requires torch >= 2.3 to run.")

NUM_TOKENS_TO_GENERATE = 40
# Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
# was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
EXPECTED_TEXT_COMPLETION = {
8: [
"Can you provide ways to eat combinations of bananas and dragonfruits? Certainly! Bananas and dragonfruits can be combined in various delicious ways. Here are some ideas for eating combinations of bananas and dragonfruits:\n\n1",
],
7: [
"Can you provide ways to eat combinations of bananas and dragonfruits? Certainly! Bananas and dragonfruits can be combined in various delicious ways. Here are some ideas for eating combinations of bananas and dragonfruits:\n\n1",
],
}
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct", torch_dtype=torch.float16).to(
torch_device
)
messages = [
{
"role": "system",
"content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.",
},
{"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(
torch_device
)

# Dynamic Cache
generated_ids = model.generate(inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text) # Both GPU architectures have the same output

# Static Cache
generated_ids = model.generate(
inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)

# Static Cache + compile
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
2 changes: 1 addition & 1 deletion tests/test_modeling_common.py
@@ -4414,7 +4414,7 @@ def test_custom_4d_attention_mask(self):
if not model_class._supports_static_cache:
self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
if getattr(config, "sliding_window", 0) > 0:
if getattr(config, "sliding_window", None) is not None:
self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")
model = model_class(config).to(device=torch_device, dtype=torch.float32)

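
For context on the changed condition: some configs (Phi-3's, for instance) expose a `sliding_window` attribute that may be set to `None`, which the old `> 0` comparison cannot handle. A small sketch of the behavioral difference (the config class is hypothetical):

```python
class DummyConfig:
    sliding_window = None  # attribute exists but sliding-window attention is disabled


cfg = DummyConfig()

# Old check: getattr returns None (the attribute exists), and `None > 0` raises TypeError.
try:
    getattr(cfg, "sliding_window", 0) > 0
except TypeError:
    print("old check crashes on sliding_window=None")

# New check: skip the test only when sliding_window is explicitly set to a non-None value.
print(getattr(cfg, "sliding_window", None) is not None)  # False -> the 4D mask test still runs
```
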