diff --git a/LLama.Examples/Examples/BatchedDecoding.cs b/LLama.Examples/Examples/BatchedDecoding.cs
deleted file mode 100644
index f893b613f..000000000
--- a/LLama.Examples/Examples/BatchedDecoding.cs
+++ /dev/null
@@ -1,172 +0,0 @@
-using System.Diagnostics;
-using System.Text;
-using LLama.Common;
-using LLama.Native;
-using LLama.Sampling;
-
-namespace LLama.Examples.Examples;
-
-/// <summary>
-/// This demonstrates generating multiple replies to the same prompt, with a shared cache
-/// </summary>
-/// <remarks>Note that this is currently using the low level API directly, future work will provide a safer C# wrapper over this!</remarks>
-public class BatchedDecoding
-{
- private const int n_parallel = 8;
- private const int n_len = 32;
-
- public static async Task Run()
- {
- Console.Write("Please input your model path: ");
- var modelPath = Console.ReadLine();
-
- Console.WriteLine("Prompt (leave blank to select automatically):");
- var prompt = Console.ReadLine();
- if (string.IsNullOrWhiteSpace(prompt))
- prompt = "Not many people know that";
-
- // Load model
- var parameters = new ModelParams(modelPath);
-
- using var model = LLamaWeights.LoadFromFile(parameters);
-
- // Tokenize prompt
- var prompt_tokens = model.Tokenize(prompt, true, false, Encoding.UTF8);
- var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;
-
- // Create a context
- parameters.ContextSize = (uint)model.ContextSize;
- parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
- using var context = model.CreateContext(parameters);
-
- var n_ctx = context.ContextSize;
-
- // make sure the KV cache is big enough to hold all the prompt and generated tokens
- if (n_kv_req > n_ctx)
- {
- await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n");
- await Console.Error.WriteLineAsync(" either reduce n_parallel or increase n_ctx\n");
- return;
- }
-
- var batch = new LLamaBatch();
-
- // evaluate the initial prompt
- batch.AddRange(prompt_tokens, 0, LLamaSeqId.Zero, true);
-
- if (await context.DecodeAsync(batch) != DecodeResult.Ok)
- {
- await Console.Error.WriteLineAsync("llama_decode failed");
- return;
- }
-
- // assign the system KV cache to all parallel sequences
- // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
- for (var i = 1; i < n_parallel; ++i)
- {
- context.NativeHandle.KvCacheSequenceCopy((LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
- }
-
- if (n_parallel > 1)
- {
- Console.WriteLine();
- Console.WriteLine($"generating {n_parallel} sequences...");
- }
-
- // remember the batch index of the last token for each parallel sequence
- // we need this to determine which logits to sample from
- List<int> i_batch = new();
- for (var i = 0; i < n_parallel; i++)
- i_batch.Add(batch.TokenCount - 1);
-
- // Create per-stream decoder and sampler
- var decoders = new StreamingTokenDecoder[n_parallel];
- var samplers = new ISamplingPipeline[n_parallel];
- for (var i = 0; i < n_parallel; i++)
- {
- decoders[i] = new StreamingTokenDecoder(context);
- samplers[i] = new DefaultSamplingPipeline
- {
- Temperature = 0.1f + (float)i / n_parallel,
- MinP = 0.25f,
- };
- }
-
- var n_cur = batch.TokenCount;
- var n_decode = 0;
-
- var timer = new Stopwatch();
- timer.Start();
- while (n_cur <= n_len)
- {
- batch.Clear();
-
- for (var i = 0; i < n_parallel; i++)
- {
- // Skip completed streams
- if (i_batch[i] < 0)
- continue;
-
- // Use the sampling pipeline to select a token
- var new_token_id = samplers[i].Sample(
- context.NativeHandle,
- context.NativeHandle.GetLogitsIth(i_batch[i]),
- Array.Empty<LLamaToken>()
- );
-
- // Finish this stream early if necessary
- if (new_token_id == model.EndOfSentenceToken || new_token_id == model.NewlineToken)
- {
- i_batch[i] = -1;
- Console.WriteLine($"Completed Stream {i} early");
- continue;
- }
-
- // Add this token to the decoder, so it will be turned into text
- decoders[i].Add(new_token_id);
-
- i_batch[i] = batch.TokenCount;
-
- // push this new token for next evaluation
- batch.Add(new_token_id, n_cur, (LLamaSeqId)i, true);
-
- n_decode++;
- }
-
- // Check if all streams are finished
- if (batch.TokenCount == 0)
- {
- break;
- }
-
- n_cur++;
-
- // evaluate the current batch with the transformer model
- if (await context.DecodeAsync(batch) != 0)
- {
- await Console.Error.WriteLineAsync("failed to eval");
- return;
- }
- }
-
- timer.Stop();
- Console.ForegroundColor = ConsoleColor.Yellow;
- Console.WriteLine();
- Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms");
- Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");
-
- var index = 0;
- foreach (var stream in decoders)
- {
- var text = stream.Read();
-
- Console.ForegroundColor = ConsoleColor.Green;
- Console.Write($"{index++}. {prompt}");
- Console.ForegroundColor = ConsoleColor.Red;
- Console.WriteLine(text);
- }
-
- Console.WriteLine("Press any key to exit demo");
- Console.ReadKey(true);
- }
-}
\ No newline at end of file
diff --git a/LLama.Examples/Examples/BatchedExecutorFork.cs b/LLama.Examples/Examples/BatchedExecutorFork.cs
new file mode 100644
index 000000000..b834190ce
--- /dev/null
+++ b/LLama.Examples/Examples/BatchedExecutorFork.cs
@@ -0,0 +1,138 @@
+using LLama.Batched;
+using LLama.Common;
+using LLama.Native;
+using LLama.Sampling;
+
+namespace LLama.Examples.Examples;
+
+/// <summary>
+/// This demonstrates generating multiple replies to the same prompt, with a shared cache
+/// </summary>
+public class BatchedExecutorFork
+{
+ private const int n_split = 16;
+ private const int n_len = 64;
+
+ public static async Task Run()
+ {
+ Console.Write("Please input your model path: ");
+ var modelPath = Console.ReadLine();
+
+ var parameters = new ModelParams(modelPath);
+ using var model = LLamaWeights.LoadFromFile(parameters);
+
+ Console.WriteLine("Prompt (leave blank to select automatically):");
+ var prompt = Console.ReadLine();
+ if (string.IsNullOrWhiteSpace(prompt))
+ prompt = "Not many people know that";
+
+ // Create an executor that can evaluate a batch of conversations together
+ var executor = new BatchedExecutor(model, parameters);
+
+ // Print some info
+ var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
+ Console.WriteLine($"Created executor with model: {name}");
+
+ // Evaluate the initial prompt to create one conversation
+ var start = executor.Prompt(prompt);
+ await executor.Infer();
+
+ // Create the root node of the tree
+ var root = new Node(start);
+
+ // Run inference loop
+ for (var i = 0; i < n_len; i++)
+ {
+ if (i != 0)
+ await executor.Infer();
+
+ // Occasionally fork all the active conversations
+ if (i != 0 && i % n_split == 0)
+ root.Split();
+
+ // Sample all active conversations
+ root.Sample();
+ }
+
+ Console.WriteLine($"{prompt}...");
+ root.Print(1);
+
+ Console.WriteLine("Press any key to exit demo");
+ Console.ReadKey(true);
+ }
+
+ class Node
+ {
+ private readonly StreamingTokenDecoder _decoder;
+
+ private readonly DefaultSamplingPipeline _sampler;
+ private Conversation? _conversation;
+
+ private Node? _left;
+ private Node? _right;
+
+ public int ActiveConversationCount => _conversation != null ? 1 : _left!.ActiveConversationCount + _right!.ActiveConversationCount;
+
+ public Node(Conversation conversation)
+ {
+ _sampler = new DefaultSamplingPipeline();
+ _conversation = conversation;
+ _decoder = new StreamingTokenDecoder(conversation.Executor.Context);
+ }
+
+ public void Sample()
+ {
+ if (_conversation == null)
+ {
+ _left?.Sample();
+ _right?.Sample();
+ return;
+ }
+
+ if (_conversation.RequiresInference)
+ return;
+
+ // Sample one token
+ var ctx = _conversation.Executor.Context.NativeHandle;
+ var logitsCopy = _conversation.Sample().ToArray();
+ var token = _sampler.Sample(ctx, logitsCopy, Array.Empty<LLamaToken>());
+ _sampler.Accept(ctx, token);
+ _decoder.Add(token);
+
+ // Prompt the conversation with this token, to continue generating from there
+ _conversation.Prompt(token);
+ }
+
+ public void Split()
+ {
+ if (_conversation != null)
+ {
+ _left = new Node(_conversation.Fork());
+ _right = new Node(_conversation.Fork());
+
+ _conversation.Dispose();
+ _conversation = null;
+ }
+ else
+ {
+ _left?.Split();
+ _right?.Split();
+ }
+ }
+
+ public void Print(int indentation)
+ {
+ var colors = new[] { ConsoleColor.Red, ConsoleColor.Green, ConsoleColor.Blue, ConsoleColor.Yellow, ConsoleColor.White };
+ Console.ForegroundColor = colors[indentation % colors.Length];
+
+ var message = _decoder.Read().ReplaceLineEndings("");
+
+ var prefix = new string(' ', indentation * 3);
+ var suffix = _conversation == null ? "..." : "";
+ Console.WriteLine($"{prefix}...{message}{suffix}");
+
+ _left?.Print(indentation + 2);
+ _right?.Print(indentation + 2);
+ }
+ }
+}
\ No newline at end of file
diff --git a/LLama.Examples/Examples/BatchedExecutorRewind.cs b/LLama.Examples/Examples/BatchedExecutorRewind.cs
new file mode 100644
index 000000000..25195a56e
--- /dev/null
+++ b/LLama.Examples/Examples/BatchedExecutorRewind.cs
@@ -0,0 +1,121 @@
+using LLama.Batched;
+using LLama.Common;
+using LLama.Native;
+using LLama.Sampling;
+
+namespace LLama.Examples.Examples;
+
+/// <summary>
+/// This demonstrates generating tokens and then rewinding to an earlier state
+/// </summary>
+public class BatchedExecutorRewind
+{
+ private const int n_generate = 24;
+ private const int n_rewind = 12;
+ private const int n_repeats = 6;
+
+ public static async Task Run()
+ {
+ Console.Write("Please input your model path: ");
+ var modelPath = Console.ReadLine();
+
+ var parameters = new ModelParams(modelPath);
+ using var model = LLamaWeights.LoadFromFile(parameters);
+
+ Console.WriteLine("Prompt (leave blank to select automatically):");
+ var prompt = Console.ReadLine();
+ if (string.IsNullOrWhiteSpace(prompt))
+ prompt = "Not many people know that";
+
+ // Create an executor that can evaluate a batch of conversations together
+ var executor = new BatchedExecutor(model, parameters);
+
+ // Print some info
+ var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
+ Console.WriteLine($"Created executor with model: {name}");
+
+ // Evaluate the initial prompt to create one conversation
+ var conversation = executor.Prompt(prompt);
+
+ // Create the start node wrapping the conversation
+ var node = new Node(executor.Context);
+
+ // Print the prompt
+ Console.ForegroundColor = ConsoleColor.Green;
+ Console.WriteLine(prompt);
+
+ for (var i = 0; i < n_repeats; i++)
+ {
+ for (var j = 0; j < n_generate; j++)
+ {
+ // Run inference
+ await executor.Infer();
+
+ // Sample a token
+ var token = node.Sample(conversation);
+
+ // Continue conversation with this token
+ if (j != n_generate - 1)
+ conversation.Prompt(token);
+ }
+
+ // Write out what we generated
+ node.Write(n_rewind, i + 1);
+
+ // Rewind back a few tokens
+ conversation.Rewind(n_rewind + 1);
+
+ // Prompt with a token
+ conversation.Prompt(node.GetToken(n_generate - n_rewind - 1));
+
+ // Create a new node around the rewound conversation
+ node = new Node(executor.Context);
+ }
+
+ Console.WriteLine("Press any key to exit demo");
+ Console.ReadKey(true);
+ }
+
+ private class Node
+ {
+ private readonly LLamaContext _context;
+
+ private readonly List<LLamaToken> _tokens = new List<LLamaToken>();
+ private readonly DefaultSamplingPipeline Sampler;
+
+ public Node(LLamaContext context)
+ {
+ _context = context;
+ Sampler = new DefaultSamplingPipeline();
+ }
+
+ public LLamaToken Sample(Conversation conversation)
+ {
+ var token = Sampler.Sample(_context.NativeHandle, conversation.Sample().ToArray(), Array.Empty<LLamaToken>());
+ _tokens.Add(token);
+ return token;
+ }
+
+ public void Write(int n_rewind, int depth)
+ {
+ var decoder = new StreamingTokenDecoder(_context);
+
+ for (var i = 0; i < _tokens.Count - n_rewind; i++)
+ decoder.Add(_tokens[i]);
+
+ Console.ForegroundColor = ConsoleColor.Green;
+ Console.Write(new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" "));
+
+ for (var i = _tokens.Count - n_rewind; i < _tokens.Count; i++)
+ decoder.Add(_tokens[i]);
+
+ Console.ForegroundColor = ConsoleColor.DarkRed;
+ Console.WriteLine(decoder.Read().ReplaceLineEndings(" "));
+ }
+
+ public LLamaToken GetToken(int index)
+ {
+ return _tokens[index];
+ }
+ }
+}
\ No newline at end of file
diff --git a/LLama.Examples/Examples/Runner.cs b/LLama.Examples/Examples/Runner.cs
index 3d9858e1d..54358ce70 100644
--- a/LLama.Examples/Examples/Runner.cs
+++ b/LLama.Examples/Examples/Runner.cs
@@ -23,7 +23,8 @@ public class Runner
{ "Semantic Kernel Chat.", SemanticKernelChat.Run },
{ "Semantic Kernel Memory.", SemanticKernelMemory.Run },
{ "Coding Assistant.", CodingAssistant.Run },
- { "Batch Decoding.", BatchedDecoding.Run },
+ { "Batched Executor (Fork)", BatchedExecutorFork.Run },
+ { "Batched Executor (Rewind)", BatchedExecutorRewind.Run },
{ "SK Kernel Memory.", KernelMemory.Run },
{ "Exit", async () => Environment.Exit(0) }
};
diff --git a/LLama/Batched/BatchedExecutor.cs b/LLama/Batched/BatchedExecutor.cs
new file mode 100644
index 000000000..f432a89f1
--- /dev/null
+++ b/LLama/Batched/BatchedExecutor.cs
@@ -0,0 +1,119 @@
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using LLama.Abstractions;
+using LLama.Native;
+
+namespace LLama.Batched;
+
+/// <summary>
+/// A batched executor that can infer multiple separate "conversations" simultaneously.
+/// </summary>
+public sealed class BatchedExecutor
+ : IDisposable
+{
+ private int _nextSequenceId;
+
+ internal LLamaBatch Batch { get; }
+
+ /// <summary>
+ /// Epoch is incremented every time Infer is called. Conversations can use this to keep track of
+ /// whether they're waiting for inference, or can be sampled.
+ /// </summary>
+ internal ulong Epoch { get; private set; }
+
+ /// <summary>
+ /// The <see cref="LLamaContext"/> this executor is using
+ /// </summary>
+ public LLamaContext Context { get; }
+
+ /// <summary>
+ /// The <see cref="LLamaWeights"/> this executor is using
+ /// </summary>
+ public LLamaWeights Model { get; }
+
+ /// <summary>
+ /// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
+ /// </summary>
+ public int BatchedTokenCount => Batch.TokenCount;
+
+ /// <summary>
+ /// Check if this executor has been disposed.
+ /// </summary>
+ public bool IsDisposed { get; private set; }
+
+ /// <summary>
+ /// Create a new batched executor
+ /// </summary>
+ /// <param name="model">The model to use</param>
+ /// <param name="contextParams">Parameters to create a new context</param>
+ public BatchedExecutor(LLamaWeights model, IContextParams contextParams)
+ {
+ Model = model;
+ Batch = new LLamaBatch();
+ Context = model.CreateContext(contextParams);
+ Epoch = 1;
+ }
+
+ ~BatchedExecutor()
+ {
+ Dispose();
+ }
+
+ /// <summary>
+ /// Start a new <see cref="Conversation"/> with the given prompt
+ /// </summary>
+ /// <param name="prompt"></param>
+ /// <returns></returns>
+ public Conversation Prompt(string prompt)
+ {
+ if (IsDisposed)
+ throw new ObjectDisposedException(nameof(BatchedExecutor));
+
+ var conversation = new Conversation(this, GetNextSequenceId(), 0);
+ conversation.Prompt(prompt);
+
+ return conversation;
+ }
+
+ /// <summary>
+ /// Run inference for all conversations in the batch which have pending tokens.
+ ///
+ /// If the result is `NoKvSlot` then there is not enough memory for inference, try disposing some conversation
+ /// threads and running inference again.
+ /// </summary>
+ public async Task Infer(CancellationToken cancellation = default)
+ {
+ if (IsDisposed)
+ throw new ObjectDisposedException(nameof(BatchedExecutor));
+
+ var status = await Context.DecodeAsync(Batch, cancellation);
+
+ // Only clear the batch if the result was Ok. Leaving all this state in place means that "Infer" can
+ // be called again after a warning (e.g. NoKvSlot).
+ if (status == DecodeResult.Ok)
+ {
+ Epoch++;
+ Batch.Clear();
+ }
+
+ return status;
+ }
+
+ /// <inheritdoc />
+ public void Dispose()
+ {
+ if (IsDisposed)
+ return;
+ IsDisposed = true;
+
+ GC.SuppressFinalize(this);
+
+ Context.Dispose();
+ }
+
+ internal LLamaSeqId GetNextSequenceId()
+ {
+ return checked((LLamaSeqId)_nextSequenceId++);
+ }
+}
\ No newline at end of file
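Reviewer note: a minimal usage sketch of the executor API added above. The model path, prompt and token count are placeholder values, and the flow simply mirrors the new examples rather than defining any additional API:

```csharp
using System;
using LLama;
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

// Sketch only: "model.gguf" and the 32-token limit are placeholders.
var parameters = new ModelParams("model.gguf");
using var model = LLamaWeights.LoadFromFile(parameters);
using var executor = new BatchedExecutor(model, parameters);

// Prompting creates a Conversation; Infer() evaluates every pending token in one batch.
var conversation = executor.Prompt("Not many people know that");
var sampler = new DefaultSamplingPipeline();
var decoder = new StreamingTokenDecoder(executor.Context);

for (var i = 0; i < 32; i++)
{
    await executor.Infer();

    // Sample() exposes this conversation's logits; any ISamplingPipeline can pick the token.
    var token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample().ToArray(), Array.Empty<LLamaToken>());
    sampler.Accept(executor.Context.NativeHandle, token);
    decoder.Add(token);

    // Prompt with the sampled token so the next Infer() continues from it.
    conversation.Prompt(token);
}

Console.WriteLine(decoder.Read());
```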
diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs
new file mode 100644
index 000000000..6cf6e3128
--- /dev/null
+++ b/LLama/Batched/Conversation.cs
@@ -0,0 +1,294 @@
+using System;
+using System.Collections.Generic;
+using LLama.Native;
+
+namespace LLama.Batched;
+
+/// <summary>
+/// A single conversation thread that can be prompted (adding tokens from the user) or inferred (extracting a token from the LLM)
+/// </summary>
+public sealed class Conversation
+ : IDisposable
+{
+ private ulong _requiredEpoch;
+ private LLamaPos _end;
+ private int _batchIndex;
+ private bool _disposed;
+
+ /// <summary>
+ /// The executor which this conversation belongs to
+ /// </summary>
+ public BatchedExecutor Executor { get; }
+
+ /// <summary>
+ /// Unique ID for this conversation
+ /// </summary>
+ public LLamaSeqId ConversationId { get; }
+
+ /// <summary>
+ /// Total number of tokens in this conversation, cannot exceed the context length.
+ /// </summary>
+ public int TokenCount => _end.Value;
+
+ /// <summary>
+ /// Indicates if this conversation has been disposed, nothing can be done with a disposed conversation
+ /// </summary>
+ public bool IsDisposed => _disposed || Executor.IsDisposed;
+
+ /// <summary>
+ /// Indicates if this conversation is waiting for inference to be run on the executor. "Prompt" and "Sample" cannot be called when this is true.
+ /// </summary>
+ public bool RequiresInference => _requiredEpoch > Executor.Epoch;
+
+ /// <summary>
+ /// Indicates that this conversation should be sampled.
+ /// </summary>
+ public bool RequiresSampling => _requiredEpoch == Executor.Epoch;
+
+ #region construction/destruction
+ internal Conversation(BatchedExecutor batch, LLamaSeqId id, LLamaPos end)
+ {
+ ConversationId = id;
+ Executor = batch;
+
+ _end = end;
+ }
+
+ ~Conversation()
+ {
+ Dispose();
+ }
+
+ /// <summary>
+ /// End this conversation, freeing all resources used by it
+ /// </summary>
+ /// <exception cref="ObjectDisposedException"></exception>
+ public void Dispose()
+ {
+ if (IsDisposed)
+ return;
+ _disposed = true;
+
+ // Remove this conversation from the KV cache
+ Executor.Context.NativeHandle.KvCacheRemove(ConversationId, 0, _end);
+
+ // Prevent finalizer from running
+ GC.SuppressFinalize(this);
+ }
+
+ private void AssertNotDisposed()
+ {
+ if (Executor.IsDisposed)
+ throw new ObjectDisposedException(nameof(BatchedExecutor));
+ if (IsDisposed)
+ throw new ObjectDisposedException(nameof(Conversation));
+ }
+ #endregion
+
+ /// <summary>
+ /// Create a copy of the current conversation
+ /// </summary>
+ /// <remarks>The copy shares internal state, so consumes very little extra memory.</remarks>
+ /// <exception cref="ObjectDisposedException"></exception>
+ /// <returns></returns>
+ public Conversation Fork()
+ {
+ AssertNotDisposed();
+
+ if (RequiresInference)
+ throw new CannotForkWhileRequiresInference();
+
+ // Create a new conversation which references the current position in this one
+ var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
+ {
+ _batchIndex = _batchIndex,
+ _requiredEpoch = _requiredEpoch,
+ };
+
+ // Assign tokens to the new sequence
+ NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);
+
+ return c;
+ }
+
+ #region sample
+ /// <summary>
+ /// Get the logits from this conversation, ready for sampling
+ /// </summary>
+ /// <returns></returns>
+ /// <exception cref="ObjectDisposedException"></exception>
+ /// <exception cref="CannotSampleRequiresPromptException">Thrown if this conversation was not prompted before the previous call to infer</exception>
+ /// <exception cref="CannotSampleRequiresInferenceException">Thrown if Infer() must be called on the executor</exception>
+ public ReadOnlySpan<float> Sample()
+ {
+ AssertNotDisposed();
+
+ if (_requiredEpoch < Executor.Epoch)
+ throw new CannotSampleRequiresPromptException();
+ if (_requiredEpoch > Executor.Epoch)
+ throw new CannotSampleRequiresInferenceException();
+
+ return Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
+ }
+ #endregion
+
+ #region prompt
+ private void AssertCanBePrompted()
+ {
+ AssertNotDisposed();
+
+ if (RequiresInference)
+ throw new AlreadyPromptedConversationException();
+ }
+
+ /// <summary>
+ /// Add tokens to this conversation
+ /// </summary>
+ /// <param name="input"></param>
+ /// <exception cref="ObjectDisposedException"></exception>
+ public void Prompt(string input)
+ {
+ AssertCanBePrompted();
+
+ Prompt(Executor.Context.Tokenize(input));
+ }
+
+ /// <summary>
+ /// Add tokens to this conversation
+ /// </summary>
+ /// <param name="tokens"></param>
+ /// <returns></returns>
+ /// <exception cref="ObjectDisposedException"></exception>
+ public void Prompt(IReadOnlyList<LLamaToken> tokens)
+ {
+ AssertCanBePrompted();
+
+ // Add the prompt to the batch
+ for (var i = 0; i < tokens.Count; i++)
+ _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1);
+
+ // Mark this conversation as needing inference/sampling
+ _requiredEpoch = Executor.Epoch + 1;
+ }
+
+ /// <summary>
+ /// Add a single token to this conversation
+ /// </summary>
+ /// <param name="token"></param>
+ /// <returns></returns>
+ /// <exception cref="ObjectDisposedException"></exception>
+ /// <exception cref="AlreadyPromptedConversationException"></exception>
+ public void Prompt(LLamaToken token)
+ {
+ AssertCanBePrompted();
+
+ // Add this token as input
+ _batchIndex = Executor.Batch.Add(token, _end++, ConversationId, true);
+
+ // Mark this conversation as needing inference/sampling
+ _requiredEpoch = Executor.Epoch + 1;
+ }
+ #endregion
+
+ #region modify
+ /// <summary>
+ /// Directly modify the KV cache of this conversation
+ /// </summary>
+ /// <param name="modifier"></param>
+ /// <exception cref="CannotModifyWhileRequiresInference">Thrown if this method is called while <see cref="RequiresInference"/> == true</exception>
+ public void Modify(ModifyKvCache modifier)
+ {
+ AssertNotDisposed();
+
+ if (RequiresInference)
+ throw new CannotModifyWhileRequiresInference();
+
+ // do whatever the modification is
+ _end = modifier.Invoke(_end, new KvAccessor(this));
+
+ // Set the epoch down to zero, this ensures that this conversation
+ // cannot be sampled until it is prompted again.
+ _requiredEpoch = 0;
+ }
+
+ /// <summary>
+ /// Provides direct access to the KV cache of a <see cref="Conversation"/>.
+ /// See <see cref="Modify"/> for how to use this.
+ /// </summary>
+ public readonly ref struct KvAccessor
+ {
+ private readonly Conversation _conversation;
+
+ internal KvAccessor(Conversation conversation)
+ {
+ _conversation = conversation;
+ }
+
+ #region remove
+ /// <summary>
+ /// Removes all tokens that have positions in [start, end)
+ /// </summary>
+ /// <param name="start">Start position (inclusive)</param>
+ /// <param name="end">End position (exclusive)</param>
+ public void Remove(LLamaPos start, LLamaPos end)
+ {
+ _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
+ }
+
+ /// <summary>
+ /// Removes all tokens starting from the given position
+ /// </summary>
+ /// <param name="start">Start position (inclusive)</param>
+ /// <param name="count">Number of tokens</param>
+ public void Remove(LLamaPos start, int count)
+ {
+ if (count <= 0)
+ return;
+
+ var end = start.Value + count;
+ _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
+ }
+ #endregion
+
+ #region shift
+ /// <summary>
+ /// Adds relative position "delta" to all tokens that have positions in [p0, p1).
+ /// If the KV cache is RoPEd, the KV data is updated
+ /// accordingly
+ /// </summary>
+ /// <param name="start">Start position (inclusive)</param>
+ /// <param name="end">End position (exclusive)</param>
+ /// <param name="delta">Amount to add on to each token position</param>
+ public void Shift(LLamaPos start, LLamaPos end, int delta)
+ {
+ _conversation.Executor.Context.NativeHandle.KvCacheSequenceShift(_conversation.ConversationId, start, end, delta);
+ }
+ #endregion
+
+ #region divide
+ /// <summary>
+ /// Integer division of the positions by factor of `d > 1`.
+ /// If the KV cache is RoPEd, the KV data is updated accordingly.
+ /// </summary>
+ /// <param name="start">Start position (inclusive). If less than zero, it is clamped to zero.</param>
+ /// <param name="end">End position (exclusive). If less than zero, it is treated as "infinity".</param>
+ /// <param name="divisor">Amount to divide each position by.</param>
+ public void Divide(LLamaPos start, LLamaPos end, int divisor)
+ {
+ if (divisor <= 0)
+ throw new ArgumentOutOfRangeException(nameof(divisor));
+
+ _conversation.Executor.Context.NativeHandle.KvCacheSequenceDivide(_conversation.ConversationId, start, end, divisor);
+ }
+ #endregion
+ }
+
+ /// <summary>
+ /// A function which can temporarily access the KV cache of a <see cref="Conversation"/> to modify it directly
+ /// </summary>
+ /// <param name="end">The current end token of this conversation</param>
+ /// <param name="kv">An <see cref="KvAccessor"/> which allows direct access to modify the KV cache</param>
+ /// <returns>The new end token position</returns>
+ public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv);
+ #endregion
+}
\ No newline at end of file
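Reviewer note: a sketch of how the Modify/KvAccessor API above can be driven from calling code. The helper name and behaviour are illustrative only (not part of this PR); it relies solely on members defined in this file:

```csharp
using LLama.Batched;

public static class ConversationTrimming
{
    // Illustrative helper (not part of this PR): drop everything generated after the
    // first `keep` tokens, leaving only the original prompt in the KV cache.
    public static void TrimToPrefix(Conversation conversation, int keep)
    {
        conversation.Modify((end, kv) =>
        {
            // Remove tokens with positions in [keep, end)
            kv.Remove(keep, end.Value - keep);

            // The conversation now ends at position `keep`. Because Modify() resets the
            // epoch, it must be prompted again before the next Sample().
            return keep;
        });
    }
}
```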
diff --git a/LLama/Batched/ConversationExtensions.cs b/LLama/Batched/ConversationExtensions.cs
new file mode 100644
index 000000000..5fca5e94b
--- /dev/null
+++ b/LLama/Batched/ConversationExtensions.cs
@@ -0,0 +1,59 @@
+using System;
+
+namespace LLama.Batched;
+
+/// <summary>
+/// Extension methods for <see cref="Conversation"/>
+/// </summary>
+public static class ConversationExtensions
+{
+ /// <summary>
+ /// Rewind a <see cref="Conversation"/> back to an earlier state by removing tokens from the end
+ /// </summary>
+ /// <param name="conversation">The conversation to rewind</param>
+ /// <param name="tokens">The number of tokens to rewind</param>
+ /// <exception cref="ArgumentOutOfRangeException">Thrown if `tokens` parameter is larger than TokenCount</exception>
+ public static void Rewind(this Conversation conversation, int tokens)
+ {
+ if (tokens > conversation.TokenCount)
+ throw new ArgumentOutOfRangeException(nameof(tokens), "Cannot rewind more than the total number of tokens");
+
+ conversation.Modify((end, kv) =>
+ {
+ // Remove those tokens from KV
+ kv.Remove(end.Value - tokens, tokens);
+
+ // Return adjusted end position
+ return end.Value - tokens;
+ });
+ }
+
+ /// <summary>
+ /// Shift all tokens over to the left, removing "count" tokens from the start and shifting everything over.
+ /// Leaves "keep" tokens at the start completely untouched. This can be used to free up space when the context
+ /// gets full, keeping the prompt at the start intact.
+ /// </summary>
+ /// <param name="conversation">The conversation to shift tokens in</param>
+ /// <param name="count">How much to shift tokens over by</param>
+ /// <param name="keep">The number of tokens at the start which should not be shifted</param>
+ public static void ShiftLeft(this Conversation conversation, int count, int keep)
+ {
+ // Given a setup like this (count=5, keep=3):
+ //
+ // AAABBBBBCCCCCCCCC...
+ //
+ // We want to remove all the B's, shift all the C's and leave all the A's untouched
+
+ conversation.Modify((end, kv) =>
+ {
+ // Remove the B's
+ kv.Remove(keep, count);
+
+ // Shift the C's
+ kv.Shift(keep + count, end, -count);
+
+ // Update total count
+ return end.Value - count;
+ });
+ }
+}
\ No newline at end of file
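Reviewer note: a hedged usage sketch for ShiftLeft. The threshold and token counts are illustrative choices, not values the library prescribes:

```csharp
using LLama.Batched;

public static class ContextBudget
{
    // Illustrative numbers: protect the first 32 prompt tokens and reclaim 128 tokens
    // whenever the conversation gets within 16 tokens of the context limit.
    public static void MaybeMakeRoom(Conversation conversation, int contextSize)
    {
        const int keepPrompt = 32;
        const int freeCount = 128;

        if (conversation.TokenCount + 16 >= contextSize)
        {
            // Drops `freeCount` tokens after the kept prefix and shifts the rest left,
            // so the prompt survives while space is reclaimed.
            conversation.ShiftLeft(freeCount, keepPrompt);
        }
    }
}
```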
diff --git a/LLama/Batched/Exceptions.cs b/LLama/Batched/Exceptions.cs
new file mode 100644
index 000000000..1feb270c2
--- /dev/null
+++ b/LLama/Batched/Exceptions.cs
@@ -0,0 +1,81 @@
+using System;
+
+namespace LLama.Batched;
+
+/// <summary>
+/// Base class for exceptions thrown from <see cref="BatchedExecutor"/>
+/// </summary>
+public class ExperimentalBatchedExecutorException
+ : Exception
+{
+ internal ExperimentalBatchedExecutorException(string message)
+ : base(message)
+ {
+ }
+}
+
+/// <summary>
+/// This exception is thrown when "Prompt()" is called on a <see cref="Conversation"/> which has
+/// already been prompted and before "Infer()" has been called on the associated
+/// <see cref="BatchedExecutor"/>.
+/// </summary>
+public class AlreadyPromptedConversationException
+ : ExperimentalBatchedExecutorException
+{
+ internal AlreadyPromptedConversationException()
+ : base("Must call `Infer()` before prompting this Conversation again")
+ {
+ }
+}
+
+/// <summary>
+/// This exception is thrown when "Sample()" is called on a <see cref="Conversation"/> which has
+/// already been prompted and before "Infer()" has been called on the associated
+/// <see cref="BatchedExecutor"/>.
+/// </summary>
+public class CannotSampleRequiresInferenceException
+ : ExperimentalBatchedExecutorException
+{
+ internal CannotSampleRequiresInferenceException()
+ : base("Must call `Infer()` before sampling from this Conversation")
+ {
+ }
+}
+
+/// <summary>
+/// This exception is thrown when "Sample()" is called on a <see cref="Conversation"/> which was not
+/// first prompted.
+/// <see cref="BatchedExecutor"/>.
+/// </summary>
+public class CannotSampleRequiresPromptException
+ : ExperimentalBatchedExecutorException
+{
+ internal CannotSampleRequiresPromptException()
+ : base("Must call `Prompt()` and then `Infer()` before sampling from this Conversation")
+ {
+ }
+}
+
+/// <summary>
+/// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true
+/// </summary>
+public class CannotForkWhileRequiresInference
+ : ExperimentalBatchedExecutorException
+{
+ internal CannotForkWhileRequiresInference()
+ : base("Cannot `Fork()` a conversation while RequiresInference is true")
+ {
+ }
+}
+
+/// <summary>
+/// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
+/// </summary>
+public class CannotModifyWhileRequiresInference
+ : ExperimentalBatchedExecutorException
+{
+ internal CannotModifyWhileRequiresInference()
+ : base("Cannot `Modify()` a conversation while RequiresInference is true")
+ {
+ }
+}
\ No newline at end of file
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 5d026b672..9075c89fc 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -221,7 +221,9 @@ public void LoadState(State state)
/// <returns>The selected token</returns>
public LLamaToken Sample(ISamplingPipeline pipeline, ReadOnlySpan<LLamaToken> lastTokens)
{
- return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
+ var token = pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
+ pipeline.Accept(NativeHandle, token);
+ return token;
}
/// <summary>
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 969da783c..31e975caf 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -213,6 +213,7 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+ inferenceParams.SamplingPipeline.Accept(Context.NativeHandle, id);
}
else
{
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 7d742c813..9338d8396 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -192,6 +192,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+ inferenceParams.SamplingPipeline.Accept(Context.NativeHandle, id);
}
else
{
diff --git a/LLama/Native/LLamaBatch.cs b/LLama/Native/LLamaBatch.cs
index 532e16fff..50b9c8f0e 100644
--- a/LLama/Native/LLamaBatch.cs
+++ b/LLama/Native/LLamaBatch.cs
@@ -18,6 +18,11 @@ public class LLamaBatch
private LLamaSeqId[][] _sequenceIds;
private IntPtr[] _sequenceIdsPtrs;
+ /// <summary>
+ /// Keep track of the index of existing token/position combos in the batch
+ /// </summary>
+ private readonly Dictionary<(LLamaToken, LLamaPos), int> _index = new();
+
/// <summary>
/// The number of tokens in this batch
/// </summary>
@@ -130,23 +135,44 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
/// <param name="pos">The position to add it at</param>
/// <param name="sequences">The set of sequences to add this token to</param>
/// <param name="logits"></param>
- public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
+ /// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
+ public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
{
+ // Try to find this (token, position) combo somewhere in the batch to re-use it
+ if (_index.TryGetValue((token, pos), out var existingIndex))
+ {
+ if (_sequenceIdCount[existingIndex] + sequences.Length > SequenceCapacity)
+ GrowMaxSequences(_sequenceIdCount[existingIndex] + sequences.Length);
+
+ foreach (var sequence in sequences)
+ {
+ _sequenceIds[existingIndex][_sequenceIdCount[existingIndex]] = sequence;
+ _sequenceIdCount[existingIndex]++;
+ }
+
+ return existingIndex;
+ }
+
+ // Couldn't find this token in the batch, add a new item
+
+ // Grow capacity as necessary
if (TokenCount == TokenCapacity)
GrowTokenCapacity();
if (sequences.Length > SequenceCapacity)
GrowMaxSequences(sequences.Length);
+ // Store the position in the index, so it can be found later
+ _index.Add((token, pos), TokenCount);
+
+ // Add the items to the arrays
_tokens[TokenCount] = token;
_positions[TokenCount] = pos;
-
_sequenceIdCount[TokenCount] = sequences.Length;
for (var i = 0; i < sequences.Length; i++)
_sequenceIds[TokenCount][i] = sequences[i];
-
_logits[TokenCount] = Convert.ToByte(logits);
- TokenCount++;
+ return TokenCount++;
}
/// <summary>
@@ -157,11 +183,12 @@ public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan sequenc
/// <param name="pos">The position to add it at</param>
/// <param name="sequences">The set of sequences to add this token to</param>
/// <param name="logits"></param>
- public void Add(LLamaToken token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
+ /// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
+ public int Add(LLamaToken token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
{
#if NET5_0_OR_GREATER
var seqSpan = CollectionsMarshal.AsSpan(sequences);
- Add(token, pos, seqSpan, logits);
+ return Add(token, pos, seqSpan, logits);
#else
// on netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of
// the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
@@ -171,7 +198,7 @@ public void Add(LLamaToken token, LLamaPos pos, List sequences, bool
try
{
sequences.CopyTo(rented, 0);
- Add(token, pos, rented.AsSpan(0, sequences.Count), logits);
+ return Add(token, pos, rented.AsSpan(0, sequences.Count), logits);
}
finally
{
@@ -188,14 +215,15 @@ public void Add(LLamaToken token, LLamaPos pos, List sequences, bool
/// <param name="pos">The position to add it at</param>
/// <param name="sequence">The sequence to add this token to</param>
/// <param name="logits"></param>
- public void Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
+ /// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
+ public int Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
{
// Create a temporary span to contain 1 item without allocating
Span<LLamaSeqId> sequences = stackalloc LLamaSeqId[1];
sequences[0] = sequence;
// Add it
- Add(token, pos, sequences, logits);
+ return Add(token, pos, sequences, logits);
}
/// <summary>
@@ -205,13 +233,17 @@ public void Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits
/// <param name="start">The starting position to add tokens at</param>
/// <param name="sequence">The sequence to add this token to</param>
/// <param name="logitsLast">Whether the final token should generate logits</param>
- public void AddRange(ReadOnlySpan<LLamaToken> tokens, LLamaPos start, LLamaSeqId sequence, bool logitsLast)
+ /// <returns>The index that the final token was added at. Use this for GetLogitsIth</returns>
+ public int AddRange(ReadOnlySpan<LLamaToken> tokens, LLamaPos start, LLamaSeqId sequence, bool logitsLast)
{
+ var last = -1;
for (var i = 0; i < tokens.Length; i++)
{
var logits = (i == tokens.Length - 1) & logitsLast;
- Add(tokens[i], start.Value + i, sequence, logits);
+ last = Add(tokens[i], start.Value + i, sequence, logits);
}
+
+ return last;
}
#endregion
@@ -221,5 +253,6 @@ public void AddRange(ReadOnlySpan tokens, LLamaPos start, LLamaSeqId
public void Clear()
{
TokenCount = 0;
+ _index.Clear();
}
}
\ No newline at end of file
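Reviewer note: the reason Add/AddRange now return the batch index is so callers can pass it straight to GetLogitsIth after decoding instead of tracking TokenCount by hand. A hedged sketch (the helper name and error handling are illustrative):

```csharp
using System;
using System.Threading.Tasks;
using LLama;
using LLama.Native;

public static class BatchHelpers
{
    // Illustrative helper: evaluate a prompt and return the logits of its final token,
    // using the index returned by AddRange rather than a hand-tracked position.
    public static async Task<float[]> DecodePromptAsync(LLamaContext context, LLamaToken[] promptTokens)
    {
        var batch = new LLamaBatch();
        var logitIndex = batch.AddRange(promptTokens, 0, LLamaSeqId.Zero, true);

        if (await context.DecodeAsync(batch) != DecodeResult.Ok)
            throw new InvalidOperationException("llama_decode failed");

        // GetLogitsIth accepts the index returned by Add/AddRange.
        return context.NativeHandle.GetLogitsIth(logitIndex).ToArray();
    }
}
```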
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index c953cb237..578cad405 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -388,7 +388,7 @@ public static int llama_token_to_piece(SafeLlamaModelHandle model, LLamaToken ll
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta);
+ public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta);
///
/// Integer division of the positions by factor of `d > 1`
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index 91e82c85d..2d9387ae4 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -369,15 +369,15 @@ public void KvCacheSequenceKeep(LLamaSeqId seq)
///
///
///
- public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta)
+ public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta)
{
NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta);
}
/// <summary>
- /// Integer division of the positions by factor of `d > 1`
- /// If the KV cache is RoPEd, the KV data is updated accordingly
- /// p0 < 0 : [0, p1]
+ /// Integer division of the positions by factor of `d > 1`.
+ /// If the KV cache is RoPEd, the KV data is updated accordingly.
+ /// p0 < 0 : [0, p1]
/// p1 < 0 : [p0, inf)
/// </summary>
///
diff --git a/LLama/Sampling/BaseSamplingPipeline.cs b/LLama/Sampling/BaseSamplingPipeline.cs
index a41aa67e1..b86001aa0 100644
--- a/LLama/Sampling/BaseSamplingPipeline.cs
+++ b/LLama/Sampling/BaseSamplingPipeline.cs
@@ -40,10 +40,7 @@ public LLamaToken Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnl
var candidates = LLamaTokenDataArray.Create(logits);
// Process token data array
- ProcessTokenDataArray(ctx, candidates, lastTokens);
-
- // Choose the final value
- return ChooseToken(ctx, candidates);
+ return ProcessTokenDataArray(ctx, candidates, lastTokens);
}
finally
{
@@ -53,6 +50,9 @@ public LLamaToken Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnl
}
}
+ /// <inheritdoc />
+ public abstract void Accept(SafeLLamaContextHandle ctx, LLamaToken token);
+
#region protected tokens
/// <summary>
/// Get all of the "protected" tokens that cannot be changed by ProcessLogits
@@ -107,19 +107,14 @@ protected void RestoreProtectedTokens(LLamaTokenDataArray candidates)
/// <returns></returns>
protected abstract LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens);
- /// <summary>
- /// Choose the final token from the candidates
- /// </summary>
- /// <param name="ctx"></param>
- /// <param name="candidates"></param>
- /// <returns></returns>
- protected abstract LLamaToken ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates);
-
/// <inheritdoc />
public virtual void Reset()
{
}
+ /// <inheritdoc />
+ public abstract ISamplingPipeline Clone();
+
/// <inheritdoc />
public virtual void Dispose()
{
diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs
index b0fb5c596..531f34faa 100644
--- a/LLama/Sampling/DefaultSamplingPipeline.cs
+++ b/LLama/Sampling/DefaultSamplingPipeline.cs
@@ -141,9 +141,31 @@ protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx,
return id;
}
+ public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
+ {
+ Grammar?.AcceptToken(ctx, token);
+ }
+
/// <inheritdoc />
- protected override LLamaToken ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
+ public override ISamplingPipeline Clone()
{
- return candidates.SampleToken(ctx);
+ var clone = new DefaultSamplingPipeline();
+
+ foreach (var (k, v) in LogitBias)
+ clone.LogitBias.Add(k, v);
+
+ clone.Grammar = Grammar?.Clone();
+ clone.RepeatPenalty = RepeatPenalty;
+ clone.AlphaFrequency = AlphaFrequency;
+ clone.AlphaPresence = AlphaPresence;
+ clone.Temperature = Temperature;
+ clone.TopK = TopK;
+ clone.TailFreeZ = TailFreeZ;
+ clone.TypicalP = TypicalP;
+ clone.TopP = TopP;
+ clone.MinP = MinP;
+ clone.PenalizeNewline = PenalizeNewline;
+
+ return clone;
}
}
\ No newline at end of file
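Reviewer note: Clone() exists so that forked conversations do not share mutable sampler state (grammar position, penalty history, logit biases). A short sketch of the intended pattern, with illustrative parameter values:

```csharp
using LLama.Sampling;

// Each branch of a forked conversation gets an independent copy of the parent's
// sampler, so grammar state and penalty history are duplicated rather than shared.
var parentSampler = new DefaultSamplingPipeline { Temperature = 0.7f, MinP = 0.25f };

var leftSampler = parentSampler.Clone();
var rightSampler = parentSampler.Clone();
```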
diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs
index be1398790..b538d1feb 100644
--- a/LLama/Sampling/ISamplingPipeline.cs
+++ b/LLama/Sampling/ISamplingPipeline.cs
@@ -21,10 +21,23 @@ public interface ISamplingPipeline
/// <returns></returns>
LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);
+ /// <summary>
+ /// Update the pipeline, with knowledge that a particular token was just accepted
+ /// </summary>
+ /// <param name="ctx"></param>
+ /// <param name="token"></param>
+ void Accept(SafeLLamaContextHandle ctx, LLamaToken token);
+
/// <summary>
/// Reset all internal state of the sampling pipeline
/// </summary>
void Reset();
+
+ /// <summary>
+ /// Create a copy of this sampling pipeline
+ /// </summary>
+ /// <returns></returns>
+ ISamplingPipeline Clone();
}
/// <summary>