
Commit 80d1080

Created a guided sampling demo using the batched executor
1 parent 528bb01 commit 80d1080

2 files changed: +126 -0 lines changed


LLama.Examples/ExampleRunner.cs

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ public class ExampleRunner
         { "Semantic Kernel: Store", SemanticKernelMemory.Run },
         { "Batched Executor: Fork", BatchedExecutorFork.Run },
         { "Batched Executor: Rewind", BatchedExecutorRewind.Run },
+        { "Batched Executor: Guidance", BatchedExecutorGuidance.Run },
         { "Exit", () => { Environment.Exit(0); return Task.CompletedTask; } }
     };
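For context, the registry above simply maps a display name to a Func<Task>, so the new demo is launched by awaiting the delegate stored under its key. A minimal sketch of such a dispatch loop, assuming Spectre.Console's SelectionPrompt and a registry shaped like the hunk above (the real ExampleRunner body is not part of this diff):

using LLama.Examples.Examples;
using Spectre.Console;

// Hypothetical dispatcher; illustrative only, not code from this commit.
// Relies on .NET implicit usings for Dictionary, Func and Task.
var examples = new Dictionary<string, Func<Task>>
{
    { "Batched Executor: Guidance", BatchedExecutorGuidance.Run },
    { "Exit", () => { Environment.Exit(0); return Task.CompletedTask; } },
};

while (true)
{
    var choice = AnsiConsole.Prompt(
        new SelectionPrompt<string>()
            .Title("Example to run")
            .AddChoices(examples.Keys));

    await examples[choice]();
}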

LLama.Examples/Examples/BatchedExecutorGuidance.cs

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates using a batch to generate two sequences and then using one
/// sequence as the negative guidance ("classifier-free guidance") for the other.
/// </summary>
public class BatchedExecutorGuidance
{
    private const int n_len = 32;

    public static async Task Run()
    {
        string modelPath = UserSettings.GetModelPath();

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
        var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();

        // Create an executor that can evaluate a batch of conversations together
        var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Load the two prompts into two conversations
        var guided = executor.Prompt(positivePrompt);
        var guidance = executor.Prompt(negativePrompt);

        // Run inference to evaluate prompts
        await AnsiConsole
            .Status()
            .Spinner(Spinner.Known.Line)
            .StartAsync("Evaluating Prompts...", _ => executor.Infer());

        // Fork the "guided" conversation. We'll run this one without guidance for comparison
        var unguided = guided.Fork();

        // Run inference loop
        var unguidedSampler = new GuidedSampler(null);
        var unguidedDecoder = new StreamingTokenDecoder(executor.Context);
        var guidedSampler = new GuidedSampler(guidance);
        var guidedDecoder = new StreamingTokenDecoder(executor.Context);
        await AnsiConsole
            .Progress()
            .StartAsync(async progress =>
            {
                var reporter = progress.AddTask("Running Inference", maxValue: n_len);

                for (var i = 0; i < n_len; i++)
                {
                    if (i != 0)
                        await executor.Infer();

                    // Sample from the "unguided" conversation
                    var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample().ToArray(), Array.Empty<LLamaToken>());
                    unguidedDecoder.Add(u);
                    unguided.Prompt(u);

                    // Sample from the "guided" conversation
                    var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample().ToArray(), Array.Empty<LLamaToken>());
                    guidedDecoder.Add(g);

                    // Use this token to advance both guided _and_ guidance, keeping them in sync (except for the initial prompt)
                    guided.Prompt(g);
                    guidance.Prompt(g);

                    // Early exit if we reach the natural end of the guided sentence
                    if (g == model.EndOfSentenceToken)
                        break;

                    reporter.Increment(1);
                }
            });

        AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read()}[/]");
        AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read()}[/]");
    }

    private class GuidedSampler(Conversation? guidance)
        : BaseSamplingPipeline
    {
        public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
        {
        }

        public override ISamplingPipeline Clone()
        {
            throw new NotSupportedException();
        }

        protected override IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx)
        {
            return Array.Empty<LLamaToken>();
        }

        protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
        {
            if (guidance != null)
            {
                // Get the logits generated by the guidance sequence
                var guidanceLogits = guidance.Sample();

                // Use those logits to guide this sequence
                NativeApi.llama_sample_apply_guidance(ctx, logits, guidanceLogits, 2);
            }
        }

        protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
        {
            candidates.Temperature(ctx, 0.8f);
            candidates.TopK(ctx, 25);

            return candidates.SampleToken(ctx);
        }
    }
}
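The GuidedSampler above delegates the actual guidance step to NativeApi.llama_sample_apply_guidance with a scale of 2. As a rough sketch of the classifier-free guidance blend that call is assumed to perform (based on llama.cpp's behaviour at the time, not code from this commit): both logit vectors are log-softmax normalised, then the guided logits are extrapolated away from the negative-prompt logits by the scale factor, so a scale of 1 leaves the positive distribution unchanged and larger scales push it further from the negative prompt.

using System;

// Illustrative managed approximation of the guidance blend; assumptions noted above.
internal static class GuidanceMath
{
    public static void ApplyGuidance(Span<float> positive, Span<float> negative, float scale)
    {
        LogSoftmaxInPlace(positive);
        LogSoftmaxInPlace(negative);

        // positive[i] = negative[i] + scale * (positive[i] - negative[i])
        for (var i = 0; i < positive.Length; i++)
            positive[i] = negative[i] + scale * (positive[i] - negative[i]);
    }

    private static void LogSoftmaxInPlace(Span<float> logits)
    {
        var max = float.NegativeInfinity;
        foreach (var l in logits)
            max = Math.Max(max, l);

        var sum = 0.0;
        foreach (var l in logits)
            sum += Math.Exp(l - max);

        var logSum = (float)Math.Log(sum);
        for (var i = 0; i < logits.Length; i++)
            logits[i] -= max + logSum;
    }
}

With the scale of 2 used in the demo, tokens favoured by the negative prompt are suppressed relative to the positive prompt, which is why the two conversations are evaluated together in a single batch and kept in sync token by token.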
