Skip to content

Commit 899bba9

Browse files
committed
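update: use a preallocated per-thread B_blocked_xform_u8 buffer (indexed by at::get_thread_num()) instead of a raw `new uint8_t[...]` allocation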
1 parent 28fedb2 commit 899bba9

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

torchao/prototype/inductor/codegen/cpp_int8_sdpa_template.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1262,6 +1262,9 @@
12621262
scalar_t* key_reorder_ptr = kv_reorder_buf_data;
12631263
scalar_t* value_reorder_ptr = kv_reorder_buf_data + batchSize * num_head * rndHeadSize * rndkvSize;
12641264
1265+
int64_t B_blocked_xform_u8_per_thread = rndHeadSize * kvSplitSize;
1266+
{{template.codegen_allocate_buffer("B_blocked_xform_u8_data", "scalar_t", "num_thread * B_blocked_xform_u8_per_thread")}}
1267+
12651268
// sum k and v
12661269
at::parallel_for(
12671270
0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) {
@@ -1300,7 +1303,8 @@
13001303
int64_t i = 0, j = 0, l = 0, n = 0;
13011304
at::native::data_index_init(
13021305
begin, i, batchSize, j, num_head, l, kvSlice);
1303-
uint8_t* B_blocked_xform_u8 = new uint8_t[rndHeadSize * kvSplitSize];
1306+
int ompIdx = at::get_thread_num();
1307+
scalar_t* B_blocked_xform_u8 = B_blocked_xform_u8_data + ompIdx * B_blocked_xform_u8_per_thread;
13041308
for (const auto z : c10::irange(begin, end)) {
13051309
(void)z; // Suppress unused variable
13061310
n = l * kvSplitSize;

0 commit comments

Comments
 (0)