 
 #include <stdio.h>
 
+#include <algorithm>
 #include <string>
 #include <vector>
 
@@ -48,48 +49,78 @@ class GemmaBatchBench : public ::testing::Test { |
 };
 
 TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
-  const std::vector<std::string> questions = {
-      {"Write me a poem about Australia?"},
-      {"What's the history of Denmark?"},
-      {"Write me a comedy story about the USA."},
-      {"Teach me about GPU programming."},
-      {"Write me a story about the moon."},
-      {"Write me a story about the universe."},
-      {"Write a poem about planet earth."},
-      {"Tell me more about olympic sports."},
-      {"How would you describe Washington State?"},
-      {"Write me a story about Silicon Valley."},
-      {"Write me about your best friend."},
+  std::vector<std::string> prompts = {
+      {"Describe dynamic programming."},
+      {"Explain how electric cars work."},
+      {"Explain to me how to use Google Maps."},
+      {"How does AI work?"},
       {"How would you describe a unicorn?"},
-      {"Tell me about world war history."},
+      {"Please share some good cooking tips."},
+      {"Teach me about GPU programming."},
+      {"Tell me a fact about World War 2."},
       {"Tell me about Google."},
-      {"Explain to me how to use Google Maps."},
-      {"Explain to me how AI works."},
-      {"Write me a poem about France."},
-      {"What's the history of Great Britain?"},
+      {"Tell me more about olympic sports."},
+      {"Tell me something about space travel."},
+      {"What is a horse?"},
+      {"What is Michigan State?"},
+      {"What's the history of Denmark?"},
+      {"Write a poem about planet earth."},
+      {"Write a story about Jupiter."},
+      {"Write about the moon."},
       {"Write me a comedy story about Florida."},
-      {"Teach me about dynamic programming."},
-      {"Write me a story about Jupiter."},
-      {"Write me a story about space ships."},
-      {"Write a poem about some random planet."},
-      {"Tell me more about team sports."},
-      {"How would you describe Michigan State?"},
-      {"Write me a story about Europe."},
-      {"Write me about your best colleague."},
-      {"How would you describe a horse?"},
-      {"Tell me about World War 2."},
-      {"Please share some good cooking tips."},
-      {"Tell me about space travel."},
-      {"Explain to me how electric cars work."},
+      {"Write me a poem about France."},
   };
+  const std::vector<std::string> start = {
+      {"What is"}, {"When did"}, {"Where did"}, {"How did"}, {"Why did"}};
+  const std::vector<std::string> concepts = {"Socrates",
+                                             "Einstein",
+                                             "Leonardo",
+                                             "Cleopatra",
+                                             "Adele",
+                                             "Mars",
+                                             "Turing",
+                                             "Mozart",
+                                             "democracy",
+                                             "gravity",
+                                             "AI",
+                                             "evolution",
+                                             "physics",
+                                             "the internet",
+                                             "steam engine",
+                                             "inflation",
+                                             "electricity",
+                                             "the Sahara",
+                                             "NASA",
+                                             "Rome",
+                                             "the UN",
+                                             "Google",
+                                             "the Renaissance",
+                                             "Hamlet",
+                                             "poetry",
+                                             "Stoicism",
+                                             "geometry",
+                                             "DNA",
+                                             "Star Wars",
+                                             "1984"};
+  const std::vector<std::string> end = {"exist?", "work?", "happen?",
+                                        "lead to?", "believe?", "result in?"};
+  for (const std::string& s : start) {
+    for (const std::string& c : concepts) {
+      for (const std::string& e : end) {
+        prompts.push_back(s + " " + c + " " + e);
+      }
+    }
+  }
+  AesCtrEngine engine(true);
+  std::shuffle(prompts.begin(), prompts.end(), RngStream(engine, 123));
 
-  // Fills prompts round robin from `questions` until the desired batch size.
+  // Fills `inputs` by repeating from `prompts` until the desired batch size.
   std::vector<std::string> inputs;
   inputs.reserve(s_env->MutableConfig().decode_qbatch_size);
   size_t qpos = 0;
   for (size_t i = 0; i < inputs.capacity(); ++i) {
-    inputs.push_back(questions[qpos++]);
-    if (qpos == questions.size()) qpos = 0;
+    inputs.push_back(prompts[qpos++]);
+    if (qpos == prompts.size()) qpos = 0;
   }
   s_env->SetMaxGeneratedTokens(24);
   std::vector<std::string> responses = BatchGemmaReply(inputs);
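For context, the nested loops in this change expand every opener, concept, and ending into one question: 5 openers × 30 concepts × 6 endings yields 900 generated prompts on top of the 19 listed literally, and the combined set is shuffled deterministically before being sampled round-robin into the batch. Below is a minimal standalone sketch of that expand-and-shuffle step with trimmed word lists; it substitutes std::mt19937 with a fixed seed for the repo's AesCtrEngine/RngStream, purely as an illustrative stand-in.

#include <algorithm>
#include <cstdio>
#include <random>
#include <string>
#include <vector>

int main() {
  // A few fixed prompts; the real test lists 19 of them.
  std::vector<std::string> prompts = {"How does AI work?",
                                      "Tell me about Google."};
  // Trimmed word lists; the test uses 5 openers, 30 concepts, 6 endings.
  const std::vector<std::string> start = {"What is", "Why did"};
  const std::vector<std::string> concepts = {"Socrates", "gravity", "NASA"};
  const std::vector<std::string> end = {"exist?", "work?", "happen?"};
  // Cross product: every opener/concept/ending combination becomes a prompt.
  for (const std::string& s : start) {
    for (const std::string& c : concepts) {
      for (const std::string& e : end) {
        prompts.push_back(s + " " + c + " " + e);
      }
    }
  }
  // Deterministic shuffle; std::mt19937 with a fixed seed stands in for the
  // repo's AesCtrEngine/RngStream shown in the diff.
  std::mt19937 rng(123);
  std::shuffle(prompts.begin(), prompts.end(), rng);
  for (const std::string& p : prompts) std::printf("%s\n", p.c_str());
  return 0;
}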