
Commit 49ab28d

Enable SNIP on multiple cards using DeepSpeed ZeRO-3 (#1492)
Signed-off-by: yiliu30 <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 061884d commit 49ab28d

File tree: 17 files changed (+447, -59 lines)


examples/pytorch/nlp/huggingface_models/language-modeling/pruning/magnitude/README.md renamed to examples/pytorch/nlp/huggingface_models/language-modeling/pruning/multi_cards/README.md

Lines changed: 83 additions & 3 deletions
@@ -1,7 +1,7 @@
 Step-by-Step
 ============
 
-# single GPU
+# Single GPU
 
 ```
 export CUDA_VISIBLE_DEVICES=0
@@ -15,10 +15,11 @@ bash run.sh \
   --pruning_frequency=1000
 ```
 
-# multi GPU
+# Multi GPU
 
-we use `accelerate` and `deepspeed ZeRO Stage-2` to conduct weight magnitude pruning
+We use `accelerate` and DeepSpeed ZeRO to conduct weight-magnitude and SNIP pruning. Below are two usage examples: 1) magnitude pruning with ZeRO Stage-2, and 2) SNIP pruning with ZeRO Stage-3.
 
+## Magnitude pruning with ZeRO Stage-2
 ### Accelerate DeepSpeed Plugin
 
 On your machine(s) just run:
@@ -105,3 +106,82 @@ bash run_ds.sh \
   --pruning_pattern=4x1 \
   --pruning_frequency=1000
 ```
+
+
+## SNIP pruning with ZeRO Stage-3
+
+To configure `accelerate` to use DeepSpeed ZeRO Stage-3, on your machine(s) just run:
+``` shell
+accelerate config
+
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+  deepspeed_config_file: config/zero_stage3_config.json
+  zero3_init_flag: true
+distributed_type: DEEPSPEED
+fsdp_config: {}
+machine_rank: 0
+main_process_ip: null
+main_process_port: null
+main_training_function: main
+mixed_precision: fp16
+num_machines: 1
+num_processes: 2
+use_cpu: false
+```
+with the contents of `config/zero_stage3_config.json` being:
+
+```
+{
+    "train_batch_size": 64,
+    "train_micro_batch_size_per_gpu": 8,
+    "gradient_accumulation_steps": 4,
+    "fp16": {
+        "enabled": true,
+        "min_loss_scale": 1,
+        "opt_level": "O2"
+    },
+    "zero_optimization": {
+        "stage": 3,
+        "allgather_partitions": true,
+        "allgather_bucket_size": 5e8,
+        "contiguous_gradients": true
+    },
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "torch_adam": true,
+            "adam_w_mode": true
+        }
+    },
+    "scheduler": {
+        "type": "WarmupDecayLR",
+        "params": {
+            "warmup_min_lr": 0.0,
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto",
+            "total_num_steps": "auto",
+            "warmup_type": "cosine"
+        }
+    }
+}
+```
+
+### Pruning
+> Note: Since ZeRO Stage-3 partitions all three model states (optimizer states, gradients, and parameters), please set `pruning_scope` to `local`. Choosing `global` requires gathering all parameters to update the mask, which negates the benefits of ZeRO Stage-3.
+
+```
+# 2 GPU cards example
+export CUDA_VISIBLE_DEVICES=0,1 USE_DEEPSPEED=1
+bash run_ds_z3.sh \
+  --model_name_or_path=facebook/opt-125m \
+  --dataset_name=NeelNanda/pile-10k \
+  --block_size=128 \
+  --output_dir=./test-clm \
+  --pruning_type=snip_momentum \
+  --pruning_scope=local \
+  --pruning_pattern=4x1 \
+  --pruning_frequency=1000
+```
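For intuition about why the note above recommends `pruning_scope=local` under ZeRO Stage-3, here is a minimal, self-contained sketch (not INC's actual masking code; names and shapes are illustrative) contrasting the two scopes:

```python
import torch

def local_masks(scores: dict, sparsity: float) -> dict:
    """'local' scope: rank weights within each layer, so each rank can
    score and prune its own parameters without gathering anything."""
    masks = {}
    for name, s in scores.items():
        k = max(int(s.numel() * sparsity), 1)
        threshold = torch.kthvalue(s.flatten(), k).values  # per-layer threshold
        masks[name] = (s > threshold).float()
    return masks

def global_masks(scores: dict, sparsity: float) -> dict:
    """'global' scope: one threshold across all layers, which requires
    gathering every score tensor -- the very thing ZeRO Stage-3 avoids."""
    flat = torch.cat([s.flatten() for s in scores.values()])
    k = max(int(flat.numel() * sparsity), 1)
    threshold = torch.kthvalue(flat, k).values
    return {name: (s > threshold).float() for name, s in scores.items()}

scores = {"fc1": torch.rand(16, 16), "fc2": torch.rand(16, 16)}
print({n: m.mean().item() for n, m in local_masks(scores, sparsity=0.8).items()})
```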
config/zero_stage3_config.json

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+{
+    "train_batch_size": 64,
+    "train_micro_batch_size_per_gpu": 8,
+    "gradient_accumulation_steps": 4,
+    "fp16": {
+        "enabled": true,
+        "min_loss_scale": 1,
+        "opt_level": "O2"
+    },
+    "zero_optimization": {
+        "stage": 3,
+        "allgather_partitions": true,
+        "allgather_bucket_size": 5e8,
+        "contiguous_gradients": true
+    },
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "torch_adam": true,
+            "adam_w_mode": true
+        }
+    },
+    "scheduler": {
+        "type": "WarmupDecayLR",
+        "params": {
+            "warmup_min_lr": 0.0,
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto",
+            "total_num_steps": "auto",
+            "warmup_type": "cosine"
+        }
+    }
+}
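One relation worth checking in this JSON: DeepSpeed requires `train_batch_size` to equal `train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size`. With `num_processes: 2` from the accelerate config shown in the README above, the values are consistent:

```python
# Sanity check of the batch-size relation DeepSpeed enforces for this config.
micro_batch_per_gpu = 8   # train_micro_batch_size_per_gpu
grad_accum_steps = 4      # gradient_accumulation_steps
world_size = 2            # num_processes in the accelerate config
assert micro_batch_per_gpu * grad_accum_steps * world_size == 64  # train_batch_size
```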
Lines changed: 10 additions & 2 deletions
@@ -274,6 +274,13 @@ def parse_args():
         help="pruning criteria to use.",
         choices=["magnitude", "snip", "snip_momentum"],
     )
+    parser.add_argument(
+        "--pruning_scope",
+        type=str,
+        default="global",
+        help="determine whether layers' scores are gathered together for sorting.",
+        choices=["local", "global"],
+    )
     parser.add_argument(
         "--warm_epochs",
         type=int,
@@ -688,7 +695,7 @@ def group_texts(examples):
     pruning_configs=[
         {
             "pruning_type": args.pruning_type,
-            "pruning_scope": "global",
+            "pruning_scope": args.pruning_scope,
             "sparsity_decay_type": "exp",
             "excluded_op_names": ["pooler"],
             "pruning_op_types": ["Linear"],
@@ -800,7 +807,8 @@ def group_texts(examples):
 
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
-        unwrapped_model = accelerator.unwrap_model(model)
+        # fetch the DeepSpeed model from the INC model wrapper
+        unwrapped_model = accelerator.unwrap_model(model.model)
         unwrapped_model.save_pretrained(
             args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
         )
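The save-path change above assumes the layering sketched below: once pruning is prepared, `model` is an INC wrapper whose `.model` attribute holds the module that Accelerate/DeepSpeed actually trains. A hypothetical illustration (class name assumed, not INC's real type):

```python
class INCModelWrapper:
    """Stand-in for the wrapper INC returns when pruning is enabled."""
    def __init__(self, module):
        self.model = module  # the underlying (DeepSpeed-wrapped) HF model

inner = object()                 # pretend this is the HF model
model = INCModelWrapper(inner)
assert model.model is inner      # what accelerator.unwrap_model must receive
```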
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+#!/bin/bash
+set -x
+
+function main {
+
+    init_params "$@"
+    run_pruning
+
+}
+
+# init params
+function init_params {
+    dataset_name="NeelNanda/pile-10k"
+    model_name_or_path="facebook/opt-125m"
+    output_dir="./test-clm"
+    per_device_train_batch_size=8
+    block_size=128
+    gradient_accumulation_steps=4
+    num_train_epochs=3
+    target_sparsity=0.8
+    pruning_type="snip_momentum"
+    pruning_scope="local"
+    pruning_pattern="4x1"
+    pruning_frequency=1000
+    for var in "$@"
+    do
+        case $var in
+            --dataset_name=*)
+                dataset_name=$(echo $var | cut -f2 -d=)
+            ;;
+            --model_name_or_path=*)
+                model_name_or_path=$(echo $var | cut -f2 -d=)
+            ;;
+            --output_dir=*)
+                output_dir=$(echo $var | cut -f2 -d=)
+            ;;
+            --per_device_train_batch_size=*)
+                per_device_train_batch_size=$(echo $var | cut -f2 -d=)
+            ;;
+            --block_size=*)
+                block_size=$(echo $var | cut -f2 -d=)
+            ;;
+            --gradient_accumulation_steps=*)
+                gradient_accumulation_steps=$(echo $var | cut -f2 -d=)
+            ;;
+            --num_train_epochs=*)
+                num_train_epochs=$(echo $var | cut -f2 -d=)
+            ;;
+            --target_sparsity=*)
+                target_sparsity=$(echo $var | cut -f2 -d=)
+            ;;
+            --pruning_type=*)
+                pruning_type=$(echo $var | cut -f2 -d=)
+            ;;
+            --pruning_scope=*)
+                pruning_scope=$(echo $var | cut -f2 -d=)
+            ;;
+            --pruning_pattern=*)
+                pruning_pattern=$(echo $var | cut -f2 -d=)
+            ;;
+            --pruning_frequency=*)
+                pruning_frequency=$(echo $var | cut -f2 -d=)
+            ;;
+            *)
+                echo "Error: No such parameter: ${var}"
+                exit 1
+            ;;
+        esac
+    done
+
+}
+
+# run_pruning
+function run_pruning {
+    accelerate launch --deepspeed_config_file config/ds_config.json --mixed_precision fp16 \
+        run_clm_no_trainer_deepspeed.py \
+        --dataset_name $dataset_name \
+        --model_name_or_path $model_name_or_path \
+        --block_size $block_size \
+        --per_device_train_batch_size $per_device_train_batch_size \
+        --gradient_accumulation_steps $gradient_accumulation_steps \
+        --output_dir $output_dir \
+        --do_prune \
+        --num_train_epochs $num_train_epochs \
+        --target_sparsity $target_sparsity \
+        --pruning_type $pruning_type \
+        --pruning_scope $pruning_scope \
+        --pruning_pattern $pruning_pattern \
+        --pruning_frequency $pruning_frequency
+
+}
+
+main "$@"

neural_compressor/compression/pruner/criteria.py

Lines changed: 21 additions & 10 deletions
@@ -16,7 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .utils import torch
+from .utils import safe_get_data, safe_get_grad, safe_get_shape, torch
 
 CRITERIA = {}
 
@@ -96,7 +96,8 @@ def on_step_begin(self):
         """Calculate and store the pruning scores based on a magnitude criterion."""
         with torch.no_grad():
             for key in self.modules.keys():
-                p = self.modules[key].weight.data
+                param = self.modules[key].weight
+                p = safe_get_data(param)
                 if hasattr(self.pattern, "reduce_score"):
                     self.scores[key] = self.pattern.reduce_score(torch.abs(p), key)
                 else:
@@ -161,12 +162,15 @@ def on_before_optimizer_step(self):
         """Calculate and store the pruning scores based on snip criterion."""
         with torch.no_grad():
            for key in self.modules.keys():
-                p = self.modules[key].weight
+                # p = self.modules[key].weight
+                param = self.modules[key].weight
+                data = safe_get_data(param)
+                grad = safe_get_grad(param)
                 # self.scores[key] = torch.abs(p * p.grad)
                 if hasattr(self.pattern, "reduce_score"):
-                    self.scores[key] = self.pattern.reduce_score(torch.abs(p * p.grad), key)
+                    self.scores[key] = self.pattern.reduce_score(torch.abs(data * grad), key)
                 else:
-                    self.scores[key] = torch.abs(p * p.grad)
+                    self.scores[key] = torch.abs(data * grad)
 
 
 @register_criterion("snip_momentum")
@@ -191,15 +195,19 @@ def __init__(self, modules, config, pattern):
         super(SnipMomentumCriterion, self).__init__(modules, config, pattern)
         assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion"
         for key in modules.keys():
-            p = modules[key].weight
+            param = modules[key].weight
+            # p = modules[key].weight
+            param_shape = safe_get_shape(param)
             dtype = torch.float32
             if self.low_memory_usage:
-                dtype = torch.bfloat16 if p.device.type == "cpu" else torch.float16
+                dtype = torch.bfloat16 if param.device.type == "cpu" else torch.float16
             # self.scores[key] = torch.zeros(p.shape, dtype=dtype).to(p.device)
             if hasattr(self.pattern, "reduce_score"):
-                self.scores[key] = self.pattern.reduce_score(torch.zeros(p.shape, dtype=dtype).to(p.device), key)
+                self.scores[key] = self.pattern.reduce_score(
+                    torch.zeros(param_shape, dtype=dtype).to(param.device), key
+                )
             else:
-                self.scores[key] = torch.zeros(p.shape, dtype=dtype).to(p.device)
+                self.scores[key] = torch.zeros(param_shape, dtype=dtype).to(param.device)
 
         self.alpha = 0.9
         self.beta = 1.0
@@ -209,8 +217,11 @@ def on_before_optimizer_step(self):
         with torch.no_grad():
             for key in self.modules.keys():
                 p = self.modules[key].weight
+                param = self.modules[key].weight
+                data = safe_get_data(param)
+                grad = safe_get_grad(param)
                 self.scores[key] *= self.alpha
-                tmp = torch.abs(p * p.grad)
+                tmp = torch.abs(data * grad)
                 if hasattr(self.pattern, "reduce_score"):
                     tmp = self.pattern.reduce_score(tmp, key, force=True)
                 if self.low_memory_usage:
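These `safe_get_*` helpers are what make the criteria work under ZeRO Stage-3, where `param.data` and `param.grad` expose only the local shard of a partitioned parameter; the SNIP score itself is unchanged, still the saliency |w * grad|, just computed on gathered full tensors. Below is a hedged sketch of what such helpers can look like, built on DeepSpeed's documented `safe_get_full_fp32_param`/`safe_get_full_grad` utilities and the `ds_id`/`ds_shape` attributes ZeRO Stage-3 attaches to parameters; the actual implementations live in the pruner's `utils` module and may differ:

```python
import torch

def _is_ds_param(param: torch.nn.Parameter) -> bool:
    # ZeRO Stage-3 attaches `ds_id` to every partitioned parameter.
    return hasattr(param, "ds_id")

def safe_get_shape(param):
    # A partitioned parameter reports only its local shard via `.shape`;
    # the full shape is kept in `ds_shape`.
    return param.ds_shape if _is_ds_param(param) else param.shape

def safe_get_data(param):
    if _is_ds_param(param):
        from deepspeed.utils import safe_get_full_fp32_param
        return safe_get_full_fp32_param(param)  # gathers the full weight
    return param.data

def safe_get_grad(param):
    if _is_ds_param(param):
        from deepspeed.utils import safe_get_full_grad
        return safe_get_full_grad(param)  # gathers the full gradient
    return param.grad
```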

neural_compressor/compression/pruner/patterns/base.py

Lines changed: 17 additions & 7 deletions
@@ -20,7 +20,7 @@
 
 import numpy as np
 
-from ..utils import tf, torch
+from ..utils import safe_get_data, safe_get_grad, safe_get_shape, tf, torch
 
 PATTERNS = {}
 
@@ -75,12 +75,18 @@ def _reshape_2dims_to_orig(data, orig_shape):
     Returns:
         Reshaped data.
     """
-    if len(orig_shape) == 4:
+    if len(orig_shape) == 2:
+        return data
+    elif len(orig_shape) == 4:
         data = data.reshape(orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1])
         data = data.permute(0, 3, 1, 2)
-    if len(orig_shape) == 3:
+    elif len(orig_shape) == 3:
         data = data.reshape(orig_shape[0], orig_shape[2], orig_shape[1])
         data = data.permute(0, 2, 1)
+    elif len(orig_shape) == 1:
+        data = data.reshape(orig_shape)
+    else:
+        raise NotImplementedError(f"unsupported original shape: {orig_shape}")
     return data
 
 # some util functions which can be used.
@@ -601,12 +607,16 @@ def get_pattern_lock_masks(self, modules):
         """
         pattern_lock_masks = {}
         for key in modules.keys():
-            weight = modules[key].weight
-            shape = weight.shape
+            # weight = modules[key].weight
+            # shape = weight.shape
+            param = modules[key].weight
+            data = safe_get_data(param)
+            shape = safe_get_shape(param)
             mask = torch.ones(shape)
-            mask[weight == 0] = 0.0
+            # mask[weight == 0] = 0.0
+            mask[data == 0] = 0.0
             mask = mask.bool()
-            pattern_lock_masks[key] = mask.to(weight.device)
+            pattern_lock_masks[key] = mask.to(param.device)
 
         return pattern_lock_masks
 

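With the new 2-D and 1-D branches, `_reshape_2dims_to_orig` now covers every weight rank the pruner encounters instead of silently falling through. A small round-trip check for the 4-D (conv) branch, using an assumed channels-last forward transform as the counterpart (`reshape_4d_to_2d` is illustrative, not the library's function):

```python
import torch

def reshape_4d_to_2d(data):
    # Assumed forward counterpart: (N, C, H, W) -> (N, H*W*C), channels-last.
    n, c, h, w = data.shape
    return data.permute(0, 2, 3, 1).reshape(n, h * w * c)

def reshape_2d_to_4d(data, orig_shape):
    # Mirrors the 4-D branch added in the diff above.
    data = data.reshape(orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1])
    return data.permute(0, 3, 1, 2)

w = torch.randn(8, 3, 5, 5)  # a conv weight (N, C, H, W)
assert torch.equal(reshape_2d_to_4d(reshape_4d_to_2d(w), w.shape), w)
```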