58 changes: 53 additions & 5 deletions test/test_ops.py
@@ -857,9 +857,13 @@ def test_swizzle_mm():


def _test_scaled_embedding_bag_cpu_helper(
multi_hot, batch_size, vector_size, index_type, qtype
multi_hot,
batch_size,
vector_size,
index_type,
qtype,
out_dtype=torch.float,
):
dtype = torch.float32
include_last_offset = True
mode = "sum"

@@ -876,7 +880,7 @@ def _test_scaled_embedding_bag_cpu_helper(
1000,
vector_size,
mode=mode,
dtype=dtype,
dtype=torch.float,
include_last_offset=include_last_offset,
)
if qtype == torch.int8:
@@ -887,19 +891,63 @@ def _test_scaled_embedding_bag_cpu_helper(
qweight = m.weight.data.to(qtype)
m.weight.data = qweight.to(m.weight.dtype)

out_scale = 1.0
if out_dtype == torch.int8:
out_scale = 2.0

with torch.no_grad():
refe_out = m.forward(indices, offsets) * weight_scale
if out_dtype == torch.int8:
refe_out = torch.round(refe_out / out_scale).to(torch.int32)
refe_out = torch.clamp(refe_out, -128, 127).to(out_dtype)
test_out = torch.ops.torchao._scaled_embedding_bag(
qweight,
indices,
offsets,
weight_scale,
1.0,
out_scale,
mode_enum,
include_last_offset,
).to(dtype)
out_dtype,
)
torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5)


@pytest.mark.skipif(
"CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
reason="cpp kernels not built",
)
@pytest.mark.parametrize(
"multi_hot, batch_size, vector_size, index_type",
EMBEDINGBAG_TEST_PARAMS,
ids=str,
)
def test_scaled_embedding_bag_int8_cpu(multi_hot, batch_size, vector_size, index_type):
for out_dtype in [torch.float, torch.int8]:
_test_scaled_embedding_bag_cpu_helper(
multi_hot,
batch_size,
vector_size,
index_type,
torch.int8,
out_dtype,
)


@pytest.mark.skipif(
"CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
reason="cpp kernels not built",
)
@pytest.mark.parametrize(
"multi_hot, batch_size, vector_size, index_type",
EMBEDINGBAG_TEST_PARAMS,
ids=str,
)
def test_scaled_embedding_bag_fp8_cpu(multi_hot, batch_size, vector_size, index_type):
_test_scaled_embedding_bag_cpu_helper(
multi_hot, batch_size, vector_size, index_type, torch.float8_e4m3fn
)


if __name__ == "__main__":
pytest.main(sys.argv)
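For readers skimming the test changes: the sketch below restates, outside the test harness, the reference numerics the updated helper checks — a sum-mode embedding bag dequantized by `weight_scale`, with int8 outputs produced by dividing by `out_scale`, rounding, and clamping to [-128, 127]. It is a minimal illustration based on the helper above, not code from this PR; names and shapes are placeholders.

```python
import torch


def reference_scaled_embedding_bag(
    weight_fp32, indices, offsets, weight_scale, out_scale=1.0, out_dtype=torch.float
):
    # Sum-mode EmbeddingBag with include_last_offset=True:
    # bag b sums the rows selected by indices[offsets[b]:offsets[b + 1]].
    bags = []
    for b in range(offsets.numel() - 1):
        lo, hi = int(offsets[b]), int(offsets[b + 1])
        bags.append(weight_fp32[indices[lo:hi]].sum(dim=0))
    out = torch.stack(bags) * weight_scale  # dequantize the accumulated sums
    if out_dtype == torch.int8:
        # Quantize the output: scale, round, saturate to the int8 range.
        out = torch.clamp(torch.round(out / out_scale), -128, 127)
    return out.to(out_dtype)
```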
173 changes: 126 additions & 47 deletions torchao/csrc/cpu/aten_kernels/scaled_embedding_bag.cpp
@@ -6,6 +6,38 @@
#include <c10/util/Unroll.h>
#include <torch/all.h>

#define QTYPE_DISPATCH(TYPE, ...) \
[&]() { \
switch (TYPE) { \
case c10::ScalarType::Float8_e4m3fn: { \
using data_t = at::Float8_e4m3fn; \
return __VA_ARGS__(); \
} \
case c10::ScalarType::Char: { \
using data_t = int8_t; \
return __VA_ARGS__(); \
} \
default: \
TORCH_CHECK(false, "scaled_embeding_bag: unsupport qtype"); \
} \
}()

#define OUTTYPE_DISPATCH(TYPE, ...) \
[&]() { \
switch (TYPE) { \
case c10::ScalarType::Float: { \
using output_t = float; \
return __VA_ARGS__(); \
} \
case c10::ScalarType::Char: { \
using output_t = int8_t; \
return __VA_ARGS__(); \
} \
default: \
TORCH_CHECK(false, "scaled_embeding_bag: unsupport output type"); \
} \
}()

namespace torchao {

namespace {
@@ -53,14 +85,71 @@ static inline CHUNK load_chunk(const int8_t *x) {
x7 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 3));
return {x0, x1, x2, x3, x4, x5, x6, x7};
}

static inline void store_chunk(float *output, CHUNK chunk) {
__m512 x0, x1, x2, x3, x4, x5, x6, x7;
std::tie(x0, x1, x2, x3, x4, x5, x6, x7) = chunk;
_mm512_store_ps(output, x0);
_mm512_store_ps(output + 16, x1);
_mm512_store_ps(output + 32, x2);
_mm512_store_ps(output + 48, x3);
_mm512_store_ps(output + 64, x4);
_mm512_store_ps(output + 80, x5);
_mm512_store_ps(output + 96, x6);
_mm512_store_ps(output + 112, x7);
}

static inline void store_chunk(int8_t *output, CHUNK chunk) {
__m512i x00, x64;
__m512i y0, y1, y2, y3, y4, y5, y6, y7;
__m512 f0, f1, f2, f3, f4, f5, f6, f7;
std::tie(f0, f1, f2, f3, f4, f5, f6, f7) = chunk;
y0 = _mm512_cvt_roundps_epi32(
f0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
y1 = _mm512_cvt_roundps_epi32(
f1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
y2 = _mm512_cvt_roundps_epi32(
f2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
y3 = _mm512_cvt_roundps_epi32(
f3, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
y4 = _mm512_cvt_roundps_epi32(
f4, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
y5 = _mm512_cvt_roundps_epi32(
f5, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
y6 = _mm512_cvt_roundps_epi32(
f6, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
y7 = _mm512_cvt_roundps_epi32(
f7, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
x00 = _mm512_inserti32x4(x00, _mm512_cvtsepi32_epi8(y0), 0);
x00 = _mm512_inserti32x4(x00, _mm512_cvtsepi32_epi8(y1), 1);
x00 = _mm512_inserti32x4(x00, _mm512_cvtsepi32_epi8(y2), 2);
x00 = _mm512_inserti32x4(x00, _mm512_cvtsepi32_epi8(y3), 3);
x64 = _mm512_inserti32x4(x64, _mm512_cvtsepi32_epi8(y4), 0);
x64 = _mm512_inserti32x4(x64, _mm512_cvtsepi32_epi8(y5), 1);
x64 = _mm512_inserti32x4(x64, _mm512_cvtsepi32_epi8(y6), 2);
x64 = _mm512_inserti32x4(x64, _mm512_cvtsepi32_epi8(y7), 3);
_mm512_store_si512(output, x00);
_mm512_store_si512(output + 64, x64);
}
#endif

template <typename index_t, typename data_t>
static inline void store_elem(float &out, float input) {
out = input;
}

static inline void store_elem(int8_t &out, float input) {
float rounded = std::round(input);
float clamped = std::max(-128.0f, std::min(127.0f, rounded));
int32_t int32_value = static_cast<int32_t>(clamped);
out = static_cast<int8_t>(int32_value);
}

template <typename index_t, typename data_t, typename output_t>
inline void _scaled_embedding_bag_krnl(
const int64_t bs_begin, const int64_t bs_end, const int64_t num_emb,
const int64_t emb_dim, const index_t last_offset, const index_t *indices,
const index_t *offsets, const data_t *weight, const double scale,
float *result, const int64_t num_batch) {
output_t *result, const int64_t num_batch) {
#if defined(CPU_CAPABILITY_AVX512)
if (emb_dim % 128 == 0) {
constexpr int64_t block_dim = 128;
@@ -76,7 +165,7 @@ inline void _scaled_embedding_bag_krnl(
for (int64_t block_id = 0; block_id < num_blocks; block_id++) {
// load first indices
int64_t idx = indices[start_idx] * emb_dim + block_dim * block_id;
float *block_result = result + block_dim * block_id;
output_t *block_result = result + block_dim * block_id;
std::tie(x0, x1, x2, x3, x4, x5, x6, x7) = load_chunk(weight + idx);
for (int64_t j = start_idx + 1; j < end_idx; ++j) {
// add following idx
@@ -100,14 +189,7 @@ inline void _scaled_embedding_bag_krnl(
x6 = _mm512_mul_ps(x6, scale_v);
x7 = _mm512_mul_ps(x7, scale_v);
// store
_mm512_store_ps(block_result, x0);
_mm512_store_ps(block_result + 16, x1);
_mm512_store_ps(block_result + 32, x2);
_mm512_store_ps(block_result + 48, x3);
_mm512_store_ps(block_result + 64, x4);
_mm512_store_ps(block_result + 80, x5);
_mm512_store_ps(block_result + 96, x6);
_mm512_store_ps(block_result + 112, x7);
store_chunk(block_result, {x0, x1, x2, x3, x4, x5, x6, x7});
}
result += num_emb * emb_dim;
}
@@ -127,14 +209,14 @@ inline void _scaled_embedding_bag_krnl(
value += float(weight[idx + d]);
}
value = value * scale;
result[d] = value;
store_elem(result[d], value);
}
result += num_emb * emb_dim;
}
}

template <typename index_t, typename data_t>
void _scaled_embedding_bag(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
template <typename index_t, typename data_t, typename output_t>
void _scaled_embedding_bag(output_t *o_ptr, data_t *w_ptr, index_t *indices_ptr,
index_t *offsets_ptr, int64_t num_batch,
int64_t emb_dim, index_t last_offset, double w_scale,
double o_scale) {
@@ -147,7 +229,7 @@ void _scaled_embedding_bag(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
for (int64_t n = 0; n < num_emb; ++n) {
const int64_t bs_begin = b * b_block;
const int64_t bs_end = std::min(num_batch, (b + 1) * b_block);
float *r = &o_ptr[b * b_block * num_emb * emb_dim + n * emb_dim];
output_t *r = &o_ptr[b * b_block * num_emb * emb_dim + n * emb_dim];
// guard against offsets that do not include the last batch
_scaled_embedding_bag_krnl(bs_begin, bs_end, num_emb, emb_dim,
last_offset, indices_ptr, offsets_ptr, w_ptr,
@@ -156,12 +238,24 @@ void _scaled_embedding_bag(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
}
}

at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
const at::Tensor &indices,
const at::Tensor &offsets,
const at::Tensor &w_scales,
double o_scale, const int64_t mode,
bool include_last_offset) {
template <typename index_t, typename data_t, typename output_t>
void _scaled_embedding_bag_dispatch_dtype(
const at::Tensor &qweight, const at::Tensor &indices,
const at::Tensor &offsets, const at::Tensor &output, int64_t batch_size,
int64_t emb_dim, index_t last_offset, double w_scale, double o_scale) {
data_t *qweight_ptr = qweight.data_ptr<data_t>();
index_t *indices_ptr = indices.data_ptr<index_t>();
index_t *offsets_ptr = offsets.data_ptr<index_t>();
output_t *output_ptr = output.data_ptr<output_t>();
_scaled_embedding_bag<index_t, data_t, output_t>(
output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size, emb_dim,
last_offset, w_scale, o_scale);
}

at::Tensor _scaled_embedding_bag_impl(
const at::Tensor &qweight, const at::Tensor &indices,
const at::Tensor &offsets, const at::Tensor &w_scales, double o_scale,
const int64_t mode, bool include_last_offset, at::ScalarType output_dtype) {
// Only support include_last_offset == True and mode ==
// at::native::EmbeddingBagMode::SUM
// TODO: Support more case
@@ -193,32 +287,17 @@ at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
int64_t last_offset = indices.numel();

at::Tensor output =
at::empty({batch_size, emb_dim}, qweight.options().dtype(at::kFloat));
if (qtype == c10::ScalarType::Float8_e4m3fn) {
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "_scaled_embedding_bag", [&] {
at::Float8_e4m3fn *qweight_ptr =
qweight.data_ptr<at::Float8_e4m3fn>();
index_t *indices_ptr = indices.data_ptr<index_t>();
index_t *offsets_ptr = offsets.data_ptr<index_t>();
float *output_ptr = output.data_ptr<float>();
_scaled_embedding_bag<index_t, at::Float8_e4m3fn>(
output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
emb_dim, last_offset, w_scale, o_scale);
});
} else {
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "_scaled_embedding_bag", [&] {
int8_t *qweight_ptr = qweight.data_ptr<int8_t>();
index_t *indices_ptr = indices.data_ptr<index_t>();
index_t *offsets_ptr = offsets.data_ptr<index_t>();
float *output_ptr = output.data_ptr<float>();
_scaled_embedding_bag<index_t, int8_t>(
output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
emb_dim, last_offset, w_scale, o_scale);
});
}

at::empty({batch_size, emb_dim}, qweight.options().dtype(output_dtype));
OUTTYPE_DISPATCH(output_dtype, [&] {
QTYPE_DISPATCH(qtype, [&] {
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "_scaled_embedding_bag", [&] {
_scaled_embedding_bag_dispatch_dtype<index_t, data_t, output_t>(
qweight, indices, offsets, output, batch_size, emb_dim,
last_offset, w_scale, o_scale);
});
});
});
return output;
}

5 changes: 3 additions & 2 deletions torchao/ops.py
@@ -69,7 +69,7 @@
"da8w4_linear_cpu(Tensor input, Tensor input_scales, Tensor input_qzeros, Tensor weight, Tensor weight_scales, Tensor weight_qzeros, Tensor compensation, Tensor? bias, ScalarType output_dtype) -> Tensor"
)
lib.define(
"_scaled_embedding_bag(Tensor qweight, Tensor indices, Tensor offsets, Tensor weight_scale, float o_scale, int mode, bool include_last_offset) -> Tensor"
"_scaled_embedding_bag(Tensor qweight, Tensor indices, Tensor offsets, Tensor weight_scale, float o_scale, int mode, bool include_last_offset, ScalarType output_dtype) -> Tensor"
)


@@ -1112,8 +1112,9 @@ def _(
o_scale: float,
mode: int,
include_last_offset: bool,
out_dtype: torch.dtype,
) -> Tensor:
# Only support include_last_offset == True
assert include_last_offset == True
batch_size = offsets.shape[0] - 1
return qweight.new_empty(batch_size, qweight.shape[1], dtype=qweight.dtype)
return qweight.new_empty(batch_size, qweight.shape[1], dtype=out_dtype)
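With the new `ScalarType output_dtype` argument, callers select the output precision at call time, mirroring the updated tests. A minimal call sketch follows; it requires the torchao C++ CPU kernels to be built, and the shapes, scales, and the sum-mode enum value of 0 are illustrative assumptions rather than values taken from this PR.

```python
import torch
import torchao  # registers torch.ops.torchao._scaled_embedding_bag

# Hypothetical int8 embedding table with 1000 rows of dimension 128,
# and 4 bags of 2 indices each (include_last_offset=True layout).
qweight = torch.randint(-128, 128, (1000, 128), dtype=torch.int8)
indices = torch.randint(0, 1000, (8,), dtype=torch.int64)
offsets = torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64)
weight_scale = torch.tensor([0.05])
o_scale = 2.0
mode_sum = 0  # assumed value of at::native::EmbeddingBagMode::SUM
include_last_offset = True

int8_out = torch.ops.torchao._scaled_embedding_bag(
    qweight, indices, offsets, weight_scale, o_scale,
    mode_sum, include_last_offset, torch.int8,
)
fp32_out = torch.ops.torchao._scaled_embedding_bag(
    qweight, indices, offsets, weight_scale, 1.0,
    mode_sum, include_last_offset, torch.float,
)
```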