Classifier Free Guidance #536

Merged
1 change: 1 addition & 0 deletions LLama.Examples/ExampleRunner.cs
@@ -26,6 +26,7 @@ public class ExampleRunner
{ "Semantic Kernel: Store", SemanticKernelMemory.Run },
{ "Batched Executor: Fork", BatchedExecutorFork.Run },
{ "Batched Executor: Rewind", BatchedExecutorRewind.Run },
{ "Batched Executor: Guidance", BatchedExecutorGuidance.Run },
{ "Exit", () => { Environment.Exit(0); return Task.CompletedTask; } }
};

71 changes: 40 additions & 31 deletions LLama.Examples/Examples/BatchedExecutorFork.cs
@@ -12,7 +12,7 @@ namespace LLama.Examples.Examples;
public class BatchedExecutorFork
{
private const int n_split = 16;
- private const int n_len = 64;
+ private const int n_len = 72;

public static async Task Run()
{
@@ -24,41 +24,51 @@ public static async Task Run()
var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

// Create an executor that can evaluate a batch of conversations together
- var executor = new BatchedExecutor(model, parameters);
+ using var executor = new BatchedExecutor(model, parameters);

// Print some info
var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
Console.WriteLine($"Created executor with model: {name}");

// Evaluate the initial prompt to create one conversation
- var start = executor.Prompt(prompt);
+ using var start = executor.Prompt(prompt);
await executor.Infer();

// Create the root node of the tree
var root = new Node(start);

- // Run inference loop
- for (var i = 0; i < n_len; i++)
- {
-     if (i != 0)
-         await executor.Infer();
-
-     // Occasionally fork all the active conversations
-     if (i != 0 && i % n_split == 0)
-         root.Split();
-
-     // Sample all active conversations
-     root.Sample();
- }
-
- Console.WriteLine($"{prompt}...");
- root.Print(1);
-
- Console.WriteLine("Press any key to exit demo");
- Console.ReadKey(true);
+ await AnsiConsole
+     .Progress()
+     .StartAsync(async progress =>
+     {
+         var reporter = progress.AddTask("Running Inference (1)", maxValue: n_len);
+
+         // Run inference loop
+         for (var i = 0; i < n_len; i++)
+         {
+             if (i != 0)
+                 await executor.Infer();
+
+             // Occasionally fork all the active conversations
+             if (i != 0 && i % n_split == 0)
+                 root.Split();
+
+             // Sample all active conversations
+             root.Sample();
+
+             // Update progress bar
+             reporter.Increment(1);
+             reporter.Description($"Running Inference ({root.ActiveConversationCount})");
+         }
+
+         // Display results
+         var display = new Tree(prompt);
+         root.Display(display);
+         AnsiConsole.Write(display);
+     });
}

- class Node
+ private class Node
{
private readonly StreamingTokenDecoder _decoder;

@@ -116,19 +126,18 @@ public void Split()
}
}

- public void Print(int indendation)
+ public void Display<T>(T tree, int depth = 0)
+     where T : IHasTreeNodes
{
- var colors = new[] { ConsoleColor.Red, ConsoleColor.Green, ConsoleColor.Blue, ConsoleColor.Yellow, ConsoleColor.White };
- Console.ForegroundColor = colors[indendation % colors.Length];
+ var colors = new[] { "red", "green", "blue", "yellow", "white" };
+ var color = colors[depth % colors.Length];

var message = _decoder.Read().ReplaceLineEndings("");

- var prefix = new string(' ', indendation * 3);
- var suffix = _conversation == null ? "..." : "";
- Console.WriteLine($"{prefix}...{message}{suffix}");
+ var n = tree.AddNode($"[{color}]{message}[/]");

- _left?.Print(indendation + 2);
- _right?.Print(indendation + 2);
+ _left?.Display(n, depth + 1);
+ _right?.Display(n, depth + 1);
}
}
}
128 changes: 128 additions & 0 deletions LLama.Examples/Examples/BatchedExecutorGuidance.cs
@@ -0,0 +1,128 @@
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates using a batch to generate two sequences and then using one
/// sequence as the negative guidance ("classifier free guidance") for the other.
/// </summary>
public class BatchedExecutorGuidance
{
private const int n_len = 32;

public static async Task Run()
{
string modelPath = UserSettings.GetModelPath();

var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);

var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();
var weight = AnsiConsole.Ask("Guidance Weight (or ENTER for default):", 2.0f);

// Create an executor that can evaluate a batch of conversations together
using var executor = new BatchedExecutor(model, parameters);

// Print some info
var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
Console.WriteLine($"Created executor with model: {name}");

// Load the two prompts into two conversations
using var guided = executor.Prompt(positivePrompt);
using var guidance = executor.Prompt(negativePrompt);

// Run inference to evaluate prompts
await AnsiConsole
.Status()
.Spinner(Spinner.Known.Line)
.StartAsync("Evaluating Prompts...", _ => executor.Infer());

// Fork the "guided" conversation. We'll run this one without guidance for comparison
using var unguided = guided.Fork();

// Run inference loop
var unguidedSampler = new GuidedSampler(null, weight);
var unguidedDecoder = new StreamingTokenDecoder(executor.Context);
var guidedSampler = new GuidedSampler(guidance, weight);
var guidedDecoder = new StreamingTokenDecoder(executor.Context);
await AnsiConsole
.Progress()
.StartAsync(async progress =>
{
var reporter = progress.AddTask("Running Inference", maxValue: n_len);

for (var i = 0; i < n_len; i++)
{
if (i != 0)
await executor.Infer();

// Sample from the "unguided" conversation. This is just a conversation using the same prompt, without any
// guidance. This serves as a comparison to show the effect of guidance.
var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample(), Array.Empty<LLamaToken>());
unguidedDecoder.Add(u);
unguided.Prompt(u);

// Sample from the "guided" conversation. This sampler will internally use the "guidance" conversation
// to steer the conversation. See how this is done in GuidedSampler.ProcessLogits (bottom of this file).
var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample(), Array.Empty<LLamaToken>());
guidedDecoder.Add(g);

// Use this token to advance both the guided _and_ the guidance conversation, keeping them in sync (except for the initial prompt).
guided.Prompt(g);
guidance.Prompt(g);
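// (If the guidance conversation were allowed to drift onto different tokens, its logits
// would no longer be comparable position-by-position with the guided conversation.)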

// Early exit if we reach the natural end of the guided sentence
if (g == model.EndOfSentenceToken)
break;

// Update progress bar
reporter.Increment(1);
}
});

AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read()}[/]");
AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read()}[/]");
}

private class GuidedSampler(Conversation? guidance, float weight)
: BaseSamplingPipeline
{
public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
{
}

public override ISamplingPipeline Clone()
{
throw new NotSupportedException();
}

protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
{
if (guidance == null)
return logits;

var logitsCopy = logits.ToArray();

// Get the logits generated by the guidance sequences
var guidanceLogits = guidance.Sample();

// Use those logits to guide this sequence
NativeApi.llama_sample_apply_guidance(ctx, logitsCopy, guidanceLogits, weight);

return logitsCopy;
}

protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
{
candidates.Temperature(ctx, 0.8f);
candidates.TopK(ctx, 25);

return candidates.SampleToken(ctx);
}
}
}
4 changes: 2 additions & 2 deletions LLama.Examples/Examples/BatchedExecutorRewind.cs
@@ -25,14 +25,14 @@ public static async Task Run()
var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

// Create an executor that can evaluate a batch of conversations together
- var executor = new BatchedExecutor(model, parameters);
+ using var executor = new BatchedExecutor(model, parameters);

// Print some info
var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
Console.WriteLine($"Created executor with model: {name}");

// Evaluate the initial prompt to create one conversation
- var conversation = executor.Prompt(prompt);
+ using var conversation = executor.Prompt(prompt);

// Create the start node wrapping the conversation
var node = new Node(executor.Context);
50 changes: 50 additions & 0 deletions LLama/Native/LLamaTokenDataArray.cs
@@ -185,6 +185,56 @@ public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<LLama
}
}

/// <summary>
/// Apply classifier-free guidance to the logits, as described in the academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
/// </summary>
/// <param name="context"></param>
/// <param name="guidanceLogits">Logits extracted from a separate context from the same model.
/// Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.</param>
/// <param name="guidance">Guidance strength. 0 means no guidance, higher values applies stronger guidance</param>
public void Guidance(SafeLLamaContextHandle context, ReadOnlySpan<float> guidanceLogits, float guidance)
{
if (guidanceLogits.Length != data.Length)
throw new ArgumentException("Guidance logits count must equal vocabulary size", nameof(guidanceLogits));

if (guidance < 0)
throw new ArgumentOutOfRangeException(nameof(guidance), "Guidance strength must be greater than or equal to zero");

// This method accepts 0 as "no guidance" (higher = stronger); llama.cpp expects 1 as "no guidance".
// Add one to move up to the llama.cpp baseline.
guidance += 1;
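// Conceptually (a sketch of what the native call below computes; llama.cpp may also
// normalise both vectors via log-softmax first):
//   logits[i] = guidanceLogits[i] + scale * (logits[i] - guidanceLogits[i])
// so scale == 1 leaves the distribution untouched, and larger values push it further
// away from the guidance sequence's prediction.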

// We need a raw logits array, which we don't have at this point.
// Copy the logits into a temporary array, apply guidance, then copy them back.
var logits = ArrayPool<float>.Shared.Rent(context.VocabCount);
try
{
// Copy logits into a temporary array
for (var i = 0; i < data.Length; i++)
{
ref var item = ref data.Span[i];
logits[(int)item.id] = item.logit;
}

// Apply guidance
NativeApi.llama_sample_apply_guidance(context, logits, guidanceLogits, guidance);

// Copy logits back into data array
for (var i = 0; i < data.Length; i++)
{
ref var item = ref data.Span[i];
item.logit = logits[(int)item.id];
}

// No longer sorted since we just mutated logits!
sorted = false;
}
finally
{
ArrayPool<float>.Shared.Return(logits);
}
}

/// <summary>
/// Sample with temperature.
/// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual
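For reference, a minimal sketch of how the new `Guidance` method might be used from a custom sampling pipeline. The helper below is hypothetical (not part of this PR); it assumes logits obtained from the guided and negative conversations, as in the example above, and uses only methods shown in this diff plus `LLamaTokenDataArray.Create`:

```csharp
// Hypothetical helper: applies classifier-free guidance through
// LLamaTokenDataArray rather than calling the native API directly.
LLamaToken SampleWithGuidance(
    SafeLLamaContextHandle ctx,
    ReadOnlySpan<float> logits,         // logits from the guided conversation
    ReadOnlySpan<float> guidanceLogits, // logits from the negative-prompt conversation
    float weight)                       // 0 = no guidance, higher = stronger
{
    var candidates = LLamaTokenDataArray.Create(logits);

    // Steer the distribution away from the negative sequence's prediction
    candidates.Guidance(ctx, guidanceLogits, weight);

    // Then sample as usual
    candidates.Temperature(ctx, 0.8f);
    candidates.TopK(ctx, 25);
    return candidates.SampleToken(ctx);
}
```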
30 changes: 29 additions & 1 deletion LLama/Native/NativeApi.Sampling.cs
@@ -1,4 +1,5 @@
- using System.Runtime.InteropServices;
+ using System;
+ using System.Runtime.InteropServices;

namespace LLama.Native
{
@@ -23,6 +24,33 @@ public static extern unsafe void llama_sample_repetition_penalties(SafeLLamaCont
float penalty_freq,
float penalty_present);

/// <summary>
/// Apply classifier-free guidance to the logits, as described in the academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
/// </summary>
/// <param name="ctx"></param>
/// <param name="logits">Logits extracted from the original generation context.</param>
/// <param name="logits_guidance">Logits extracted from a separate context from the same model.
/// Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.</param>
/// <param name="scale">Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.</param>
public static void llama_sample_apply_guidance(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<float> logits_guidance, float scale)
{
if (logits == null)
throw new ArgumentNullException(nameof(logits));
if (logits_guidance == null)
throw new ArgumentNullException(nameof(logits_guidance));
if (logits.Length != ctx.VocabCount)
throw new ArgumentException("Logits count must have equal context vocab size", nameof(logits));
if (logits_guidance.Length != ctx.VocabCount)
throw new ArgumentException("Guidance logits count must have equal context vocab size", nameof(logits_guidance));

unsafe
{
fixed (float* logitsPtr = logits)
fixed (float* logitsGuidancePtr = logits_guidance)
llama_sample_apply_guidance(ctx, logitsPtr, logitsGuidancePtr, scale);
}
}

/// <summary>
/// Apply classifier-free guidance to the logits, as described in the academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
/// </summary>