diff --git a/LLama.Examples/ExampleRunner.cs b/LLama.Examples/ExampleRunner.cs
index c194f6f87..09819c120 100644
--- a/LLama.Examples/ExampleRunner.cs
+++ b/LLama.Examples/ExampleRunner.cs
@@ -33,6 +33,7 @@ public class ExampleRunner
{ "Batched Executor: Rewind", BatchedExecutorRewind.Run },
{ "Batched Executor: Guidance", BatchedExecutorGuidance.Run },
{ "Batched Executor: LLava", BatchedExecutorLLava.Run },
+ { "Batched Executor: Beam Search", BatchedExecutorBeamSearch.Run },
{ "Speech Chat: Integration with Whisper.net", SpeechChat.Run },
{ "Exit", () => { Environment.Exit(0); return Task.CompletedTask; } }
};
diff --git a/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs b/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs
new file mode 100644
index 000000000..ce91fccc3
--- /dev/null
+++ b/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs
@@ -0,0 +1,181 @@
+using LLama.Batched;
+using LLama.Common;
+using LLama.Native;
+using Spectre.Console;
+
+namespace LLama.Examples.Examples;
+
+/// <summary>
+/// This demonstrates beam search using the batched executor.
+///
+/// Beam search is a technique for finding the most likely multi-token completion of a prompt. The search keeps track of a
+/// set of "beams"; each beam is a possible completion and tracks its cumulative probability. At each step every current
+/// beam is split into multiple new beams by extending it with different possible tokens (greedily taking the top N tokens),
+/// and the set of _all_ beams is then trimmed down to just the most likely ones. This allows multiple possibilities to be
+/// considered, and can find a higher probability result than simply greedy sampling the most likely token at every step.
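+///
+/// As a worked example (illustrative numbers only): suppose greedy sampling picks "the" (p = 0.5) over "a" (p = 0.4),
+/// but the best continuation of "the" is "zebra" (p = 0.1) while "a" continues with "cat" (p = 0.9). The beam
+/// ["a", "cat"] then has cumulative probability 0.4 * 0.9 = 0.36, beating ["the", "zebra"] at 0.5 * 0.1 = 0.05, so
+/// beam search can find a completion that greedy sampling would never reach.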
+/// </summary>
+public class BatchedExecutorBeamSearch
+{
+    public static async Task Run()
+    {
+        // Load model weights
+        var parameters = new ModelParams(UserSettings.GetModelPath());
+        using var model = await LLamaWeights.LoadFromFileAsync(parameters);
+
+        var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "The cat sat on");
+        var tokensGenerate = AnsiConsole.Ask("How many tokens to generate?", 8);
+        var beamsCount = AnsiConsole.Ask("How many parallel beams to keep track of?", 8);
+
+        // Create an executor that can evaluate a batch of conversations together
+        using var executor = new BatchedExecutor(model, parameters);
+
+        // Print some info
+        var name = model.Metadata.GetValueOrDefault("general.name", "unknown model name");
+        Console.WriteLine($"Created executor with model: {name}");
+
+        // Evaluate the initial prompt to create one conversation
+        var conversation = executor.Create();
+        var startTokens = executor.Context.Tokenize(prompt);
+        conversation.Prompt(startTokens);
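+        // Note: Prompt() only queues these tokens; they are evaluated by the next call to Infer() in the loop below.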
+
+        // Create one beam, containing that conversation
+        var beams = new List<Beam> { new Beam(conversation, 1.0, startTokens, [conversation.ConversationId]) };
+
+        // Generate loop
+        for (var i = 0; i < tokensGenerate; i++)
+        {
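+            // Evaluate all queued tokens in one batch (the whole prompt on the first pass, then one token per beam).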
+            await executor.Infer();
+
+            // Create new beams, forked from all original beams
+            beams = (from oldBeam in beams
+                     from beam in oldBeam.Sample(beamsCount)
+                     select beam).OrderBy(a => a.CumulativeProbability).ToList();
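+            // (OrderBy sorts ascending, so the least probable beams are at the front of the list and are culled first.)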
+
+            // Trim down the list by removing low probability beams
+            while (beams.Count > beamsCount)
+            {
+                var beam = beams[0];
+                AnsiConsole.MarkupLineInterpolated($"[red]Culling Beam {beam.Conversation.ConversationId} (prob:{beam.CumulativeProbability:P10})[/]: {beam}");
+
+                beam.Dispose();
+                beams.RemoveAt(0);
+            }
+
+            // Normalize all remaining beam probabilities.
+            NormalizeBeams(beams);
+        }
+
+        // Print out all remaining beams
+        AnsiConsole.MarkupLine("Final Beams:");
+        beams.Reverse();
+        foreach (var beam in beams)
+        {
+            AnsiConsole.MarkupLineInterpolated($"[yellow]Probability: {beam.CumulativeProbability:P10}[/]");
+            AnsiConsole.MarkupLineInterpolated($"[yellow]Sequence: {string.Join(",", beam.Sequence)}[/]");
+            AnsiConsole.MarkupLineInterpolated($"[green]{beam}[/]");
+            Console.WriteLine();
+        }
+
+        Console.WriteLine("Press any key to exit demo");
+        Console.ReadKey(true);
+    }
+
+    /// <summary>
+    /// As a beam grows, its cumulative probability gets very small. Normalizing all of the beams prevents the value from collapsing to zero.
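+    ///
+    /// For example, eight tokens at p = 0.1 each give a cumulative probability of 1e-8. Dividing every beam by the
+    /// current maximum keeps the values near 1.0 without changing the ordering, since only the ratios between beams
+    /// matter. (An alternative would be to accumulate log-probabilities instead of multiplying probabilities, which
+    /// avoids normalization entirely; plain probabilities are used here for readability.)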
+    /// </summary>
+    /// <param name="beams">The beams to normalize (modified in place)</param>
+    private static void NormalizeBeams(List<Beam> beams)
+    {
+        // Find the max probability
+        var max = beams.MaxBy(a => a.CumulativeProbability)!.CumulativeProbability;
+
+        // Divide all beams by the max, so the most probable beam has probability 1.0
+        foreach (var beam in beams)
+            beam.CumulativeProbability /= max;
+    }
+
+    private class Beam
+        : IDisposable
+    {
+        public readonly Conversation Conversation;
+        public readonly IReadOnlyList<LLamaToken> Tokens;
+        public readonly IReadOnlyList<LLamaSeqId> Sequence;
+
+        public double CumulativeProbability;
+
+        public Beam(Conversation conversation, double prob, IReadOnlyList<LLamaToken> tokens, IReadOnlyList<LLamaSeqId> sequence)
+        {
+            Conversation = conversation;
+            Tokens = tokens;
+            Sequence = sequence;
+
+            CumulativeProbability = prob;
+        }
+
+        public void Dispose()
+        {
+            Conversation.Dispose();
+        }
+
+        public List<Beam> Sample(int nbeams)
+        {
+            // Apply softmax. This calculates probabilities and sorts the tokens into descending order of probability.
+            var logitsArr = LLamaTokenDataArray.Create(Conversation.Sample());
+            logitsArr.Softmax(Conversation.Executor.Context.NativeHandle);
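+            // (Softmax: p_i = exp(logit_i) / sum_j exp(logit_j), turning the raw logits into a normalized distribution.)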
+
+            // Create new forked conversations, one for each beam
+            var results = new List<Beam>();
+            for (var i = 0; i < nbeams; i++)
+            {
+                // After softmax the logits array is in descending order of probability. Take the first `nbeams` items to make new beams.
+                var item = logitsArr.Data.Span[i];
+
+                // Fork the parent conversation. This shares all of the KV cache with the parent (and other forks),
+                // so it does not cost any extra memory.
+                var c = Conversation.Fork();
+
+                // Extend the conversation with the selected token.
+                c.Prompt(item.id);
+
+                // Keep track of the cumulative probability of this entire sequence.
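+                // (Each token's probability is conditional on the tokens before it, so P(sequence) is the product of per-token probabilities.)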
+                var p = CumulativeProbability * item.p;
+
+                // Keep track of all tokens in this sequence, for decoding later
+                var t = Tokens.ToList();
+                t.Add(item.id);
+
+                // Keep track of which conversations this beam was derived from.
+                var s = Sequence.ToList();
+                s.Add(c.ConversationId);
+
+                results.Add(new Beam(c, p, t, s));
+            }
+
+            // Dispose this conversation now that the child beams have been spawned from it
+            Conversation.Dispose();
+            return results;
+        }
+
+        public override string ToString()
+        {
+            var decoder = new StreamingTokenDecoder(Conversation.Executor.Context);
+            decoder.AddRange(Tokens);
+            return decoder.Read();
+        }
+    }
+}
\ No newline at end of file
diff --git a/LLama/Native/LLamaSeqId.cs b/LLama/Native/LLamaSeqId.cs
index 8a3dae5d8..e3c6e8f43 100644
--- a/LLama/Native/LLamaSeqId.cs
+++ b/LLama/Native/LLamaSeqId.cs
@@ -1,4 +1,4 @@
-using System.Runtime.InteropServices;
+using System.Runtime.InteropServices;
namespace LLama.Native;
@@ -39,4 +39,10 @@ private LLamaSeqId(int value)
     /// <param name="value"></param>
     /// <returns></returns>
     public static explicit operator LLamaSeqId(int value) => new(value);
+
+    /// <inheritdoc />
+    public readonly override string ToString()
+    {
+        return Value.ToString();
+    }
}
\ No newline at end of file