@@ -22,14 +22,16 @@
#include <stdio.h>

#include <cmath>
- #include <limits>
+ #include <cstdint>
#include <random>
#include <type_traits>  // std::enable_if_t
#include <vector>

#include "compression/compress.h"
#include "util/basics.h"  // TokenAndProb
#include "hwy/base.h"
+ #include "hwy/contrib/sort/order.h"
+ #include "hwy/contrib/sort/vqsort.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_targets.h"
#include "hwy/profiler.h"
@@ -54,6 +56,35 @@ namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

+ HWY_INLINE double PackTokenAndProb(int32_t token, float prob) {
+   // Casting prob from float to double only adjusts the exponent bias and
+   // pads the mantissa with zeros.
+   double packed = static_cast<double>(prob);
+   int64_t packed_int64;
+   hwy::CopySameSize(&packed, &packed_int64);
+   // Stuff the token into the lower 32 bits of packed_int64 (it is an
+   // int32_t anyway).
+   packed_int64 &= 0xFFFFFFFF00000000;
+   packed_int64 |= token;
+   // Copy the bytes back into packed.
+   hwy::CopySameSize(&packed_int64, &packed);
+   return packed;
+ }
+
+ HWY_INLINE TokenAndProb UnpackTokenAndProb(double packed) {
+   TokenAndProb tp;
+
+   int64_t packed_int64;
+   hwy::CopySameSize(&packed, &packed_int64);
+   tp.token = static_cast<int>(packed_int64 & 0xFFFFFFFFULL);
+
+   // Clear the lower 32 bits of packed_int64 before copying back into packed.
+   packed_int64 &= 0xFFFFFFFF00000000ULL;
+   hwy::CopySameSize(&packed_int64, &packed);
+   tp.prob = static_cast<float>(packed);
+   return tp;
+ }
+
template <typename To, typename From>
HWY_INLINE constexpr std::enable_if_t<
    std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
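Note on the packing trick above: widening a float to double preserves its value, so the probability lands in the upper 32 bits of the double's bit pattern and the token id can be stored in the lower 32 bits. For non-negative probabilities, a descending sort of the packed doubles then tracks the probability ordering, up to ties in the lowest few float mantissa bits, which the packing discards. A minimal standalone sketch of the same idea, not the library code, using std::memcpy in place of hwy::CopySameSize and hypothetical Pack/Unpack helpers:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <functional>
#include <vector>

// Hypothetical stand-ins for PackTokenAndProb/UnpackTokenAndProb above.
double Pack(int32_t token, float prob) {
  double packed = static_cast<double>(prob);  // value-preserving widening
  uint64_t bits;
  std::memcpy(&bits, &packed, sizeof(bits));
  // Keep the upper 32 bits (sign, exponent, high mantissa bits) and store the
  // token id in the lower 32 bits.
  bits = (bits & 0xFFFFFFFF00000000ULL) | static_cast<uint32_t>(token);
  std::memcpy(&packed, &bits, sizeof(packed));
  return packed;
}

void Unpack(double packed, int32_t& token, float& prob) {
  uint64_t bits;
  std::memcpy(&bits, &packed, sizeof(bits));
  token = static_cast<int32_t>(bits & 0xFFFFFFFFULL);
  bits &= 0xFFFFFFFF00000000ULL;  // drop the token before reading the prob
  std::memcpy(&packed, &bits, sizeof(packed));
  prob = static_cast<float>(packed);  // approximately the original prob
}

int main() {
  std::vector<double> packed = {Pack(7, 0.10f), Pack(42, 0.70f), Pack(3, 0.20f)};
  // Descending sort of the packed doubles orders by probability because the
  // token only occupies low mantissa bits.
  std::sort(packed.begin(), packed.end(), std::greater<double>());
  for (double p : packed) {
    int32_t token = 0;
    float prob = 0.0f;
    Unpack(p, token, prob);
    std::printf("token=%d prob=%.3f\n", token, prob);  // 42, 3, 7
  }
}
```

The read-back prob is only approximate because the lowest mantissa bits were overwritten by the token; the sampling code below needs only the ordering and a value close to the original probability.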
@@ -705,37 +736,44 @@ HWY_INLINE HWY_MAYBE_UNUSED std::discrete_distribution<int> create_distribution(
}

template <typename TAcceptToken>
- HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
-     const float* HWY_RESTRICT probabilities, size_t k, size_t vocab_size,
-     std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
+ HWY_NOINLINE HWY_MAYBE_UNUSED std::vector<TokenAndProb> TopK(
+     const float* HWY_RESTRICT probabilities, size_t vocab_size, size_t k,
+     TAcceptToken& accept_token) {
  HWY_ASSERT(k != 0);
  HWY_ASSERT(k <= vocab_size);
-   // TODO: Optimize, potentially using new VQSort PartialSort.
-   // Sorted from highest [0], to lowest [k-1].
-   std::vector<float> top_k(k, -std::numeric_limits<float>::infinity());
-   std::vector<int> indices(k);
-   size_t num_accepted = 0;
-   for (size_t i = 0; i < vocab_size; ++i) {
-     if (probabilities[i] < top_k[k - 1]) continue;
-     bool accepted =
-         !accept_token || accept_token(StaticCast<int>(i), probabilities[i]);
-     if (!accepted) continue;
-     num_accepted++;
-     for (size_t j = 0; j < k; ++j) {
-       if (probabilities[i] > top_k[j]) {
-         // Shift elements by 1, insert the new value, move on to next value.
-         for (size_t idx = k - 1; idx > j; --idx) {
-           top_k[idx] = top_k[idx - 1];
-           indices[idx] = indices[idx - 1];
-         }
-         top_k[j] = probabilities[i];
-         indices[j] = StaticCast<int>(i);
-         break;
-       }
+   std::vector<double> packed_token_probs;
+   for (int32_t i = 0; i < vocab_size; ++i) {
+     if (accept_token && !accept_token(StaticCast<int>(i), probabilities[i])) {
+       continue;
    }
+     packed_token_probs.push_back(PackTokenAndProb(i, probabilities[i]));
+   }
+
+   hwy::VQSelect(packed_token_probs.data(), packed_token_probs.size(), k,
+                 hwy::SortDescending());
+   hwy::VQSort(packed_token_probs.data(), k, hwy::SortDescending());
+
+   std::vector<TokenAndProb> token_probs;
+   token_probs.reserve(k);
+   for (int32_t i = 0; i < k; ++i) {
+     token_probs.push_back(UnpackTokenAndProb(packed_token_probs[i]));
  }
-   HWY_ASSERT(k <= num_accepted);
-   return indices[create_distribution(top_k, temperature)(gen)];
+   return token_probs;
+ }
+
+ template <typename TAcceptToken>
+ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
+     const float* HWY_RESTRICT probabilities, size_t k, size_t vocab_size,
+     std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
+   std::vector<TokenAndProb> token_probs =
+       TopK(probabilities, vocab_size, k, accept_token);
+   std::vector<int> topk_indices(k);
+   std::vector<float> topk_probs(k);
+   for (int i = 0; i < k; ++i) {
+     topk_indices[i] = token_probs[i].token;
+     topk_probs[i] = token_probs[i].prob;
+   }
+   return topk_indices[create_distribution(topk_probs, temperature)(gen)];
}

template <typename TAcceptToken>
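The rewritten TopK above has a select-then-sort shape: hwy::VQSelect partitions the packed values so the k largest occupy the front in no particular order, and hwy::VQSort then orders only that k-element prefix, so the full candidate list is never sorted. A rough scalar equivalent using the standard library, shown only to illustrate the shape (the diff uses Highway's vectorized routines; this sketch assumes at least k accepted candidates, and TopKPackedSketch is a hypothetical name):

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// Illustrative only: same two-phase shape as VQSelect + VQSort above,
// expressed with scalar standard-library calls. Assumes k <= packed.size().
std::vector<double> TopKPackedSketch(std::vector<double> packed, size_t k) {
  // Phase 1: move the k largest values into the first k slots, unordered
  // (this is the role VQSelect plays above).
  std::nth_element(packed.begin(), packed.begin() + k, packed.end(),
                   std::greater<double>());
  // Phase 2: order only that k-element prefix descending (VQSort on k items).
  std::sort(packed.begin(), packed.begin() + k, std::greater<double>());
  packed.resize(k);
  return packed;
}
```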
@@ -745,40 +783,23 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
  // Softmax and sample top-K is equivalent to taking the top-K logits and
  // sampling from the softmax of the top-K logits. The latter is faster as it
  // avoids computing the softmax of all logits.
-   HWY_ASSERT(k != 0);
-   HWY_ASSERT(k <= vocab_size);
-
-   std::vector<float> top_k(k, -std::numeric_limits<float>::infinity());
-   std::vector<int> indices(k);
-   size_t num_accepted = 0;
-   for (size_t i = 0; i < vocab_size; ++i) {
-     if (logits[i] < top_k[k - 1]) continue;
-     bool accepted =
-         !accept_token || accept_token(StaticCast<int>(i), logits[i]);
-     if (!accepted) continue;
-     num_accepted++;
-     for (size_t j = 0; j < k; ++j) {
-       if (logits[i] > top_k[j]) {
-         // Shift elements by 1, insert the new value, move on to next value.
-         for (size_t idx = k - 1; idx > j; --idx) {
-           top_k[idx] = top_k[idx - 1];
-           indices[idx] = indices[idx - 1];
-         }
-         top_k[j] = logits[i];
-         indices[j] = StaticCast<int>(i);
-         break;
-       }
-     }
+   std::vector<TokenAndProb> token_logits =
+       TopK(logits, vocab_size, k, accept_token);
+   std::vector<int> topk_indices(k);
+   std::vector<float> topk_logits(k);
+   for (int i = 0; i < token_logits.size(); ++i) {
+     topk_indices[i] = token_logits[i].token;
+     topk_logits[i] = token_logits[i].prob;
  }

-   size_t mask = k <= num_accepted ? k : num_accepted;
-   Softmax(top_k.data(), mask, temperature);
-   auto distribution = std::discrete_distribution<int>(std::begin(top_k),
-                                                       std::begin(top_k) + mask);
+   size_t mask = token_logits.size();
+   Softmax(topk_logits.data(), mask, temperature);
+   auto distribution = std::discrete_distribution<int>(
+       std::begin(topk_logits), std::begin(topk_logits) + mask);
  int topk_sampled_index = distribution(gen);
-   int sampled_index = indices[topk_sampled_index];
+   int sampled_index = topk_indices[topk_sampled_index];
  return TokenAndProb{.token = sampled_index,
-                       .prob = top_k[topk_sampled_index]};
+                       .prob = topk_logits[topk_sampled_index]};
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
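For the fused path above, the comment's equivalence boils down to this: once the non-top-k logits are discarded, sampling token i with probability proportional to exp(logit_i / temperature) needs a softmax over only the k kept logits. A hedged sketch of that step with the standard library (the Softmax(data, size, temperature) helper in this file is assumed here to implement the conventional softmax(logits / temperature), which may differ in detail; SampleFromTopKLogits is a hypothetical name):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Illustrative only: temperature-scaled softmax over just the k kept logits,
// followed by one draw from the resulting distribution.
int SampleFromTopKLogits(const std::vector<float>& topk_logits,
                         float temperature, std::mt19937& gen) {
  const float max_logit =
      *std::max_element(topk_logits.begin(), topk_logits.end());
  std::vector<double> probs(topk_logits.size());
  double sum = 0.0;
  for (size_t i = 0; i < topk_logits.size(); ++i) {
    // Subtracting the max keeps exp() well-conditioned; it cancels out in the
    // normalization below.
    probs[i] = std::exp((topk_logits[i] - max_logit) / temperature);
    sum += probs[i];
  }
  for (double& p : probs) p /= sum;
  // Returns an index into the top-k list; callers map it back to a token id
  // through the parallel topk_indices array, as the diff does.
  std::discrete_distribution<int> dist(probs.begin(), probs.end());
  return dist(gen);
}
```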