diff --git a/LLama.Examples/Examples/BatchedDecoding.cs b/LLama.Examples/Examples/BatchedDecoding.cs deleted file mode 100644 index f893b613f..000000000 --- a/LLama.Examples/Examples/BatchedDecoding.cs +++ /dev/null @@ -1,172 +0,0 @@ -using System.Diagnostics; -using System.Text; -using LLama.Common; -using LLama.Native; -using LLama.Sampling; - -namespace LLama.Examples.Examples; - -/// -/// This demonstrates generating multiple replies to the same prompt, with a shared cache -/// -/// Note that this is currently using the low level API directly, future work will provide a safer C# wrapper over this! -public class BatchedDecoding -{ - private const int n_parallel = 8; - private const int n_len = 32; - - public static async Task Run() - { - Console.Write("Please input your model path: "); - var modelPath = Console.ReadLine(); - - Console.WriteLine("Prompt (leave blank to select automatically):"); - var prompt = Console.ReadLine(); - if (string.IsNullOrWhiteSpace(prompt)) - prompt = "Not many people know that"; - - // Load model - var parameters = new ModelParams(modelPath); - - using var model = LLamaWeights.LoadFromFile(parameters); - - // Tokenize prompt - var prompt_tokens = model.Tokenize(prompt, true, false, Encoding.UTF8); - var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel; - - // Create a context - parameters.ContextSize = (uint)model.ContextSize; - parameters.BatchSize = (uint)Math.Max(n_len, n_parallel); - using var context = model.CreateContext(parameters); - - var n_ctx = context.ContextSize; - - // make sure the KV cache is big enough to hold all the prompt and generated tokens - if (n_kv_req > n_ctx) - { - await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n"); - await Console.Error.WriteLineAsync(" either reduce n_parallel or increase n_ctx\n"); - return; - } - - var batch = new LLamaBatch(); - - // evaluate the initial prompt - batch.AddRange(prompt_tokens, 0, LLamaSeqId.Zero, true); - - if (await context.DecodeAsync(batch) != DecodeResult.Ok) - { - await Console.Error.WriteLineAsync("llama_decode failed"); - return; - } - - // assign the system KV cache to all parallel sequences - // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them - for (var i = 1; i < n_parallel; ++i) - { - context.NativeHandle.KvCacheSequenceCopy((LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount); - } - - if (n_parallel > 1) - { - Console.WriteLine(); - Console.WriteLine($"generating {n_parallel} sequences..."); - } - - // remember the batch index of the last token for each parallel sequence - // we need this to determine which logits to sample from - List i_batch = new(); - for (var i = 0; i < n_parallel; i++) - i_batch.Add(batch.TokenCount - 1); - - // Create per-stream decoder and sampler - var decoders = new StreamingTokenDecoder[n_parallel]; - var samplers = new ISamplingPipeline[n_parallel]; - for (var i = 0; i < n_parallel; i++) - { - decoders[i] = new StreamingTokenDecoder(context); - samplers[i] = new DefaultSamplingPipeline - { - Temperature = 0.1f + (float)i / n_parallel, - MinP = 0.25f, - }; - } - - var n_cur = batch.TokenCount; - var n_decode = 0; - - var timer = new Stopwatch(); - timer.Start(); - while (n_cur <= n_len) - { - batch.Clear(); - - for (var i = 0; i < n_parallel; i++) - { - // Skip completed streams - if (i_batch[i] < 0) - continue; - - // Use the sampling pipeline to select a token - var new_token_id = samplers[i].Sample( - 
context.NativeHandle, - context.NativeHandle.GetLogitsIth(i_batch[i]), - Array.Empty() - ); - - // Finish this stream early if necessary - if (new_token_id == model.EndOfSentenceToken || new_token_id == model.NewlineToken) - { - i_batch[i] = -1; - Console.WriteLine($"Completed Stream {i} early"); - continue; - } - - // Add this token to the decoder, so it will be turned into text - decoders[i].Add(new_token_id); - - i_batch[i] = batch.TokenCount; - - // push this new token for next evaluation - batch.Add(new_token_id, n_cur, (LLamaSeqId)i, true); - - n_decode++; - } - - // Check if all streams are finished - if (batch.TokenCount == 0) - { - break; - } - - n_cur++; - - // evaluate the current batch with the transformer model - if (await context.DecodeAsync(batch) != 0) - { - await Console.Error.WriteLineAsync("failed to eval"); - return; - } - } - - timer.Stop(); - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine(); - Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms"); - Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second"); - - var index = 0; - foreach (var stream in decoders) - { - var text = stream.Read(); - - Console.ForegroundColor = ConsoleColor.Green; - Console.Write($"{index++}. {prompt}"); - Console.ForegroundColor = ConsoleColor.Red; - Console.WriteLine(text); - } - - Console.WriteLine("Press any key to exit demo"); - Console.ReadKey(true); - } -} \ No newline at end of file diff --git a/LLama.Examples/Examples/BatchedExecutorFork.cs b/LLama.Examples/Examples/BatchedExecutorFork.cs new file mode 100644 index 000000000..b834190ce --- /dev/null +++ b/LLama.Examples/Examples/BatchedExecutorFork.cs @@ -0,0 +1,138 @@ +using LLama.Batched; +using LLama.Common; +using LLama.Native; +using LLama.Sampling; + +namespace LLama.Examples.Examples; + +/// +/// This demonstrates generating multiple replies to the same prompt, with a shared cache +/// +public class BatchedExecutorFork +{ + private const int n_split = 16; + private const int n_len = 64; + + public static async Task Run() + { + Console.Write("Please input your model path: "); + var modelPath = Console.ReadLine(); + + var parameters = new ModelParams(modelPath); + using var model = LLamaWeights.LoadFromFile(parameters); + + Console.WriteLine("Prompt (leave blank to select automatically):"); + var prompt = Console.ReadLine(); + if (string.IsNullOrWhiteSpace(prompt)) + prompt = "Not many people know that"; + + // Create an executor that can evaluate a batch of conversations together + var executor = new BatchedExecutor(model, parameters); + + // Print some info + var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name"); + Console.WriteLine($"Created executor with model: {name}"); + + // Evaluate the initial prompt to create one conversation + var start = executor.Prompt(prompt); + await executor.Infer(); + + // Create the root node of the tree + var root = new Node(start); + + // Run inference loop + for (var i = 0; i < n_len; i++) + { + if (i != 0) + await executor.Infer(); + + // Occasionally fork all the active conversations + if (i != 0 && i % n_split == 0) + root.Split(); + + // Sample all active conversations + root.Sample(); + } + + Console.WriteLine($"{prompt}..."); + root.Print(1); + + Console.WriteLine("Press any key to exit demo"); + Console.ReadKey(true); + } + + class Node + { + private readonly StreamingTokenDecoder _decoder; + + private readonly DefaultSamplingPipeline _sampler; + private Conversation? 
_conversation; + + private Node? _left; + private Node? _right; + + public int ActiveConversationCount => _conversation != null ? 1 : _left!.ActiveConversationCount + _right!.ActiveConversationCount; + + public Node(Conversation conversation) + { + _sampler = new DefaultSamplingPipeline(); + _conversation = conversation; + _decoder = new StreamingTokenDecoder(conversation.Executor.Context); + } + + public void Sample() + { + if (_conversation == null) + { + _left?.Sample(); + _right?.Sample(); + return; + } + + if (_conversation.RequiresInference) + return; + + // Sample one token + var ctx = _conversation.Executor.Context.NativeHandle; + var logitsCopy = _conversation.Sample().ToArray(); + var token = _sampler.Sample(ctx, logitsCopy, Array.Empty()); + _sampler.Accept(ctx, token); + _decoder.Add(token); + + // Prompt the conversation with this token, to continue generating from there + _conversation.Prompt(token); + } + + public void Split() + { + if (_conversation != null) + { + _left = new Node(_conversation.Fork()); + _right = new Node(_conversation.Fork()); + + _conversation.Dispose(); + _conversation = null; + } + else + { + _left?.Split(); + _right?.Split(); + } + } + + public void Print(int indendation) + { + var colors = new[] { ConsoleColor.Red, ConsoleColor.Green, ConsoleColor.Blue, ConsoleColor.Yellow, ConsoleColor.White }; + Console.ForegroundColor = colors[indendation % colors.Length]; + + var message = _decoder.Read().ReplaceLineEndings(""); + + var prefix = new string(' ', indendation * 3); + var suffix = _conversation == null ? "..." : ""; + Console.WriteLine($"{prefix}...{message}{suffix}"); + + _left?.Print(indendation + 2); + _right?.Print(indendation + 2); + } + } +} \ No newline at end of file diff --git a/LLama.Examples/Examples/BatchedExecutorRewind.cs b/LLama.Examples/Examples/BatchedExecutorRewind.cs new file mode 100644 index 000000000..25195a56e --- /dev/null +++ b/LLama.Examples/Examples/BatchedExecutorRewind.cs @@ -0,0 +1,121 @@ +using LLama.Batched; +using LLama.Common; +using LLama.Native; +using LLama.Sampling; + +namespace LLama.Examples.Examples; + +/// +/// This demonstrates generating tokens and then rewinding to an earlier state +/// +public class BatchedExecutorRewind +{ + private const int n_generate = 24; + private const int n_rewind = 12; + private const int n_repeats = 6; + + public static async Task Run() + { + Console.Write("Please input your model path: "); + var modelPath = Console.ReadLine(); + + var parameters = new ModelParams(modelPath); + using var model = LLamaWeights.LoadFromFile(parameters); + + Console.WriteLine("Prompt (leave blank to select automatically):"); + var prompt = Console.ReadLine(); + if (string.IsNullOrWhiteSpace(prompt)) + prompt = "Not many people know that"; + + // Create an executor that can evaluate a batch of conversations together + var executor = new BatchedExecutor(model, parameters); + + // Print some info + var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name"); + Console.WriteLine($"Created executor with model: {name}"); + + // Evaluate the initial prompt to create one conversation + var conversation = executor.Prompt(prompt); + + // Create the start node wrapping the conversation + var node = new Node(executor.Context); + + // Print the prompt + Console.ForegroundColor = ConsoleColor.Green; + Console.WriteLine(prompt); + + for (var i = 0; i < n_repeats; i++) + { + for (var j = 0; j < n_generate; j++) + { + // Run inference + await executor.Infer(); + + // Sample a token + 
var token = node.Sample(conversation); + + // Continue conversation with this token + if (j != n_generate - 1) + conversation.Prompt(token); + } + + // Write out what we generated + node.Write(n_rewind, i + 1); + + // Rewind back a few tokens + conversation.Rewind(n_rewind + 1); + + // Prompt with a token + conversation.Prompt(node.GetToken(n_generate - n_rewind - 1)); + + // Create a new node around the rewound conversation + node = new Node(executor.Context); + } + + Console.WriteLine("Press any key to exit demo"); + Console.ReadKey(true); + } + + private class Node + { + private readonly LLamaContext _context; + + private readonly List _tokens = new List(); + private readonly DefaultSamplingPipeline Sampler; + + public Node(LLamaContext context) + { + _context = context; + Sampler = new DefaultSamplingPipeline(); + } + + public LLamaToken Sample(Conversation conversation) + { + var token = Sampler.Sample(_context.NativeHandle, conversation.Sample().ToArray(), Array.Empty()); + _tokens.Add(token); + return token; + } + + public void Write(int n_rewind, int depth) + { + var decoder = new StreamingTokenDecoder(_context); + + for (var i = 0; i < _tokens.Count - n_rewind; i++) + decoder.Add(_tokens[i]); + + Console.ForegroundColor = ConsoleColor.Green; + Console.Write(new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" ")); + + for (var i = _tokens.Count - n_rewind; i < _tokens.Count; i++) + decoder.Add(_tokens[i]); + + Console.ForegroundColor = ConsoleColor.DarkRed; + Console.WriteLine(decoder.Read().ReplaceLineEndings(" ")); + } + + public LLamaToken GetToken(int index) + { + return _tokens[index]; + } + } +} \ No newline at end of file diff --git a/LLama.Examples/Examples/Runner.cs b/LLama.Examples/Examples/Runner.cs index 3d9858e1d..54358ce70 100644 --- a/LLama.Examples/Examples/Runner.cs +++ b/LLama.Examples/Examples/Runner.cs @@ -23,7 +23,8 @@ public class Runner { "Semantic Kernel Chat.", SemanticKernelChat.Run }, { "Semantic Kernel Memory.", SemanticKernelMemory.Run }, { "Coding Assistant.", CodingAssistant.Run }, - { "Batch Decoding.", BatchedDecoding.Run }, + { "Batched Executor (Fork)", BatchedExecutorFork.Run }, + { "Batched Executor (Rewind)", BatchedExecutorRewind.Run }, { "SK Kernel Memory.", KernelMemory.Run }, { "Exit", async () => Environment.Exit(0) } }; diff --git a/LLama/Batched/BatchedExecutor.cs b/LLama/Batched/BatchedExecutor.cs new file mode 100644 index 000000000..f432a89f1 --- /dev/null +++ b/LLama/Batched/BatchedExecutor.cs @@ -0,0 +1,119 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using LLama.Abstractions; +using LLama.Native; + +namespace LLama.Batched; + +/// +/// A batched executor that can infer multiple separate "conversations" simultaneously. +/// +public sealed class BatchedExecutor + : IDisposable +{ + private int _nextSequenceId; + + internal LLamaBatch Batch { get; } + + /// + /// Epoch is incremented every time Infer is called. Conversations can use this to keep track of + /// whether they're waiting for inference, or can be sampled. + /// + internal ulong Epoch { get; private set; } + + /// + /// The this executor is using + /// + public LLamaContext Context { get; } + + /// + /// The this executor is using + /// + public LLamaWeights Model { get; } + + /// + /// Get the number of tokens in the batch, waiting for to be called + /// + public int BatchedTokenCount => Batch.TokenCount; + + /// + /// Check if this executor has been disposed. 
+ /// + public bool IsDisposed { get; private set; } + + /// + /// Create a new batched executor + /// + /// The model to use + /// Parameters to create a new context + public BatchedExecutor(LLamaWeights model, IContextParams contextParams) + { + Model = model; + Batch = new LLamaBatch(); + Context = model.CreateContext(contextParams); + Epoch = 1; + } + + ~BatchedExecutor() + { + Dispose(); + } + + /// + /// Start a new with the given prompt + /// + /// + /// + public Conversation Prompt(string prompt) + { + if (IsDisposed) + throw new ObjectDisposedException(nameof(BatchedExecutor)); + + var conversation = new Conversation(this, GetNextSequenceId(), 0); + conversation.Prompt(prompt); + + return conversation; + } + + /// + /// Run inference for all conversations in the batch which have pending tokens. + /// + /// If the result is `NoKvSlot` then there is not enough memory for inference, try disposing some conversation + /// threads and running inference again. + /// + public async Task Infer(CancellationToken cancellation = default) + { + if (IsDisposed) + throw new ObjectDisposedException(nameof(BatchedExecutor)); + + var status = await Context.DecodeAsync(Batch, cancellation); + + // Only clear the batch if the result was ok. leaving all this state in place means that "Infer" can + // be called again after a warning (e.g. NoKvSlot). + if (status == DecodeResult.Ok) + { + Epoch++; + Batch.Clear(); + } + + return status; + } + + /// + public void Dispose() + { + if (IsDisposed) + return; + IsDisposed = true; + + GC.SuppressFinalize(this); + + Context.Dispose(); + } + + internal LLamaSeqId GetNextSequenceId() + { + return checked((LLamaSeqId)_nextSequenceId++); + } +} \ No newline at end of file diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs new file mode 100644 index 000000000..6cf6e3128 --- /dev/null +++ b/LLama/Batched/Conversation.cs @@ -0,0 +1,294 @@ +using System; +using System.Collections.Generic; +using LLama.Native; + +namespace LLama.Batched; + +/// +/// A single conversation thread that can be prompted (adding tokens from the user) or inferred (extracting a token from the LLM) +/// +public sealed class Conversation + : IDisposable +{ + private ulong _requiredEpoch; + private LLamaPos _end; + private int _batchIndex; + private bool _disposed; + + /// + /// The executor which this conversation belongs to + /// + public BatchedExecutor Executor { get; } + + /// + /// Unique ID for this conversation + /// + public LLamaSeqId ConversationId { get; } + + /// + /// Total number of tokens in this conversation, cannot exceed the context length. + /// + public int TokenCount => _end.Value; + + /// + /// Indicates if this conversation has been disposed, nothing can be done with a disposed conversation + /// + public bool IsDisposed => _disposed || Executor.IsDisposed; + + /// + /// Indicates if this conversation is waiting for inference to be run on the executor. "Prompt" and "Sample" cannot be called when this is true. + /// + public bool RequiresInference => _requiredEpoch > Executor.Epoch; + + /// + /// Indicates that this conversation should be sampled. 
+ /// + public bool RequiresSampling => _requiredEpoch == Executor.Epoch; + + #region construction/destruction + internal Conversation(BatchedExecutor batch, LLamaSeqId id, LLamaPos end) + { + ConversationId = id; + Executor = batch; + + _end = end; + } + + ~Conversation() + { + Dispose(); + } + + /// + /// End this conversation, freeing all resources used by it + /// + /// + public void Dispose() + { + if (IsDisposed) + return; + _disposed = true; + + // Remove this conversation from the KV cache + Executor.Context.NativeHandle.KvCacheRemove(ConversationId, 0, _end); + + // Prevent finalizer from running + GC.SuppressFinalize(this); + } + + private void AssertNotDisposed() + { + if (Executor.IsDisposed) + throw new ObjectDisposedException(nameof(BatchedExecutor)); + if (IsDisposed) + throw new ObjectDisposedException(nameof(Conversation)); + } + #endregion + + /// + /// Create a copy of the current conversation + /// + /// The copy shares internal state, so consumes very little extra memory. + /// + /// + public Conversation Fork() + { + AssertNotDisposed(); + + if (RequiresInference) + throw new CannotForkWhileRequiresInference(); + + // Create a new conversation which references the current position in this one + var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end) + { + _batchIndex = _batchIndex, + _requiredEpoch = _requiredEpoch, + }; + + // Assign tokens to the new sequence + NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end); + + return c; + } + + #region sample + /// + /// Get the logits from this conversation, ready for sampling + /// + /// + /// + /// Thrown if this conversation was not prompted before the previous call to infer + /// Thrown if Infer() must be called on the executor + public ReadOnlySpan Sample() + { + AssertNotDisposed(); + + if (_requiredEpoch < Executor.Epoch) + throw new CannotSampleRequiresPromptException(); + if (_requiredEpoch > Executor.Epoch) + throw new CannotSampleRequiresInferenceException(); + + return Executor.Context.NativeHandle.GetLogitsIth(_batchIndex); + } + #endregion + + #region prompt + private void AssertCanBePrompted() + { + AssertNotDisposed(); + + if (RequiresInference) + throw new AlreadyPromptedConversationException(); + } + + /// + /// Add tokens to this conversation + /// + /// + /// + public void Prompt(string input) + { + AssertCanBePrompted(); + + Prompt(Executor.Context.Tokenize(input)); + } + + /// + /// Add tokens to this conversation + /// + /// + /// + /// + public void Prompt(IReadOnlyList tokens) + { + AssertCanBePrompted(); + + // Add the prompt to the batch + for (var i = 0; i < tokens.Count; i++) + _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1); + + // Mark this conversation as needing inference/sampling + _requiredEpoch = Executor.Epoch + 1; + } + + /// + /// Add a single token to this conversation + /// + /// + /// + /// + /// + public void Prompt(LLamaToken token) + { + AssertCanBePrompted(); + + // Add this token as input + _batchIndex = Executor.Batch.Add(token, _end++, ConversationId, true); + + // Mark this conversation as needing inference/sampling + _requiredEpoch = Executor.Epoch + 1; + } + #endregion + + #region modify + /// + /// Directly modify the KV cache of this conversation + /// + /// + /// Thrown if this method is called while == true + public void Modify(ModifyKvCache modifier) + { + AssertNotDisposed(); + + if (RequiresInference) + throw new CannotModifyWhileRequiresInference(); + + // 
do whatever the modification is + _end = modifier.Invoke(_end, new KvAccessor(this)); + + // Set the epoch down to zero, this ensures that this conversation + // cannot be sampled until it is prompted again. + _requiredEpoch = 0; + } + + /// + /// Provides direct access to the KV cache of a . + /// See for how to use this. + /// + public readonly ref struct KvAccessor + { + private readonly Conversation _conversation; + + internal KvAccessor(Conversation conversation) + { + _conversation = conversation; + } + + #region remove + /// + /// Removes all tokens that have positions in [start, end) + /// + /// Start position (inclusive) + /// End position (exclusive) + public void Remove(LLamaPos start, LLamaPos end) + { + _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end); + } + + /// + /// Removes all tokens starting from the given position + /// + /// Start position (inclusive) + /// Number of tokens + public void Remove(LLamaPos start, int count) + { + if (count <= 0) + return; + + var end = start.Value + count; + _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end); + } + #endregion + + #region shift + /// + /// Adds relative position "delta" to all tokens that have positions in [p0, p1). + /// If the KV cache is RoPEd, the KV data is updated + /// accordingly + /// + /// Start position (inclusive) + /// End position (exclusive) + /// Amount to add on to each token position + public void Shift(LLamaPos start, LLamaPos end, int delta) + { + _conversation.Executor.Context.NativeHandle.KvCacheSequenceShift(_conversation.ConversationId, start, end, delta); + } + #endregion + + #region divide + /// + /// Integer division of the positions by factor of `d > 1`. + /// If the KV cache is RoPEd, the KV data is updated accordingly. + /// + /// Start position (inclusive). If less than zero, it is clamped to zero. + /// End position (exclusive). If less than zero, it is treated as "infinity". + /// Amount to divide each position by. 
+ public void Divide(LLamaPos start, LLamaPos end, int divisor) + { + if (divisor <= 0) + throw new ArgumentOutOfRangeException(nameof(divisor)); + + _conversation.Executor.Context.NativeHandle.KvCacheSequenceDivide(_conversation.ConversationId, start, end, divisor); + } + #endregion + } + + /// + /// A function which can temporarily access the KV cache of a to modify it directly + /// + /// The current end token of this conversation + /// An which allows direct access to modify the KV cache + /// The new end token position + public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv); + #endregion +} \ No newline at end of file diff --git a/LLama/Batched/ConversationExtensions.cs b/LLama/Batched/ConversationExtensions.cs new file mode 100644 index 000000000..5fca5e94b --- /dev/null +++ b/LLama/Batched/ConversationExtensions.cs @@ -0,0 +1,59 @@ +using System; + +namespace LLama.Batched; + +/// +/// Extension method for +/// +public static class ConversationExtensions +{ + /// + /// Rewind a back to an earlier state by removing tokens from the end + /// + /// The conversation to rewind + /// The number of tokens to rewind + /// Thrown if `tokens` parameter is larger than TokenCount + public static void Rewind(this Conversation conversation, int tokens) + { + if (tokens > conversation.TokenCount) + throw new ArgumentOutOfRangeException(nameof(tokens), "Cannot rewind more than the total number of tokens"); + + conversation.Modify((end, kv) => + { + // Remove those tokens from KV + kv.Remove(end.Value - tokens, tokens); + + // Return adjusted end position + return end.Value - tokens; + }); + } + + /// + /// Shift all tokens over to the left, removing "count" tokens from the start and shifting everything over. + /// Leaves "keep" tokens at the start completely untouched. This can be used to free up space when the context + /// gets full, keeping the prompt at the start intact. + /// + /// The conversation to rewind + /// How much to shift tokens over by + /// The number of tokens at the start which should not be shifted + public static void ShiftLeft(this Conversation conversation, int count, int keep) + { + // Given a setup like this (shift=5, keep=3): + // + // AAABBBBBCCCCCCCCC... + // + // We want to remove all the B's, shift all the C's and leave all the A's untouched + + conversation.Modify((end, kv) => + { + // Remove the B's + kv.Remove(keep, count); + + // Shift the C's + kv.Shift(keep + count, end, -count); + + // Update total count + return end.Value - count; + }); + } +} \ No newline at end of file diff --git a/LLama/Batched/Exceptions.cs b/LLama/Batched/Exceptions.cs new file mode 100644 index 000000000..1feb270c2 --- /dev/null +++ b/LLama/Batched/Exceptions.cs @@ -0,0 +1,81 @@ +using System; + +namespace LLama.Batched; + +/// +/// Base class for exceptions thrown from +/// +public class ExperimentalBatchedExecutorException + : Exception +{ + internal ExperimentalBatchedExecutorException(string message) + : base(message) + { + } +} + +/// +/// This exception is thrown when "Prompt()" is called on a which has +/// already been prompted and before "Infer()" has been called on the associated +/// . 
+/// +public class AlreadyPromptedConversationException + : ExperimentalBatchedExecutorException +{ + internal AlreadyPromptedConversationException() + : base("Must call `Infer()` before prompting this Conversation again") + { + } +} + +/// +/// This exception is thrown when "Sample()" is called on a which has +/// already been prompted and before "Infer()" has been called on the associated +/// . +/// +public class CannotSampleRequiresInferenceException + : ExperimentalBatchedExecutorException +{ + internal CannotSampleRequiresInferenceException() + : base("Must call `Infer()` before sampling from this Conversation") + { + } +} + +/// +/// This exception is thrown when "Sample()" is called on a which was not +/// first prompted. +/// . +/// +public class CannotSampleRequiresPromptException + : ExperimentalBatchedExecutorException +{ + internal CannotSampleRequiresPromptException() + : base("Must call `Prompt()` and then `Infer()` before sampling from this Conversation") + { + } +} + +/// +/// This exception is thrown when is called when = true +/// +public class CannotForkWhileRequiresInference + : ExperimentalBatchedExecutorException +{ + internal CannotForkWhileRequiresInference() + : base("Cannot `Fork()` a conversation while RequiresInference is true") + { + } +} + +/// +/// This exception is thrown when is called when = true +/// +public class CannotModifyWhileRequiresInference + : ExperimentalBatchedExecutorException +{ + internal CannotModifyWhileRequiresInference() + : base("Cannot `Modify()` a conversation while RequiresInference is true") + { + } +} \ No newline at end of file diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 5d026b672..9075c89fc 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -221,7 +221,9 @@ public void LoadState(State state) /// The selected token public LLamaToken Sample(ISamplingPipeline pipeline, ReadOnlySpan lastTokens) { - return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens); + var token = pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens); + pipeline.Accept(NativeHandle, token); + return token; } /// diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 969da783c..31e975caf 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -213,6 +213,7 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta if (inferenceParams.SamplingPipeline is not null) { id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray()); + inferenceParams.SamplingPipeline.Accept(Context.NativeHandle, id); } else { diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 7d742c813..9338d8396 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -192,6 +192,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In if (inferenceParams.SamplingPipeline is not null) { id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray()); + inferenceParams.SamplingPipeline.Accept(Context.NativeHandle, id); } else { diff --git a/LLama/Native/LLamaBatch.cs b/LLama/Native/LLamaBatch.cs index 532e16fff..50b9c8f0e 100644 --- a/LLama/Native/LLamaBatch.cs +++ b/LLama/Native/LLamaBatch.cs @@ -18,6 +18,11 @@ public class LLamaBatch private LLamaSeqId[][] _sequenceIds; private IntPtr[] _sequenceIdsPtrs; + /// + /// Keep track of the index 
of existing token/position combos in the batch + /// + private readonly Dictionary<(LLamaToken, LLamaPos), int> _index = new(); + /// /// The number of tokens in this batch /// @@ -130,23 +135,44 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch) /// The position to add it att /// The set of sequences to add this token to /// - public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan sequences, bool logits) + /// The index that the token was added at. Use this for GetLogitsIth + public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan sequences, bool logits) { + // Try to find this (token, position) combo somewhere in the batch to re-use it + if (_index.TryGetValue((token, pos), out var existingIndex)) + { + if (_sequenceIdCount[existingIndex] + sequences.Length > SequenceCapacity) + GrowMaxSequences(_sequenceIdCount[existingIndex] + sequences.Length); + + foreach (var sequence in sequences) + { + _sequenceIds[existingIndex][_sequenceIdCount[existingIndex]] = sequence; + _sequenceIdCount[existingIndex]++; + } + + return existingIndex; + } + + // Couldn't find this it in the batch, add a new item + + // Frow capacity as necessary if (TokenCount == TokenCapacity) GrowTokenCapacity(); if (sequences.Length > SequenceCapacity) GrowMaxSequences(sequences.Length); + // Store the position in the index, so it can be found later + _index.Add((token, pos), TokenCount); + + // Add the items to the arrays _tokens[TokenCount] = token; _positions[TokenCount] = pos; - _sequenceIdCount[TokenCount] = sequences.Length; for (var i = 0; i < sequences.Length; i++) _sequenceIds[TokenCount][i] = sequences[i]; - _logits[TokenCount] = Convert.ToByte(logits); - TokenCount++; + return TokenCount++; } /// @@ -157,11 +183,12 @@ public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan sequenc /// The position to add it att /// The set of sequences to add this token to /// - public void Add(LLamaToken token, LLamaPos pos, List sequences, bool logits) + /// The index that the token was added at. Use this for GetLogitsIth + public int Add(LLamaToken token, LLamaPos pos, List sequences, bool logits) { #if NET5_0_OR_GREATER var seqSpan = CollectionsMarshal.AsSpan(sequences); - Add(token, pos, seqSpan, logits); + return Add(token, pos, seqSpan, logits); #else // on netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't @@ -171,7 +198,7 @@ public void Add(LLamaToken token, LLamaPos pos, List sequences, bool try { sequences.CopyTo(rented, 0); - Add(token, pos, rented.AsSpan(0, sequences.Count), logits); + return Add(token, pos, rented.AsSpan(0, sequences.Count), logits); } finally { @@ -188,14 +215,15 @@ public void Add(LLamaToken token, LLamaPos pos, List sequences, bool /// The position to add it att /// The sequence to add this token to /// - public void Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits) + /// The index that the token was added at. 
Use this for GetLogitsIth + public int Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits) { // Create a temporary span to contain 1 item without allocating Span sequences = stackalloc LLamaSeqId[1]; sequences[0] = sequence; // Add it - Add(token, pos, sequences, logits); + return Add(token, pos, sequences, logits); } /// @@ -205,13 +233,17 @@ public void Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits /// The starting position to add tokens at /// The sequence to add this token to /// Whether the final token should generate logits - public void AddRange(ReadOnlySpan tokens, LLamaPos start, LLamaSeqId sequence, bool logitsLast) + /// The index that the final token was added at. Use this for GetLogitsIth + public int AddRange(ReadOnlySpan tokens, LLamaPos start, LLamaSeqId sequence, bool logitsLast) { + var last = -1; for (var i = 0; i < tokens.Length; i++) { var logits = (i == tokens.Length - 1) & logitsLast; - Add(tokens[i], start.Value + i, sequence, logits); + last = Add(tokens[i], start.Value + i, sequence, logits); } + + return last; } #endregion @@ -221,5 +253,6 @@ public void AddRange(ReadOnlySpan tokens, LLamaPos start, LLamaSeqId public void Clear() { TokenCount = 0; + _index.Clear(); } } \ No newline at end of file diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index c953cb237..578cad405 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -388,7 +388,7 @@ public static int llama_token_to_piece(SafeLlamaModelHandle model, LLamaToken ll /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta); + public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta); /// /// Integer division of the positions by factor of `d > 1` diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 91e82c85d..2d9387ae4 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -369,15 +369,15 @@ public void KvCacheSequenceKeep(LLamaSeqId seq) /// /// /// - public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta) + public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta) { NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta); } /// - /// Integer division of the positions by factor of `d > 1` - /// If the KV cache is RoPEd, the KV data is updated accordingly - /// p0 < 0 : [0, p1] + /// Integer division of the positions by factor of `d > 1`. + /// If the KV cache is RoPEd, the KV data is updated accordingly.
+ /// p0 < 0 : [0, p1]
/// p1 < 0 : [p0, inf)
/// </summary>
/// diff --git a/LLama/Sampling/BaseSamplingPipeline.cs b/LLama/Sampling/BaseSamplingPipeline.cs index a41aa67e1..b86001aa0 100644 --- a/LLama/Sampling/BaseSamplingPipeline.cs +++ b/LLama/Sampling/BaseSamplingPipeline.cs @@ -40,10 +40,7 @@ public LLamaToken Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnl var candidates = LLamaTokenDataArray.Create(logits); // Process token data array - ProcessTokenDataArray(ctx, candidates, lastTokens); - - // Choose the final value - return ChooseToken(ctx, candidates); + return ProcessTokenDataArray(ctx, candidates, lastTokens); } finally { @@ -53,6 +50,9 @@ public LLamaToken Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnl } } + /// + public abstract void Accept(SafeLLamaContextHandle ctx, LLamaToken token); + #region protected tokens /// /// Get all of the "protected" tokens that cannot be changed by ProcessLogits @@ -107,19 +107,14 @@ protected void RestoreProtectedTokens(LLamaTokenDataArray candidates) /// protected abstract LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens); - /// - /// Choose the final token from the candidates - /// - /// - /// - /// - protected abstract LLamaToken ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates); - /// public virtual void Reset() { } + /// + public abstract ISamplingPipeline Clone(); + /// public virtual void Dispose() { diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index b0fb5c596..531f34faa 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -141,9 +141,31 @@ protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, return id; } + public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token) + { + Grammar?.AcceptToken(ctx, token); + } + /// - protected override LLamaToken ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) + public override ISamplingPipeline Clone() { - return candidates.SampleToken(ctx); + var clone = new DefaultSamplingPipeline(); + + foreach (var (k, v) in LogitBias) + clone.LogitBias.Add(k, v); + + clone.Grammar = Grammar?.Clone(); + clone.RepeatPenalty = RepeatPenalty; + clone.AlphaFrequency = AlphaFrequency; + clone.AlphaPresence = AlphaPresence; + clone.Temperature = Temperature; + clone.TopK = TopK; + clone.TailFreeZ = TailFreeZ; + clone.TypicalP = TypicalP; + clone.TopP = TopP; + clone.MinP = MinP; + clone.PenalizeNewline = PenalizeNewline; + + return clone; } } \ No newline at end of file diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs index be1398790..b538d1feb 100644 --- a/LLama/Sampling/ISamplingPipeline.cs +++ b/LLama/Sampling/ISamplingPipeline.cs @@ -21,10 +21,23 @@ public interface ISamplingPipeline /// LLamaToken Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); + /// + /// Update the pipeline, with knowledge that a particular token was just accepted + /// + /// + /// + void Accept(SafeLLamaContextHandle ctx, LLamaToken token); + /// /// Reset all internal state of the sampling pipeline /// void Reset(); + + /// + /// Create a copy of this sampling pipeline + /// + /// + ISamplingPipeline Clone(); } ///
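A minimal usage sketch of the batched API introduced above, assembled only from the types this diff adds (BatchedExecutor, Conversation, DefaultSamplingPipeline) together with the existing StreamingTokenDecoder; the model path, prompt text and the 32-token limit are illustrative placeholders rather than values taken from the change:

using LLama;
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

// Load weights (the path is a placeholder)
var parameters = new ModelParams("path/to/model.gguf");
using var model = LLamaWeights.LoadFromFile(parameters);

// One executor can run many conversations in a single batch
using var executor = new BatchedExecutor(model, parameters);

// Start a conversation by prompting it; nothing is evaluated until Infer() runs
var conversation = executor.Prompt("Not many people know that");

var sampler = new DefaultSamplingPipeline();
var decoder = new StreamingTokenDecoder(executor.Context);

for (var i = 0; i < 32; i++)
{
    // Evaluate everything queued in the batch (the prompt on the first pass, one token afterwards).
    // A production caller would check the returned DecodeResult, e.g. for NoKvSlot.
    await executor.Infer();

    // Sample from this conversation's logits, then feed the chosen token back in
    var token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample().ToArray(), Array.Empty<LLamaToken>());
    sampler.Accept(executor.Context.NativeHandle, token);
    decoder.Add(token);
    conversation.Prompt(token);
}

Console.WriteLine(decoder.Read());

The key point is the epoch handshake: Prompt() only queues tokens into the shared LLamaBatch, Infer() evaluates everything queued across all conversations in a single decode call, and Sample() then reads that conversation's logits.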
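Continuing from the same variables, a short sketch of how Fork() on Conversation and the new Clone() on ISamplingPipeline are intended to combine when branching a conversation; the branch and sampler names are illustrative:

// Fork() throws CannotForkWhileRequiresInference while RequiresInference is true,
// so evaluate the pending token first
await executor.Infer();

// Both branches reference the same KV cache entries, so forking is cheap
var branchA = conversation.Fork();
var branchB = conversation.Fork();
conversation.Dispose();   // the original thread is no longer needed

// Give each branch an independent copy of the configured sampler state
var samplerA = sampler.Clone();
var samplerB = sampler.Clone();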