@@ -77,6 +77,47 @@ at::Tensor custom_sdpa_aten(
77
77
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
78
78
const std::optional<double > scale);
79
79
80
#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
// Out-variant wrapper around the quantized SDPA kernel that supplies its own
// RuntimeContext (for use from AOT/ATen registration, which has none).
// Parameter order mirrors the non-quantized custom_sdpa wrapper above, with
// the per-tensor quantization params (zero points / scales) appended.
Tensor& custom_quantized_sdpa_out_no_context(
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const int64_t start_pos,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    const optional<Tensor> q_zero_points,
    const optional<Tensor> q_scales,
    const optional<Tensor> k_zero_points,
    const optional<Tensor> k_scales,
    const optional<Tensor> v_zero_points,
    const optional<Tensor> v_scales,
    Tensor& output);

// Functional ATen entry point for the quantized SDPA op: allocates the
// output tensor itself and forwards to the out-variant wrapper above.
at::Tensor custom_quantized_sdpa_aten(
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    const int64_t start_pos,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<at::Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<double> scale,
    const std::optional<at::Tensor>& q_zero_points,
    const std::optional<at::Tensor>& q_scales,
    const std::optional<at::Tensor>& k_zero_points,
    const std::optional<at::Tensor>& k_scales,
    const std::optional<at::Tensor>& v_zero_points,
    const std::optional<at::Tensor>& v_scales);
#endif // ENABLE_CUSTOM_QUANTIZED_SDPA
120
+
80
121
Tensor& update_cache_out_no_context (
81
122
const Tensor& value,
82
123
Tensor& cache,
@@ -198,6 +239,85 @@ at::Tensor custom_sdpa_aten(
198
239
return output;
199
240
}
200
241
242
+ #ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
243
+ Tensor& custom_quantized_sdpa_out_no_context (
244
+ const Tensor& q,
245
+ const Tensor& k,
246
+ const Tensor& v,
247
+ const int64_t start_pos,
248
+ // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
249
+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
250
+ const optional<Tensor> attn_mask,
251
+ const double dropout_p,
252
+ const bool is_causal,
253
+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
254
+ const optional<double > scale,
255
+ const optional<Tensor> q_zero_points,
256
+ const optional<Tensor> q_scales,
257
+ const optional<Tensor> k_zero_points,
258
+ const optional<Tensor> k_scales,
259
+ const optional<Tensor> v_zero_points,
260
+ const optional<Tensor> v_scales,
261
+ Tensor& output) {
262
+ executorch::aten::RuntimeContext context{};
263
+ return torch::executor::native::custom_quantized_sdpa_out (
264
+ context,
265
+ q,
266
+ k,
267
+ v,
268
+ start_pos,
269
+ attn_mask,
270
+ dropout_p,
271
+ is_causal,
272
+ scale,
273
+ q_zero_points,
274
+ q_scales,
275
+ k_zero_points,
276
+ k_scales,
277
+ v_zero_points,
278
+ v_scales,
279
+ output);
280
+ }
281
+
282
+ at::Tensor custom_quantized_sdpa_aten (
283
+ const at::Tensor& q,
284
+ const at::Tensor& k,
285
+ const at::Tensor& v,
286
+ const int64_t start_pos,
287
+ // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
288
+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
289
+ const std::optional<at::Tensor> attn_mask,
290
+ const double dropout_p,
291
+ const bool is_causal,
292
+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
293
+ const std::optional<double > scale,
294
+ const std::optional<at::Tensor>& q_zero_points,
295
+ const std::optional<at::Tensor>& q_scales,
296
+ const std::optional<at::Tensor>& k_zero_points,
297
+ const std::optional<at::Tensor>& k_scales,
298
+ const std::optional<at::Tensor>& v_zero_points,
299
+ const std::optional<at::Tensor>& v_scales) {
300
+ auto output = at::empty (q.sizes ());
301
+ WRAP_TO_ATEN (custom_quantized_sdpa_out_no_context, 14 )
302
+ (q,
303
+ k,
304
+ v,
305
+ start_pos,
306
+ attn_mask,
307
+ dropout_p,
308
+ is_causal,
309
+ scale,
310
+ q_zero_points,
311
+ q_scales,
312
+ k_zero_points,
313
+ k_scales,
314
+ v_zero_points,
315
+ v_scales,
316
+ output);
317
+ return output;
318
+ }
319
+ #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
320
+
201
321
Tensor& update_cache_out_no_context (
202
322
const Tensor& value,
203
323
Tensor& cache,
@@ -245,6 +365,20 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
245
365
m.def (
246
366
" update_cache.out(Tensor value, Tensor(a!) cache, "
247
367
" SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)" );
368
+ #ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
369
+ m.def (
370
+ " custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
371
+ " Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
372
+ " float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
373
+ " Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
374
+ " Tensor? v_scales=None) -> Tensor" );
375
+ m.def (
376
+ " custom_quantized_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
377
+ " Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
378
+ " float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
379
+ " Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
380
+ " Tensor? v_scales=None, *, Tensor(a!) out) -> Tensor(a!)" );
381
+ #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
248
382
}
249
383
250
384
// TODO: Rename this file to op_custom_ops_aot.cpp
@@ -263,4 +397,13 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
263
397
m.impl (
264
398
" update_cache.out" ,
265
399
WRAP_TO_ATEN (torch::executor::native::update_cache_out_no_context, 3 ));
400
+ #ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
401
+ m.impl (
402
+ " custom_quantized_sdpa" ,
403
+ torch::executor::native::custom_quantized_sdpa_aten);
404
+ m.impl (
405
+ " custom_quantized_sdpa.out" ,
406
+ WRAP_TO_ATEN (
407
+ torch::executor::native::custom_quantized_sdpa_out_no_context, 14 ));
408
+ #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
266
409
}
0 commit comments