Skip to content

Commit d03c1a9

Browse files
authored
Merge pull request #503 from martindevans/batched_executor_again
Introduced a new `BatchedExecutor`
2 parents 968e1e4 + e9d9042 commit d03c1a9

17 files changed

+912
-204
lines changed

LLama.Examples/Examples/BatchedDecoding.cs

Lines changed: 0 additions & 172 deletions
This file was deleted.
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
using LLama.Batched;
2+
using LLama.Common;
3+
using LLama.Native;
4+
using LLama.Sampling;
5+
6+
namespace LLama.Examples.Examples;
7+
8+
/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache
/// </summary>
public class BatchedExecutorFork
{
    // Fork every active conversation each time this many tokens have been generated
    private const int n_split = 16;

    // Total number of tokens to generate across the whole tree
    private const int n_len = 64;

    /// <summary>
    /// Run the interactive demo: load a model, prompt it, and periodically fork
    /// the conversation into a tree of diverging continuations.
    /// </summary>
    public static async Task Run()
    {
        Console.Write("Please input your model path: ");
        var modelPath = Console.ReadLine();

        // Console.ReadLine() may return null (EOF) or blank input; ModelParams
        // cannot load anything from that, so bail out early with a clear message.
        if (string.IsNullOrWhiteSpace(modelPath))
        {
            Console.WriteLine("No model path supplied, exiting demo");
            return;
        }

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        Console.WriteLine("Prompt (leave blank to select automatically):");
        var prompt = Console.ReadLine();
        if (string.IsNullOrWhiteSpace(prompt))
            prompt = "Not many people know that";

        // Create an executor that can evaluate a batch of conversations together
        var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Evaluate the initial prompt to create one conversation
        var start = executor.Prompt(prompt);
        await executor.Infer();

        // Create the root node of the tree
        var root = new Node(start);

        // Run inference loop
        for (var i = 0; i < n_len; i++)
        {
            // The initial prompt was already inferred above, so skip the first pass
            if (i != 0)
                await executor.Infer();

            // Occasionally fork all the active conversations
            if (i != 0 && i % n_split == 0)
                root.Split();

            // Sample all active conversations
            root.Sample();
        }

        Console.WriteLine($"{prompt}...");
        root.Print(1);

        Console.WriteLine("Press any key to exit demo");
        Console.ReadKey(true);
    }

    /// <summary>
    /// A node in the binary tree of conversations. A node either holds a live
    /// conversation (leaf) or two children it was forked into (interior node).
    /// </summary>
    private class Node
    {
        // Accumulates the text decoded from tokens sampled on this branch
        private readonly StreamingTokenDecoder _decoder;

        // Each node gets its own sampler so branches sample independently
        private readonly DefaultSamplingPipeline _sampler;

        // Null once this node has been split into _left/_right
        private Conversation? _conversation;

        private Node? _left;
        private Node? _right;

        /// <summary>
        /// Number of live (leaf) conversations in this subtree.
        /// </summary>
        public int ActiveConversationCount => _conversation != null ? 1 : _left!.ActiveConversationCount + _right!.ActiveConversationCount;

        public Node(Conversation conversation)
        {
            _sampler = new DefaultSamplingPipeline();
            _conversation = conversation;
            _decoder = new StreamingTokenDecoder(conversation.Executor.Context);
        }

        /// <summary>
        /// Sample one token from every live conversation in this subtree and
        /// prompt each conversation with its own sampled token.
        /// </summary>
        public void Sample()
        {
            // Interior node: recurse into both children
            if (_conversation == null)
            {
                _left?.Sample();
                _right?.Sample();
                return;
            }

            // Cannot sample until the pending batch has been inferred
            if (_conversation.RequiresInference)
                return;

            // Sample one token
            var ctx = _conversation.Executor.Context.NativeHandle;
            var logitsCopy = _conversation.Sample().ToArray();
            var token = _sampler.Sample(ctx, logitsCopy, Array.Empty<LLamaToken>());
            _sampler.Accept(ctx, token);
            _decoder.Add(token);

            // Prompt the conversation with this token, to continue generating from there
            _conversation.Prompt(token);
        }

        /// <summary>
        /// Fork every live conversation in this subtree into two children,
        /// turning each leaf into an interior node.
        /// </summary>
        public void Split()
        {
            if (_conversation != null)
            {
                _left = new Node(_conversation.Fork());
                _right = new Node(_conversation.Fork());

                // The original conversation is no longer needed once both forks exist
                _conversation.Dispose();
                _conversation = null;
            }
            else
            {
                _left?.Split();
                _right?.Split();
            }
        }

        /// <summary>
        /// Print this subtree, colour-coded and indented by depth.
        /// </summary>
        /// <param name="indentation">Depth-based indentation level (also selects the colour)</param>
        public void Print(int indentation)
        {
            var colors = new[] { ConsoleColor.Red, ConsoleColor.Green, ConsoleColor.Blue, ConsoleColor.Yellow, ConsoleColor.White };
            Console.ForegroundColor = colors[indentation % colors.Length];

            // Strip newlines so each branch renders on a single line
            var message = _decoder.Read().ReplaceLineEndings("");

            var prefix = new string(' ', indentation * 3);
            var suffix = _conversation == null ? "..." : "";
            Console.WriteLine($"{prefix}...{message}{suffix}");

            _left?.Print(indentation + 2);
            _right?.Print(indentation + 2);
        }
    }
}

0 commit comments

Comments
 (0)