@@ -236,7 +236,7 @@ def _apply_ac_to_transformer_block(
     *,
     base_fqn: str | None = None,
     model_compile_enabled: bool = False,
-    attn_type: str = "sdpa",
+    use_flex_attn: bool = False,
     op_sac_save_list: set[torch._ops.OpOverload] | None = None,
 ) -> nn.Module:
     valid_ac_modes = ("full", "selective")
@@ -259,7 +259,7 @@ def _apply_ac_to_transformer_block(

     if use_op_sac:
         op_sac_save_list = op_sac_save_list or set()
-        if attn_type == "flex":
+        if use_flex_attn:
             """
             For Flex Attention, we need to apply SAC carefully to avoid invalidating
             torch.compile. Any torch.compile inside the SAC region will be ignored,
@@ -288,7 +288,7 @@ def apply_ac(
     ac_config: ACConfig,
     *,
     model_compile_enabled: bool = False,
-    attn_type: str = "sdpa",
+    use_flex_attn: bool = False,
     op_sac_save_list: set[torch._ops.OpOverload] | None = None,
     base_folder: str = "",
 ) -> None:
@@ -302,7 +302,7 @@ def apply_ac(
         model (nn.Module): The model to apply activation checkpointing to.
         ac_config (ACConfig): The activation checkpointing config.
         model_compile_enabled (bool): Whether torch.compile is enabled for the model.
-        attn_type (str): Attention type (one of [sdpa, varlen, flex])
+        use_flex_attn (bool): Whether flex attention is enabled for the model.
         op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead
             of recomputing.
     Returns:
@@ -326,7 +326,7 @@ def apply_ac(
             ac_config,
             base_fqn=f"layers.{layer_id}",
             model_compile_enabled=model_compile_enabled,
-            attn_type=attn_type,
+            use_flex_attn=use_flex_attn,
             op_sac_save_list=op_sac_save_list,
         )
         model.layers.register_module(layer_id, transformer_block)
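
For context, a minimal sketch of how a caller might thread the renamed flag into `apply_ac` after this change. The surrounding names (`model`, `job_config`, `_op_sac_save_list`, and the `use_flex_attn` model attribute) are assumptions for illustration and are not defined by this diff.

```python
# Hypothetical call site (sketch only): derive the boolean flag from the model
# config and pass it where the old attn_type string used to go.
use_flex_attn = getattr(model.model_args, "use_flex_attn", False)  # assumed attribute

apply_ac(
    model,
    job_config.activation_checkpoint,            # assumed ACConfig source
    model_compile_enabled=job_config.compile.enable,  # assumed compile toggle
    use_flex_attn=use_flex_attn,                 # replaces attn_type="flex"/"sdpa"
    op_sac_save_list=_op_sac_save_list,          # assumed save-list constant
)
```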