Skip to content

Commit 9b6ed1a

Browse files
jan-wassenberg authored and copybara-github committed
gemma_batch_bench: generate more unique prompts
PiperOrigin-RevId: 819944137
1 parent 503aadd commit 9b6ed1a

File tree

1 file changed

+64
-33
lines changed

1 file changed

+64
-33
lines changed

evals/gemma_batch_bench.cc

Lines changed: 64 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515

1616
#include <stdio.h>
1717

18+
#include <algorithm>
1819
#include <string>
1920
#include <vector>
2021

@@ -48,48 +49,78 @@ class GemmaBatchBench : public ::testing::Test {
4849
};
4950

5051
TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
51-
const std::vector<std::string> questions = {
52-
{"Write me a poem about Australia?"},
53-
{"What's the history of Denmark?"},
54-
{"Write me a comedy story about the USA."},
55-
{"Teach me about GPU programming."},
56-
{"Write me a story about the moon."},
57-
{"Write me a story about the universe."},
58-
{"Write a poem about planet earth."},
59-
{"Tell me more about olympic sports."},
60-
{"How would you describe Washington State?"},
61-
{"Write me a story about Silicon Valley."},
62-
{"Write me about your best friend."},
52+
std::vector<std::string> prompts = {
53+
{"Describe dynamic programming."},
54+
{"Explain how electric cars work."},
55+
{"Explain to me how to use Google Maps."},
56+
{"How does AI work?"},
6357
{"How would you describe a unicorn?"},
64-
{"Tell me about world war history."},
58+
{"Please share some good cooking tips."},
59+
{"Teach me about GPU programming."},
60+
{"Tell me a fact about World War 2."},
6561
{"Tell me about Google."},
66-
{"Explain to me how to use Google Maps."},
67-
{"Explain to me how AI works."},
68-
{"Write me a poem about France."},
69-
{"What's the history of Great Britain?"},
62+
{"Tell me more about olympic sports."},
63+
{"Tell me something about space travel."},
64+
{"What is a horse?"},
65+
{"What is Michigan State?"},
66+
{"What's the history of Denmark?"},
67+
{"Write a poem about planet earth."},
68+
{"Write a story about Jupiter."},
69+
{"Write about the moon."},
7070
{"Write me a comedy story about Florida."},
71-
{"Teach me about dynamic programming."},
72-
{"Write me a story about Jupiter."},
73-
{"Write me a story about space ships."},
74-
{"Write a poem about some random planet."},
75-
{"Tell me more about team sports."},
76-
{"How would you describe Michigan State?"},
77-
{"Write me a story about Europe."},
78-
{"Write me about your best colleague."},
79-
{"How would you describe a horse?"},
80-
{"Tell me about World War 2."},
81-
{"Please share some good cooking tips."},
82-
{"Tell me about space travel."},
83-
{"Explain to me how electric cars work."},
71+
{"Write me a poem about France."},
8472
};
73+
const std::vector<std::string> start = {
74+
{"What is"}, {"When did"}, {"Where did"}, {"How did"}, {"Why did"}};
75+
const std::vector<std::string> concepts = {"Socrates",
76+
"Einstein",
77+
"Leonardo",
78+
"Cleopatra",
79+
"Adele",
80+
"Mars",
81+
"Turing",
82+
"Mozart",
83+
"democracy",
84+
"gravity",
85+
"AI",
86+
"evolution",
87+
"physics",
88+
"the internet",
89+
"steam engine",
90+
"inflation",
91+
"electricity",
92+
"the Sahara",
93+
"NASA",
94+
"Rome",
95+
"the UN",
96+
"Google",
97+
"the Renaissance",
98+
"Hamlet",
99+
"poetry",
100+
"Stoicism",
101+
"geometry",
102+
"DNA",
103+
"Star Wars",
104+
"1984"};
105+
const std::vector<std::string> end = {"exist?", "work?", "happen?",
106+
"lead to?", "believe?", "result in?"};
107+
for (const std::string& s : start) {
108+
for (const std::string& c : concepts) {
109+
for (const std::string& e : end) {
110+
prompts.push_back(s + " " + c + " " + e);
111+
}
112+
}
113+
}
114+
AesCtrEngine engine(true);
115+
std::shuffle(prompts.begin(), prompts.end(), RngStream(engine, 123));
85116

86-
// Fills prompts round robin from `questions` until the desired batch size.
117+
// Fills `inputs` by repeating from `prompts` until the desired batch size.
87118
std::vector<std::string> inputs;
88119
inputs.reserve(s_env->MutableConfig().decode_qbatch_size);
89120
size_t qpos = 0;
90121
for (size_t i = 0; i < inputs.capacity(); ++i) {
91-
inputs.push_back(questions[qpos++]);
92-
if (qpos == questions.size()) qpos = 0;
122+
inputs.push_back(prompts[qpos++]);
123+
if (qpos == prompts.size()) qpos = 0;
93124
}
94125
s_env->SetMaxGeneratedTokens(24);
95126
std::vector<std::string> responses = BatchGemmaReply(inputs);

0 commit comments

Comments
 (0)