Skip to content

Commit b7cb433

Browse files
committed
Add sparsity ratio calibration for skip softmax
Signed-off-by: Kai Xu <[email protected]>
1 parent 2ea0b35 commit b7cb433

File tree

13 files changed

+2298
-10
lines changed

13 files changed

+2298
-10
lines changed

examples/llm_sparsity/attention_sparsity/hf_sa.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@
2929
import modelopt.torch.sparsity.attention_sparsity as mtsa
3030
from modelopt.torch.export import export_hf_checkpoint
3131
from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig
32-
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT
32+
from modelopt.torch.sparsity.attention_sparsity.config import (
33+
SKIP_SOFTMAX_CALIB,
34+
SKIP_SOFTMAX_DEFAULT,
35+
)
3336
from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule
3437
from modelopt.torch.utils.memory_monitor import launch_memory_monitor
3538

@@ -38,9 +41,10 @@
3841
# Enable HuggingFace checkpointing support
3942
mto.enable_huggingface_checkpointing()
4043

41-
# You can define custom configurations or use the default
44+
# Sparse attention configuration choices
4245
SPARSE_ATTN_CFG_CHOICES = {
4346
"skip_softmax": SKIP_SOFTMAX_DEFAULT,
47+
"skip_softmax_calib": SKIP_SOFTMAX_CALIB,
4448
}
4549

4650

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Calibration framework for sparse attention methods."""
17+
18+
from .calibrate import calibrate_sparse_attention
19+
from .calibrator import DynamicThresholdCalibrator
20+
from .dataset import RulerDatasetBuilder
21+
22+
__all__ = [
23+
"DynamicThresholdCalibrator",
24+
"RulerDatasetBuilder",
25+
"calibrate_sparse_attention",
26+
]
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Calibration functions for sparse attention."""
17+
18+
import warnings
19+
from collections.abc import Callable
20+
from typing import Any
21+
22+
import torch
23+
import torch.nn as nn
24+
from transformers import AutoTokenizer
25+
26+
from ..config import CalibrationConfig
27+
from ..sparse_attention import SparseAttentionModule
28+
from .calibrator import DynamicThresholdCalibrator
29+
from .dataset import RulerDatasetBuilder
30+
31+
32+
def _extract_tokenizer_from_model(model: nn.Module) -> str:
33+
"""Extract tokenizer name/path from model config.
34+
35+
Args:
36+
model: Model to extract tokenizer from
37+
38+
Returns:
39+
Tokenizer name or path
40+
41+
Raises:
42+
ValueError: If tokenizer path cannot be determined from model
43+
"""
44+
# Extract tokenizer path from model config
45+
tokenizer_path = getattr(getattr(model, "config", None), "_name_or_path", None)
46+
47+
if not tokenizer_path:
48+
raise ValueError("Could not load tokenizer from model.")
49+
50+
return tokenizer_path
51+
52+
53+
def _extract_calibration_config(config: dict[str, Any]) -> CalibrationConfig | None:
54+
"""Extract and validate calibration config from sparse_cfg patterns.
55+
56+
Args:
57+
config: Sparse attention configuration dict
58+
59+
Returns:
60+
Validated CalibrationConfig or None if not found
61+
"""
62+
# Extract sparse_cfg and search for calibration
63+
sparse_cfg = config.get("sparse_cfg", {})
64+
65+
calib_dict = next(
66+
(
67+
cfg["calibration"]
68+
for cfg in sparse_cfg.values()
69+
if isinstance(cfg, dict) and "calibration" in cfg
70+
),
71+
None,
72+
)
73+
74+
# Create and calidate the calibration config
75+
return CalibrationConfig(**calib_dict) if calib_dict else None
76+
77+
78+
def create_calibration_forward_loop(
79+
calibration_data: list[dict[str, Any]],
80+
tokenizer_name_or_path: str,
81+
batch_size: int = 1,
82+
) -> Callable:
83+
"""Create forward loop for calibration.
84+
85+
Args:
86+
calibration_data: List of samples with 'input' and 'length' fields
87+
tokenizer_name_or_path: HuggingFace tokenizer path
88+
batch_size: Batch size (currently unused, always 1)
89+
90+
Returns:
91+
Forward loop function that takes model as argument
92+
"""
93+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
94+
if not tokenizer.pad_token:
95+
tokenizer.pad_token = tokenizer.eos_token
96+
97+
def forward_loop(model: nn.Module) -> None:
98+
device = next(model.parameters()).device
99+
100+
for sample in calibration_data:
101+
inputs = tokenizer(
102+
sample["input"], return_tensors="pt", truncation=True, max_length=sample["length"]
103+
)
104+
inputs = {k: v.to(device) for k, v in inputs.items()}
105+
106+
with torch.no_grad():
107+
model(**inputs)
108+
109+
return forward_loop
110+
111+
112+
def calibrate_sparse_attention(
113+
model: nn.Module,
114+
config: dict[str, Any],
115+
forward_loop: Callable | None = None,
116+
) -> dict[str, Any]:
117+
"""Calibrate sparse attention parameters for optimal sparsity.
118+
119+
Args:
120+
model: Model with sparse attention modules
121+
config: Sparse attention configuration dict
122+
forward_loop: Callable that forwards calibration data through model.
123+
If None, auto-generates RULER dataset.
124+
125+
Returns:
126+
Dictionary with calibration results
127+
"""
128+
# Extract and validate calibration config
129+
calib_config = _extract_calibration_config(config)
130+
if not calib_config:
131+
return {}
132+
133+
# Generate forward_loop if not provided
134+
if not forward_loop:
135+
tokenizer = _extract_tokenizer_from_model(model)
136+
builder = RulerDatasetBuilder(
137+
samples=calib_config.samples,
138+
max_seqlen=calib_config.max_seqlen,
139+
tokenizer_name_or_path=tokenizer,
140+
num_length_bins=calib_config.num_length_bins,
141+
max_length_filter=int(calib_config.max_seqlen * 1.2),
142+
)
143+
calibration_data = builder.build_calibration_dataset()
144+
print(f"Generated {len(calibration_data)} calibration samples")
145+
forward_loop = create_calibration_forward_loop(calibration_data, tokenizer)
146+
147+
# Get sparse attention modules
148+
sparse_modules = [
149+
(name, m) for name, m in model.named_modules() if isinstance(m, SparseAttentionModule)
150+
]
151+
152+
if not sparse_modules:
153+
print("No sparse attention modules found for calibration")
154+
return {}
155+
156+
print(f"Calibrating {len(sparse_modules)} sparse attention modules together...")
157+
158+
# Run calibration
159+
calibrator = DynamicThresholdCalibrator(
160+
target_sparse_ratio=calib_config.target_sparse_ratio,
161+
threshold_trials=calib_config.threshold_trials,
162+
)
163+
calibration_result = calibrator.calibrate(model, forward_loop)
164+
165+
if "scale_factor" not in calibration_result:
166+
warnings.warn("Calibration did not produce valid results")
167+
return {}
168+
169+
# Apply calibrated scale factor to all modules
170+
scale_factor = calibration_result["scale_factor"]
171+
print(f"\nApplying calibrated scale factor={scale_factor:.6f} to {len(sparse_modules)} modules")
172+
173+
for module_name, module in sparse_modules:
174+
module._sparse_method_instance.threshold_scale_factor = scale_factor
175+
176+
return {"calibration_results": {name: calibration_result for name, _ in sparse_modules}}

0 commit comments

Comments
 (0)