diff --git a/kernels/quantized/cpu/op_embedding4b.cpp b/kernels/quantized/cpu/op_embedding4b.cpp index 33be86e5cc4..f234ee224ca 100644 --- a/kernels/quantized/cpu/op_embedding4b.cpp +++ b/kernels/quantized/cpu/op_embedding4b.cpp @@ -195,7 +195,7 @@ void resize_out_tensor( for (size_t i = 0; i < indices.dim(); i++) { expected_output_size[i] = indices.size(i); } - const size_t embedding_dim = weight.size(1); + const size_t embedding_dim = weight.size(1) * 2; expected_output_size[out.dim() - 1] = embedding_dim; exec_aten::ArrayRef output_size{ diff --git a/kernels/quantized/test/op_embedding4b_test.cpp b/kernels/quantized/test/op_embedding4b_test.cpp index 56944c57857..1eb7aa11b2a 100644 --- a/kernels/quantized/test/op_embedding4b_test.cpp +++ b/kernels/quantized/test/op_embedding4b_test.cpp @@ -19,6 +19,7 @@ using namespace ::testing; using exec_aten::ArrayRef; using exec_aten::optional; +using exec_aten::RuntimeContext; using exec_aten::ScalarType; using exec_aten::Tensor; using torch::executor::native::quantized_embedding_4bit_out; @@ -60,6 +61,20 @@ TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbedding) { EXPECT_TENSOR_EQ(out, expected); + out = tf.zeros({3, 4}); + auto context = RuntimeContext(); + torch::executor::native::quantized_embedding_4bit_out( + context, + qweight, + weight_scales, + weight_zero_points, + quant_min, + quant_max, + indices, + out); + + EXPECT_TENSOR_EQ(out, expected); + // Groupwise quantization. groupsize = 2 weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.0, 2.5, 3.0}); weight_zero_points = tf.make({3, 2}, {1, -5, 0, 2, -3, -1});