
Commit 3e93569

Add sparse attention integration to llm_eval
Signed-off-by: Kai Xu <[email protected]>
1 parent b7cb433 commit 3e93569

24 files changed, +1851 -255 lines

examples/llm_eval/lm_eval_hf.py

Lines changed: 28 additions & 0 deletions

@@ -43,9 +43,11 @@
 from lm_eval.api.model import T
 from lm_eval.models.huggingface import HFLM
 from quantization_utils import quantize_model
+from sparse_attention_utils import sparsify_model

 import modelopt.torch.opt as mto
 from modelopt.torch.quantization.utils import is_quantized
+from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified


 def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T:
@@ -60,9 +62,20 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
     calib_size = arg_dict.pop("calib_size", 512)
     compress = arg_dict.pop("compress", False)

+    # Sparse attention arguments
+    sparse_cfg = arg_dict.pop("sparse_cfg", None)
+
     additional_config = {} if additional_config is None else additional_config
     additional_config = {k: v for k, v in additional_config.items() if v is not None}

+    # Force eager attention if sparse attention is requested
+    if sparse_cfg:
+        additional_config["attn_implementation"] = "eager"
+        warnings.warn(
+            "Sparse attention requires attn_implementation='eager'. "
+            "Forcing eager attention implementation."
+        )
+
     # Enable automatic save/load of modelopt state huggingface checkpointing
     mto.enable_huggingface_checkpointing()

@@ -91,6 +104,15 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
         auto_quantize_checkpoint=auto_quantize_checkpoint,
     )

+    if sparse_cfg:
+        if is_attn_sparsified(model_obj.model):
+            warnings.warn("Skipping sparse attention: model already has sparse attention applied.")
+        else:
+            sparsify_model(
+                model=model_obj,
+                sparse_cfg=sparse_cfg,
+            )
+
     return model_obj


@@ -152,6 +174,11 @@ def setup_parser_with_modelopt_args():
         action="store_true",
         help="Compress the model after quantization",
     )
+    parser.add_argument(
+        "--sparse_cfg",
+        type=str,
+        help="Sparse attention configuration (e.g., SKIP_SOFTMAX_DEFAULT, SKIP_SOFTMAX_CALIB)",
+    )
     return parser


@@ -177,6 +204,7 @@ def setup_parser_with_modelopt_args():
             "calib_batch_size": args.calib_batch_size,
             "calib_size": args.calib_size,
             "compress": args.compress,
+            "sparse_cfg": args.sparse_cfg,
         }
     )
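For orientation, here is a minimal sketch of what the new sparse_cfg handling in create_from_arg_obj does before the model is built; the argument values are illustrative and not part of the commit.

```python
# Sketch of the sparse_cfg handling added above, using plain dicts instead of a
# real HFLM instance; the values in arg_dict are illustrative.
import warnings

arg_dict = {"calib_size": 512, "compress": False, "sparse_cfg": "SKIP_SOFTMAX_DEFAULT"}
additional_config: dict = {}

sparse_cfg = arg_dict.pop("sparse_cfg", None)
if sparse_cfg:
    # The commit forces eager attention whenever sparse attention is requested.
    additional_config["attn_implementation"] = "eager"
    warnings.warn(
        "Sparse attention requires attn_implementation='eager'. "
        "Forcing eager attention implementation."
    )

print(additional_config)  # {'attn_implementation': 'eager'}
```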

examples/llm_eval/mmlu.py

Lines changed: 39 additions & 0 deletions

@@ -48,6 +48,7 @@
 from fire import Fire
 from modeling import EvalModel, select_model
 from quantization_utils import MAX_SEQ_LEN, get_tokenizer, quantize_model
+from sparse_attention_utils import sparsify_model
 from tqdm import tqdm

 try:
@@ -56,6 +57,7 @@
     LLM = None  # type: ignore[misc]
 import modelopt.torch.opt as mto
 from modelopt.torch.quantization.utils import is_quantized
+from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified

 os.environ["TOKENIZERS_PARALLELISM"] = "false"

@@ -230,6 +232,7 @@ def main(
     auto_quantize_method: str = "gradient",
     auto_quantize_score_size: int = 128,
     auto_quantize_checkpoint: str | None = None,
+    sparse_cfg: str | None = None,
     **kwargs,
 ):
     random.seed(RAND_SEED)
@@ -266,6 +269,14 @@ def main(
             max_batch_size=1,
         )
     else:
+        # Force eager attention if sparse attention is requested
+        if sparse_cfg:
+            kwargs["attn_implementation"] = "eager"
+            warnings.warn(
+                "Sparse attention requires attn_implementation='eager'. "
+                "Forcing eager attention implementation."
+            )
+
         model = select_model(
             max_input_length=MAX_SEQ_LEN, max_output_length=2, dtype=dtype, **kwargs
         )
@@ -289,6 +300,34 @@ def main(
             auto_quantize_checkpoint=auto_quantize_checkpoint,
         )

+    # Apply sparse attention if requested
+    if sparse_cfg:
+        model.load()
+
+        if is_attn_sparsified(model.model):
+            warnings.warn(
+                "Skipping sparse attention: model already has sparse attention applied."
+            )
+        else:
+            sparsify_model(
+                model=model,
+                sparse_cfg=sparse_cfg,
+            )
+
+    # Apply sparse attention if requested
+    if sparse_cfg:
+        model.load()
+
+        if is_attn_sparsified(model.model):
+            warnings.warn(
+                "Skipping sparse attention: model already has sparse attention applied."
+            )
+        else:
+            sparsify_model(
+                model=model,
+                sparse_cfg=sparse_cfg,
+            )
+
     for subject in tqdm(subjects):
         dev_df = pd.read_csv(os.path.join(data_dir, "dev", subject + "_dev.csv"), header=None)[
             :ntrain
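The apply-once logic above can be read as a small helper; a hedged sketch follows, where maybe_sparsify is a hypothetical name (not part of the commit) and model is assumed to be an EvalModel wrapper whose load() populates model.model.

```python
# Hypothetical helper illustrating the guard used in mmlu.py: skip
# sparsification when the wrapped model already has sparse attention,
# e.g. after restoring a modelopt checkpoint.
import warnings

from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified
from sparse_attention_utils import sparsify_model


def maybe_sparsify(model, sparse_cfg):
    model.load()  # ensure model.model exists before inspecting it
    if is_attn_sparsified(model.model):
        warnings.warn("Skipping sparse attention: model already has sparse attention applied.")
    else:
        sparsify_model(model=model, sparse_cfg=sparse_cfg)
    return model
```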

examples/llm_eval/modeling.py

Lines changed: 5 additions & 0 deletions

@@ -179,6 +179,7 @@ class SeqToSeqModel(EvalModel):
     lora_path: str = ""
     device: str = "cuda"
     load_8bit: bool = False
+    attn_implementation: str | None = None

     def load(self):
         if self.model is None:
@@ -188,6 +189,8 @@ def load(self):
             if self.load_8bit:
                 args.update(device_map="auto", load_in_8bit=True)
             args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto")
+            if self.attn_implementation:
+                args["attn_implementation"] = self.attn_implementation
             self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_path, **args)
             print_gpu_utilization()
         if self.lora_path:
@@ -241,6 +244,8 @@ def load(self):
             if self.load_8bit:
                 args.update(device_map="auto", load_in_8bit=True)
             args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto")
+            if self.attn_implementation:
+                args["attn_implementation"] = self.attn_implementation
             self.model = AutoModelForCausalLM.from_pretrained(
                 self.model_path, trust_remote_code=True, **args
             )
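
A short sketch of how the new attn_implementation field reaches from_pretrained inside load(); the stand-in config object and dtype below are assumptions for illustration only.

```python
# Sketch: mirrors how load() assembles kwargs once attn_implementation is set
# (mmlu.py sets it to "eager" when --sparse_cfg is given). _Stub is a stand-in
# for the EvalModel fields used here, not a class from the commit.
import torch


class _Stub:
    dtype = "bfloat16"
    load_8bit = False
    attn_implementation = "eager"


cfg = _Stub()
args: dict = {}
if cfg.load_8bit:
    args.update(device_map="auto", load_in_8bit=True)
args.update(torch_dtype=getattr(torch, cfg.dtype) if cfg.dtype != "auto" else "auto")
if cfg.attn_implementation:
    args["attn_implementation"] = cfg.attn_implementation

print(args)  # {'torch_dtype': torch.bfloat16, 'attn_implementation': 'eager'}
```
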
examples/llm_eval/sparse_attention_utils.py

Lines changed: 111 additions & 0 deletions

@@ -0,0 +1,111 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for sparse attention integration with llm_eval."""
+
+import modelopt.torch.sparsity.attention_sparsity as mtsa
+
+# Custom sparse attention configurations
+CUSTOM_SPARSE_CONFIG = {
+    "SPARSE_CONSERVATIVE": {
+        "sparse_cfg": {
+            "*attn*": {
+                "method": "flash_skip_softmax",
+                "threshold": {"prefill": 5e-4, "decode": 1e-5},
+                "br": 128,
+                "bc": 128,
+                "backend": "pytorch",
+                "enable": True,
+            },
+            "default": {"enable": False},
+        },
+    },
+    "SPARSE_AGGRESSIVE": {
+        "sparse_cfg": {
+            "*attn*": {
+                "method": "flash_skip_softmax",
+                "threshold": {"prefill": 5e-3, "decode": 5e-4},
+                "br": 128,
+                "bc": 128,
+                "backend": "pytorch",
+                "enable": True,
+            },
+            "default": {"enable": False},
+        },
+    },
+}
+
+
+def _extract_model(model_obj):
+    """Extract actual model from wrapper (HFLM or EvalModel)."""
+    if hasattr(model_obj, "gpt2"):
+        return model_obj.gpt2
+    elif hasattr(model_obj, "model"):
+        return model_obj.model
+    else:
+        return model_obj
+
+
+def sparsify_model(
+    model,
+    sparse_cfg: str,
+    backend=None,
+):
+    """Apply sparse attention to model with optional RULER calibration.
+
+    Args:
+        model: Model wrapper (HFLM or EvalModel) or raw model
+        sparse_cfg: Sparse attention config name or dict
+        backend: Backend to use (optional, overrides config backend)
+
+    Returns:
+        The model with sparse attention applied
+
+    Note:
+        Calibration is automatically triggered if the config contains a 'calibration' field.
+        The calibration will auto-generate RULER dataset from the model's tokenizer.
+    """
+    # Extract actual model
+    net = _extract_model(model)
+
+    # Resolve config
+    if isinstance(sparse_cfg, str):
+        # Try custom configs first
+        mtsa_cfg = CUSTOM_SPARSE_CONFIG.get(sparse_cfg)
+        if mtsa_cfg is None:
+            # Try predefined configs
+            mtsa_cfg = getattr(mtsa, sparse_cfg, None)
+        if mtsa_cfg is None:
+            raise ValueError(f"Unknown sparse_cfg: {sparse_cfg}")
+    else:
+        mtsa_cfg = sparse_cfg
+
+    # Override backend if specified
+    if backend:
+        if isinstance(mtsa_cfg, dict) and "sparse_cfg" in mtsa_cfg:
+            modified_sparse_cfg = {}
+            for pattern, cfg in mtsa_cfg["sparse_cfg"].items():
+                modified_cfg = cfg.copy() if isinstance(cfg, dict) else cfg
+                if isinstance(modified_cfg, dict):
+                    modified_cfg["backend"] = backend
+                modified_sparse_cfg[pattern] = modified_cfg
+            mtsa_cfg = {"sparse_cfg": modified_sparse_cfg}
+
+    # Apply sparsification
+    print(f"\nApplying sparse attention with config: {sparse_cfg}")
+    mtsa.sparsify(net, mtsa_cfg)
+    print("Sparse attention applied successfully!")
+
+    return model
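
A usage sketch for the new utility, assuming a Hugging Face causal LM loaded with eager attention; the model id below is a placeholder, not something referenced by the commit.

```python
# Usage sketch for sparse_attention_utils.sparsify_model (placeholder model id).
from transformers import AutoModelForCausalLM

from sparse_attention_utils import sparsify_model

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",  # placeholder
    attn_implementation="eager",  # required for sparse attention
)

# sparse_cfg accepts a custom name defined above (SPARSE_CONSERVATIVE /
# SPARSE_AGGRESSIVE), a predefined modelopt config name such as
# SKIP_SOFTMAX_DEFAULT, or a config dict.
model = sparsify_model(model, sparse_cfg="SPARSE_CONSERVATIVE", backend="pytorch")
```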

examples/llm_sparsity/attention_sparsity/hf_sa.py

Lines changed: 6 additions & 26 deletions

@@ -28,7 +28,6 @@
 import modelopt.torch.opt as mto
 import modelopt.torch.sparsity.attention_sparsity as mtsa
 from modelopt.torch.export import export_hf_checkpoint
-from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig
 from modelopt.torch.sparsity.attention_sparsity.config import (
     SKIP_SOFTMAX_CALIB,
     SKIP_SOFTMAX_DEFAULT,
@@ -196,29 +195,6 @@ def generate_text(model, inputs, args, tokenizer):
     print("\nOutputs differ")


-def sparsify_model(model, args):
-    """Apply sparse attention to the model with optional calibration."""
-    print(f"\nApplying sparse attention: {args.sparse_attn} with backend: {args.backend}")
-    base_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn]
-
-    # Create modified config with selected backend
-    modified_sparse_cfg = {}
-    for pattern, cfg in base_config["sparse_cfg"].items():
-        modified_cfg = cfg.copy()
-        modified_cfg["backend"] = args.backend
-        modified_sparse_cfg[pattern] = modified_cfg
-
-    # Create new config with modified settings
-    sparse_config = SparseAttentionConfig(sparse_cfg=modified_sparse_cfg)
-
-    # Sparsify the model
-    model = mtsa.sparsify(model, config=sparse_config)
-
-    print("Sparse attention applied successfully!")
-
-    return model
-
-
 def main(args):
     """Main function to run the selected mode."""
     if not torch.cuda.is_available():
@@ -249,8 +225,12 @@ def main(args):
     model = model.cuda()
     print("Model moved to CUDA")

-    # Apply sparse attention to the model (with calibration if configured)
-    model = sparsify_model(model, args)
+    # Apply sparse attention with optional calibration
+    print(f"\nApplying sparse attention: {args.sparse_attn}")
+    sparse_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn]
+    model = mtsa.sparsify(model, config=sparse_config)
+
+    print("Sparse attention applied successfully!")

     # Verify outputs if requested (compares baseline vs calibrated sparse model)
     if args.verify_output:
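
With the local sparsify_model helper removed, hf_sa.py now passes a predefined config straight to mtsa.sparsify. A brief sketch follows, under the assumption that SPARSE_ATTN_CFG_CHOICES maps CLI choice names to the imported configs; the "skip_softmax" key and the wrapper function are illustrative, not part of the commit.

```python
# Sketch of the simplified call path after this change.
import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT

SPARSE_ATTN_CFG_CHOICES = {"skip_softmax": SKIP_SOFTMAX_DEFAULT}  # illustrative key


def apply_sparse_attention(model, sparse_attn: str = "skip_softmax"):
    print(f"\nApplying sparse attention: {sparse_attn}")
    sparse_config = SPARSE_ATTN_CFG_CHOICES[sparse_attn]
    model = mtsa.sparsify(model, config=sparse_config)
    print("Sparse attention applied successfully!")
    return model
```
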
Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+nltk
+wonderwords