
Commit b0acecf (parent 859160d)
Created a new BatchedExecutor which processes multiple "Conversations" in a single inference batch. This is faster even when the conversations are unrelated, and is much faster if the conversations share some overlap (e.g. a common system prompt prefix).
Conversations can be "forked", creating a copy of a conversation at a given point. This allows, for example, prompting a conversation with a system prefix just once and then forking it for each individual conversation. Conversations can also be "rewound" to an earlier state. Added two new examples, demonstrating forking and rewinding.
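
Only the fork example appears in the diff below; the rewind example is not part of this excerpt. As a minimal sketch of the rewinding described above, assuming a Conversation.Rewind(int tokens) method that discards the last few tokens of a conversation's state (the method name and signature are assumptions based on the commit message, not confirmed by this diff):

    using LLama;
    using LLama.Batched;
    using LLama.Common;

    // Hypothetical usage sketch: setup mirrors the fork example below,
    // but Rewind(int) itself is an assumption based on the commit message.
    var parameters = new ModelParams("path/to/model.gguf"); // hypothetical path
    using var model = LLamaWeights.LoadFromFile(parameters);
    var executor = new BatchedExecutor(model, parameters);

    // Prompt once and evaluate, exactly as in the fork example below
    var conversation = executor.Prompt("Not many people know that");
    await executor.Infer();

    // ... sample some tokens and Prompt() them back into the conversation ...

    // Rewind the conversation, discarding its last 5 tokens, so generation
    // can be retried from that earlier state (assumed API)
    conversation.Rewind(5);

Rewinding this way would let a caller retry generation from an earlier point without re-evaluating the whole prompt.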

14 files changed: +748 −199 lines

LLama.Examples/Examples/BatchedDecoding.cs

Lines changed: 0 additions & 172 deletions
This file was deleted.
LLama.Examples/Examples/BatchedExecutorFork.cs (new file; name inferred from the class it declares)

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache
/// </summary>
public class BatchedExecutorFork
{
    private const int n_split = 16;
    private const int n_len = 64;

    public static async Task Run()
    {
        Console.Write("Please input your model path: ");
        var modelPath = Console.ReadLine();

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        Console.WriteLine("Prompt (leave blank to select automatically):");
        var prompt = Console.ReadLine();
        if (string.IsNullOrWhiteSpace(prompt))
            prompt = "Not many people know that";

        // Create an executor that can evaluate a batch of conversations together
        var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Evaluate the initial prompt to create one conversation
        var start = executor.Prompt(prompt);
        await executor.Infer();

        // Create the root node of the tree
        var root = new Node(start);

        // Run inference loop
        for (var i = 0; i < n_len; i++)
        {
            if (i != 0)
                await executor.Infer();

            // Occasionally fork all the active conversations
            if (i != 0 && i % n_split == 0)
                root.Split();

            // Sample all active conversations
            root.Sample();
        }

        Console.WriteLine($"{prompt}...");
        root.Print(1);

        Console.WriteLine("Press any key to exit demo");
        Console.ReadKey(true);
    }

    class Node
    {
        private readonly StreamingTokenDecoder _decoder;
        private readonly DefaultSamplingPipeline _sampler;

        // A node either holds a live conversation (leaf) or two forked children
        private Conversation? _conversation;
        private Node? _left;
        private Node? _right;

        public int ActiveConversationCount => _conversation != null ? 1 : _left!.ActiveConversationCount + _right!.ActiveConversationCount;

        public Node(Conversation conversation)
        {
            _sampler = new DefaultSamplingPipeline();
            _conversation = conversation;
            _decoder = new StreamingTokenDecoder(conversation.Executor.Context);
        }

        public void Sample()
        {
            if (_conversation == null)
            {
                _left?.Sample();
                _right?.Sample();
                return;
            }

            // Can't sample from a conversation which is still waiting for inference
            if (_conversation.RequiresInference)
                return;

            // Sample one token
            var ctx = _conversation.Executor.Context.NativeHandle;
            var logitsCopy = _conversation.Sample().ToArray();
            var token = _sampler.Sample(ctx, logitsCopy, Array.Empty<LLamaToken>());
            _sampler.Accept(ctx, token);
            _decoder.Add(token);

            // Prompt the conversation with this token, to continue generating from there
            _conversation.Prompt(token);
        }

        public void Split()
        {
            if (_conversation != null)
            {
                // Fork this leaf into two children, each continuing from the current state
                _left = new Node(_conversation.Fork());
                _right = new Node(_conversation.Fork());

                // The parent conversation is no longer needed once it has been forked
                _conversation.Dispose();
                _conversation = null;
            }
            else
            {
                _left?.Split();
                _right?.Split();
            }
        }

        public void Print(int indentation)
        {
            var colors = new[] { ConsoleColor.Red, ConsoleColor.Green, ConsoleColor.Blue, ConsoleColor.Yellow, ConsoleColor.White };
            Console.ForegroundColor = colors[indentation % colors.Length];

            var message = _decoder.Read().ReplaceLineEndings("");

            var prefix = new string(' ', indentation * 3);
            var suffix = _conversation == null ? "..." : "";
            Console.WriteLine($"{prefix}...{message}{suffix}");

            _left?.Print(indentation + 2);
            _right?.Print(indentation + 2);
        }
    }
}
