@@ -43,19 +43,16 @@ def evaluate_gfx_arch_within(arch_list):
43
43
effective_arch = os .environ .get ('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE' , gcn_arch_name )
44
44
# gcnArchName can be complicated strings like gfx90a:sramecc+:xnack-
45
45
# Hence the matching should be done reversely
46
- result = any (arch in effective_arch for arch in arch_list )
47
-
48
- if result and gcn_arch_name == "gfx1201" :
49
- os .environ ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL' ] = '1'
50
-
51
- return result
46
+ return any (arch in effective_arch for arch in arch_list )
52
47
53
48
def CDNA2OrLater():
    """Check the GPU architecture against the gfx90a / gfx942 names.

    Passes the two architecture names to ``evaluate_gfx_arch_within`` and
    returns whatever that helper reports, unchanged.
    """
    accepted_archs = ["gfx90a", "gfx942"]
    return evaluate_gfx_arch_within(accepted_archs)
55
50
56
51
def evaluate_platform_supports_flash_attention ():
57
52
if TEST_WITH_ROCM :
58
- arch_list = ["gfx90a" , "gfx942" , "gfx1100" , "gfx1201" ]
53
+ arch_list = ["gfx90a" , "gfx942" , "gfx1100" ]
54
+ if os .environ .get ("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL" , "0" ) != "0" :
55
+ arch_list += ["gfx1201" , "gfx950" ]
59
56
return evaluate_gfx_arch_within (arch_list )
60
57
if TEST_CUDA :
61
58
return not IS_WINDOWS and SM80OrLater
@@ -64,6 +61,8 @@ def evaluate_platform_supports_flash_attention():
64
61
def evaluate_platform_supports_efficient_attention ():
65
62
if TEST_WITH_ROCM :
66
63
arch_list = ["gfx90a" , "gfx942" , "gfx1100" ]
64
+ if os .environ .get ("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL" , "0" ) != "0" :
65
+ arch_list += ["gfx1201" , "gfx950" ]
67
66
return evaluate_gfx_arch_within (arch_list )
68
67
if TEST_CUDA :
69
68
return True
0 commit comments