From 62381ab3f86b5ead233ad68f42b265542c17460a Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Sat, 24 Feb 2024 16:50:14 +0000
Subject: [PATCH 01/10] Added a `Guidance` method to `LLamaTokenDataArray`
 which applies classifier free guidance

---
 LLama/Native/LLamaTokenDataArray.cs | 57 +++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)

diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs
index 98dd91b6e..515a4eea6 100644
--- a/LLama/Native/LLamaTokenDataArray.cs
+++ b/LLama/Native/LLamaTokenDataArray.cs
@@ -185,6 +185,63 @@ public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<LLam
 
+        /// <summary>
+        /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="guidanceLogits">Logits extracted from a separate context from the same model.
+        /// Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.</param>
+        /// <param name="guidance">Guidance strength. 0 means no guidance, higher values apply stronger guidance.</param>
+        public void Guidance(SafeLLamaContextHandle context, ReadOnlySpan<float> guidanceLogits, float guidance)
+        {
+            if (guidanceLogits.Length != data.Length)
+                throw new ArgumentException("Guidance logits count must equal vocabulary size", nameof(guidanceLogits));
+
+            if (guidance < 0)
+                throw new ArgumentOutOfRangeException(nameof(guidance), "Guidance strength must be greater than or equal to zero");
+
+            // This method accepts 0 (no guidance), where higher means more guidance. llama.cpp expects 1 (no guidance), higher means more.
+            // Add one to move up to the llama.cpp baseline.
+            guidance += 1;
+
+            // We need a logits array, which we don't have at this point.
+            // Copy the logits into a temporary array, apply guidance, then copy them back.
+            var logits = ArrayPool<float>.Shared.Rent(context.VocabCount);
+            try
+            {
+                // Copy logits into a temporary array
+                for (var i = 0; i < data.Length; i++)
+                {
+                    ref var item = ref data.Span[i];
+                    logits[(int)item.id] = item.logit;
+                }
+
+                // Apply guidance
+                unsafe
+                {
+                    fixed (float* logitsPtr = logits)
+                    fixed (float* guidanceLogitsPtr = guidanceLogits)
+                    {
+                        NativeApi.llama_sample_apply_guidance(context, logitsPtr, guidanceLogitsPtr, guidance);
+                    }
+                }
+
+                // Copy logits back into the data array
+                for (var i = 0; i < data.Length; i++)
+                {
+                    ref var item = ref data.Span[i];
+                    item.logit = logits[(int)item.id];
+                }
+
+                // No longer sorted since we just mutated logits!
+                sorted = false;
+            }
+            finally
+            {
+                ArrayPool<float>.Shared.Return(logits);
+            }
+        }
+
         /// <summary>
         /// Sample with temperature.
         /// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual
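For reference, the arithmetic behind classifier-free guidance is a simple interpolation in logit space. The sketch below is illustrative only (the name ApplyGuidanceSketch is invented here, and llama.cpp additionally normalises both arrays with log-softmax before interpolating, which this sketch skips):

    // Classifier-free guidance per the paper: with scale == 1 the result is the
    // original logit unchanged, so larger values push the distribution further
    // away from the guidance (negative prompt) distribution.
    static void ApplyGuidanceSketch(Span<float> logits, ReadOnlySpan<float> guidanceLogits, float scale)
    {
        for (var i = 0; i < logits.Length; i++)
            logits[i] = guidanceLogits[i] + scale * (logits[i] - guidanceLogits[i]);
    }

This is also why the method above adds one to its zero-based guidance parameter: a caller's 0 becomes llama.cpp's 1, the no-op scale.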
From 528bb0185615325b1f7a1e9535998848943aec8c Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Sat, 24 Feb 2024 17:10:35 +0000
Subject: [PATCH 02/10] Factored out a safer `llama_sample_apply_guidance`
 method based on spans

---
 LLama/Native/LLamaTokenDataArray.cs |  9 +--------
 LLama/Native/NativeApi.Sampling.cs  | 30 ++++++++++++++++++++++++++++-
 2 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs
index 515a4eea6..f36679255 100644
--- a/LLama/Native/LLamaTokenDataArray.cs
+++ b/LLama/Native/LLamaTokenDataArray.cs
@@ -217,14 +217,7 @@ public void Guidance(SafeLLamaContextHandle context, ReadOnlySpan<float> guidanc
             }
 
             // Apply guidance
-            unsafe
-            {
-                fixed (float* logitsPtr = logits)
-                fixed (float* guidanceLogitsPtr = guidanceLogits)
-                {
-                    NativeApi.llama_sample_apply_guidance(context, logitsPtr, guidanceLogitsPtr, guidance);
-                }
-            }
+            NativeApi.llama_sample_apply_guidance(context, logits, guidanceLogits, guidance);
 
             // Copy logits back into the data array
             for (var i = 0; i < data.Length; i++)
diff --git a/LLama/Native/NativeApi.Sampling.cs b/LLama/Native/NativeApi.Sampling.cs
index a52edc668..441e70ecd 100644
--- a/LLama/Native/NativeApi.Sampling.cs
+++ b/LLama/Native/NativeApi.Sampling.cs
@@ -1,4 +1,5 @@
-using System.Runtime.InteropServices;
+using System;
+using System.Runtime.InteropServices;
 
 namespace LLama.Native
 {
@@ -23,6 +24,33 @@ public static extern unsafe void llama_sample_repetition_penalties(SafeLLamaCont
             float penalty_freq,
             float penalty_present);
 
+        /// <summary>
+        /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
+        /// </summary>
+        /// <param name="ctx"></param>
+        /// <param name="logits">Logits extracted from the original generation context.</param>
+        /// <param name="logits_guidance">Logits extracted from a separate context from the same model.
+        /// Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.</param>
+        /// <param name="scale">Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.</param>
+        public static void llama_sample_apply_guidance(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<float> logits_guidance, float scale)
+        {
+            if (logits == null)
+                throw new ArgumentNullException(nameof(logits));
+            if (logits_guidance == null)
+                throw new ArgumentNullException(nameof(logits_guidance));
+            if (logits.Length != ctx.VocabCount)
+                throw new ArgumentException("Logits count must equal the context vocab size", nameof(logits));
+            if (logits_guidance.Length != ctx.VocabCount)
+                throw new ArgumentException("Guidance logits count must equal the context vocab size", nameof(logits_guidance));
+
+            unsafe
+            {
+                fixed (float* logitsPtr = logits)
+                fixed (float* logitsGuidancePtr = logits_guidance)
+                    llama_sample_apply_guidance(ctx, logitsPtr, logitsGuidancePtr, scale);
+            }
+        }
+
         /// <summary>
         /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
         /// </summary>
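The span-based overload keeps the unsafe pointer code in one place. A hypothetical call site (names assumed for illustration; ctx is an existing SafeLLamaContextHandle and both buffers have been filled from their respective contexts):

    var logits = new float[ctx.VocabCount];           // logits from the guided context
    var guidanceLogits = new float[ctx.VocabCount];   // logits from the negative-prompt context
    // ... fill both arrays from their contexts ...
    NativeApi.llama_sample_apply_guidance(ctx, logits, guidanceLogits, scale: 2.0f);

Arrays convert implicitly to Span<float> and ReadOnlySpan<float>, so callers never need to pin memory themselves.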
My favourite colour is").Trim(); + + // Create an executor that can evaluate a batch of conversations together + var executor = new BatchedExecutor(model, parameters); + + // Print some info + var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name"); + Console.WriteLine($"Created executor with model: {name}"); + + // Load the two prompts into two conversations + var guided = executor.Prompt(positivePrompt); + var guidance = executor.Prompt(negativePrompt); + + // Run inference to evaluate prompts + await AnsiConsole + .Status() + .Spinner(Spinner.Known.Line) + .StartAsync("Evaluating Prompts...", _ => executor.Infer()); + + // Fork the "guided" conversation. We'll run this one without guidance for comparison + var unguided = guided.Fork(); + + // Run inference loop + var unguidedSampler = new GuidedSampler(null); + var unguidedDecoder = new StreamingTokenDecoder(executor.Context); + var guidedSampler = new GuidedSampler(guidance); + var guidedDecoder = new StreamingTokenDecoder(executor.Context); + await AnsiConsole + .Progress() + .StartAsync(async progress => + { + var reporter = progress.AddTask("Running Inference", maxValue: n_len); + + for (var i = 0; i < n_len; i++) + { + if (i != 0) + await executor.Infer(); + + // Sample from the "unguided" conversation + var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample().ToArray(), Array.Empty()); + unguidedDecoder.Add(u); + unguided.Prompt(u); + + // Sample form the "guided" conversation + var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample().ToArray(), Array.Empty()); + guidedDecoder.Add(g); + + // Use this token to advance both guided _and_ guidance. Keeping them in sync (except for the initial prompt). + guided.Prompt(g); + guidance.Prompt(g); + + // Early exit if we reach the natural end of the guided sentence + if (g == model.EndOfSentenceToken) + break; + + reporter.Increment(1); + } + }); + + AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read()}[/]"); + AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read()}[/]"); + } + + private class GuidedSampler(Conversation? 
guidance) + : BaseSamplingPipeline + { + public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token) + { + } + + public override ISamplingPipeline Clone() + { + throw new NotSupportedException(); + } + + protected override IReadOnlyList GetProtectedTokens(SafeLLamaContextHandle ctx) + { + return Array.Empty(); + } + + protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) + { + if (guidance != null) + { + // Get the logits generated by the guidance sequences + var guidanceLogits = guidance.Sample(); + + // Use those logits to guide this sequence + NativeApi.llama_sample_apply_guidance(ctx, logits, guidanceLogits, 2); + } + } + + protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) + { + candidates.Temperature(ctx, 0.8f); + candidates.TopK(ctx, 25); + + return candidates.SampleToken(ctx); + } + } +} \ No newline at end of file From 8526d6bb1cda55fd69e0c2fbc10d930252232d4a Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sat, 24 Feb 2024 20:12:40 +0000 Subject: [PATCH 04/10] fixed comment, "classifier free" not "context free" --- LLama.Examples/Examples/BatchedExecutorGuidance.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LLama.Examples/Examples/BatchedExecutorGuidance.cs b/LLama.Examples/Examples/BatchedExecutorGuidance.cs index 130c4dfdf..b1a13687c 100644 --- a/LLama.Examples/Examples/BatchedExecutorGuidance.cs +++ b/LLama.Examples/Examples/BatchedExecutorGuidance.cs @@ -8,7 +8,7 @@ namespace LLama.Examples.Examples; /// /// This demonstrates using a batch to generate two sequences and then using one -/// sequence as the negative guidance ("context free guidance") for the other. +/// sequence as the negative guidance ("classifier free guidance") for the other. 
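The sampler above calls the raw native wrapper directly, but the same effect is available through the higher-level method added in the first patch. A hedged sketch (candidates is a LLamaTokenDataArray built from the guided context's logits and guidanceLogits comes from the negative-prompt conversation; note this method's scale is zero-based, so 1.0f here corresponds to the hardcoded 2 passed to the native call):

    // Hypothetical: apply guidance through LLamaTokenDataArray instead of NativeApi.
    candidates.Guidance(ctx, guidanceLogits, guidance: 1.0f);  // 0 would mean no guidance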
From 8526d6bb1cda55fd69e0c2fbc10d930252232d4a Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Sat, 24 Feb 2024 20:12:40 +0000
Subject: [PATCH 04/10] fixed comment, "classifier free" not "context free"

---
 LLama.Examples/Examples/BatchedExecutorGuidance.cs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/LLama.Examples/Examples/BatchedExecutorGuidance.cs b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
index 130c4dfdf..b1a13687c 100644
--- a/LLama.Examples/Examples/BatchedExecutorGuidance.cs
+++ b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
@@ -8,7 +8,7 @@ namespace LLama.Examples.Examples;
 
 /// <summary>
 /// This demonstrates using a batch to generate two sequences and then using one
-/// sequence as the negative guidance ("context free guidance") for the other.
+/// sequence as the negative guidance ("classifier free guidance") for the other.
 /// </summary>
 public class BatchedExecutorGuidance
 {

From 879c62e618657e905583decb371cdd6aed01ea37 Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Sun, 25 Feb 2024 17:55:29 +0000
Subject: [PATCH 05/10] Rebased onto master and fixed breakage due to changes
 in `BaseSamplingPipeline`

---
 .../Examples/BatchedExecutorGuidance.cs       | 23 +++++++++----------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/LLama.Examples/Examples/BatchedExecutorGuidance.cs b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
index b1a13687c..12c22a1b8 100644
--- a/LLama.Examples/Examples/BatchedExecutorGuidance.cs
+++ b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
@@ -97,21 +97,20 @@ public override ISamplingPipeline Clone()
             throw new NotSupportedException();
         }
 
-        protected override IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx)
+        protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
         {
-            return Array.Empty<LLamaToken>();
-        }
+            if (guidance == null)
+                return logits;
 
-        protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
-        {
-            if (guidance != null)
-            {
-                // Get the logits generated by the guidance sequences
-                var guidanceLogits = guidance.Sample();
+            var logitsCopy = logits.ToArray();
+
+            // Get the logits generated by the guidance sequences
+            var guidanceLogits = guidance.Sample();
+
+            // Use those logits to guide this sequence
+            NativeApi.llama_sample_apply_guidance(ctx, logitsCopy, guidanceLogits, 2);
 
-                // Use those logits to guide this sequence
-                NativeApi.llama_sample_apply_guidance(ctx, logits, guidanceLogits, 2);
-            }
+            return logitsCopy;
         }
 
         protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
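The rebased override returns a new buffer rather than mutating its input, because ReadOnlySpan<float> cannot be written through. The copy-then-return shape is the general pattern for any logit processor with this post-rebase signature that needs to modify logits. A minimal sketch, with an invented helper name and a trivial edit standing in for the guidance call:

    // Sketch: mutate a writable copy and return it; float[] converts implicitly
    // back to ReadOnlySpan<float>, so the caller sees the modified logits.
    static ReadOnlySpan<float> WithBoostedToken(ReadOnlySpan<float> logits, int tokenId, float boost)
    {
        var copy = logits.ToArray();  // ReadOnlySpan cannot be written through, so copy first
        copy[tokenId] += boost;       // stand-in for NativeApi.llama_sample_apply_guidance
        return copy;
    }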
From d509968f3cb77d84cf0810f6f340e21dee864b5e Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Sun, 25 Feb 2024 18:53:55 +0000
Subject: [PATCH 06/10] Asking user for guidance weight

---
 LLama.Examples/Examples/BatchedExecutorGuidance.cs | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/LLama.Examples/Examples/BatchedExecutorGuidance.cs b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
index 12c22a1b8..38064e1b0 100644
--- a/LLama.Examples/Examples/BatchedExecutorGuidance.cs
+++ b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
@@ -23,6 +23,7 @@ public static async Task Run()
 
         var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
         var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();
+        var weight = AnsiConsole.Ask("Guidance Weight (or ENTER for default):", 2.0f);
 
         // Create an executor that can evaluate a batch of conversations together
         var executor = new BatchedExecutor(model, parameters);
@@ -45,9 +46,9 @@ await AnsiConsole
         var unguided = guided.Fork();
 
         // Run inference loop
-        var unguidedSampler = new GuidedSampler(null);
+        var unguidedSampler = new GuidedSampler(null, weight);
         var unguidedDecoder = new StreamingTokenDecoder(executor.Context);
-        var guidedSampler = new GuidedSampler(guidance);
+        var guidedSampler = new GuidedSampler(guidance, weight);
         var guidedDecoder = new StreamingTokenDecoder(executor.Context);
         await AnsiConsole
             .Progress()
             .StartAsync(async progress =>
@@ -86,7 +86,7 @@ await AnsiConsole
         AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read()}[/]");
     }
 
-    private class GuidedSampler(Conversation? guidance)
+    private class GuidedSampler(Conversation? guidance, float weight)
         : BaseSamplingPipeline
     {
@@ -108,7 +109,7 @@ protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx,
             var guidanceLogits = guidance.Sample();
 
             // Use those logits to guide this sequence
-            NativeApi.llama_sample_apply_guidance(ctx, logitsCopy, guidanceLogits, 2);
+            NativeApi.llama_sample_apply_guidance(ctx, logitsCopy, guidanceLogits, weight);
 
             return logitsCopy;
         }

From d50fe797bf8eec8df08f85e37bd9617f1ddd0101 Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Sun, 25 Feb 2024 19:17:45 +0000
Subject: [PATCH 07/10] Progress bar in batched fork demo

---
 .../Examples/BatchedExecutorFork.cs           | 43 +++++++++++++------
 .../Examples/BatchedExecutorGuidance.cs       |  1 +
 2 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/LLama.Examples/Examples/BatchedExecutorFork.cs b/LLama.Examples/Examples/BatchedExecutorFork.cs
index 861eecc75..3887b4243 100644
--- a/LLama.Examples/Examples/BatchedExecutorFork.cs
+++ b/LLama.Examples/Examples/BatchedExecutorFork.cs
@@ -32,24 +32,39 @@ public static async Task Run()
 
         // Evaluate the initial prompt to create one conversation
         var start = executor.Prompt(prompt);
-        await executor.Infer();
+
+        // Run inference to evaluate prompts
+        await AnsiConsole
+            .Status()
+            .Spinner(Spinner.Known.Line)
+            .StartAsync("Evaluating Prompt...", _ => executor.Infer());
 
         // Create the root node of the tree
         var root = new Node(start);
 
-        // Run inference loop
-        for (var i = 0; i < n_len; i++)
-        {
-            if (i != 0)
-                await executor.Infer();
-
-            // Occasionally fork all the active conversations
-            if (i != 0 && i % n_split == 0)
-                root.Split();
-
-            // Sample all active conversations
-            root.Sample();
-        }
+        await AnsiConsole
+            .Progress()
+            .StartAsync(async progress =>
+            {
+                var reporter = progress.AddTask("Running Inference", maxValue: n_len);
+
+                // Run inference loop
+                for (var i = 0; i < n_len; i++)
+                {
+                    if (i != 0)
+                        await executor.Infer();
+
+                    // Occasionally fork all the active conversations
+                    if (i != 0 && i % n_split == 0)
+                        root.Split();
+
+                    // Sample all active conversations
+                    root.Sample();
+
+                    // Update progress bar
+                    reporter.Increment(1);
+                }
+            });
 
         Console.WriteLine($"{prompt}...");
         root.Print(1);
diff --git a/LLama.Examples/Examples/BatchedExecutorGuidance.cs b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
index 38064e1b0..aa4f3156e 100644
--- a/LLama.Examples/Examples/BatchedExecutorGuidance.cs
+++ b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
@@ -78,6 +78,7 @@ await AnsiConsole
                     if (g == model.EndOfSentenceToken)
                         break;
 
+                    // Update progress bar
                     reporter.Increment(1);
                 }
             });
From cec4d658de8c990b9e140ed05a93df636dd87c2f Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Mon, 26 Feb 2024 00:17:40 +0000
Subject: [PATCH 08/10] Improved fork example (using tree display)

---
 .../Examples/BatchedExecutorFork.cs           | 82 +++++++++----------
 1 file changed, 38 insertions(+), 44 deletions(-)

diff --git a/LLama.Examples/Examples/BatchedExecutorFork.cs b/LLama.Examples/Examples/BatchedExecutorFork.cs
index 3887b4243..9389755c5 100644
--- a/LLama.Examples/Examples/BatchedExecutorFork.cs
+++ b/LLama.Examples/Examples/BatchedExecutorFork.cs
@@ -12,7 +12,7 @@ namespace LLama.Examples.Examples;
 public class BatchedExecutorFork
 {
     private const int n_split = 16;
-    private const int n_len = 64;
+    private const int n_len = 72;
 
     public static async Task Run()
@@ -32,48 +32,43 @@ public static async Task Run()
 
         // Evaluate the initial prompt to create one conversation
         var start = executor.Prompt(prompt);
-
-        // Run inference to evaluate prompts
-        await AnsiConsole
-            .Status()
-            .Spinner(Spinner.Known.Line)
-            .StartAsync("Evaluating Prompt...", _ => executor.Infer());
+        await executor.Infer();
 
         // Create the root node of the tree
         var root = new Node(start);
 
         await AnsiConsole
-            .Progress()
-            .StartAsync(async progress =>
-            {
-                var reporter = progress.AddTask("Running Inference", maxValue: n_len);
-
-                // Run inference loop
-                for (var i = 0; i < n_len; i++)
-                {
-                    if (i != 0)
-                        await executor.Infer();
-
-                    // Occasionally fork all the active conversations
-                    if (i != 0 && i % n_split == 0)
-                        root.Split();
-
-                    // Sample all active conversations
-                    root.Sample();
-
-                    // Update progress bar
-                    reporter.Increment(1);
-                }
-            });
-
-        Console.WriteLine($"{prompt}...");
-        root.Print(1);
-
-        Console.WriteLine("Press any key to exit demo");
-        Console.ReadKey(true);
+            .Progress()
+            .StartAsync(async progress =>
+            {
+                var reporter = progress.AddTask("Running Inference (1)", maxValue: n_len);
+
+                // Run inference loop
+                for (var i = 0; i < n_len; i++)
+                {
+                    if (i != 0)
+                        await executor.Infer();
+
+                    // Occasionally fork all the active conversations
+                    if (i != 0 && i % n_split == 0)
+                        root.Split();
+
+                    // Sample all active conversations
+                    root.Sample();
+
+                    // Update progress bar
+                    reporter.Increment(1);
+                    reporter.Description($"Running Inference ({root.ActiveConversationCount})");
+                }
+
+                // Display results
+                var display = new Tree(prompt);
+                root.Display(display);
+                AnsiConsole.Write(display);
+            });
     }
 
-    class Node
+    private class Node
     {
         private readonly StreamingTokenDecoder _decoder;
@@ -131,19 +126,18 @@ public void Split()
             }
         }
 
-        public void Print(int indendation)
+        public void Display<T>(T tree, int depth = 0)
+            where T : IHasTreeNodes
         {
-            var colors = new[] { ConsoleColor.Red, ConsoleColor.Green, ConsoleColor.Blue, ConsoleColor.Yellow, ConsoleColor.White };
-            Console.ForegroundColor = colors[indendation % colors.Length];
+            var colors = new[] { "red", "green", "blue", "yellow", "white" };
+            var color = colors[depth % colors.Length];
 
             var message = _decoder.Read().ReplaceLineEndings("");
 
-            var prefix = new string(' ', indendation * 3);
-            var suffix = _conversation == null ? "..." : "";
-            Console.WriteLine($"{prefix}...{message}{suffix}");
+            var n = tree.AddNode($"[{color}]{message}[/]");
 
-            _left?.Print(indendation + 2);
-            _right?.Print(indendation + 2);
+            _left?.Display(n, depth + 1);
+            _right?.Display(n, depth + 1);
         }
     }
 }
\ No newline at end of file
From 29e043dc1cf8e68257a5ebb126cd79b4695b4228 Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Mon, 26 Feb 2024 13:47:13 +0000
Subject: [PATCH 09/10] Added proper disposal of resources in batched examples

---
 LLama.Examples/Examples/BatchedExecutorFork.cs     |  4 ++--
 LLama.Examples/Examples/BatchedExecutorGuidance.cs | 12 ++++++------
 LLama.Examples/Examples/BatchedExecutorRewind.cs   |  4 ++--
 3 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/LLama.Examples/Examples/BatchedExecutorFork.cs b/LLama.Examples/Examples/BatchedExecutorFork.cs
index 9389755c5..3366951a1 100644
--- a/LLama.Examples/Examples/BatchedExecutorFork.cs
+++ b/LLama.Examples/Examples/BatchedExecutorFork.cs
@@ -24,14 +24,14 @@ public static async Task Run()
         var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
 
         // Create an executor that can evaluate a batch of conversations together
-        var executor = new BatchedExecutor(model, parameters);
+        using var executor = new BatchedExecutor(model, parameters);
 
         // Print some info
         var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
         Console.WriteLine($"Created executor with model: {name}");
 
         // Evaluate the initial prompt to create one conversation
-        var start = executor.Prompt(prompt);
+        using var start = executor.Prompt(prompt);
         await executor.Infer();
 
         // Create the root node of the tree
diff --git a/LLama.Examples/Examples/BatchedExecutorGuidance.cs b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
index aa4f3156e..43f3471d6 100644
--- a/LLama.Examples/Examples/BatchedExecutorGuidance.cs
+++ b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
@@ -26,15 +26,15 @@ public static async Task Run()
         var weight = AnsiConsole.Ask("Guidance Weight (or ENTER for default):", 2.0f);
 
         // Create an executor that can evaluate a batch of conversations together
-        var executor = new BatchedExecutor(model, parameters);
+        using var executor = new BatchedExecutor(model, parameters);
 
         // Print some info
         var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
         Console.WriteLine($"Created executor with model: {name}");
 
         // Load the two prompts into two conversations
-        var guided = executor.Prompt(positivePrompt);
-        var guidance = executor.Prompt(negativePrompt);
+        using var guided = executor.Prompt(positivePrompt);
+        using var guidance = executor.Prompt(negativePrompt);
 
         // Run inference to evaluate prompts
         await AnsiConsole
             .Status()
             .Spinner(Spinner.Known.Line)
             .StartAsync("Evaluating Prompts...", _ => executor.Infer());
 
         // Fork the "guided" conversation. We'll run this one without guidance for comparison
-        var unguided = guided.Fork();
+        using var unguided = guided.Fork();
 
         // Run inference loop
@@ -62,12 +62,12 @@ await AnsiConsole
                     // Sample from the "unguided" conversation
-                    var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample().ToArray(), Array.Empty<LLamaToken>());
+                    var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample(), Array.Empty<LLamaToken>());
                     unguidedDecoder.Add(u);
                     unguided.Prompt(u);
 
                     // Sample from the "guided" conversation
-                    var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample().ToArray(), Array.Empty<LLamaToken>());
+                    var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample(), Array.Empty<LLamaToken>());
                     guidedDecoder.Add(g);
 
                     // Use this token to advance both guided _and_ guidance, keeping them in sync (except for the initial prompt).
diff --git a/LLama.Examples/Examples/BatchedExecutorRewind.cs b/LLama.Examples/Examples/BatchedExecutorRewind.cs
index 54c6e5f90..4007dd508 100644
--- a/LLama.Examples/Examples/BatchedExecutorRewind.cs
+++ b/LLama.Examples/Examples/BatchedExecutorRewind.cs
@@ -25,14 +25,14 @@ public static async Task Run()
         var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
 
         // Create an executor that can evaluate a batch of conversations together
-        var executor = new BatchedExecutor(model, parameters);
+        using var executor = new BatchedExecutor(model, parameters);
 
         // Print some info
         var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
         Console.WriteLine($"Created executor with model: {name}");
 
         // Evaluate the initial prompt to create one conversation
-        var conversation = executor.Prompt(prompt);
+        using var conversation = executor.Prompt(prompt);
 
         // Create the start node wrapping the conversation
         var node = new Node(executor.Context);

From 6c8dea9f027d125a690b55d9fbafcc60f40dfa7f Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Mon, 26 Feb 2024 13:57:42 +0000
Subject: [PATCH 10/10] Added some more comments in BatchedExecutorGuidance

---
 LLama.Examples/Examples/BatchedExecutorGuidance.cs | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/LLama.Examples/Examples/BatchedExecutorGuidance.cs b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
index 43f3471d6..1c2c4b49c 100644
--- a/LLama.Examples/Examples/BatchedExecutorGuidance.cs
+++ b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
@@ -61,12 +61,14 @@ await AnsiConsole
                     if (i != 0)
                         await executor.Infer();
 
-                    // Sample from the "unguided" conversation
+                    // Sample from the "unguided" conversation. This is just a conversation using the same prompt, without any
+                    // guidance. This serves as a comparison to show the effect of guidance.
                     var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample(), Array.Empty<LLamaToken>());
                     unguidedDecoder.Add(u);
                     unguided.Prompt(u);
 
-                    // Sample from the "guided" conversation
+                    // Sample from the "guided" conversation. This sampler will internally use the "guidance" conversation
+                    // to steer the conversation. See how this is done in GuidedSampler.ProcessLogits (bottom of this file).
                     var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample(), Array.Empty<LLamaToken>());
                     guidedDecoder.Add(g);
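Taken together, the series leaves callers with a simple resource pattern. A hedged sketch of the overall shape of a guided generation loop after these changes (abbreviated; the real demo above fills in sampling and decoding):

    using var executor = new BatchedExecutor(model, parameters);
    using var guided = executor.Prompt(positivePrompt);
    using var guidance = executor.Prompt(negativePrompt);
    await executor.Infer();
    // ... sample a token from `guided`, then keep both sequences in lock-step:
    //     guided.Prompt(token); guidance.Prompt(token);
    // `using` makes disposal of the executor and conversations deterministic,
    // releasing the native contexts without waiting for the garbage collector.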