From 0a775c4a0de3f6e021a24d102c031e8d43d55b59 Mon Sep 17 00:00:00 2001 From: Jack Taylor Date: Thu, 13 Mar 2025 11:23:47 +0000 Subject: [PATCH 1/5] [ROCm] Experimental flag for flex attn exhaustive tuning --- torch/_inductor/kernel/flex_attention.py | 232 ++++++++++++++++------- 1 file changed, 166 insertions(+), 66 deletions(-) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 2318be5c423e57..67a71867c950fd 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -1,5 +1,7 @@ # mypy: allow-untyped-defs """ Triton Implementation of the flex_attention Kernel""" +import os +import itertools import logging import math @@ -1206,6 +1208,16 @@ def flex_attention( if torch.version.hip: configs = [(c[0], c[1], c[2], 1) for c in configs] + # Check if the environment variable is set + if os.getenv("TORCHINDUCTOR_EXHAUSTIVE_FLEX_ATTENTION_EXPERIMENTAL") == "1": + param1 = [16, 32, 64, 128, 256] + param2 = [16, 32, 64, 128, 256] + param3 = [0, 1, 2, 4, 8] + param4 = [1, 2, 3] + + # Generate full search space + configs = list(itertools.product(param1, param2, param3, param4)) + # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards. SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) @@ -1223,8 +1235,8 @@ def flex_attention( ) continue # Work around https://github.com/pytorch/pytorch/issues/129625 - if num_stages == 2: - continue + #if num_stages == 2: + # continue cur_kernel_options = original_kernel_options.copy() # Performance tuning @@ -1234,33 +1246,67 @@ def flex_attention( cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) - error = flex_attention_template.maybe_append_choice( - choices=choices, - input_nodes=[ - query, - key, - value, - logsumexp, - kv_num_blocks, - kv_indices, - full_kv_num_blocks, - full_kv_indices, - ], - layout=layout, - subgraphs=[ - subgraph_buffer, - mask_graph_buffer, - ], - mutated_inputs=[ - logsumexp, - ], - num_stages=num_stages, - num_warps=num_warps, - call_sizes=query.get_size(), - **cur_kernel_options, - ) - if error is not None and len(configs) == 1: - raise error + if os.getenv("TORCHINDUCTOR_EXHAUSTIVE_FLEX_ATTENTION_EXPERIMENTAL") == "1": + for mfma in [0, 16]: + for wpeu in [0, 1, 2, 4, 8]: + cur_kernel_options["waves_per_eu"] = wpeu + cur_kernel_options["matrix_instr_non_kdim"] = mfma + error = flex_attention_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + logsumexp, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ], + layout=layout, + subgraphs=[ + subgraph_buffer, + mask_graph_buffer, + ], + mutated_inputs=[ + logsumexp, + ], + num_stages=num_stages, + num_warps=num_warps, + call_sizes=query.get_size(), + **cur_kernel_options, + ) + if error is not None and len(configs) == 1: + raise error + else: + error = flex_attention_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + logsumexp, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ], + layout=layout, + subgraphs=[ + subgraph_buffer, + mask_graph_buffer, + ], + mutated_inputs=[ + logsumexp, + ], + num_stages=num_stages, + num_warps=num_warps, + call_sizes=query.get_size(), + **cur_kernel_options, + ) + if error is not None and 
len(configs) == 1: + raise error + inputs_for_autotuning = ( [ query, @@ -2264,6 +2310,17 @@ def flex_attention_backward(*args, **kwargs): if BLOCK2 % BLOCK1 == 0 ] ) + + # Check if the environment variable is set + if os.getenv("TORCHINDUCTOR_EXHAUSTIVE_FLEX_ATTENTION_EXPERIMENTAL") == "1": + param1 = [16, 32, 64, 128, 256] + param2 = [16, 32, 64, 128, 256] + param3 = [0, 1, 2, 4, 8] + param4 = [1, 2, 3] + + # Generate full search space + configs = list(itertools.product(param1, param2, param3, param4)) + original_kernel_options = kernel_options.copy() for BLOCK1, BLOCK2, num_warps, num_stages in configs: if ( @@ -2287,43 +2344,86 @@ def flex_attention_backward(*args, **kwargs): cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) - flex_attention_backward_template.maybe_append_choice( - choices=choices, - input_nodes=[ - query, - key, - value, - logsumexp, - delta, - grad_out, - grad_query, - broadcasted_grad_value, - kv_num_blocks, - kv_indices, - q_num_blocks, - q_indices, - full_kv_num_blocks, - full_kv_indices, - full_q_num_blocks, - full_q_indices, - ], - layout=layout_broadcasted_k, # We use store_output only for grad_key - subgraphs=[ - fw_subgraph_buffer, - joint_outputs.grad_input, - mask_graph_buffer, - joint_outputs.captured_grads_compute, - ], - mutated_inputs=[ - grad_query, - broadcasted_grad_value, - *joint_outputs.mutated_grads, - ], - call_sizes=query.get_size() + key.get_size()[1:3], - num_stages=num_stages, - num_warps=num_warps, - **cur_kernel_options, - ) + if os.getenv("TORCHINDUCTOR_EXHAUSTIVE_FLEX_ATTENTION_EXPERIMENTAL") == "1": + for mfma in [0, 16]: + for wpeu in [0, 1, 2, 4, 8]: + cur_kernel_options["waves_per_eu"] = wpeu + cur_kernel_options["matrix_instr_non_kdim"] = mfma + flex_attention_backward_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + logsumexp, + delta, + grad_out, + grad_query, + broadcasted_grad_value, + kv_num_blocks, + kv_indices, + q_num_blocks, + q_indices, + full_kv_num_blocks, + full_kv_indices, + full_q_num_blocks, + full_q_indices, + ], + layout=layout_broadcasted_k, # We use store_output only for grad_key + subgraphs=[ + fw_subgraph_buffer, + joint_outputs.grad_input, + mask_graph_buffer, + joint_outputs.captured_grads_compute, + ], + mutated_inputs=[ + grad_query, + broadcasted_grad_value, + *joint_outputs.mutated_grads, + ], + call_sizes=query.get_size() + key.get_size()[1:3], + num_stages=num_stages, + num_warps=num_warps, + **cur_kernel_options, + ) + else: + flex_attention_backward_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + logsumexp, + delta, + grad_out, + grad_query, + broadcasted_grad_value, + kv_num_blocks, + kv_indices, + q_num_blocks, + q_indices, + full_kv_num_blocks, + full_kv_indices, + full_q_num_blocks, + full_q_indices, + ], + layout=layout_broadcasted_k, # We use store_output only for grad_key + subgraphs=[ + fw_subgraph_buffer, + joint_outputs.grad_input, + mask_graph_buffer, + joint_outputs.captured_grads_compute, + ], + mutated_inputs=[ + grad_query, + broadcasted_grad_value, + *joint_outputs.mutated_grads, + ], + call_sizes=query.get_size() + key.get_size()[1:3], + num_stages=num_stages, + num_warps=num_warps, + **cur_kernel_options, + ) inputs_for_autotuning = ( [ query, From a1c6a723ea7284382f1aa4a0a00f8a773f08ffb6 Mon Sep 17 00:00:00 2001 From: Jack Taylor Date: Fri, 14 Mar 2025 10:11:44 +0000 Subject: [PATCH 2/5] 26 tuning 
updates --- torch/_inductor/kernel/flex_attention.py | 140 ++++++++--------------- 1 file changed, 46 insertions(+), 94 deletions(-) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 67a71867c950fd..5fb9681ed42787 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -1212,8 +1212,8 @@ def flex_attention( if os.getenv("TORCHINDUCTOR_EXHAUSTIVE_FLEX_ATTENTION_EXPERIMENTAL") == "1": param1 = [16, 32, 64, 128, 256] param2 = [16, 32, 64, 128, 256] - param3 = [0, 1, 2, 4, 8] - param4 = [1, 2, 3] + param3 = [2, 4, 8, 16] + param4 = [1] # Generate full search space configs = list(itertools.product(param1, param2, param3, param4)) @@ -1235,8 +1235,8 @@ def flex_attention( ) continue # Work around https://github.com/pytorch/pytorch/issues/129625 - #if num_stages == 2: - # continue + if num_stages == 2: + continue cur_kernel_options = original_kernel_options.copy() # Performance tuning @@ -2303,23 +2303,14 @@ def flex_attention_backward(*args, **kwargs): configs.extend( [ (BLOCK1, BLOCK2, w, s) - for BLOCK1 in [32, 64] - for BLOCK2 in [32, 64, 128] + for BLOCK1 in [16, 32, 64, 128, 256] + for BLOCK2 in [16, 32, 64, 128, 256] for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4]) for s in num_stages_list if BLOCK2 % BLOCK1 == 0 ] ) - # Check if the environment variable is set - if os.getenv("TORCHINDUCTOR_EXHAUSTIVE_FLEX_ATTENTION_EXPERIMENTAL") == "1": - param1 = [16, 32, 64, 128, 256] - param2 = [16, 32, 64, 128, 256] - param3 = [0, 1, 2, 4, 8] - param4 = [1, 2, 3] - - # Generate full search space - configs = list(itertools.product(param1, param2, param3, param4)) original_kernel_options = kernel_options.copy() for BLOCK1, BLOCK2, num_warps, num_stages in configs: @@ -2344,86 +2335,47 @@ def flex_attention_backward(*args, **kwargs): cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) - if os.getenv("TORCHINDUCTOR_EXHAUSTIVE_FLEX_ATTENTION_EXPERIMENTAL") == "1": + for wpeu in [0, 4, 8]: for mfma in [0, 16]: - for wpeu in [0, 1, 2, 4, 8]: - cur_kernel_options["waves_per_eu"] = wpeu - cur_kernel_options["matrix_instr_non_kdim"] = mfma - flex_attention_backward_template.maybe_append_choice( - choices=choices, - input_nodes=[ - query, - key, - value, - logsumexp, - delta, - grad_out, - grad_query, - broadcasted_grad_value, - kv_num_blocks, - kv_indices, - q_num_blocks, - q_indices, - full_kv_num_blocks, - full_kv_indices, - full_q_num_blocks, - full_q_indices, - ], - layout=layout_broadcasted_k, # We use store_output only for grad_key - subgraphs=[ - fw_subgraph_buffer, - joint_outputs.grad_input, - mask_graph_buffer, - joint_outputs.captured_grads_compute, - ], - mutated_inputs=[ - grad_query, - broadcasted_grad_value, - *joint_outputs.mutated_grads, - ], - call_sizes=query.get_size() + key.get_size()[1:3], - num_stages=num_stages, - num_warps=num_warps, - **cur_kernel_options, - ) - else: - flex_attention_backward_template.maybe_append_choice( - choices=choices, - input_nodes=[ - query, - key, - value, - logsumexp, - delta, - grad_out, - grad_query, - broadcasted_grad_value, - kv_num_blocks, - kv_indices, - q_num_blocks, - q_indices, - full_kv_num_blocks, - full_kv_indices, - full_q_num_blocks, - full_q_indices, - ], - layout=layout_broadcasted_k, # We use store_output only for grad_key - subgraphs=[ - fw_subgraph_buffer, - joint_outputs.grad_input, - mask_graph_buffer, - 
joint_outputs.captured_grads_compute, - ], - mutated_inputs=[ - grad_query, - broadcasted_grad_value, - *joint_outputs.mutated_grads, - ], - call_sizes=query.get_size() + key.get_size()[1:3], - num_stages=num_stages, - num_warps=num_warps, - **cur_kernel_options, - ) + cur_kernel_options["waves_per_eu"] = wpeu + cur_kernel_options["matrix_instr_non_kdim"] = mfma + flex_attention_backward_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + logsumexp, + delta, + grad_out, + grad_query, + broadcasted_grad_value, + kv_num_blocks, + kv_indices, + q_num_blocks, + q_indices, + full_kv_num_blocks, + full_kv_indices, + full_q_num_blocks, + full_q_indices, + ], + layout=layout_broadcasted_k, # We use store_output only for grad_key + subgraphs=[ + fw_subgraph_buffer, + joint_outputs.grad_input, + mask_graph_buffer, + joint_outputs.captured_grads_compute, + ], + mutated_inputs=[ + grad_query, + broadcasted_grad_value, + *joint_outputs.mutated_grads, + ], + call_sizes=query.get_size() + key.get_size()[1:3], + num_stages=num_stages, + num_warps=num_warps, + **cur_kernel_options, + ) inputs_for_autotuning = ( [ query, From 0df99be88428684bc53d69dcdd4d4ee69a9d8c7e Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Fri, 7 Mar 2025 18:09:42 +0000 Subject: [PATCH 3/5] [ROCm] Incorporate ROCm triton specific tuning parameters (#148437) Splitting https://github.com/pytorch/pytorch/pull/147315 into two PRs. This PR adds general support for kpack and waves_per_eu triton kernel args for AMD backend. More detail in the PR above. A follow up PR will update the configs used by ROCm but this requires https://github.com/pytorch/pytorch/pull/147452 to land first Pull Request resolved: https://github.com/pytorch/pytorch/pull/148437 Approved by: https://github.com/eellison, https://github.com/jansel (cherry picked from commit 8059ead823740acb07afd5c7d61131bc46c790ce) --- torch/_inductor/autotune_process.py | 4 ++++ torch/_inductor/kernel/flex_attention.py | 4 ++++ torch/_inductor/kernel/mm_common.py | 7 ++++++- .../_inductor/runtime/coordinate_descent_tuner.py | 6 ++++++ torch/_inductor/select_algorithm.py | 14 ++++++++++++-- 5 files changed, 32 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 2828f48b79c233..fe1620cc90ca30 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -630,6 +630,8 @@ def __init__( num_stages: int, num_warps: int, matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction. 
+ waves_per_eu: int = 0, # only used for hip to schedule waves per execution unit + kpack: int = 0, # ROCm specific gemm paramete workspace_arg: Optional[WorkspaceArg] = None, ) -> None: super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) @@ -639,6 +641,8 @@ def __init__( self.num_stages = num_stages self.num_warps = num_warps self.matrix_instr_nonkdim = matrix_instr_nonkdim + self.waves_per_eu = waves_per_eu + self.kpack = kpack self.workspace_arg = workspace_arg def make_run_fn( diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 5fb9681ed42787..8cf7906f111996 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -1222,6 +1222,10 @@ def flex_attention( SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) + # ROCm specific considerations + if torch.version.hip: + kernel_options["kpack"] = 2 + # Note, we don't need to pass in the captured buffers explicitly # because they're implicitly added by the score_mod function # We do need to explicitly pass it in for autotuning though. diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index cb3b2d7836c1ae..56da09bc76adff 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -75,7 +75,7 @@ def filtered_configs( ), min_block_size_k, ) - used = set() + used = OrderedSet[tuple[int, ...]]() for block_m, block_n, block_k, num_stages, num_warps in configs: # shrink configs for small sizes block_m = max(min(int(block_m * scale), m), min_block_size) @@ -88,6 +88,7 @@ def filtered_configs( # each warp computes 16x16 tile = 256 num_warps = min(num_warps, block_m * block_n // 256) if torch.version.hip: + kpack = 2 for matrix_instr_nonkdim in [0, 16]: if matrix_instr_nonkdim != 0 and ( block_m % matrix_instr_nonkdim != 0 @@ -95,6 +96,7 @@ def filtered_configs( ): # block_m and block_n must be a multiple of matrix_instr_nonkdim continue + if ( block_m, block_n, @@ -102,6 +104,7 @@ def filtered_configs( num_stages, num_warps, matrix_instr_nonkdim, + kpack, ) not in used: used.add( ( @@ -111,6 +114,7 @@ def filtered_configs( num_stages, num_warps, matrix_instr_nonkdim, + kpack, ) ) yield triton_config( @@ -120,6 +124,7 @@ def filtered_configs( num_stages=num_stages, num_warps=num_warps, matrix_instr_nonkdim=matrix_instr_nonkdim, + kpack=kpack, ) else: if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used: diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py index 62a2abcea8d2d7..5fe978276637a8 100644 --- a/torch/_inductor/runtime/coordinate_descent_tuner.py +++ b/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -21,6 +21,8 @@ def get_field(config, name): return config.num_warps elif name == "num_stages": return config.num_stages + elif name == "waves_per_eu": + return config.kwargs.get(name, int(8 // config.num_warps)) else: return config.kwargs.get(name, None) @@ -97,6 +99,8 @@ def tunable_fields(self): ] if self.is_mm: out.append("num_stages") + if self.inductor_meta.get("is_hip") is True: + out.append("waves_per_eu") return out @@ -105,6 +109,8 @@ def value_too_large(self, name: str, val: int) -> bool: return val > self.get_config_max(name[0].lower()) if name == "num_warps": return val > self.get_warpsmax() + if name == "waves_per_eu": + return val > 8 return 
False diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index b3fde21699dba6..da0b06bc0fe826 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -354,9 +354,17 @@ def jit_lines(self): triton_meta["configs"] = [config_of(signature)] for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index] triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index] - matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", 0) - if matrix_instr_nonkdim != 0: + for arg_num in equal_1_arg_indices(signature): # type: ignore[index] + triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index,union-attr] + matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", None) + waves_per_eu = self.meta.get("waves_per_eu", None) + kpack = self.meta.get("kpack", None) + if matrix_instr_nonkdim: triton_meta["matrix_instr_nonkdim"] = matrix_instr_nonkdim + if waves_per_eu: + triton_meta["waves_per_eu"] = waves_per_eu + if kpack: + triton_meta["kpack"] = kpack self.triton_meta = triton_meta @@ -920,6 +928,8 @@ def make_kernel_render(out_node): num_stages=num_stages, num_warps=num_warps, matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0), + waves_per_eu=kwargs.get("waves_per_eu", 0), + kpack=kwargs.get("kpack", 2), input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes), # type: ignore[arg-type] output_tensor_meta=TensorMeta.from_irnodes(layout), workspace_arg=workspace_arg, From 13625d1c1444e9cdfe5236e1d387ba86638d3cf2 Mon Sep 17 00:00:00 2001 From: Jack Taylor Date: Thu, 27 Mar 2025 12:59:48 +0000 Subject: [PATCH 4/5] Fixes --- torch/_inductor/kernel/flex_attention.py | 12 ++++++------ torch/_inductor/kernel/mm_common.py | 2 ++ torch/_inductor/select_algorithm.py | 3 +-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 8cf7906f111996..b8220ef564b46f 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -1210,8 +1210,8 @@ def flex_attention( # Check if the environment variable is set if os.getenv("TORCHINDUCTOR_EXHAUSTIVE_FLEX_ATTENTION_EXPERIMENTAL") == "1": - param1 = [16, 32, 64, 128, 256] - param2 = [16, 32, 64, 128, 256] + param1 = [16, 32, 64, 128, 256, 512] + param2 = [16, 32, 64, 128, 256, 512] param3 = [2, 4, 8, 16] param4 = [1] @@ -2307,9 +2307,9 @@ def flex_attention_backward(*args, **kwargs): configs.extend( [ (BLOCK1, BLOCK2, w, s) - for BLOCK1 in [16, 32, 64, 128, 256] - for BLOCK2 in [16, 32, 64, 128, 256] - for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4]) + for BLOCK1 in [16, 32, 64, 128, 256, 512] + for BLOCK2 in [16, 32, 64, 128, 256, 512] + for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4, 8]) for s in num_stages_list if BLOCK2 % BLOCK1 == 0 ] @@ -2339,7 +2339,7 @@ def flex_attention_backward(*args, **kwargs): cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) - for wpeu in [0, 4, 8]: + for wpeu in [0, 1, 2, 4, 8]: for mfma in [0, 16]: cur_kernel_options["waves_per_eu"] = wpeu cur_kernel_options["matrix_instr_non_kdim"] = mfma diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 56da09bc76adff..6e6e7faf4e3e29 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -4,6 +4,8 @@ import logging from typing import Any, 
cast, Dict, Sequence, Tuple +from torch.utils._ordered_set import OrderedSet + import sympy import torch diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index da0b06bc0fe826..82ed98921d133e 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -354,8 +354,6 @@ def jit_lines(self): triton_meta["configs"] = [config_of(signature)] for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index] triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index] - for arg_num in equal_1_arg_indices(signature): # type: ignore[index] - triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index,union-attr] matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", None) waves_per_eu = self.meta.get("waves_per_eu", None) kpack = self.meta.get("kpack", None) @@ -366,6 +364,7 @@ def jit_lines(self): if kpack: triton_meta["kpack"] = kpack + self.triton_meta = triton_meta inductor_meta = { From 8bf46b07ecdd31aea1641a30998e0ac00699ef4a Mon Sep 17 00:00:00 2001 From: Jack Taylor Date: Thu, 27 Mar 2025 13:03:18 +0000 Subject: [PATCH 5/5] Updates --- .ci/docker/ci_commit_pins/triton.txt | 2 +- torch/_inductor/kernel/flex_attention.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 396be2dd54aeee..2a9c139109996f 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -6da9e66008b58a7b8553f96c69021cca0d0028f0 +a34a79dbd711ea9f8fb5090bcaf24a7717574206 diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index b8220ef564b46f..2b1884cd00504f 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -2325,9 +2325,6 @@ def flex_attention_backward(*args, **kwargs): or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0 ): continue - if num_warps == 8: - # Working around https://github.com/pytorch/pytorch/issues/141603 - continue # Performance tuning cur_kernel_options = original_kernel_options.copy()
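
Usage sketch for the experimental flag introduced in this series. The
environment variable name is taken from the patches above; the
flex_attention call, score_mod signature, and torch.compile mode are the
standard public PyTorch API, and the tensor shapes/dtype are arbitrary
placeholders. This assumes a ROCm build with the patches applied and an
autotuning mode enabled so that the expanded search space (block sizes,
num_warps, num_stages, waves_per_eu, MFMA settings) is actually
benchmarked; expect long compile times given the size of the product.

    import os

    # Must be set before Inductor lowers flex_attention so the exhaustive
    # config list replaces the default one.
    os.environ["TORCHINDUCTOR_EXHAUSTIVE_FLEX_ATTENTION_EXPERIMENTAL"] = "1"

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    def rel_bias(score, b, h, q_idx, kv_idx):
        # Minimal score_mod: relative-position bias, enough to exercise the
        # flex_attention Triton template rather than a fallback path.
        return score + (q_idx - kv_idx)

    device = "cuda"  # ROCm GPUs are exposed through the CUDA device type
    q, k, v = (
        torch.randn(2, 8, 1024, 64, device=device, dtype=torch.float16)
        for _ in range(3)
    )

    # max-autotune makes Inductor benchmark each template choice appended by
    # the patched lowering for the enlarged config product.
    compiled = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs")
    out = compiled(q, k, v, score_mod=rel_bias)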