@@ -58,6 +58,40 @@ class VarInfo(NamedTuple):
5858 fx_node : torch .fx .Node
5959
6060
61+ def find_block_size_symbols (
62+ expr : sympy .Expr ,
63+ ) -> tuple [dict [sympy .Symbol , int ], set [sympy .Symbol ]]:
64+ """
65+ Find block size symbols in a sympy expression.
66+
67+ Returns:
68+ tuple of (block_size_mapping, non_block_size_symbols) where:
69+ - block_size_mapping: dict mapping block size symbols to their block_id
70+ - non_block_size_symbols: set of symbols that are NOT block sizes
71+ """
72+ if not isinstance (expr , sympy .Expr ):
73+ return {}, set ()
74+
75+ hf = HostFunction .current ()
76+ block_sizes = {}
77+ non_block_size_symbols = set ()
78+
79+ for symbol in expr .free_symbols :
80+ origin_info = hf .expr_to_origin .get (symbol ) # pyright: ignore[reportArgumentType]
81+ if origin_info is None or not isinstance (origin_info .origin , BlockSizeOrigin ):
82+ non_block_size_symbols .add (symbol )
83+ else :
84+ block_sizes [symbol ] = origin_info .origin .block_id
85+
86+ return block_sizes , non_block_size_symbols
87+
88+
89+ def contains_only_block_size_symbols (expr : sympy .Expr ) -> bool :
90+ """Check if expression contains only block size symbols (no other variables)."""
91+ _ , non_block = find_block_size_symbols (expr )
92+ return len (non_block ) == 0
93+
94+
6195@dataclasses .dataclass
6296class Argument :
6397 name : str # in the device function
@@ -209,6 +243,35 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
209243 def block_size_var (self , block_id : int ) -> str | None :
210244 return self .block_size_var_cache .get ((block_id ,))
211245
246+ def try_map_block_symbols_to_vars (self , expr : sympy .Expr ) -> sympy .Expr | None :
247+ """Try to map all block size symbols in expression to their variable names.
248+
249+ Returns:
250+ - The expression with symbols replaced if ALL symbols are block sizes and have variables
251+ - None if the expression contains non-block symbols or unmapped block symbols
252+ """
253+ block_mapping , non_block_symbols = find_block_size_symbols (expr )
254+
255+ # Can't map if there are non-block symbols
256+ if non_block_symbols :
257+ return None
258+
259+ # No symbols to map - return as-is
260+ if not block_mapping :
261+ return expr
262+
263+ # Try to map all block symbols to their variables
264+ var_map = {}
265+ for symbol , block_id in block_mapping .items ():
266+ block_var = self .block_size_var (block_id )
267+ if not block_var :
268+ # Can't map this block symbol - fail
269+ return None
270+ var_map [symbol ] = sympy .Symbol (block_var , integer = True )
271+
272+ # Successfully mapped all symbols
273+ return expr .xreplace (var_map )
274+
212275 def merge_variable_names (self , a : str , b : str ) -> None :
213276 name_group = [
214277 * self ._variable_renames .get (a , [a ]),
0 commit comments