@@ -146,7 +146,7 @@ def addmm_bwd(grad_out: Tensor, bias: Tensor, mat1: Tensor, mat2: Tensor, alpha:
146146 _BLOCK_SIZE_5 = 16
147147 _BLOCK_SIZE_6 = 16
148148 _BLOCK_SIZE_7 = 16
149- _launcher(_helion_addmm_bwd, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1) + triton.cdiv(m, _BLOCK_SIZE_2) * triton.cdiv(k, _BLOCK_SIZE_3) + triton.cdiv(k, _BLOCK_SIZE_5) * triton.cdiv(n, _BLOCK_SIZE_6),), grad_out, grad_input, mat2, grad_mat1, mat1, grad_mat2, grad_input.stride(0), grad_input.stride(1), grad_mat1.stride(0), grad_mat1.stride(1), grad_mat2.stride(0), grad_mat2.stride(1), grad_out.stride(0), grad_out.stride(1), mat1.stride(0), mat1.stride(1), mat2.stride(0), mat2.stride(1), m, n, beta, k, alpha, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, _BLOCK_SIZE_6, _BLOCK_SIZE_7, num_warps=4, num_stages=3 )
149+ _launcher(_helion_addmm_bwd, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1) + triton.cdiv(m, _BLOCK_SIZE_2) * triton.cdiv(k, _BLOCK_SIZE_3) + triton.cdiv(k, _BLOCK_SIZE_5) * triton.cdiv(n, _BLOCK_SIZE_6),), grad_out, grad_input, mat2, grad_mat1, mat1, grad_mat2, grad_input.stride(0), grad_input.stride(1), grad_mat1.stride(0), grad_mat1.stride(1), grad_mat2.stride(0), grad_mat2.stride(1), grad_out.stride(0), grad_out.stride(1), mat1.stride(0), mat1.stride(1), mat2.stride(0), mat2.stride(1), m, n, beta, k, alpha, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, _BLOCK_SIZE_6, _BLOCK_SIZE_7, num_warps=4, num_stages=2 )
150150 return (grad_input, grad_mat1, grad_mat2)
151151
152152--- assertExpectedJournal(TestExamples.test_attention_block_pointer)
@@ -3050,7 +3050,7 @@ def matmul_bwd(grad_out: Tensor, mat1: Tensor, mat2: Tensor, *, _launcher=_defau
30503050 _BLOCK_SIZE_3 = 16
30513051 _BLOCK_SIZE_4 = 16
30523052 _BLOCK_SIZE_5 = 16
3053- _launcher(_helion_matmul_bwd, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(k, _BLOCK_SIZE_1) + triton.cdiv(k, _BLOCK_SIZE_3) * triton.cdiv(n, _BLOCK_SIZE_4),), grad_out, mat2, grad_mat1, mat1, grad_mat2, grad_mat1.stride(0), grad_mat1.stride(1), grad_mat2.stride(0), grad_mat2.stride(1), grad_out.stride(0), grad_out.stride(1), mat1.stride(0), mat1.stride(1), mat2.stride(0), mat2.stride(1), m, k, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=3 )
3053+ _launcher(_helion_matmul_bwd, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(k, _BLOCK_SIZE_1) + triton.cdiv(k, _BLOCK_SIZE_3) * triton.cdiv(n, _BLOCK_SIZE_4),), grad_out, mat2, grad_mat1, mat1, grad_mat2, grad_mat1.stride(0), grad_mat1.stride(1), grad_mat2.stride(0), grad_mat2.stride(1), grad_out.stride(0), grad_out.stride(1), mat1.stride(0), mat1.stride(1), mat2.stride(0), mat2.stride(1), m, k, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=2 )
30543054 return (grad_mat1, grad_mat2)
30553055
30563056--- assertExpectedJournal(TestExamples.test_matmul_layernorm_dynamic_shapes)
0 commit comments