
Conversation

DianjingLiu (Contributor)

Summary:
When setting `use_cached_outputs=False`, `LLMAttribution` fails to run on some older versions of PyTorch/Python.
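A minimal sketch of the kind of call that hits this path (the checkpoint name, prompt, and target below are illustrative placeholders, not the actual test code in `tests/attr/test_llm_attr_hf_compatibility.py`):

```python
# Hypothetical repro sketch; model name, prompt, and target are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer

from captum.attr import FeatureAblation, LLMAttribution, TextTokenInput

name = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # any small causal LM
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)
inp = TextTokenInput("Dave lives in Palm Coast", tokenizer)

# With use_cached_outputs=False, the model re-runs the full (growing)
# sequence at every generation step instead of reusing past_key_values.
# On the affected versions this raised the RuntimeError below.
res = llm_attr.attribute(inp, target="FL", use_cached_outputs=False)
```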

## Error message

```
======================================================================
ERROR: test_llm_attr_hf_compatibility_0 (tests.attr.test_llm_attr_hf_compatibility.TestLLMAttrHFCompatibility_1_cpu)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/parameterized/parameterized.py", line 620, in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
  File "/data/users/liudj/captum/tests/attr/test_llm_attr_hf_compatibility.py", line 80, in test_llm_attr_hf_compatibility
    res = llm_attr.attribute(
  File "/data/users/liudj/captum/captum/attr/_core/llm_attr.py", line 461, in attribute
    cur_attr = self.attr_method.attribute(
  File "/data/users/liudj/captum/captum/log/__init__.py", line 52, in wrapper
    return func(*args, **kwargs)
  File "/data/users/liudj/captum/captum/attr/_core/feature_ablation.py", line 292, in attribute
    initial_eval: Union[Tensor, Future[Tensor]] = _run_forward(
  File "/data/users/liudj/captum/captum/_utils/common.py", line 599, in _run_forward
    output = forward_func(
  File "/data/users/liudj/captum/captum/attr/_core/llm_attr.py", line 335, in _forward_func
    outputs = self.model.forward(model_inp, **model_kwargs)
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
    outputs = self.model(
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1001, in forward
    layer_outputs = decoder_layer(
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 734, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 428, in forward
    attn_weights = attn_weights + causal_mask
RuntimeError: The size of tensor a (8) must match the size of tensor b (7) at non-singleton dimension 3
```

## Root cause

The `attention_mask` was not updated as the input grew by one token per generation step, so the causal mask built from it was one position short of the attention weights. See the test plan for the full error message.
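The fix amounts to growing the mask in step with the input. A minimal sketch of the idea, assuming a small Hugging Face causal LM (the checkpoint name is a placeholder, and this is not Captum's exact code):

```python
# Sketch: keep attention_mask in sync with the input when generating
# step-by-step without cached outputs. The checkpoint name is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

enc = tokenizer("a b c d", return_tensors="pt")
model_inp = enc["input_ids"]            # shape (1, seq_len)
attention_mask = enc["attention_mask"]  # shape (1, seq_len)

with torch.no_grad():
    for _ in range(3):  # greedy generation, one token per step
        out = model(model_inp, attention_mask=attention_mask)
        next_tok = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        model_inp = torch.cat([model_inp, next_tok], dim=1)
        # Without this line the mask keeps its original length, the causal
        # mask derived from it is one column short, and older stacks raise
        # the size-mismatch RuntimeError shown above.
        attention_mask = torch.cat(
            [attention_mask, torch.ones_like(next_tok)], dim=1
        )
```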

## Impacted versions

- Python 3.8-3.10, PyTorch 1.10-1.12, transformers 4.44.2
- Python 3.8-3.11, PyTorch 1.13-2.1.0, transformers 4.44.2

Reviewed By: vivekmig

Differential Revision: D63016032

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D63016032

DianjingLiu added a commit to DianjingLiu/captum that referenced this pull request Sep 19, 2024

DianjingLiu added a commit to DianjingLiu/captum that referenced this pull request Sep 19, 2024

@facebook-github-bot (Contributor)

This pull request has been merged in fc910e5.
