
Commit 0d9d697

kimishpatel authored and kirklandsign committed
Add a path to use quantized gemm from torchao in sdpa
Differential Revision: D71370593
Pull Request resolved: #9933
1 parent e53c017 commit 0d9d697

File tree: 8 files changed, +860 −45 lines

extension/llm/custom_ops/TARGETS

Lines changed: 14 additions & 0 deletions
```diff
@@ -47,3 +47,17 @@ runtime.python_test(
         "//caffe2:torch",
     ],
 )
+
+runtime.python_test(
+    name = "test_quantized_sdpa",
+    srcs = [
+        "test_quantized_sdpa.py",
+    ],
+    preload_deps = [
+        ":custom_ops_aot_lib_mkl_noomp",
+        ":custom_ops_aot_py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
```
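The test file itself is part of this commit but not shown in this section. A minimal sketch of the kind of check such a target enables, assuming the preloaded AOT library was built with ENABLE_CUSTOM_QUANTIZED_SDPA (class and test names below are hypothetical, not the real test_quantized_sdpa.py):

```python
# Hypothetical skeleton, not the real test_quantized_sdpa.py from this commit.
# preload_deps loads the AOT library, which registers the custom ops under torch.ops.llama.
import unittest

import torch


class QuantizedSDPATest(unittest.TestCase):
    def test_op_is_registered(self):
        # custom_quantized_sdpa is only registered when the library is built
        # with ENABLE_CUSTOM_QUANTIZED_SDPA (see op_sdpa_aot.cpp below).
        self.assertTrue(hasattr(torch.ops.llama, "custom_quantized_sdpa"))


if __name__ == "__main__":
    unittest.main()
```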

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 72 additions & 22 deletions
```diff
@@ -44,7 +44,9 @@ bool validate_flash_attention_args(
       "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");
 
   ET_CHECK_OR_RETURN_FALSE(
-      (query.scalar_type() == ScalarType::Float), "Query must be Float type");
+      (query.scalar_type() == ScalarType::Float) ||
+          (query.scalar_type() == ScalarType::Char),
+      "Query must be Float type");
 
   ET_CHECK_OR_RETURN_FALSE(
       (query.scalar_type() == key.scalar_type()) &&
@@ -354,9 +356,14 @@ Tensor& custom_sdpa_out_impl(
       output,
       "Invalid arguments");
 
+  int64_t seq_len = q.size(1);
+  auto q_seq_len = q.size(1);
+
   bool is_seq_at_dim_1{true};
   if (q.scalar_type() == ScalarType::Char) {
     is_seq_at_dim_1 = false;
+    seq_len = q.size(2);
+    q_seq_len = q.size(2);
     ET_KERNEL_CHECK_MSG(
         ctx,
         q_scales.has_value() && q_zero_points.has_value() &&
@@ -390,9 +397,6 @@ Tensor& custom_sdpa_out_impl(
 
   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
 
-  const int64_t seq_len = q.size(1);
-  auto q_seq_len = q.size(1);
-
   const int64_t num_keys_for_causal_attention = start_pos + seq_len;
 
   ET_KERNEL_CHECK(
@@ -418,12 +422,12 @@ Tensor& custom_sdpa_out_impl(
         is_causal,
         attn_mask,
         scale,
-        nullopt, // q_zero_points
-        nullopt, // q_scales
-        nullopt, // k_zero_points
-        nullopt, // k_scales
-        nullopt, // v_zero_points
-        nullopt, // v_scales
+        q_zero_points, // q_zero_points
+        q_scales, // q_scales
+        k_zero_points, // k_zero_points
+        k_scales, // k_scales
+        v_zero_points, // v_zero_points
+        v_scales, // v_scales
         is_seq_at_dim_1, /* is_seq_at_dim_1 */
         start_pos,
         num_keys_for_causal_attention);
@@ -437,12 +441,12 @@ Tensor& custom_sdpa_out_impl(
         is_causal,
         attn_mask,
         scale,
-        nullopt, // q_zero_points
-        nullopt, // q_scales
-        nullopt, // k_zero_points
-        nullopt, // k_scales
-        nullopt, // v_zero_points
-        nullopt, // v_scales
+        q_zero_points, // q_zero_points
+        q_scales, // q_scales
+        k_zero_points, // k_zero_points
+        k_scales, // k_scales
+        v_zero_points, // v_zero_points
+        v_scales, // v_scales
         is_seq_at_dim_1, /* is_seq_at_dim_1 */
         start_pos,
         num_keys_for_causal_attention);
@@ -456,12 +460,12 @@ Tensor& custom_sdpa_out_impl(
         is_causal,
         attn_mask,
         scale,
-        nullopt, // q_zero_points
-        nullopt, // q_scales
-        nullopt, // k_zero_points
-        nullopt, // k_scales
-        nullopt, // v_zero_points
-        nullopt, // v_scales
+        q_zero_points, // q_zero_points
+        q_scales, // q_scales
+        k_zero_points, // k_zero_points
+        k_scales, // k_scales
+        v_zero_points, // v_zero_points
+        v_scales, // v_scales
         is_seq_at_dim_1, /* is_seq_at_dim_1 */
         start_pos,
         num_keys_for_causal_attention);
@@ -470,6 +474,45 @@ Tensor& custom_sdpa_out_impl(
   return output;
 }
 
+#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
+Tensor& custom_quantized_sdpa_out(
+    RuntimeContext& ctx,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    const optional<Tensor>& q_zero_points,
+    const optional<Tensor>& q_scales,
+    const optional<Tensor>& k_zero_points,
+    const optional<Tensor>& k_scales,
+    const optional<Tensor>& v_zero_points,
+    const optional<Tensor>& v_scales,
+    Tensor& output) {
+  return custom_sdpa_out_impl(
+      ctx,
+      q,
+      k,
+      v,
+      start_pos,
+      attn_mask,
+      dropout_p,
+      is_causal,
+      scale,
+      output,
+      q_zero_points,
+      q_scales,
+      k_zero_points,
+      k_scales,
+      v_zero_points,
+      v_scales);
+}
+#endif // ENABLE_CUSTOM_QUANTIZED_SDPA
+
 /*
   Input params
   @param[in] q_projected Projected query with query weights.
@@ -570,3 +613,10 @@ EXECUTORCH_LIBRARY(
     llama,
     "custom_sdpa.out",
     torch::executor::native::custom_sdpa_out);
+
+#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
+EXECUTORCH_LIBRARY(
+    llama,
+    "custom_quantized_sdpa.out",
+    torch::executor::native::custom_quantized_sdpa_out);
+#endif // ENABLE_CUSTOM_QUANTIZED_SDPA
```
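Two behavioral changes stand out here: the quantization parameters are now forwarded into custom_sdpa_out_impl instead of being hard-coded to nullopt, and the sequence length is read from dim 2 rather than dim 1 when the query is int8 (ScalarType::Char), since the quantized path keeps the sequence at a different position. A small illustration of that dtype-driven dimension choice, assuming the usual [batch, seq_len, num_heads, head_dim] float layout (the exact quantized layout is an assumption for illustration):

```python
# Illustration only: mirrors the is_seq_at_dim_1 / seq_len logic above.
import torch


def infer_seq_len(q: torch.Tensor) -> int:
    # Float queries keep the sequence at dim 1; int8 (Char) queries keep it at dim 2,
    # e.g. a [batch, num_heads, seq_len, head_dim] layout (assumed for illustration).
    return q.size(2) if q.dtype == torch.int8 else q.size(1)


assert infer_seq_len(torch.zeros(1, 7, 4, 8)) == 7                    # float path
assert infer_seq_len(torch.zeros(1, 4, 7, 8, dtype=torch.int8)) == 7  # quantized path
```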

extension/llm/custom_ops/op_sdpa.h

Lines changed: 20 additions & 0 deletions
```diff
@@ -56,6 +56,26 @@ Tensor& flash_attention_kernel_out(
     const optional<double> scale,
     Tensor& output);
 
+#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
+Tensor& custom_quantized_sdpa_out(
+    RuntimeContext& ctx,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    const optional<Tensor>& q_zero_points,
+    const optional<Tensor>& q_scales,
+    const optional<Tensor>& k_zero_points,
+    const optional<Tensor>& k_scales,
+    const optional<Tensor>& v_zero_points,
+    const optional<Tensor>& v_scales,
+    Tensor& output);
+#endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 } // namespace native
 } // namespace executor
 } // namespace torch
```

extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 143 additions & 0 deletions
```diff
@@ -77,6 +77,47 @@ at::Tensor custom_sdpa_aten(
     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
     const std::optional<double> scale);
 
+#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
+Tensor& custom_quantized_sdpa_out_no_context(
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<Tensor> attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    const optional<Tensor> q_zero_points,
+    const optional<Tensor> q_scales,
+    const optional<Tensor> k_zero_points,
+    const optional<Tensor> k_scales,
+    const optional<Tensor> v_zero_points,
+    const optional<Tensor> v_scales,
+    Tensor& output);
+
+at::Tensor custom_quantized_sdpa_aten(
+    const at::Tensor& q,
+    const at::Tensor& k,
+    const at::Tensor& v,
+    const int64_t start_pos,
+    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const std::optional<at::Tensor> attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const std::optional<double> scale,
+    const std::optional<at::Tensor>& q_zero_points,
+    const std::optional<at::Tensor>& q_scales,
+    const std::optional<at::Tensor>& k_zero_points,
+    const std::optional<at::Tensor>& k_scales,
+    const std::optional<at::Tensor>& v_zero_points,
+    const std::optional<at::Tensor>& v_scales);
+#endif // ENABLE_CUSTOM_QUANTIZED_SDPA
+
 Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
@@ -198,6 +239,85 @@ at::Tensor custom_sdpa_aten(
   return output;
 }
 
+#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
+Tensor& custom_quantized_sdpa_out_no_context(
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<Tensor> attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    const optional<Tensor> q_zero_points,
+    const optional<Tensor> q_scales,
+    const optional<Tensor> k_zero_points,
+    const optional<Tensor> k_scales,
+    const optional<Tensor> v_zero_points,
+    const optional<Tensor> v_scales,
+    Tensor& output) {
+  executorch::aten::RuntimeContext context{};
+  return torch::executor::native::custom_quantized_sdpa_out(
+      context,
+      q,
+      k,
+      v,
+      start_pos,
+      attn_mask,
+      dropout_p,
+      is_causal,
+      scale,
+      q_zero_points,
+      q_scales,
+      k_zero_points,
+      k_scales,
+      v_zero_points,
+      v_scales,
+      output);
+}
+
+at::Tensor custom_quantized_sdpa_aten(
+    const at::Tensor& q,
+    const at::Tensor& k,
+    const at::Tensor& v,
+    const int64_t start_pos,
+    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const std::optional<at::Tensor> attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const std::optional<double> scale,
+    const std::optional<at::Tensor>& q_zero_points,
+    const std::optional<at::Tensor>& q_scales,
+    const std::optional<at::Tensor>& k_zero_points,
+    const std::optional<at::Tensor>& k_scales,
+    const std::optional<at::Tensor>& v_zero_points,
+    const std::optional<at::Tensor>& v_scales) {
+  auto output = at::empty(q.sizes());
+  WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 14)
+  (q,
+   k,
+   v,
+   start_pos,
+   attn_mask,
+   dropout_p,
+   is_causal,
+   scale,
+   q_zero_points,
+   q_scales,
+   k_zero_points,
+   k_scales,
+   v_zero_points,
+   v_scales,
+   output);
+  return output;
+}
+#endif // ENABLE_CUSTOM_QUANTIZED_SDPA
+
 Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
@@ -245,6 +365,20 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
   m.def(
       "update_cache.out(Tensor value, Tensor(a!) cache, "
      "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
+#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
+  m.def(
+      "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
+      "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
+      "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
+      "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
+      "Tensor? v_scales=None) -> Tensor");
+  m.def(
+      "custom_quantized_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
+      "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
+      "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
+      "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
+      "Tensor? v_scales=None, *, Tensor(a!) out) -> Tensor(a!)");
+#endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 }
 
 // TODO: Rename this file to op_custom_ops_aot.cpp
@@ -263,4 +397,13 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl(
       "update_cache.out",
       WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3));
+#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
+  m.impl(
+      "custom_quantized_sdpa",
+      torch::executor::native::custom_quantized_sdpa_aten);
+  m.impl(
+      "custom_quantized_sdpa.out",
+      WRAP_TO_ATEN(
+          torch::executor::native::custom_quantized_sdpa_out_no_context, 14));
+#endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 }
```
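With the schemas and CompositeExplicitAutograd impls above, the op becomes callable from eager PyTorch once the AOT library is loaded, which is what the new Python test relies on. A hedged usage sketch follows; the library path and the shapes/dtypes of the scale and zero-point tensors are assumptions, not dictated by this diff:

```python
import torch

# Hypothetical path; the library must be built with ENABLE_CUSTOM_QUANTIZED_SDPA.
# torch.ops.load_library("/path/to/libcustom_ops_aot_lib.so")


def quantized_sdpa(q_int8, k_int8, v_int8, start_pos,
                   q_zp, q_sc, k_zp, k_sc, v_zp, v_sc):
    """Thin wrapper over the functional schema registered above."""
    return torch.ops.llama.custom_quantized_sdpa(
        q_int8, k_int8, v_int8, start_pos,
        None,   # attn_mask
        0.0,    # dropout (the schema spells this argument "drpout_p")
        True,   # is_causal
        None,   # scale: None falls back to the kernel's default scaling
        q_zp, q_sc, k_zp, k_sc, v_zp, v_sc)
```

The constant 14 passed to WRAP_TO_ATEN appears to be the index of the out tensor in the wrapped signature (the 15th parameter), matching the 3 used for update_cache_out_no_context, so the same ExecuTorch kernel backs both the functional and the out-variant registrations.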
