From dbbaca54550cea991c65d3005f55bdc1af206265 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Fri, 26 Jul 2024 02:11:57 +0100 Subject: [PATCH] - Marked all properties for configuring sampling in `IInferenceParams` as obsolete, pushing users towards the newer `SamplingPipeline` system. - Removed old sampling code from `LLamaContext`, instead if no `SamplingPipeline` is supplied one is created (existing one is re-used, as much as possible). - Updated all examples to use new system, uncovered a bug in `TalkToYourself` which tried to use `Temperature` _and_ mirostate - not compatible! This is exactly the kind of bug this is trying to fix. - Added `AsSpan` to `FixedSizeQueue` to avoid allocations of temporary arrays for every token! --- .../Examples/ChatSessionStripRoleName.cs | 10 +- .../Examples/ChatSessionWithHistory.cs | 8 +- .../Examples/ChatSessionWithRestart.cs | 8 +- .../Examples/ChatSessionWithRoleName.cs | 10 +- LLama.Examples/Examples/CodingAssistant.cs | 11 +- LLama.Examples/Examples/GetEmbeddings.cs | 2 +- .../Examples/GrammarJsonResponse.cs | 10 +- .../Examples/InstructModeExecute.cs | 12 +- LLama.Examples/Examples/LLama3ChatSession.cs | 11 +- .../Examples/LlavaInteractiveModeExecute.cs | 17 +- LLama.Examples/Examples/LoadAndSaveSession.cs | 8 +- LLama.Examples/Examples/LoadAndSaveState.cs | 13 +- .../Examples/StatelessModeExecute.cs | 14 +- LLama.Examples/Examples/TalkToYourself.cs | 13 +- LLama/Abstractions/IInferenceParams.cs | 195 ++++++++++++------ LLama/Common/FixedSizeQueue.cs | 15 +- LLama/Common/InferenceParams.cs | 19 +- LLama/Extensions/DictionaryExtensions.cs | 9 +- LLama/LLamaContext.cs | 141 ------------- LLama/LLamaExecutorBase.cs | 5 - LLama/LLamaInstructExecutor.cs | 32 ++- LLama/LLamaInteractExecutor.cs | 31 ++- LLama/LLamaStatelessExecutor.cs | 33 +-- LLama/Sampling/DefaultSamplingPipeline.cs | 4 +- LLama/Sampling/Mirostat2SamplingPipeline.cs | 6 +- 25 files changed, 325 insertions(+), 312 deletions(-) diff --git a/LLama.Examples/Examples/ChatSessionStripRoleName.cs b/LLama.Examples/Examples/ChatSessionStripRoleName.cs index 5469aa8f0..f1a24ba86 100644 --- a/LLama.Examples/Examples/ChatSessionStripRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionStripRoleName.cs @@ -1,4 +1,5 @@ -using LLama.Common; +using LLama.Common; +using LLama.Sampling; namespace LLama.Examples.Examples; @@ -27,9 +28,12 @@ public static async Task Run() new string[] { "User:", "Assistant:" }, redundancyLength: 8)); - InferenceParams inferenceParams = new InferenceParams() + var inferenceParams = new InferenceParams { - Temperature = 0.9f, + SamplingPipeline = new DefaultSamplingPipeline + { + Temperature = 0.9f + }, AntiPrompts = new List { "User:" } }; diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs index af7d7eac4..75b40aee2 100644 --- a/LLama.Examples/Examples/ChatSessionWithHistory.cs +++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs @@ -1,4 +1,5 @@ using LLama.Common; +using LLama.Sampling; namespace LLama.Examples.Examples; @@ -39,9 +40,12 @@ public static async Task Run() new string[] { "User:", "Assistant:" }, redundancyLength: 8)); - InferenceParams inferenceParams = new InferenceParams() + var inferenceParams = new InferenceParams { - Temperature = 0.9f, + SamplingPipeline = new DefaultSamplingPipeline + { + Temperature = 0.9f + }, AntiPrompts = new List { "User:" } }; diff --git a/LLama.Examples/Examples/ChatSessionWithRestart.cs b/LLama.Examples/Examples/ChatSessionWithRestart.cs index 
c2bfb8954..c0b2ee9f4 100644 --- a/LLama.Examples/Examples/ChatSessionWithRestart.cs +++ b/LLama.Examples/Examples/ChatSessionWithRestart.cs @@ -1,4 +1,5 @@ using LLama.Common; +using LLama.Sampling; namespace LLama.Examples.Examples; @@ -29,9 +30,12 @@ public static async Task Run() ChatSession session = new ChatSession(executor); session.LoadSession(resetState); - InferenceParams inferenceParams = new InferenceParams() + var inferenceParams = new InferenceParams { - Temperature = 0.9f, + SamplingPipeline = new DefaultSamplingPipeline + { + Temperature = 0.9f + }, AntiPrompts = new List { "User:" } }; diff --git a/LLama.Examples/Examples/ChatSessionWithRoleName.cs b/LLama.Examples/Examples/ChatSessionWithRoleName.cs index 4e2befd98..51fcb248b 100644 --- a/LLama.Examples/Examples/ChatSessionWithRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionWithRoleName.cs @@ -1,4 +1,5 @@ -using LLama.Common; +using LLama.Common; +using LLama.Sampling; namespace LLama.Examples.Examples; @@ -22,9 +23,12 @@ public static async Task Run() ChatSession session = new(executor, chatHistory); - InferenceParams inferenceParams = new InferenceParams() + var inferenceParams = new InferenceParams { - Temperature = 0.9f, + SamplingPipeline = new DefaultSamplingPipeline + { + Temperature = 0.9f + }, AntiPrompts = new List { "User:" } }; diff --git a/LLama.Examples/Examples/CodingAssistant.cs b/LLama.Examples/Examples/CodingAssistant.cs index a2edf8be6..384f5c526 100644 --- a/LLama.Examples/Examples/CodingAssistant.cs +++ b/LLama.Examples/Examples/CodingAssistant.cs @@ -1,4 +1,6 @@ -namespace LLama.Examples.Examples +using LLama.Sampling; + +namespace LLama.Examples.Examples { using LLama.Common; using System; @@ -40,9 +42,12 @@ public static async Task Run() "\nWrite 'exit' to exit"); Console.ForegroundColor = ConsoleColor.White; - var inferenceParams = new InferenceParams() + var inferenceParams = new InferenceParams { - Temperature = 0.8f, + SamplingPipeline = new DefaultSamplingPipeline + { + Temperature = 0.8f + }, MaxTokens = -1, }; diff --git a/LLama.Examples/Examples/GetEmbeddings.cs b/LLama.Examples/Examples/GetEmbeddings.cs index 1e10ba22b..ad844004e 100644 --- a/LLama.Examples/Examples/GetEmbeddings.cs +++ b/LLama.Examples/Examples/GetEmbeddings.cs @@ -1,4 +1,4 @@ -using LLama.Common; +using LLama.Common; namespace LLama.Examples.Examples { diff --git a/LLama.Examples/Examples/GrammarJsonResponse.cs b/LLama.Examples/Examples/GrammarJsonResponse.cs index a5bb5486b..0a86a912d 100644 --- a/LLama.Examples/Examples/GrammarJsonResponse.cs +++ b/LLama.Examples/Examples/GrammarJsonResponse.cs @@ -1,5 +1,6 @@ -using LLama.Common; +using LLama.Common; using LLama.Grammars; +using LLama.Sampling; namespace LLama.Examples.Examples { @@ -27,10 +28,13 @@ public static async Task Run() using var grammarInstance = grammar.CreateInstance(); var inferenceParams = new InferenceParams() { - Temperature = 0.6f, + SamplingPipeline = new DefaultSamplingPipeline + { + Temperature = 0.6f, + Grammar = grammarInstance + }, AntiPrompts = new List { "Question:", "#", "Question: ", ".\n" }, MaxTokens = 50, - Grammar = grammarInstance }; while (true) diff --git a/LLama.Examples/Examples/InstructModeExecute.cs b/LLama.Examples/Examples/InstructModeExecute.cs index 4f65dd233..d319c2b58 100644 --- a/LLama.Examples/Examples/InstructModeExecute.cs +++ b/LLama.Examples/Examples/InstructModeExecute.cs @@ -1,4 +1,5 @@ -using LLama.Common; +using LLama.Common; +using LLama.Sampling; namespace LLama.Examples.Examples { @@ -25,7 +26,14 @@ public 
static async Task Run() "make friend with human, no less than 200 words.\""); Console.ForegroundColor = ConsoleColor.White; - var inferenceParams = new InferenceParams() { Temperature = 0.8f, MaxTokens = 600 }; + var inferenceParams = new InferenceParams + { + SamplingPipeline = new DefaultSamplingPipeline + { + Temperature = 0.8f + }, + MaxTokens = 600 + }; while (true) { diff --git a/LLama.Examples/Examples/LLama3ChatSession.cs b/LLama.Examples/Examples/LLama3ChatSession.cs index 01aa33cd6..e5e6167d8 100644 --- a/LLama.Examples/Examples/LLama3ChatSession.cs +++ b/LLama.Examples/Examples/LLama3ChatSession.cs @@ -1,4 +1,5 @@ -using LLama.Common; +using LLama.Common; +using LLama.Sampling; using LLama.Transformers; namespace LLama.Examples.Examples; @@ -37,10 +38,14 @@ public static async Task Run() [model.Tokens.EndOfTurnToken!, "�"], redundancyLength: 5)); - var inferenceParams = new InferenceParams() + var inferenceParams = new InferenceParams { + SamplingPipeline = new DefaultSamplingPipeline + { + Temperature = 0.6f + }, + MaxTokens = -1, // keep generating tokens until the anti prompt is encountered - Temperature = 0.6f, AntiPrompts = [model.Tokens.EndOfTurnToken!] // model specific end of turn string }; diff --git a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs b/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs index 9d396ebfe..c763d0a7c 100644 --- a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs +++ b/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs @@ -1,7 +1,8 @@ -using System.Text.RegularExpressions; +using System.Text.RegularExpressions; using LLama.Common; using Spectre.Console; using LLama.Native; +using LLama.Sampling; namespace LLama.Examples.Examples { @@ -30,9 +31,19 @@ public static async Task Run() Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. 
In this example, the prompt is printed, the maximum tokens is set to {0} and the context size is {1}.", maxTokens, parameters.ContextSize ); - Console.WriteLine("To send an image, enter its filename in curly braces, like this {c:/image.jpg}."); + Console.WriteLine("To send an image, enter its filename in curly braces, like this {c:/image.jpg}."); - var inferenceParams = new InferenceParams() { Temperature = 0.1f, AntiPrompts = new List { "\nUSER:" }, MaxTokens = maxTokens }; + var inferenceParams = new InferenceParams + { + SamplingPipeline = new DefaultSamplingPipeline + { + Temperature = 0.1f + }, + + AntiPrompts = new List { "\nUSER:" }, + MaxTokens = maxTokens + + }; do { diff --git a/LLama.Examples/Examples/LoadAndSaveSession.cs b/LLama.Examples/Examples/LoadAndSaveSession.cs index 68ed8aa32..7a0196adf 100644 --- a/LLama.Examples/Examples/LoadAndSaveSession.cs +++ b/LLama.Examples/Examples/LoadAndSaveSession.cs @@ -1,4 +1,5 @@ -using LLama.Common; +using LLama.Common; +using LLama.Sampling; namespace LLama.Examples.Examples { @@ -35,7 +36,10 @@ in session.ChatAsync( new ChatHistory.Message(AuthorRole.User, prompt), new InferenceParams() { - Temperature = 0.6f, + SamplingPipeline = new DefaultSamplingPipeline + { + Temperature = 0.6f + }, AntiPrompts = new List { "User:" } })) { diff --git a/LLama.Examples/Examples/LoadAndSaveState.cs b/LLama.Examples/Examples/LoadAndSaveState.cs index 0fef49f12..af445ee99 100644 --- a/LLama.Examples/Examples/LoadAndSaveState.cs +++ b/LLama.Examples/Examples/LoadAndSaveState.cs @@ -1,4 +1,5 @@ -using LLama.Common; +using LLama.Common; +using LLama.Sampling; namespace LLama.Examples.Examples { @@ -27,7 +28,15 @@ public static async Task Run() Console.ForegroundColor = ConsoleColor.White; Console.Write(prompt); - var inferenceParams = new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List { "User:" } }; + var inferenceParams = new InferenceParams + { + SamplingPipeline = new DefaultSamplingPipeline + { + Temperature = 0.6f + }, + + AntiPrompts = new List { "User:" } + }; while (true) { diff --git a/LLama.Examples/Examples/StatelessModeExecute.cs b/LLama.Examples/Examples/StatelessModeExecute.cs index 806616e7c..9f49919b8 100644 --- a/LLama.Examples/Examples/StatelessModeExecute.cs +++ b/LLama.Examples/Examples/StatelessModeExecute.cs @@ -1,5 +1,6 @@ -using LLama.Common; +using LLama.Common; using LLama.Examples.Extensions; +using LLama.Sampling; namespace LLama.Examples.Examples { @@ -24,7 +25,16 @@ public static async Task Run() "a prompt for it yourself!"); Console.ForegroundColor = ConsoleColor.White; - var inferenceParams = new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List { "Question:", "#", "Question: ", ".\n" }, MaxTokens = 50 }; + var inferenceParams = new InferenceParams + { + SamplingPipeline = new DefaultSamplingPipeline + { + Temperature = 0.6f + }, + + AntiPrompts = new List { "Question:", "#", "Question: ", ".\n" }, + MaxTokens = 50 + }; while (true) { diff --git a/LLama.Examples/Examples/TalkToYourself.cs b/LLama.Examples/Examples/TalkToYourself.cs index f888209ae..ae0685471 100644 --- a/LLama.Examples/Examples/TalkToYourself.cs +++ b/LLama.Examples/Examples/TalkToYourself.cs @@ -1,6 +1,7 @@ -using System.Text; +using System.Text; using LLama.Abstractions; using LLama.Common; +using LLama.Sampling; namespace LLama.Examples.Examples { @@ -43,11 +44,13 @@ private static async Task Prompt(ILLamaExecutor executor, ConsoleColor c { var inferenceParams = new InferenceParams { - Temperature = 0.9f, - AntiPrompts = new 
List { "Alice:", "Bob:", "User:" }, + SamplingPipeline = new Mirostat2SamplingPipeline + { + Tau = 10 + }, + + AntiPrompts = [ "Alice:", "Bob:", "User:" ], MaxTokens = 128, - Mirostat = MirostatType.Mirostat2, - MirostatTau = 10, }; Console.ForegroundColor = ConsoleColor.White; diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs index 425bc88d2..f00c22110 100644 --- a/LLama/Abstractions/IInferenceParams.cs +++ b/LLama/Abstractions/IInferenceParams.cs @@ -1,14 +1,15 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using LLama.Common; using LLama.Native; using LLama.Sampling; namespace LLama.Abstractions { - /// - /// The parameters used for inference. - /// - public interface IInferenceParams + /// + /// The parameters used for inference. + /// + public interface IInferenceParams { /// /// number of tokens to keep from initial prompt @@ -21,93 +22,110 @@ public interface IInferenceParams /// public int MaxTokens { get; set; } - /// - /// logit bias for specific tokens - /// - public Dictionary? LogitBias { get; set; } + /// + /// logit bias for specific tokens + /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] + public Dictionary? LogitBias { get; set; } /// /// Sequences where the model will stop generating further tokens. /// public IReadOnlyList AntiPrompts { get; set; } - /// - /// 0 or lower to use vocab size - /// - public int TopK { get; set; } + /// + /// 0 or lower to use vocab size + /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] + public int TopK { get; set; } - /// - /// 1.0 = disabled - /// - public float TopP { get; set; } + /// + /// 1.0 = disabled + /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] + public float TopP { get; set; } /// /// 0.0 = disabled /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] public float MinP { get; set; } /// /// 1.0 = disabled /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] public float TfsZ { get; set; } - /// - /// 1.0 = disabled - /// - public float TypicalP { get; set; } + /// + /// 1.0 = disabled + /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] + public float TypicalP { get; set; } - /// - /// 1.0 = disabled - /// - public float Temperature { get; set; } + /// + /// 1.0 = disabled + /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] + public float Temperature { get; set; } - /// - /// 1.0 = disabled - /// - public float RepeatPenalty { get; set; } + /// + /// 1.0 = disabled + /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] + public float RepeatPenalty { get; set; } - /// - /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n) - /// - public int RepeatLastTokensCount { get; set; } + /// + /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n) + /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. 
DefaultSamplingPipeline")] + public int RepeatLastTokensCount { get; set; } - /// - /// frequency penalty coefficient - /// 0.0 = disabled - /// - public float FrequencyPenalty { get; set; } + /// + /// frequency penalty coefficient + /// 0.0 = disabled + /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] + public float FrequencyPenalty { get; set; } - /// - /// presence penalty coefficient - /// 0.0 = disabled - /// - public float PresencePenalty { get; set; } + /// + /// presence penalty coefficient + /// 0.0 = disabled + /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] + public float PresencePenalty { get; set; } - /// - /// Mirostat uses tokens instead of words. - /// algorithm described in the paper https://arxiv.org/abs/2007.14966. - /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - /// - public MirostatType Mirostat { get; set; } + /// + /// Mirostat uses tokens instead of words. + /// algorithm described in the paper https://arxiv.org/abs/2007.14966. + /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")] + public MirostatType Mirostat { get; set; } - /// - /// target entropy - /// - public float MirostatTau { get; set; } + /// + /// target entropy + /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")] + public float MirostatTau { get; set; } - /// - /// learning rate - /// - public float MirostatEta { get; set; } + /// + /// learning rate + /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")] + public float MirostatEta { get; set; } - /// - /// consider newlines as a repeatable token (penalize_nl) - /// - public bool PenalizeNL { get; set; } + /// + /// consider newlines as a repeatable token (penalize_nl) + /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] + // ReSharper disable once InconsistentNaming (obsolete, will be removed anyway) + public bool PenalizeNL { get; set; } /// /// Grammar to constrain possible tokens /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] SafeLLamaGrammarHandle? Grammar { get; set; } /// @@ -115,4 +133,59 @@ public interface IInferenceParams /// ISamplingPipeline? SamplingPipeline { get; set; } } + + internal static class IInferanceParamsExtensions + { + public static ISamplingPipeline Create(this IInferenceParams @params, ref ISamplingPipeline? pipeline) + { + // This method exists to adapt the old style of inference params to the newer sampling pipeline system. It's touching a lot + // of obsolete things which we don't really care about, disable the warning. 
+ #pragma warning disable CS0618 // Type or member is obsolete + + if (@params.Mirostat == MirostatType.Mirostat) + { + if (pipeline is not MirostatSamplingPipeline) + pipeline = new MirostatSamplingPipeline(); + + var m = (MirostatSamplingPipeline)pipeline; + m.Eta = @params.MirostatEta; + m.Tau = @params.MirostatTau; + return m; + } + + if (@params.Mirostat == MirostatType.Mirostat2) + { + if (pipeline is not Mirostat2SamplingPipeline) + pipeline = new Mirostat2SamplingPipeline(); + + var m = (Mirostat2SamplingPipeline)pipeline; + m.Eta = @params.MirostatEta; + m.Tau = @params.MirostatTau; + return m; + } + + if (pipeline is not DefaultSamplingPipeline) + pipeline = new DefaultSamplingPipeline(); + + var d = (DefaultSamplingPipeline)pipeline; + d.AlphaPresence = @params.PresencePenalty; + d.MinP = @params.MinP; + d.PenalizeNewline = @params.PenalizeNL; + d.RepeatPenalty = @params.RepeatPenalty; + d.TailFreeZ = @params.TfsZ; + d.Temperature = @params.Temperature; + d.TopK = @params.TopK; + d.TopP = @params.TopP; + d.AlphaFrequency = @params.FrequencyPenalty; + d.TypicalP = @params.TypicalP; + d.Grammar = @params.Grammar; + + d.LogitBias.Clear(); + @params.LogitBias?.CopyTo(d.LogitBias); + + return d; + + #pragma warning restore CS0618 // Type or member is obsolete + } + } } \ No newline at end of file diff --git a/LLama/Common/FixedSizeQueue.cs b/LLama/Common/FixedSizeQueue.cs index 8c14a1961..62056498c 100644 --- a/LLama/Common/FixedSizeQueue.cs +++ b/LLama/Common/FixedSizeQueue.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections; using System.Collections.Generic; using System.Linq; @@ -82,5 +82,18 @@ IEnumerator IEnumerable.GetEnumerator() { return GetEnumerator(); } + + internal ReadOnlySpan AsSpan(int count) + { + // Ensure the request isn't for more tokens than actually exist + count = Math.Min(count, Count); + + // Take `count` items from the end +#if NET8_0_OR_GREATER + return CollectionsMarshal.AsSpan(_storage)[^count..]; +#else + return _storage.ToArray().AsSpan(_storage.Count - count, count); +#endif + } } } diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs index b2e429f83..2e83f4b41 100644 --- a/LLama/Common/InferenceParams.cs +++ b/LLama/Common/InferenceParams.cs @@ -1,7 +1,8 @@ -using LLama.Abstractions; +using LLama.Abstractions; using System.Collections.Generic; using LLama.Native; using LLama.Sampling; +using System; namespace LLama.Common { @@ -25,6 +26,7 @@ public record InferenceParams /// /// logit bias for specific tokens /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] public Dictionary? LogitBias { get; set; } = null; /// @@ -33,48 +35,63 @@ public record InferenceParams public IReadOnlyList AntiPrompts { get; set; } = []; /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] public int TopK { get; set; } = 40; /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] public float TopP { get; set; } = 0.95f; /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] public float MinP { get; set; } = 0.05f; /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] public float TfsZ { get; set; } = 1.0f; /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. 
DefaultSamplingPipeline")] public float TypicalP { get; set; } = 1.0f; /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] public float Temperature { get; set; } = 0.8f; /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] public float RepeatPenalty { get; set; } = 1.1f; /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] public int RepeatLastTokensCount { get; set; } = 64; /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] public float FrequencyPenalty { get; set; } = .0f; /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] public float PresencePenalty { get; set; } = .0f; /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")] public MirostatType Mirostat { get; set; } = MirostatType.Disable; /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")] public float MirostatTau { get; set; } = 5.0f; /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")] public float MirostatEta { get; set; } = 0.1f; /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] public bool PenalizeNL { get; set; } = true; /// + [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] public SafeLLamaGrammarHandle? Grammar { get; set; } /// diff --git a/LLama/Extensions/DictionaryExtensions.cs b/LLama/Extensions/DictionaryExtensions.cs index 6599f6316..aaa782eae 100644 --- a/LLama/Extensions/DictionaryExtensions.cs +++ b/LLama/Extensions/DictionaryExtensions.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System.Collections.Generic; +using LLama.Native; namespace LLama.Extensions { @@ -18,5 +19,11 @@ internal static TValue GetValueOrDefaultImpl(IReadOnlyDictionary(this IReadOnlyDictionary source, IDictionary dest) + { + foreach (var (k, v) in source) + dest[k] = v; + } } } diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index b4e23dd1e..a9ee4bb0e 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -2,14 +2,11 @@ using LLama.Native; using System; using System.Collections.Generic; -using System.Linq; using System.Text; using System.IO; using System.IO.MemoryMappedFiles; -using LLama.Common; using System.Threading.Tasks; using LLama.Abstractions; -using LLama.Sampling; using Microsoft.Extensions.Logging; using System.Threading; @@ -82,8 +79,6 @@ public uint BatchThreads /// public SafeLlamaModelHandle.ModelTokens Tokens { get; } - private LLamaTokenData[]? 
_samplingBuffer; - /// /// Create a new LLamaContext for the given LLamaWeights /// @@ -379,142 +374,6 @@ public void LoadState(SequenceState state, LLamaSeqId sequence) } #endregion - /// - /// Sample a single token from this context, using the given sampling pipeline - /// - /// The pipeline to use to process the logits and to select a token - /// The tokens recently returned from the model - /// The selected token - public LLamaToken Sample(ISamplingPipeline pipeline, ReadOnlySpan lastTokens) - { - var token = pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens); - pipeline.Accept(NativeHandle, token); - return token; - } - - /// - /// Perform the sampling. Please don't use it unless you fully know what it does. - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - public LLamaToken Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat, - float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP, - SafeLLamaGrammarHandle? grammar, float minP) - { - LLamaToken id; - - if (grammar != null) - { - candidates.ApplyGrammar(NativeHandle, grammar); - } - - if (temperature <= 0) - { - // Greedy sampling - id = candidates.SampleTokenGreedy(NativeHandle); - } - else - { - var mu = mirostat_mu ?? (2 * mirostatTau); - { - if (mirostat == MirostatType.Mirostat) - { - const int mirostat_m = 100; - candidates.Temperature(NativeHandle, temperature); - id = candidates.SampleTokenMirostat(NativeHandle, mirostatTau, mirostatEta, mirostat_m, ref mu); - } - else if (mirostat == MirostatType.Mirostat2) - { - candidates.Temperature(NativeHandle, temperature); - id = candidates.SampleTokenMirostat2(NativeHandle, mirostatTau, mirostatEta, ref mu); - } - else - { - candidates.TopK(NativeHandle, topK); - candidates.TailFree(NativeHandle, tfsZ); - candidates.LocallyTypical(NativeHandle, typicalP); - candidates.TopP(NativeHandle, topP); - candidates.MinP(NativeHandle, minP); - candidates.Temperature(NativeHandle, temperature); - id = candidates.SampleToken(NativeHandle); - } - } - mirostat_mu = mu; - } - - grammar?.AcceptToken(NativeHandle, id); - - return id; - } - - /// - /// Apply the penalty for the tokens. Please don't use it unless you fully know what it does. - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable lastTokens, Dictionary? logitBias = null, - int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f, - bool penalizeNL = true) - { - var logits = NativeHandle.GetLogitsIth(logits_i); - - // Apply params.logit_bias map - if (logitBias is not null) - { - foreach (var (key, value) in logitBias) - logits[(int)key] += value; - } - - // Save the newline logit value - var nl_token = NativeHandle.ModelHandle.Tokens.Newline; - var nl_logit = logits[(int?)nl_token ?? 
0]; - - // Convert logits into token candidates - if (_samplingBuffer == null || _samplingBuffer.Length < logits.Length) - _samplingBuffer = new LLamaTokenData[logits.Length]; - var candidates_p = LLamaTokenDataArray.Create(logits, _samplingBuffer); - - // Extract most recently returned tokens - var last_n_repeat = Math.Min((int)ContextSize, repeatLastTokensCount); - var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray(); - - // Apply penalties to candidates - candidates_p.RepetitionPenalty(NativeHandle, last_n_array, repeatPenalty, alphaFrequency, alphaPresence); - - // Restore newline token logit value if necessary - if (!penalizeNL && nl_token.HasValue) - { - var candidatesSpan = candidates_p.Data.Span; - for (var i = 0; i < candidates_p.Data.Length; i++) - { - ref var item = ref candidatesSpan[i]; - if (item.id == nl_token) - item.logit = nl_logit; - } - candidates_p.Sorted = false; - } - - return candidates_p; - } - /// /// Gets whether or not the Bos token should be added. /// From common.cpp https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/common/common.cpp#L2417 diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 9b2b17617..bc8812cb5 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -81,11 +81,6 @@ public bool IsMultiModal /// public List Images { get; } - /// - /// Current "mu" value for mirostat sampling - /// - protected float? MirostatMu { get; set; } - private readonly StreamingTokenDecoder _decoder; /// diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index ec41aa7fb..b37d1a79c 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -9,6 +9,7 @@ using System.Text.Json.Serialization; using System.Threading.Tasks; using LLama.Exceptions; +using LLama.Sampling; using Microsoft.Extensions.Logging; namespace LLama @@ -24,6 +25,8 @@ public class InstructExecutor private LLamaToken[] _inp_pfx; private LLamaToken[] _inp_sfx; + private ISamplingPipeline? _pipeline; + /// /// /// @@ -60,7 +63,6 @@ public override ExecutorBaseState GetStateData() SessionFilePath = _pathSession, SessionTokens = _session_tokens.ToArray(), LastTokensCapacity = _last_n_tokens.Capacity, - MirostatMu = MirostatMu }; return state; } @@ -227,28 +229,18 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In SaveSessionFile(_pathSession); } - LLamaToken id; - if (inferenceParams.SamplingPipeline is not null) - { - id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(batch.TokenCount - 1), _last_n_tokens.ToArray()); - inferenceParams.SamplingPipeline.Accept(Context.NativeHandle, id); - } + // use the explicitly supplied pipeline, if there is one. Otherwise construct a suitable one. 
+ var pipeline = inferenceParams.SamplingPipeline; + if (pipeline != null) + _pipeline = null; else - { - var tokenDataArray = Context.ApplyPenalty(batch.TokenCount - 1, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n, - inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); - - var mu = MirostatMu; - id = Context.Sample( - tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, - inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, - inferenceParams.MinP - ); - MirostatMu = mu; - } + pipeline = inferenceParams.Create(ref _pipeline); - _last_n_tokens.Enqueue(id); + // Sample with the pipeline + var id = pipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(batch.TokenCount - 1), _last_n_tokens.AsSpan(repeat_last_n)); + pipeline.Accept(Context.NativeHandle, id); + _last_n_tokens.Enqueue(id); _embeds.Add(id); args.RemainedTokens--; diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index f97a2b63c..d0a950db0 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -9,6 +9,7 @@ using System.Text.Json.Serialization; using System.Threading.Tasks; using LLama.Exceptions; +using LLama.Sampling; using Microsoft.Extensions.Logging; @@ -26,6 +27,8 @@ public class InteractiveExecutor : StatefulExecutorBase private List _imageEmbedHandles = new List(); private bool _imageInPrompt = false; + private ISamplingPipeline? _pipeline; + /// /// /// @@ -57,7 +60,6 @@ public override ExecutorBaseState GetStateData() SessionFilePath = _pathSession, SessionTokens = _session_tokens.ToArray(), LastTokensCapacity = _last_n_tokens.Capacity, - MirostatMu = MirostatMu }; return state; } @@ -304,25 +306,16 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In SaveSessionFile(_pathSession); } - LLamaToken id; - if (inferenceParams.SamplingPipeline is not null) - { - id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(batch.TokenCount - 1), _last_n_tokens.ToArray()); - inferenceParams.SamplingPipeline.Accept(Context.NativeHandle, id); - } + // use the explicitly supplied pipeline, if there is one. Otherwise construct a suitable one. 
+ var pipeline = inferenceParams.SamplingPipeline; + if (pipeline != null) + _pipeline = null; else - { - var tokenDataArray = Context.ApplyPenalty(batch.TokenCount - 1, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n, - inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); - - var mu = MirostatMu; - id = Context.Sample( - tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, - inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, - inferenceParams.MinP - ); - MirostatMu = mu; - } + pipeline = inferenceParams.Create(ref _pipeline); + + // Sample with the pipeline + var id = pipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(batch.TokenCount - 1), _last_n_tokens.AsSpan(repeat_last_n)); + pipeline.Accept(Context.NativeHandle, id); _last_n_tokens.Enqueue(id); diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 2ba0f7c8f..dcb466569 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -23,7 +23,7 @@ public class StatelessExecutor private readonly IContextParams _params; private readonly ILogger? _logger; private readonly LLamaBatch _batch; - + // LLava Section public bool IsMultiModal => false; @@ -98,29 +98,18 @@ public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams if (r != DecodeResult.Ok) throw new LLamaDecodeError(r); + // use the explicitly supplied pipeline, if there is one. Otherwise construct a suitable one. + var pipeline = inferenceParams.SamplingPipeline; + if (pipeline == null) + pipeline = inferenceParams.Create(ref pipeline); + // Begin loop, evaluating one token at a time - var mu = (float?)null; - var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; - for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++) + var maxTokens = inferenceParams.MaxTokens < 0 ? 
int.MaxValue : inferenceParams.MaxTokens; + for(var i = 0; i < maxTokens && !cancellationToken.IsCancellationRequested; i++) { - LLamaToken id; - if (inferenceParams.SamplingPipeline is not null) - { - id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(_batch.TokenCount - 1), lastTokens); - } - else - { - // Penalize the generated tokens by various penalties - var tokenDataArray = Context.ApplyPenalty(_batch.TokenCount - 1, lastTokens, inferenceParams.LogitBias, repeat_last_n, - inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); - - // Sample a single token - id = Context.Sample( - tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, - inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, - inferenceParams.MinP - ); - } + // Sample with the pipeline + var id = pipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(_batch.TokenCount - 1), lastTokens); + pipeline.Accept(Context.NativeHandle, id); // Check if this token should end generation if (_weights.Tokens.IsEndOfGeneration(id)) diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index 4f87b0030..02575bc01 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -13,7 +13,7 @@ public sealed class DefaultSamplingPipeline /// /// Bias values to add to certain logits /// - public Dictionary LogitBias { get; } = new(); + public Dictionary LogitBias { get; } = new(); /// /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858 @@ -98,7 +98,7 @@ protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span lo { // Apply logit bias foreach (var (key, value) in LogitBias) - logits[key] += value; + logits[(int)key] += value; } diff --git a/LLama/Sampling/Mirostat2SamplingPipeline.cs b/LLama/Sampling/Mirostat2SamplingPipeline.cs index de5de2f40..d0789d7e1 100644 --- a/LLama/Sampling/Mirostat2SamplingPipeline.cs +++ b/LLama/Sampling/Mirostat2SamplingPipeline.cs @@ -9,15 +9,15 @@ namespace LLama.Sampling; public class Mirostat2SamplingPipeline : BaseSamplingPipeline { - private const float DEFAULT_TAU = 5; + private const float DefaultTau = 5; - private float _mu = DEFAULT_TAU * 2; + private float _mu = DefaultTau * 2; /// /// Currently learned mu value /// public float Mu => _mu; - private float _tau = DEFAULT_TAU; + private float _tau = DefaultTau; /// /// target entropy ///
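
Note: the intent of this patch is that callers stop setting sampling options (`Temperature`, `TopK`, `Mirostat*`, `Grammar`, etc.) directly on `InferenceParams` and instead supply a configured `SamplingPipeline`. Below is a minimal sketch of the new usage, modelled on the examples touched above; the model path, prompt, and choice of `StatelessExecutor` are illustrative placeholders, not part of the patch.

```csharp
using System;
using System.Collections.Generic;
using LLama;
using LLama.Common;
using LLama.Sampling;

// Placeholder model path - substitute a real GGUF file.
var parameters = new ModelParams("path/to/model.gguf");
using var model = LLamaWeights.LoadFromFile(parameters);
var executor = new StatelessExecutor(model, parameters);

// New style: sampling settings live on the pipeline, not on InferenceParams.
var inferenceParams = new InferenceParams
{
    SamplingPipeline = new DefaultSamplingPipeline
    {
        Temperature = 0.6f
    },
    AntiPrompts = new List<string> { "User:" },
    MaxTokens = 50
};

// Mirostat sampling now has its own pipeline type, so it can no longer be
// combined with Temperature by accident (the TalkToYourself bug noted in the
// commit message).
var mirostatParams = new InferenceParams
{
    SamplingPipeline = new Mirostat2SamplingPipeline { Tau = 10 },
    MaxTokens = 128
};

await foreach (var text in executor.InferAsync("Question: what is a llama?\nAnswer:", inferenceParams))
    Console.Write(text);
```

If no `SamplingPipeline` is supplied, the executors fall back to the internal `Create(ref pipeline)` adapter added in this patch, which builds a pipeline from the obsolete properties and re-uses the instance across calls, so existing code keeps working during the transition.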