Skip to content

Commit eceaa81

Browse files
committed
[Executorch] Use temp allocator for allocating scratch memory
This allows us to leverage the temp memory allocator; if that allocator is a caching allocator, it reduces the allocation overhead. Differential Revision: [D85532076](https://our.internmc.facebook.com/intern/diff/D85532076/) ghstack-source-id: 321483013 Pull Request resolved: #15654
1 parent 593e996 commit eceaa81

File tree

2 files changed

+32
-13
lines changed

2 files changed

+32
-13
lines changed

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ Tensor& flash_attention_kernel_out(
273273
// we might consider another appraoch
274274
if (seq_len >= 768) {
275275
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
276+
ctx,
276277
output,
277278
query,
278279
key,
@@ -289,6 +290,7 @@ Tensor& flash_attention_kernel_out(
289290
nullopt);
290291
} else if (seq_len >= 192) {
291292
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
293+
ctx,
292294
output,
293295
query,
294296
key,
@@ -305,6 +307,7 @@ Tensor& flash_attention_kernel_out(
305307
nullopt);
306308
} else {
307309
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
310+
ctx,
308311
output,
309312
query,
310313
key,
@@ -418,6 +421,7 @@ Tensor& custom_sdpa_out_impl(
418421
// we might consider another appraoch
419422
if (seq_len >= 768) {
420423
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
424+
ctx,
421425
output,
422426
q,
423427
k,
@@ -437,6 +441,7 @@ Tensor& custom_sdpa_out_impl(
437441
num_keys_for_causal_attention);
438442
} else if (seq_len >= 192) {
439443
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
444+
ctx,
440445
output,
441446
q,
442447
k,
@@ -456,6 +461,7 @@ Tensor& custom_sdpa_out_impl(
456461
num_keys_for_causal_attention);
457462
} else {
458463
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
464+
ctx,
459465
output,
460466
q,
461467
k,

extension/llm/custom_ops/op_sdpa_impl.h

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ enum class SeqDim { ONE = 1, TWO };
3535

3636
namespace sdpa::impl {
3737

38+
static std::vector<char> scratch_for_quant_dequant_vec;
3839
struct MaybeQuantizedMatrixData {
3940
const void* data{nullptr};
4041
const int8_t* zero_points{nullptr};
@@ -543,6 +544,7 @@ TODO: Just handle conversion of bool mask to float
543544
*/
544545
template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
545546
void cpu_flash_attention(
547+
RuntimeContext& ctx,
546548
Tensor& output,
547549
const Tensor& query,
548550
const Tensor& key,
@@ -766,26 +768,37 @@ void cpu_flash_attention(
766768
int64_t size_of_intermediate_precision = sizeof(accum_t);
767769
int64_t size_bytes = size_per_thread * num_thread * query.element_size() *
768770
size_of_intermediate_precision;
769-
std::vector<char> buf_vec(size_bytes);
770-
void* buf = reinterpret_cast<void*>(buf_vec.data());
771-
// Need to double check the following
772-
size_bytes = num_thread * qSplitSize * kvSplitSize * query.element_size();
773-
std::vector<char> buf_reduced_vec(size_bytes);
774-
void* buf_reduced = reinterpret_cast<void*>(buf_reduced_vec.data());
775-
// at::Tensor buf_reduced = at::empty(
776-
// {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
777-
// query.options());
771+
Result<void*> buff_res = ctx.allocate_temp(size_bytes);
772+
std::unique_ptr<char[]> allocated_buf;
773+
void* buf;
774+
if (!buff_res.ok()) {
775+
allocated_buf = std::make_unique<char[]>(size_bytes);
776+
buf = reinterpret_cast<void*>(allocated_buf.get());
777+
} else {
778+
buf = buff_res.get();
779+
}
780+
void* buf_reduced = nullptr;
778781
int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize;
779782
// Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
780783
// by padding with right number of per thread elements
781784
constexpr int64_t kAlignment = 32;
782785
size_per_thread_qdq_vec =
783786
(size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
784-
int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * sizeof(accum_t);
787+
int64_t size_per_thread_qdq_bytes =
788+
size_per_thread_qdq_vec * size_of_intermediate_precision;
785789
int64_t size_qdq_bytes = size_per_thread_qdq_bytes * num_thread;
786-
std::vector<char> scratch_for_quant_dequant_vec(size_qdq_bytes);
787-
accum_t* scratch_for_quant_dequant =
788-
reinterpret_cast<accum_t*>(scratch_for_quant_dequant_vec.data());
790+
std::unique_ptr<char[]> allocated_buf_for_qdq;
791+
Result<void*> scratch_for_quant_dequant_res =
792+
ctx.allocate_temp(size_qdq_bytes);
793+
accum_t* scratch_for_quant_dequant;
794+
if (!scratch_for_quant_dequant_res.ok()) {
795+
allocated_buf_for_qdq = std::make_unique<char[]>(size_qdq_bytes);
796+
scratch_for_quant_dequant =
797+
reinterpret_cast<accum_t*>(allocated_buf_for_qdq.get());
798+
} else {
799+
scratch_for_quant_dequant =
800+
reinterpret_cast<accum_t*>(scratch_for_quant_dequant_res.get());
801+
}
789802

790803
// Data ptrs
791804
const scalar_t* q_data = query.const_data_ptr<scalar_t>();

0 commit comments

Comments
 (0)