From 9052e41d68909307c1c62cb14d2bcb62fe4a6efc Mon Sep 17 00:00:00 2001
From: DeProgrammer
Date: Fri, 12 Jul 2024 17:16:01 -0500
Subject: [PATCH] Allow continuation in Instruct and Interact executors; fix a
 minor leak

---
 LLama/LLamaExecutorBase.cs     |  6 +++---
 LLama/LLamaInstructExecutor.cs | 28 +++++++++++++++++-----------
 LLama/LLamaInteractExecutor.cs | 31 ++++++++++++++++++-------------
 3 files changed, 38 insertions(+), 27 deletions(-)

diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index e01a40ccc..9b2b17617 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -251,7 +251,7 @@ protected virtual void TryReuseMatchingPrefix()
         /// <param name="text"></param>
         /// <param name="args"></param>
         /// <returns></returns>
-        protected abstract Task PreprocessInputs(string text, InferStateArgs args);
+        protected abstract Task PreprocessInputs(string? text, InferStateArgs args);
 
         /// <summary>
         /// Do some post processing after the inference.
@@ -296,11 +296,11 @@ protected virtual void TryReuseMatchingPrefix()
         /// <summary>
         /// Execute the inference.
         /// </summary>
-        /// <param name="text"></param>
+        /// <param name="text">The prompt. If null, generation will continue where it left off previously.</param>
        /// <param name="inferenceParams"></param>
        /// <param name="cancellationToken"></param>
        /// <returns></returns>
-        public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             cancellationToken.ThrowIfCancellationRequested();
             inferenceParams ??= new InferenceParams();
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index adc82eb85..f4aa48aa5 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -117,30 +117,36 @@ protected override Task GetLoopCondition(InferStateArgs args)
         }
 
         /// <inheritdoc />
-        protected override Task PreprocessInputs(string text, InferStateArgs args)
+        protected override Task PreprocessInputs(string? text, InferStateArgs args)
         {
             args.Antiprompts ??= new List<string>();
-            args.Antiprompts.Add(_instructionPrefix);
+            if (!args.Antiprompts.Contains(_instructionPrefix)) args.Antiprompts.Add(_instructionPrefix);
             if (_is_prompt_run)
             {
                 // When running the first input (prompt) in interactive mode, we should specially process it.
+                if (text == null) throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously.");
                 _embed_inps = Context.Tokenize(text, true, true).ToList();
             }
             else
             {
-                if (!text.EndsWith("\n"))
-                {
-                    text += "\n";
-                }
                 _consumedTokensCount = _embed_inps.Count;
-                _embed_inps.AddRange(_inp_pfx);
-                var line_inp = Context.Tokenize(text, false, true);
-                _embed_inps.AddRange(line_inp);
+                // Don't append the template tokens if continuation is requested (by providing a null prompt)
+                if (text != null)
+                {
+                    if (!text.EndsWith("\n"))
+                    {
+                        text += "\n";
+                    }
+                    _embed_inps.AddRange(_inp_pfx);
+
+                    var line_inp = Context.Tokenize(text, false, true);
+                    _embed_inps.AddRange(line_inp);
 
-                _embed_inps.AddRange(_inp_sfx);
+                    _embed_inps.AddRange(_inp_sfx);
 
-                args.RemainedTokens -= line_inp.Length;
+                    args.RemainedTokens -= line_inp.Length;
+                }
             }
 
             return Task.CompletedTask;
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 3f3f4a41e..687209769 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -112,11 +112,12 @@ protected override Task GetLoopCondition(InferStateArgs args)
         }
 
         /// <inheritdoc />
-        protected override Task PreprocessInputs(string text, InferStateArgs args)
+        protected override Task PreprocessInputs(string? text, InferStateArgs args)
         {
             if (_is_prompt_run)
             {
                 // When running the first input (prompt) in interactive mode, we should specially process it.
+                if (text == null) throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously.");
                 if (!this.IsMultiModal)
                 {
                     _embed_inps = Context.Tokenize(text, true, true).ToList();
@@ -128,20 +129,24 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
             }
             else
             {
-                if (!text.EndsWith("\n"))
+                // Don't add any tokens if continuation is requested (by providing a null prompt)
+                if (text != null)
                 {
-                    text += "\n";
-                }
+                    if (!text.EndsWith("\n"))
+                    {
+                        text += "\n";
+                    }
 
-                if (!this.IsMultiModal)
-                {
-                    var line_inp = Context.Tokenize(text, false, true);
-                    _embed_inps.AddRange(line_inp);
-                    args.RemainedTokens -= line_inp.Length;
-                }
-                else
-                {
-                    PreprocessLlava(text, args, false);
+                    if (!this.IsMultiModal)
+                    {
+                        var line_inp = Context.Tokenize(text, false, true);
+                        _embed_inps.AddRange(line_inp);
+                        args.RemainedTokens -= line_inp.Length;
+                    }
+                    else
+                    {
+                        PreprocessLlava(text, args, false);
+                    }
                 }
             }
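
Usage sketch (illustrative, not part of the patch): with this change applied, passing a null prompt to a second InferAsync call resumes generation where the previous call stopped, without appending any instruction-template tokens. The snippet below assumes the usual LLamaSharp loading path; the model file name and parameter values are placeholders.

    using LLama;
    using LLama.Common;

    var parameters = new ModelParams("model.gguf"); // placeholder model path
    using var weights = LLamaWeights.LoadFromFile(parameters);
    using var context = weights.CreateContext(parameters);
    var executor = new InteractiveExecutor(context);
    var inferenceParams = new InferenceParams { MaxTokens = 64 };

    // First call: provide a prompt as before.
    await foreach (var token in executor.InferAsync("Tell me a story.", inferenceParams))
        Console.Write(token);

    // Second call: a null prompt continues generation where the first call
    // left off (e.g. after hitting MaxTokens), adding no new input tokens.
    await foreach (var token in executor.InferAsync(null, inferenceParams))
        Console.Write(token);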