Skip to content

Commit d854471

Browse files
apoorvreddycopybara-github
authored andcommitted
Use vectorized TopK using highway VQSelect
PiperOrigin-RevId: 728159153
1 parent 0e5b59d commit d854471

File tree

4 files changed

+94
-58
lines changed

4 files changed

+94
-58
lines changed

BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ cc_library(
108108
"@highway//:matvec",
109109
"@highway//:profiler",
110110
"@highway//:thread_pool",
111+
"@highway//hwy/contrib/sort:vqsort",
111112
],
112113
)
113114

ops/ops-inl.h

Lines changed: 79 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,16 @@
2222
#include <stdio.h>
2323

2424
#include <cmath>
25-
#include <limits>
25+
#include <cstdint>
2626
#include <random>
2727
#include <type_traits> // std::enable_if_t
2828
#include <vector>
2929

3030
#include "compression/compress.h"
3131
#include "util/basics.h" // TokenAndProb
3232
#include "hwy/base.h"
33+
#include "hwy/contrib/sort/order.h"
34+
#include "hwy/contrib/sort/vqsort.h"
3335
#include "hwy/contrib/thread_pool/thread_pool.h"
3436
#include "hwy/detect_targets.h"
3537
#include "hwy/profiler.h"
@@ -54,6 +56,35 @@ namespace gcpp {
5456
namespace HWY_NAMESPACE {
5557
namespace hn = hwy::HWY_NAMESPACE;
5658

59+
HWY_INLINE double PackTokenAndProb(int32_t token, float prob) {
60+
// casting prob from float to double just makes some changes to the
61+
// exponent bias and pads zeros in the mantissa.
62+
double packed = static_cast<double>(prob);
63+
int64_t packed_int64;
64+
hwy::CopySameSize(&packed, &packed_int64);
65+
// stuff the token into the lower 32 bits of packed_int64. (it is an int32_t
66+
// anyway)
67+
packed_int64 &= 0xFFFFFFFF00000000;
68+
packed_int64 |= token;
69+
// copy bytes back into packed.
70+
hwy::CopySameSize(&packed_int64, &packed);
71+
return packed;
72+
}
73+
74+
HWY_INLINE TokenAndProb UnpackTokenAndProb(double packed) {
75+
TokenAndProb tp;
76+
77+
int64_t packed_int64;
78+
hwy::CopySameSize(&packed, &packed_int64);
79+
tp.token = static_cast<int>(packed_int64 & 0xFFFFFFFFULL);
80+
81+
// clear the lower 32 bits of packed_int64 before copying back into packed.
82+
packed_int64 &= 0xFFFFFFFF00000000ULL;
83+
hwy::CopySameSize(&packed_int64, &packed);
84+
tp.prob = static_cast<float>(packed);
85+
return tp;
86+
}
87+
5788
template <typename To, typename From>
5889
HWY_INLINE constexpr std::enable_if_t<
5990
std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
@@ -705,37 +736,44 @@ HWY_INLINE HWY_MAYBE_UNUSED std::discrete_distribution<int> create_distribution(
705736
}
706737

707738
template <typename TAcceptToken>
708-
HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
709-
const float* HWY_RESTRICT probabilities, size_t k, size_t vocab_size,
710-
std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
739+
HWY_NOINLINE HWY_MAYBE_UNUSED std::vector<TokenAndProb> TopK(
740+
const float* HWY_RESTRICT probabilities, size_t vocab_size, size_t k,
741+
TAcceptToken& accept_token) {
711742
HWY_ASSERT(k != 0);
712743
HWY_ASSERT(k <= vocab_size);
713-
// TODO: Optimize, potentially using new VQSort PartialSort.
714-
// Sorted from highest [0], to lowest [k-1]
715-
std::vector<float> top_k(k, -std::numeric_limits<float>::infinity());
716-
std::vector<int> indices(k);
717-
size_t num_accepted = 0;
718-
for (size_t i = 0; i < vocab_size; ++i) {
719-
if (probabilities[i] < top_k[k - 1]) continue;
720-
bool accepted =
721-
!accept_token || accept_token(StaticCast<int>(i), probabilities[i]);
722-
if (!accepted) continue;
723-
num_accepted++;
724-
for (size_t j = 0; j < k; ++j) {
725-
if (probabilities[i] > top_k[j]) {
726-
// shift elements by 1, insert the new value, move on to next value
727-
for (size_t idx = k - 1; idx > j; --idx) {
728-
top_k[idx] = top_k[idx - 1];
729-
indices[idx] = indices[idx - 1];
730-
}
731-
top_k[j] = probabilities[i];
732-
indices[j] = StaticCast<int>(i);
733-
break;
734-
}
744+
std::vector<double> packed_token_probs;
745+
for (int32_t i = 0; i < vocab_size; ++i) {
746+
if (accept_token && !accept_token(StaticCast<int>(i), probabilities[i])) {
747+
continue;
735748
}
749+
packed_token_probs.push_back(PackTokenAndProb(i, probabilities[i]));
750+
}
751+
752+
hwy::VQSelect(packed_token_probs.data(), packed_token_probs.size(), k,
753+
hwy::SortDescending());
754+
hwy::VQSort(packed_token_probs.data(), k, hwy::SortDescending());
755+
756+
std::vector<TokenAndProb> token_probs;
757+
token_probs.reserve(k);
758+
for (int32_t i = 0; i < k; ++i) {
759+
token_probs.push_back(UnpackTokenAndProb(packed_token_probs[i]));
736760
}
737-
HWY_ASSERT(k <= num_accepted);
738-
return indices[create_distribution(top_k, temperature)(gen)];
761+
return token_probs;
762+
}
763+
764+
template <typename TAcceptToken>
765+
HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
766+
const float* HWY_RESTRICT probabilities, size_t k, size_t vocab_size,
767+
std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
768+
std::vector<TokenAndProb> token_probs =
769+
TopK(probabilities, vocab_size, k, accept_token);
770+
std::vector<int> topk_indices(k);
771+
std::vector<float> topk_probs(k);
772+
for (int i = 0; i < k; ++i) {
773+
topk_indices[i] = token_probs[i].token;
774+
topk_probs[i] = token_probs[i].prob;
775+
}
776+
return topk_indices[create_distribution(topk_probs, temperature)(gen)];
739777
}
740778

741779
template <typename TAcceptToken>
@@ -745,40 +783,23 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
745783
// Softmax and sample top-K is equivalent to taking the top-K logits and
746784
// sampling from the softmax of the top-K logits. The latter is faster as it
747785
// avoids computing the softmax of all logits.
748-
HWY_ASSERT(k != 0);
749-
HWY_ASSERT(k <= vocab_size);
750-
751-
std::vector<float> top_k(k, -std::numeric_limits<float>::infinity());
752-
std::vector<int> indices(k);
753-
size_t num_accepted = 0;
754-
for (size_t i = 0; i < vocab_size; ++i) {
755-
if (logits[i] < top_k[k - 1]) continue;
756-
bool accepted =
757-
!accept_token || accept_token(StaticCast<int>(i), logits[i]);
758-
if (!accepted) continue;
759-
num_accepted++;
760-
for (size_t j = 0; j < k; ++j) {
761-
if (logits[i] > top_k[j]) {
762-
// shift elements by 1, insert the new value, move on to next value
763-
for (size_t idx = k - 1; idx > j; --idx) {
764-
top_k[idx] = top_k[idx - 1];
765-
indices[idx] = indices[idx - 1];
766-
}
767-
top_k[j] = logits[i];
768-
indices[j] = StaticCast<int>(i);
769-
break;
770-
}
771-
}
786+
std::vector<TokenAndProb> token_logits =
787+
TopK(logits, vocab_size, k, accept_token);
788+
std::vector<int> topk_indices(k);
789+
std::vector<float> topk_logits(k);
790+
for (int i = 0; i < token_logits.size(); ++i) {
791+
topk_indices[i] = token_logits[i].token;
792+
topk_logits[i] = token_logits[i].prob;
772793
}
773794

774-
size_t mask = k <= num_accepted ? k : num_accepted;
775-
Softmax(top_k.data(), mask, temperature);
776-
auto distribution = std::discrete_distribution<int>(std::begin(top_k),
777-
std::begin(top_k) + mask);
795+
size_t mask = token_logits.size();
796+
Softmax(topk_logits.data(), mask, temperature);
797+
auto distribution = std::discrete_distribution<int>(
798+
std::begin(topk_logits), std::begin(topk_logits) + mask);
778799
int topk_sampled_index = distribution(gen);
779-
int sampled_index = indices[topk_sampled_index];
800+
int sampled_index = topk_indices[topk_sampled_index];
780801
return TokenAndProb{.token = sampled_index,
781-
.prob = top_k[topk_sampled_index]};
802+
.prob = topk_logits[topk_sampled_index]};
782803
}
783804

784805
// NOLINTNEXTLINE(google-readability-namespace-comments)

ops/ops_test.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,17 @@ void TestSampleTopK() {
600600
}
601601
}
602602

603+
void TestPackTokenAndProb() {
604+
double packed1 = PackTokenAndProb(10, 0.96f);
605+
TokenAndProb unpacked1 = UnpackTokenAndProb(packed1);
606+
EXPECT_EQ(unpacked1.token, 10);
607+
EXPECT_NEAR(unpacked1.prob, 0.96f, 1e-6);
608+
609+
double packed2 = PackTokenAndProb(1000000000, 0.87f);
610+
611+
EXPECT_LT(packed2, packed1);
612+
}
613+
603614
// NOLINTNEXTLINE(google-readability-namespace-comments)
604615
} // namespace HWY_NAMESPACE
605616
} // namespace gcpp
@@ -621,6 +632,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm);
621632
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllLayerNorm);
622633
HWY_EXPORT_AND_TEST_P(OpsTest, TestLayerNormSimple);
623634
HWY_EXPORT_AND_TEST_P(OpsTest, TestSampleTopK);
635+
HWY_EXPORT_AND_TEST_P(OpsTest, TestPackTokenAndProb);
624636
HWY_AFTER_TEST();
625637

626638
} // namespace gcpp

util/basics.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,12 @@ static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
5757
}
5858

5959
// Shared between gemma.h and ops-inl.h.
60+
#pragma pack(push, 1)
6061
struct TokenAndProb {
6162
int token;
6263
float prob;
6364
};
65+
#pragma pack(pop)
6466

6567
// Entire size of a 2D array.
6668
struct Extents2D {

0 commit comments

Comments
 (0)