Skip to content

Commit 8b56090

Browse files
committed
[release/2.6][SWDEV-523736] Fix review observations
1 parent d1855a8 commit 8b56090

File tree

3 files changed

+7
-11
lines changed

3 files changed

+7
-11
lines changed

test/dynamo/test_repros.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@
4747
parametrize,
4848
skipIfWindows,
4949
TEST_WITH_ROCM,
50-
skipIfRocmArch,
51-
NAVI44_ARCH,
5250
)
5351
from torch.testing._internal.two_tensor import TwoTensor
5452

@@ -6410,7 +6408,7 @@ def fn(x):
64106408
self.assertEqual(fn(inp), opt_fn(inp))
64116409

64126410
@requires_cuda
6413-
@skipIfRocmArch(NAVI44_ARCH)
6411+
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
64146412
def test_sdpa_dynamic_shapes(self):
64156413
def f(x, s0, s1, s2):
64166414
q = x.view(2, s0, s2, s0)

torch/testing/_internal/common_cuda.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,16 @@ def evaluate_gfx_arch_within(arch_list):
4343
effective_arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
4444
# gcnArchName can be complicated strings like gfx90a:sramecc+:xnack-
4545
# Hence the matching should be done reversely
46-
result = any(arch in effective_arch for arch in arch_list)
47-
48-
if result and gcn_arch_name == "gfx1201":
49-
os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
50-
51-
return result
46+
return any(arch in effective_arch for arch in arch_list)
5247

5348
def CDNA2OrLater():
5449
return evaluate_gfx_arch_within(["gfx90a", "gfx942"])
5550

5651
def evaluate_platform_supports_flash_attention():
5752
if TEST_WITH_ROCM:
58-
arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201"]
53+
arch_list = ["gfx90a", "gfx942", "gfx1100"]
54+
if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
55+
arch_list += ["gfx1201", "gfx950"]
5956
return evaluate_gfx_arch_within(arch_list)
6057
if TEST_CUDA:
6158
return not IS_WINDOWS and SM80OrLater
@@ -64,6 +61,8 @@ def evaluate_platform_supports_flash_attention():
6461
def evaluate_platform_supports_efficient_attention():
6562
if TEST_WITH_ROCM:
6663
arch_list = ["gfx90a", "gfx942", "gfx1100"]
64+
if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
65+
arch_list += ["gfx1201", "gfx950"]
6766
return evaluate_gfx_arch_within(arch_list)
6867
if TEST_CUDA:
6968
return True

torch/testing/_internal/common_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,6 @@
111111
NAVI_ARCH = ("gfx1030", "gfx1100", "gfx1101", "gfx1200", "gfx1201")
112112
NAVI3_ARCH = ("gfx1100", "gfx1101")
113113
NAVI4_ARCH = ("gfx1200", "gfx1201")
114-
NAVI44_ARCH = "gfx1200"
115114

116115
def is_navi3_arch():
117116
if torch.cuda.is_available():

0 commit comments

Comments (0)