diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 9bd46a4247..fc32f86754 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -346,14 +346,14 @@ mha_fwd_sparse(at::Tensor &q, // batch_size x seqlen_q x num_heads x hea const at::Tensor &block_offset, const at::Tensor &column_count, const at::Tensor &column_index, - const c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size - const c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const std::optional &out_, // batch_size x seqlen_q x num_heads x head_size + const std::optional &alibi_slopes_, // num_heads or batch_size x num_heads const double p_dropout, const double softmax_scale, bool is_causal, const double softcap, const bool return_softmax, - c10::optional gen_) { + std::optional gen_) { auto dprops = at::cuda::getCurrentDeviceProperties(); // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; @@ -515,11 +515,11 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_ const at::Tensor &block_offset, const at::Tensor &column_count, const at::Tensor &column_index, - const c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 - const c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. - const c10::optional &alibi_slopes_, // num_heads or b x num_heads + const std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. + const std::optional &alibi_slopes_, // num_heads or b x num_heads int64_t max_seqlen_q, const int64_t max_seqlen_k, const double p_dropout, @@ -528,7 +528,7 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_ bool is_causal, const double softcap, const bool return_softmax, - c10::optional gen_) { + std::optional gen_) { auto dprops = at::cuda::getCurrentDeviceProperties(); // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;