@@ -35,6 +35,7 @@ enum class SeqDim { ONE = 1, TWO };
 
 namespace sdpa::impl {
 
+static std::vector<char> scratch_for_quant_dequant_vec;
 struct MaybeQuantizedMatrixData {
   const void* data{nullptr};
   const int8_t* zero_points{nullptr};
@@ -543,6 +544,7 @@ TODO: Just handle conversion of bool mask to float
  */
 template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
 void cpu_flash_attention(
+    RuntimeContext& ctx,
     Tensor& output,
     const Tensor& query,
     const Tensor& key,
@@ -766,26 +768,37 @@ void cpu_flash_attention(
   int64_t size_of_intermediate_precision = sizeof(accum_t);
   int64_t size_bytes = size_per_thread * num_thread * query.element_size() *
       size_of_intermediate_precision;
-  std::vector<char> buf_vec(size_bytes);
-  void* buf = reinterpret_cast<void*>(buf_vec.data());
-  // Need to double check the following
-  size_bytes = num_thread * qSplitSize * kvSplitSize * query.element_size();
-  std::vector<char> buf_reduced_vec(size_bytes);
-  void* buf_reduced = reinterpret_cast<void*>(buf_reduced_vec.data());
-  // at::Tensor buf_reduced = at::empty(
-  //     {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
-  //     query.options());
+  Result<void*> buff_res = ctx.allocate_temp(size_bytes);
+  std::unique_ptr<char[]> allocated_buf;
+  void* buf;
+  if (!buff_res.ok()) {
+    allocated_buf = std::make_unique<char[]>(size_bytes);
+    buf = reinterpret_cast<void*>(allocated_buf.get());
+  } else {
+    buf = buff_res.get();
+  }
+  void* buf_reduced = nullptr;
   int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize;
   // Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
   // by padding with right number of per thread elements
   constexpr int64_t kAlignment = 32;
   size_per_thread_qdq_vec =
       (size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
-  int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * sizeof(accum_t);
+  int64_t size_per_thread_qdq_bytes =
+      size_per_thread_qdq_vec * size_of_intermediate_precision;
   int64_t size_qdq_bytes = size_per_thread_qdq_bytes * num_thread;
-  std::vector<char> scratch_for_quant_dequant_vec(size_qdq_bytes);
-  accum_t* scratch_for_quant_dequant =
-      reinterpret_cast<accum_t*>(scratch_for_quant_dequant_vec.data());
+  std::unique_ptr<char[]> allocated_buf_for_qdq;
+  Result<void*> scratch_for_quant_dequant_res =
+      ctx.allocate_temp(size_qdq_bytes);
+  accum_t* scratch_for_quant_dequant;
+  if (!scratch_for_quant_dequant_res.ok()) {
+    allocated_buf_for_qdq = std::make_unique<char[]>(size_qdq_bytes);
+    scratch_for_quant_dequant =
+        reinterpret_cast<accum_t*>(allocated_buf_for_qdq.get());
+  } else {
+    scratch_for_quant_dequant =
+        reinterpret_cast<accum_t*>(scratch_for_quant_dequant_res.get());
+  }
 
   // Data ptrs
   const scalar_t* q_data = query.const_data_ptr<scalar_t>();
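
The change replaces the function-local `std::vector` scratch buffers with memory requested from the kernel's `RuntimeContext` via `allocate_temp()`, falling back to an owned heap allocation when the temp allocator cannot satisfy the request. Below is a minimal, standalone sketch of that allocate-or-fallback pattern; `FakeRuntimeContext`, `TempResult`, `get_scratch`, and the arena size are simplified stand-ins for illustration only, not the actual runtime API used in the diff.

```cpp
#include <cstddef>
#include <cstdio>
#include <memory>

// Simplified stand-in for a Result<void*>: either a pointer handed out by
// the temp allocator, or null to signal an allocation failure.
struct TempResult {
  void* ptr;
  bool ok() const { return ptr != nullptr; }
  void* get() const { return ptr; }
};

// Hypothetical context with a fixed temp arena; allocate_temp() fails
// (returns a null result) once the arena is exhausted.
class FakeRuntimeContext {
 public:
  TempResult allocate_temp(size_t nbytes) {
    if (used_ + nbytes > sizeof(arena_)) {
      return TempResult{nullptr};  // failure: caller must fall back
    }
    void* p = arena_ + used_;
    used_ += nbytes;
    return TempResult{p};
  }

 private:
  alignas(16) char arena_[1024];
  size_t used_{0};
};

// The pattern from the diff: prefer the context's temp allocator, otherwise
// own a heap buffer via unique_ptr so it is released when the caller returns.
void* get_scratch(FakeRuntimeContext& ctx,
                  size_t size_bytes,
                  std::unique_ptr<char[]>& fallback) {
  TempResult res = ctx.allocate_temp(size_bytes);
  if (res.ok()) {
    return res.get();
  }
  fallback = std::make_unique<char[]>(size_bytes);
  return fallback.get();
}

int main() {
  FakeRuntimeContext ctx;
  std::unique_ptr<char[]> fallback_small, fallback_big;
  void* small = get_scratch(ctx, 256, fallback_small);  // served by the arena
  void* big = get_scratch(ctx, 4096, fallback_big);     // falls back to heap
  std::printf("small from arena: %d, big from heap: %d\n",
              small != nullptr && !fallback_small,
              big != nullptr && fallback_big != nullptr);
  return 0;
}
```

Note that the fallback `unique_ptr` must outlive every use of the returned pointer, which is why the diff keeps `allocated_buf` and `allocated_buf_for_qdq` as locals in `cpu_flash_attention` rather than temporaries.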