Merged
5 changes: 4 additions & 1 deletion tests/generation/test_utils.py
@@ -1263,6 +1263,9 @@ def test_dola_decoding_sample(self):

if model.get_output_embeddings() is None:
self.skipTest("DoLa is not supported for models that don't have output embeddings")

+ logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
Collaborator:

`do_sample` is random, no?

Collaborator (author):

I am using the same value as in generation_kwargs = {...} a few lines below.

Yes, it is random, but this method is test_...._sample, so it makes sense.


# Sets dola generation arguments such that:
# a) no EOS is generated, to ensure generation doesn't break early
# b) there are at least two forward passes in the main model, to ensure the input preparation of
@@ -1280,7 +1283,7 @@ def test_dola_decoding_sample(self):
"use_cache": getattr(config, "use_cache", False), # Some models don't support the cache
"dola_layers": "low",
}
- output_dola = model.generate(**generation_kwargs, **inputs_dict)
+ output_dola = model.generate(**generation_kwargs, **logits_processor_kwargs, **inputs_dict)
self._check_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False))

@pytest.mark.generate
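The review thread above is about the do_sample=True logits-processor kwargs now threaded into the DoLa generate call. A minimal, self-contained sketch of the resulting call shape (it assumes a recent transformers version with DoLa support, and the temperature/top_k values are stand-ins for whatever _get_logits_processor_kwargs actually returns for a given config):

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok("DoLa decoding test:", return_tensors="pt")

# Stand-in for _get_logits_processor_kwargs(do_sample=True, config=...);
# the real helper derives these from the model config, so the values are assumed.
logits_processor_kwargs = {"temperature": 0.7, "top_k": 10}

output_dola = model.generate(
    **inputs,
    do_sample=True,     # the method under test is test_dola_decoding_sample
    max_new_tokens=4,   # ensures more than one forward pass in the main model
    dola_layers="low",  # DoLa: contrast final-layer logits with a lower layer
    **logits_processor_kwargs,
)
print(tok.decode(output_dola[0], skip_special_tokens=True))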
6 changes: 3 additions & 3 deletions tests/models/llava/test_modeling_llava.py
@@ -85,7 +85,7 @@ def __init__(
},
is_training=True,
vision_config={
"image_size": 30,
"image_size": 8,
"patch_size": 2,
"num_channels": 3,
"is_training": True,
@@ -118,9 +118,9 @@ def __init__(
self.batch_size = 3
self.num_channels = 3
self.image_size = 336
- self.encoder_seq_length = 232
- self.num_image_tokens = 225
+ self.num_image_tokens = (self.vision_config["image_size"] // self.vision_config["patch_size"]) ** 2
self.seq_length = seq_length + self.num_image_tokens
+ self.encoder_seq_length = self.seq_length

def get_config(self):
return LlavaConfig(
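Both this change and the identical one in the ViP-LLaVA tester below derive the image-token count from the vision config instead of hard-coding it. The arithmetic, assuming a plain ViT patch grid with no extra special tokens (which is how the tester counts image placeholders):

def num_image_tokens(image_size: int, patch_size: int) -> int:
    # One placeholder token per non-overlapping patch.
    return (image_size // patch_size) ** 2

assert num_image_tokens(30, 2) == 225  # old config: a 15 x 15 grid
assert num_image_tokens(8, 2) == 16    # new config: a 4 x 4 grid

# With the tester's text seq_length of 7 (implied by the old values,
# since 225 + 7 = 232), encoder_seq_length shrinks from 232 to 23.
print(7 + num_image_tokens(8, 2))  # 23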
6 changes: 3 additions & 3 deletions tests/models/vipllava/test_modeling_vipllava.py
@@ -79,7 +79,7 @@ def __init__(
is_training=True,
vision_config={
"batch_size": 12,
"image_size": 30,
"image_size": 8,
"patch_size": 2,
"num_channels": 3,
"is_training": True,
@@ -111,9 +111,9 @@ def __init__(
self.batch_size = 3
self.num_channels = 3
self.image_size = 336
- self.encoder_seq_length = 232
- self.num_image_tokens = 225
+ self.num_image_tokens = (self.vision_config["image_size"] // self.vision_config["patch_size"]) ** 2
self.seq_length = seq_length + self.num_image_tokens
+ self.encoder_seq_length = self.seq_length

def get_config(self):
return VipLlavaConfig(
88 changes: 38 additions & 50 deletions tests/test_modeling_common.py
@@ -3982,6 +3982,13 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str):
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"

if hasattr(self.model_tester, "num_hidden_layers"):
self.model_tester.num_hidden_layers = 1
if hasattr(self.model_tester, "vision_config") and "num_hidden_layers" in self.model_tester.vision_config:
self.model_tester.vision_config["num_hidden_layers"] = 1
if hasattr(self.model_tester, "text_config") and "num_hidden_layers" in self.model_tester.text_config:
self.model_tester.text_config["num_hidden_layers"] = 1

for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
@@ -4013,7 +4020,8 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
if not (self.has_attentions and can_output_attn) and output_attentions:
continue
- for batch_size in [1, 5]:
+ # TODO: if we can also check with `batch_size=1` without being flaky?
+ for batch_size in [7]:
dummy_input = inputs_dict[model.main_input_name]

if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
@@ -4064,14 +4072,14 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):

dummy_attention_mask[:] = 1
if padding_side == "left":
- dummy_attention_mask[-1, :-1] = 1
- dummy_attention_mask[-1, -4:] = 0
+ dummy_attention_mask[-1, :2] = 0
+ dummy_attention_mask[-1, 2:] = 1
elif padding_side == "right":
- dummy_attention_mask[-1, 1:] = 1
- dummy_attention_mask[-1, :3] = 0
+ dummy_attention_mask[-1, -2:] = 0
+ dummy_attention_mask[-1, :-2] = 1

for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
failcase = f"padding_side={padding_side}, use_mask={use_mask}, enable_kernels={enable_kernels}"
if is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[
:batch_size
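Before the next hunk, note that the rewritten mask construction above always zeroes exactly two positions of the last batch row, on the side being tested; the slicing in the hunk below ([-1:, 2:] and [-1:, :-2]) relies on that fixed width. A quick sketch, with a sequence length of 6 assumed for illustration:

import torch

batch_size, seq_len = 7, 6  # seq_len is assumed for illustration
dummy_attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

# Left padding: the first two positions of the last row are masked.
left = dummy_attention_mask.clone()
left[-1, :2] = 0
left[-1, 2:] = 1
print(left[-1])   # tensor([0, 0, 1, 1, 1, 1])

# Right padding: the last two positions of the last row are masked.
right = dummy_attention_mask.clone()
right[-1, -2:] = 0
right[-1, :-2] = 1
print(right[-1])  # tensor([1, 1, 1, 1, 0, 0])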
@@ -4161,52 +4169,32 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):

# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
+ _logits_sdpa = torch.zeros_like(input=logits_sdpa)
+ _logits_eager = torch.zeros_like(input=logits_eager)
+
+ _logits_sdpa[:-1] = logits_sdpa[:-1]
+ _logits_eager[:-1] = logits_eager[:-1]

if padding_side == "left":
-     sub_sdpa = logits_sdpa[:-1]
-     sub_eager = logits_eager[:-1]
-     if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
-         fail_cases.append(
-             get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
-         )
-
-     sub_sdpa = logits_sdpa[-1, :-4]
-     sub_eager = logits_eager[-1, :-4]
-     if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
-         fail_cases.append(
-             get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
-         )
-
-     # Testing the padding tokens is not really meaningful but anyway
-     # sub_sdpa = logits_sdpa[-1, -4:]
-     # sub_eager = logits_eager[-1, -4:]
-     # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
-     #     fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
- elif padding_side == "right":
-     sub_sdpa = logits_sdpa[:-1]
-     sub_eager = logits_eager[:-1]
-     if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
-         fail_cases.append(
-             get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
-         )
-
-     sub_sdpa = logits_sdpa[-1, 3:]
-     sub_eager = logits_eager[-1, 3:]
-     if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
-         fail_cases.append(
-             get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
-         )
-
-     # Testing the padding tokens is not really meaningful but anyway
-     # sub_sdpa = logits_sdpa[-1, :3]
-     # sub_eager = logits_eager[-1, :3]
-     # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
-     #     fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
+     _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
+     _logits_eager[-1:, 2:] = logits_eager[-1:, 2:]

- else:
-     if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol):
-         fail_cases.append(
-             get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
-         )
elif padding_side == "right":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]

+ logits_sdpa = _logits_sdpa
+ logits_eager = _logits_eager

+ results = [
+     torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
+     for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
+ ]
+ # If 80% batch elements have matched results, it's fine
+ if np.mean(results) < 0.8:
+     fail_cases.append(
+         get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
+     )

self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
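End to end, the new comparison copies the unmasked positions into zero-initialized tensors (so padded positions always compare equal by construction), compares logits per batch element, and fails only when fewer than 80% of the elements match. A condensed, runnable sketch with assumed shapes:

import numpy as np
import torch

batch, seq, vocab = 7, 6, 4                 # assumed shapes for illustration
logits_sdpa = torch.randn(batch, seq, vocab)
logits_eager = logits_sdpa.clone()
logits_eager[-1, :2] += 1.0                 # deviation only at the padded positions

use_mask, padding_side = True, "left"
atol = rtol = 1e-5

if use_mask:
    _logits_sdpa = torch.zeros_like(logits_sdpa)
    _logits_eager = torch.zeros_like(logits_eager)
    _logits_sdpa[:-1] = logits_sdpa[:-1]    # rows without padding are copied whole
    _logits_eager[:-1] = logits_eager[:-1]
    if padding_side == "left":
        # Skip the two left-padded positions of the last row.
        _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
        _logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
    elif padding_side == "right":
        # Drop the two right-padded positions; both tensors are shifted the
        # same way, so the comparison is unaffected.
        _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
        _logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
    logits_sdpa, logits_eager = _logits_sdpa, _logits_eager

# Compare per batch element; up to 20% of the elements may mismatch.
results = [torch.allclose(s, e, atol=atol, rtol=rtol)
           for s, e in zip(logits_sdpa, logits_eager)]
print(np.mean(results))  # 1.0: the deviation sat in the masked positions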
