diff --git a/test/test_ops.py b/test/test_ops.py
index 11ad59db27..240488c637 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -864,9 +864,13 @@ def test_swizzle_mm():


 def _test_scaled_embedding_bag_cpu_helper(
-    multi_hot, batch_size, vector_size, index_type, qtype
+    multi_hot,
+    batch_size,
+    vector_size,
+    index_type,
+    qtype,
+    out_dtype=torch.float,
 ):
-    dtype = torch.float32
     include_last_offset = True
     mode = "sum"

@@ -883,7 +887,7 @@ def _test_scaled_embedding_bag_cpu_helper(
         1000,
         vector_size,
         mode=mode,
-        dtype=dtype,
+        dtype=torch.float,
         include_last_offset=include_last_offset,
     )
     if qtype == torch.int8:
@@ -894,17 +898,25 @@ def _test_scaled_embedding_bag_cpu_helper(
         qweight = m.weight.data.to(qtype)
         m.weight.data = qweight.to(m.weight.dtype)

+    out_scale = 1.0
+    if out_dtype == torch.int8:
+        out_scale = 2.0
+
     with torch.no_grad():
         refe_out = m.forward(indices, offsets) * weight_scale
+        if out_dtype == torch.int8:
+            refe_out = torch.round(refe_out / out_scale).to(torch.int32)
+            refe_out = torch.clamp(refe_out, -128, 127).to(out_dtype)
         test_out = torch.ops.torchao._scaled_embedding_bag(
             qweight,
             indices,
             offsets,
             weight_scale,
-            1.0,
+            out_scale,
             mode_enum,
             include_last_offset,
-        ).to(dtype)
+            out_dtype,
+        )
     torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5)


@@ -918,9 +930,15 @@
     ids=str,
 )
 def test_scaled_embedding_bag_int8_cpu(multi_hot, batch_size, vector_size, index_type):
-    _test_scaled_embedding_bag_cpu_helper(
-        multi_hot, batch_size, vector_size, index_type, torch.int8
-    )
+    for out_dtype in [torch.float, torch.int8]:
+        _test_scaled_embedding_bag_cpu_helper(
+            multi_hot,
+            batch_size,
+            vector_size,
+            index_type,
+            torch.int8,
+            out_dtype,
+        )


 @pytest.mark.skipif(
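A note on the int8 reference path in the test above: the expected output is produced by dividing by out_scale, rounding to the nearest integer, and saturating to [-128, 127], mirroring the kernel-side store_elem/store_chunk overloads added below. A minimal standalone sketch of that step (the values here are made up for illustration):

    import torch

    o_scale = 2.0
    float_out = torch.tensor([7.3, 300.0, -400.0])
    int8_out = torch.clamp(torch.round(float_out / o_scale), -128, 127).to(torch.int8)
    # 3.65 rounds to 4; 150.0 and -200.0 saturate to 127 and -128.
    print(int8_out)  # tensor([   4,  127, -128], dtype=torch.int8)
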
diff --git a/torchao/csrc/cpu/aten_kernels/scaled_embedding_bag.cpp b/torchao/csrc/cpu/aten_kernels/scaled_embedding_bag.cpp
index 7ab518aa32..82b5c59c5c 100644
--- a/torchao/csrc/cpu/aten_kernels/scaled_embedding_bag.cpp
+++ b/torchao/csrc/cpu/aten_kernels/scaled_embedding_bag.cpp
@@ -6,6 +6,38 @@
 #include
 #include

+#define QTYPE_DISPATCH(TYPE, ...)                                             \
+  [&]() {                                                                     \
+    switch (TYPE) {                                                           \
+    case c10::ScalarType::Float8_e4m3fn: {                                    \
+      using data_t = at::Float8_e4m3fn;                                       \
+      return __VA_ARGS__();                                                   \
+    }                                                                         \
+    case c10::ScalarType::Char: {                                             \
+      using data_t = int8_t;                                                  \
+      return __VA_ARGS__();                                                   \
+    }                                                                         \
+    default:                                                                  \
+      TORCH_CHECK(false, "scaled_embedding_bag: unsupported qtype");          \
+    }                                                                         \
+  }()
+
+#define OUTTYPE_DISPATCH(TYPE, ...)                                           \
+  [&]() {                                                                     \
+    switch (TYPE) {                                                           \
+    case c10::ScalarType::Float: {                                            \
+      using output_t = float;                                                 \
+      return __VA_ARGS__();                                                   \
+    }                                                                         \
+    case c10::ScalarType::Char: {                                             \
+      using output_t = int8_t;                                                \
+      return __VA_ARGS__();                                                   \
+    }                                                                         \
+    default:                                                                  \
+      TORCH_CHECK(false, "scaled_embedding_bag: unsupported output type");    \
+    }                                                                         \
+  }()
+
 namespace torchao {

 namespace {
@@ -53,14 +85,71 @@ static inline CHUNK load_chunk(const int8_t *x) {
   x7 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 3));
   return {x0, x1, x2, x3, x4, x5, x6, x7};
 }
+
+static inline void store_chunk(float *output, CHUNK chunk) {
+  __m512 x0, x1, x2, x3, x4, x5, x6, x7;
+  std::tie(x0, x1, x2, x3, x4, x5, x6, x7) = chunk;
+  _mm512_store_ps(output, x0);
+  _mm512_store_ps(output + 16, x1);
+  _mm512_store_ps(output + 32, x2);
+  _mm512_store_ps(output + 48, x3);
+  _mm512_store_ps(output + 64, x4);
+  _mm512_store_ps(output + 80, x5);
+  _mm512_store_ps(output + 96, x6);
+  _mm512_store_ps(output + 112, x7);
+}
+
+static inline void store_chunk(int8_t *output, CHUNK chunk) {
+  __m512i x00 = _mm512_setzero_si512(), x64 = _mm512_setzero_si512();
+  __m512i y0, y1, y2, y3, y4, y5, y6, y7;
+  __m512 f0, f1, f2, f3, f4, f5, f6, f7;
+  std::tie(f0, f1, f2, f3, f4, f5, f6, f7) = chunk;
+  y0 = _mm512_cvt_roundps_epi32(
+      f0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y1 = _mm512_cvt_roundps_epi32(
+      f1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y2 = _mm512_cvt_roundps_epi32(
+      f2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y3 = _mm512_cvt_roundps_epi32(
+      f3, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y4 = _mm512_cvt_roundps_epi32(
+      f4, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y5 = _mm512_cvt_roundps_epi32(
+      f5, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y6 = _mm512_cvt_roundps_epi32(
+      f6, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  y7 = _mm512_cvt_roundps_epi32(
+      f7, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  x00 = _mm512_inserti32x4(x00, _mm512_cvtsepi32_epi8(y0), 0);
+  x00 = _mm512_inserti32x4(x00, _mm512_cvtsepi32_epi8(y1), 1);
+  x00 = _mm512_inserti32x4(x00, _mm512_cvtsepi32_epi8(y2), 2);
+  x00 = _mm512_inserti32x4(x00, _mm512_cvtsepi32_epi8(y3), 3);
+  x64 = _mm512_inserti32x4(x64, _mm512_cvtsepi32_epi8(y4), 0);
+  x64 = _mm512_inserti32x4(x64, _mm512_cvtsepi32_epi8(y5), 1);
+  x64 = _mm512_inserti32x4(x64, _mm512_cvtsepi32_epi8(y6), 2);
+  x64 = _mm512_inserti32x4(x64, _mm512_cvtsepi32_epi8(y7), 3);
+  _mm512_store_si512(output, x00);
+  _mm512_store_si512(output + 64, x64);
+}
 #endif

-template <typename data_t, typename index_t>
+static inline void store_elem(float &out, float input) {
+  out = input;
+}
+
+static inline void store_elem(int8_t &out, float input) {
+  float rounded = std::round(input);
+  float clamped = std::max(-128.0f, std::min(127.0f, rounded));
+  int32_t int32_value = static_cast<int32_t>(clamped);
+  out = static_cast<int8_t>(int32_value);
+}
+
+template <typename data_t, typename index_t, typename output_t>
 inline void _scaled_embedding_bag_krnl(
     const int64_t bs_begin, const int64_t bs_end, const int64_t num_emb,
     const int64_t emb_dim, const index_t last_offset, const index_t *indices,
     const index_t *offsets, const data_t *weight, const double scale,
-    float *result, const int64_t num_batch) {
+    output_t *result, const int64_t num_batch) {
 #if defined(CPU_CAPABILITY_AVX512)
   if (emb_dim % 128 == 0) {
     constexpr int64_t block_dim = 128;
@@ -76,7 +165,7 @@ inline void _scaled_embedding_bag_krnl(
     for (int64_t block_id = 0; block_id < num_blocks; block_id++) {
       // load first indices
       int64_t idx = indices[start_idx] * emb_dim + block_dim * block_id;
-      float *block_result = result + block_dim * block_id;
+      output_t *block_result = result + block_dim * block_id;
       std::tie(x0, x1, x2, x3, x4, x5, x6, x7) = load_chunk(weight + idx);
       for (int64_t j = start_idx + 1; j < end_idx; ++j) {
         // add following idx
@@ -100,14 +189,7 @@ inline void _scaled_embedding_bag_krnl(
       x6 = _mm512_mul_ps(x6, scale_v);
       x7 = _mm512_mul_ps(x7, scale_v);
       // store
-      _mm512_store_ps(block_result, x0);
-      _mm512_store_ps(block_result + 16, x1);
-      _mm512_store_ps(block_result + 32, x2);
-      _mm512_store_ps(block_result + 48, x3);
-      _mm512_store_ps(block_result + 64, x4);
-      _mm512_store_ps(block_result + 80, x5);
-      _mm512_store_ps(block_result + 96, x6);
-      _mm512_store_ps(block_result + 112, x7);
+      store_chunk(block_result, {x0, x1, x2, x3, x4, x5, x6, x7});
     }
     result += num_emb * emb_dim;
   }
@@ -127,14 +209,14 @@ inline void _scaled_embedding_bag_krnl(
         value += float(weight[idx + d]);
       }
       value = value * scale;
-      result[d] = value;
+      store_elem(result[d], value);
     }
     result += num_emb * emb_dim;
   }
 }

-template <typename data_t, typename index_t>
-void _scaled_embedding_bag(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
+template <typename data_t, typename index_t, typename output_t>
+void _scaled_embedding_bag(output_t *o_ptr, data_t *w_ptr, index_t *indices_ptr,
                            index_t *offsets_ptr, int64_t num_batch,
                            int64_t emb_dim, index_t last_offset, double w_scale,
                            double o_scale) {
@@ -147,7 +229,7 @@ void _scaled_embedding_bag(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
     for (int64_t n = 0; n < num_emb; ++n) {
      const int64_t bs_begin = b * b_block;
      const int64_t bs_end = std::min(num_batch, (b + 1) * b_block);
-      float *r = &o_ptr[b * b_block * num_emb * emb_dim + n * emb_dim];
+      output_t *r = &o_ptr[b * b_block * num_emb * emb_dim + n * emb_dim];
      // avoid offsets not include last batch
      _scaled_embedding_bag_krnl(bs_begin, bs_end, num_emb, emb_dim,
                                 last_offset, indices_ptr, offsets_ptr, w_ptr,
@@ -156,12 +238,24 @@ void _scaled_embedding_bag(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
   }
 }

-at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
-                                      const at::Tensor &indices,
-                                      const at::Tensor &offsets,
-                                      const at::Tensor &w_scales,
-                                      double o_scale, const int64_t mode,
-                                      bool include_last_offset) {
+template <typename data_t, typename index_t, typename output_t>
+void _scaled_embedding_bag_dispatch_dtype(
+    const at::Tensor &qweight, const at::Tensor &indices,
+    const at::Tensor &offsets, const at::Tensor &output, int64_t batch_size,
+    int64_t emb_dim, index_t last_offset, double w_scale, double o_scale) {
+  data_t *qweight_ptr = qweight.data_ptr<data_t>();
+  index_t *indices_ptr = indices.data_ptr<index_t>();
+  index_t *offsets_ptr = offsets.data_ptr<index_t>();
+  output_t *output_ptr = output.data_ptr<output_t>();
+  _scaled_embedding_bag(
+      output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size, emb_dim,
+      last_offset, w_scale, o_scale);
+}
+
+at::Tensor _scaled_embedding_bag_impl(
+    const at::Tensor &qweight, const at::Tensor &indices,
+    const at::Tensor &offsets, const at::Tensor &w_scales, double o_scale,
+    const int64_t mode, bool include_last_offset, at::ScalarType output_dtype) {
   // Only support include_last_offset == True and mode ==
   // at::native::EmbeddingBagMode::SUM
   // TODO: Support more case
@@ -193,32 +287,17 @@
   int64_t last_offset = indices.numel();

   at::Tensor output =
-      at::empty({batch_size, emb_dim}, qweight.options().dtype(at::kFloat));
-  if (qtype == c10::ScalarType::Float8_e4m3fn) {
-    AT_DISPATCH_INDEX_TYPES(
-        indices.scalar_type(), "_scaled_embedding_bag", [&] {
-          at::Float8_e4m3fn *qweight_ptr =
-              qweight.data_ptr<at::Float8_e4m3fn>();
-          index_t *indices_ptr = indices.data_ptr<index_t>();
-          index_t *offsets_ptr = offsets.data_ptr<index_t>();
-          float *output_ptr = output.data_ptr<float>();
-          _scaled_embedding_bag(
-              output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
-              emb_dim, last_offset, w_scale, o_scale);
-        });
-  } else {
-    AT_DISPATCH_INDEX_TYPES(
-        indices.scalar_type(), "_scaled_embedding_bag", [&] {
-          int8_t *qweight_ptr = qweight.data_ptr<int8_t>();
-          index_t *indices_ptr = indices.data_ptr<index_t>();
-          index_t *offsets_ptr = offsets.data_ptr<index_t>();
-          float *output_ptr = output.data_ptr<float>();
-          _scaled_embedding_bag(
-              output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
-              emb_dim, last_offset, w_scale, o_scale);
-        });
-  }
-
+      at::empty({batch_size, emb_dim}, qweight.options().dtype(output_dtype));
+  OUTTYPE_DISPATCH(output_dtype, [&] {
+    QTYPE_DISPATCH(qtype, [&] {
+      AT_DISPATCH_INDEX_TYPES(
+          indices.scalar_type(), "_scaled_embedding_bag", [&] {
+            _scaled_embedding_bag_dispatch_dtype<data_t, index_t, output_t>(
+                qweight, indices, offsets, output, batch_size, emb_dim,
+                last_offset, w_scale, o_scale);
+          });
+    });
+  });
   return output;
 }
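For readers of the kernel above, the supported semantics (mode == sum, include_last_offset == True) reduce to the following pure-PyTorch reference. This is a sketch for illustration only, not the actual dispatch path: it assumes non-empty bags and takes w_scale and o_scale as Python floats, whereas the op receives weight_scale as a tensor.

    import torch

    def scaled_embedding_bag_ref(qweight, indices, offsets, w_scale, o_scale, out_dtype):
        # One bag per pair of adjacent offsets; offsets[-1] == indices.numel().
        bags = []
        for b in range(offsets.numel() - 1):
            rows = qweight[indices[offsets[b] : offsets[b + 1]]].to(torch.float)
            bags.append(rows.sum(dim=0))
        out = torch.stack(bags) * w_scale
        if out_dtype == torch.int8:
            # Same round-then-saturate behavior as store_elem / store_chunk.
            out = torch.clamp(torch.round(out / o_scale), -128, 127)
        return out.to(out_dtype)
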
diff --git a/torchao/ops.py b/torchao/ops.py
index f4191d60b5..6748565fe4 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -69,7 +69,7 @@
     "da8w4_linear_cpu(Tensor input, Tensor input_scales, Tensor input_qzeros, Tensor weight, Tensor weight_scales, Tensor weight_qzeros, Tensor compensation, Tensor? bias, ScalarType output_dtype) -> Tensor"
 )
 lib.define(
-    "_scaled_embedding_bag(Tensor qweight, Tensor indices, Tensor offsets, Tensor weight_scale, float o_scale, int mode, bool include_last_offset) -> Tensor"
+    "_scaled_embedding_bag(Tensor qweight, Tensor indices, Tensor offsets, Tensor weight_scale, float o_scale, int mode, bool include_last_offset, ScalarType output_dtype) -> Tensor"
 )
 lib.define(
     "float8_linear_prepack_cpu(Tensor weight, Tensor scales) -> (Tensor, Tensor)"
 )
@@ -1118,13 +1118,11 @@ def _(
     o_scale: float,
     mode: int,
     include_last_offset: bool,
+    out_dtype: torch.dtype,
 ) -> Tensor:
     # Only support include_last_offset == True
     assert include_last_offset == True
     batch_size = offsets.shape[0] - 1
-    # Only support out_dtype == torch.float32
-    # Next setp: support more out_dtype
-    out_dtype = torch.float32
     return qweight.new_empty(batch_size, qweight.shape[1], dtype=out_dtype)
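After this change, callers pass the desired output dtype as a trailing ScalarType argument. A hedged usage sketch, assuming the torchao CPU extension is built and loaded and that mode 0 corresponds to at::native::EmbeddingBagMode::SUM (shapes and values below are illustrative):

    import torch
    import torchao  # registers torch.ops.torchao.*

    qweight = torch.randint(-128, 127, (1000, 128), dtype=torch.int8)
    indices = torch.tensor([1, 3, 5, 7], dtype=torch.int64)
    offsets = torch.tensor([0, 2, 4], dtype=torch.int64)  # include_last_offset=True
    w_scale = torch.tensor(0.05)

    out = torch.ops.torchao._scaled_embedding_bag(
        qweight, indices, offsets, w_scale, 2.0, 0, True, torch.int8
    )
    # out: int8 tensor of shape (batch_size, emb_dim) == (2, 128)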