1 change: 1 addition & 0 deletions LLama.Examples/ExampleRunner.cs
@@ -33,6 +33,7 @@ public class ExampleRunner
{ "Batched Executor: Rewind", BatchedExecutorRewind.Run },
{ "Batched Executor: Guidance", BatchedExecutorGuidance.Run },
{ "Batched Executor: LLava", BatchedExecutorLLava.Run },
{ "Batched Executor: Beam Search", BatchedExecutorBeamSearch.Run },
{ "Speech Chat: Integration with Whisper.net", SpeechChat.Run },
{ "Exit", () => { Environment.Exit(0); return Task.CompletedTask; } }
};
166 changes: 166 additions & 0 deletions LLama.Examples/Examples/BatchedExecutorBeamSearch.cs
@@ -0,0 +1,166 @@
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates beam search using the batched executor.
///
/// Beam search is a technique for finding the most likely multi-token completion for a prompt. The search keeps a
/// set of "beams", where each beam is one possible completion along with its cumulative probability. At each step,
/// every current beam is split into several new beams by extending it with each of the top N most likely next tokens.
/// The combined set of beams is then trimmed back down to only the most likely candidates. This keeps multiple
/// possibilities in play, and can find a higher-probability result than greedily sampling the single most likely
/// token at every step.
/// </summary>
public class BatchedExecutorBeamSearch
{
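// Illustrative walkthrough (hypothetical numbers, not from any real model): with beamsCount = 2,
// step 1 might keep the beams ["the", p=0.5] and ["a", p=0.3]. Step 2 extends each beam with its own
// top 2 tokens, giving 4 candidates: "the mat" (0.5*0.4=0.20), "the floor" (0.5*0.3=0.15),
// "a rug" (0.3*0.6=0.18) and "a box" (0.3*0.2=0.06). Trimming back to 2 beams keeps "the mat" and
// "a rug"; note "a rug" overtakes "the floor" even though "a" lost to "the" at step 1, which is
// exactly the kind of result greedy sampling misses.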
public static async Task Run()
{
// Load model weights
var parameters = new ModelParams(UserSettings.GetModelPath());
using var model = await LLamaWeights.LoadFromFileAsync(parameters);

var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "The cat sat on");
var tokensGenerate = AnsiConsole.Ask("How many tokens to generate?", 8);
var beamsCount = AnsiConsole.Ask("How many parallel beams to keep track of?", 8);

// Create an executor that can evaluate a batch of conversations together
using var executor = new BatchedExecutor(model, parameters);

// Print some info
var name = model.Metadata.GetValueOrDefault("general.name", "unknown model name");
Console.WriteLine($"Created executor with model: {name}");

// Evaluate the initial prompt to create one conversation
var conversation = executor.Create();
var startTokens = executor.Context.Tokenize(prompt);
conversation.Prompt(startTokens);

// Create one beam, containing that conversation
var beams = new List<Beam> { new Beam(conversation, 1.0, startTokens, [conversation.ConversationId]) };

// Generate loop
for (var i = 0; i < tokensGenerate; i++)
{
await executor.Infer();

// Create new beams, forked from all original beams
beams = (from oldBeam in beams
from beam in oldBeam.Sample(beamsCount)
select beam).OrderBy(a => a.CumulativeProbability).ToList();

// Trim down list by removing low probability beams
while (beams.Count > beamsCount)
{
var beam = beams[0];
AnsiConsole.MarkupLineInterpolated($"[red]Culling Beam {beam.Conversation.ConversationId} (prob:{beam.CumulativeProbability:P10})[/]: {beam}");

beam.Dispose();
beams.RemoveAt(0);
}

// Normalize all remaining beam probabilities.
NormalizeBeams(beams);
}

// Print out all remaining beams
AnsiConsole.MarkupLineInterpolated($"Final Beams:");
beams.Reverse();
foreach (var beam in beams)
{
AnsiConsole.MarkupLineInterpolated($"[yellow]Probability: {beam.CumulativeProbability:P10}[/]");
AnsiConsole.MarkupLineInterpolated($"[yellow]Sequence: {string.Join(",", beam.Sequence)}[/]");
AnsiConsole.MarkupLineInterpolated($"[green]{beam}[/]");
Console.WriteLine();
}

Console.WriteLine("Press any key to exit demo");
Console.ReadKey(true);
}

/// <summary>
/// As the beams grow, their cumulative probabilities get very small. Periodically normalizing all of the
/// beams prevents the values from collapsing to zero.
/// </summary>
/// <param name="beams">Beams to normalize, modified in place.</param>
private static void NormalizeBeams(List<Beam> beams)
{
// Find max probability
var max = beams.MaxBy(a => a.CumulativeProbability)!.CumulativeProbability;

// Divide all beams by max, this makes the max prob = 1.0
foreach (var beam in beams)
beam.CumulativeProbability /= max;
}
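// A possible alternative (not used in this example) is to accumulate log-probabilities instead:
// summing Math.Log(p) per token avoids underflow entirely and needs no periodic rescaling, at the
// cost of less readable values when printing the beams.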

private class Beam
: IDisposable
{
public readonly Conversation Conversation;
public readonly IReadOnlyList<LLamaToken> Tokens;
public readonly IReadOnlyList<LLamaSeqId> Sequence;

public double CumulativeProbability;

public Beam(Conversation conversation, double prob, IReadOnlyList<LLamaToken> tokens, IReadOnlyList<LLamaSeqId> sequence)
{
Conversation = conversation;
Tokens = tokens;
Sequence = sequence;

CumulativeProbability = prob;
}

public void Dispose()
{
Conversation.Dispose();
}

public List<Beam> Sample(int nbeams)
{
// Apply softmax; this converts the logits into probabilities and sorts the tokens into descending order of probability
var logitsArr = LLamaTokenDataArray.Create(Conversation.Sample());
logitsArr.Softmax(Conversation.Executor.Context.NativeHandle);

// Create new forked conversations, one for each beam
var results = new List<Beam>();
for (var i = 0; i < nbeams; i++)
{
// After softmax the logits array is in descending order of probability. Take the first `nbeams` items to make new beams.
var item = logitsArr.Data.Span[i];

// Fork the parent conversation. This shares all of the KV cache with the parent (and other forks)
// so does not cost any extra memory.
var c = Conversation.Fork();

// Extend the conversation with the selected token.
c.Prompt(item.id);

// Keep track of the cumulative probability of this entire sequence.
var p = CumulativeProbability * item.p;
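// (A beam's probability is the product of its per-token probabilities, P(t1..tn) = p(t1) * p(t2|t1) * ...,
// which is why it shrinks on every step and is periodically rescaled by NormalizeBeams in the generate loop.)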

// Keep track of all tokens in this sequence, for decoding later
var t = Tokens.ToList();
t.Add(item.id);

// Keep track of the chain of conversation IDs this beam was derived through.
var s = Sequence.ToList();
s.Add(c.ConversationId);

results.Add(new Beam(c, p, t, s));
}

// Dispose self now that child beams have spawned
Conversation.Dispose();
return results;
}

public override string ToString()
{
var decoder = new StreamingTokenDecoder(Conversation.Executor.Context);
decoder.AddRange(Tokens);
return decoder.Read();
}
}
}
8 changes: 7 additions & 1 deletion LLama/Native/LLamaSeqId.cs
@@ -1,4 +1,4 @@
using System.Runtime.InteropServices;

namespace LLama.Native;

@@ -39,4 +39,10 @@ private LLamaSeqId(int value)
/// <param name="value"></param>
/// <returns></returns>
public static explicit operator LLamaSeqId(int value) => new(value);

/// <inheritdoc />
public readonly override string ToString()
{
return Value.ToString();
}
}