@@ -84,7 +84,7 @@ def get_bw_flops(model_fn):
8484 model_selective_ac ,
8585 ac_config_no_force ,
8686 model_compile_enabled = False ,
87- use_flex_attn = False ,
87+ attn_type = "sdpa" ,
8888 op_sac_save_list = _op_sac_save_list ,
8989 )
9090 flops_selective_ac = get_bw_flops (model_selective_ac )
@@ -102,7 +102,7 @@ def get_bw_flops(model_fn):
102102 model_with_force_first ,
103103 ac_config_with_force_first ,
104104 model_compile_enabled = False ,
105- use_flex_attn = False ,
105+ attn_type = "sdpa" ,
106106 op_sac_save_list = _op_sac_save_list ,
107107 )
108108 flops_with_force_first = get_bw_flops (model_with_force_first )
@@ -119,7 +119,7 @@ def get_bw_flops(model_fn):
119119 model_with_force_last ,
120120 ac_config_with_force_last ,
121121 model_compile_enabled = False ,
122- use_flex_attn = False ,
122+ attn_type = "sdpa" ,
123123 op_sac_save_list = _op_sac_save_list ,
124124 )
125125 flops_with_force_last = get_bw_flops (model_with_force_last )
@@ -134,7 +134,7 @@ def get_bw_flops(model_fn):
134134 model_with_full_ac ,
135135 ac_config_full_ac ,
136136 model_compile_enabled = False ,
137- use_flex_attn = False ,
137+ attn_type = "sdpa" ,
138138 op_sac_save_list = _op_sac_save_list ,
139139 )
140140 flops_full_ac = get_bw_flops (model_with_full_ac )
@@ -177,7 +177,7 @@ def get_act_mem(model_fn):
177177 model_selective_ac ,
178178 ac_config_no_force ,
179179 model_compile_enabled = False ,
180- use_flex_attn = False ,
180+ attn_type = "sdpa" ,
181181 op_sac_save_list = _op_sac_save_list ,
182182 )
183183 mem_selective_ac = get_act_mem (model_selective_ac )
@@ -194,7 +194,7 @@ def get_act_mem(model_fn):
194194 model_with_force_first ,
195195 ac_config_with_force_first ,
196196 model_compile_enabled = False ,
197- use_flex_attn = False ,
197+ attn_type = "sdpa" ,
198198 op_sac_save_list = _op_sac_save_list ,
199199 )
200200 mem_with_force_first = get_act_mem (model_with_force_first )
@@ -210,7 +210,7 @@ def get_act_mem(model_fn):
210210 model_with_force_last ,
211211 ac_config_with_force_last ,
212212 model_compile_enabled = False ,
213- use_flex_attn = False ,
213+ attn_type = "sdpa" ,
214214 op_sac_save_list = _op_sac_save_list ,
215215 )
216216 mem_with_force_last = get_act_mem (model_with_force_last )
@@ -224,7 +224,7 @@ def get_act_mem(model_fn):
224224 model_with_full_ac ,
225225 ac_config_full_ac ,
226226 model_compile_enabled = False ,
227- use_flex_attn = False ,
227+ attn_type = "sdpa" ,
228228 op_sac_save_list = _op_sac_save_list ,
229229 )
230230 mem_full_ac = get_act_mem (model_with_full_ac )
@@ -251,7 +251,7 @@ def test_correctness(self):
251251 per_op_sac_force_recompute_mm_shapes_by_fqns = [],
252252 ),
253253 model_compile_enabled = False ,
254- use_flex_attn = False ,
254+ attn_type = "sdpa" ,
255255 op_sac_save_list = _op_sac_save_list ,
256256 )
257257 model_force_first = ToyModule ()
@@ -264,7 +264,7 @@ def test_correctness(self):
264264 per_op_sac_force_recompute_mm_shapes_by_fqns = ["moe.router.gate" ],
265265 ),
266266 model_compile_enabled = False ,
267- use_flex_attn = False ,
267+ attn_type = "sdpa" ,
268268 op_sac_save_list = _op_sac_save_list ,
269269 )
270270
@@ -278,7 +278,7 @@ def test_correctness(self):
278278 per_op_sac_force_recompute_mm_shapes_by_fqns = ["output" ],
279279 ),
280280 model_compile_enabled = False ,
281- use_flex_attn = False ,
281+ attn_type = "sdpa" ,
282282 op_sac_save_list = _op_sac_save_list ,
283283 )
284284
0 commit comments