Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions LLama/LLamaExecutorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ protected virtual void TryReuseMatchingPrefix()
/// </summary>
/// <param name="text"></param>
/// <param name="args"></param>
protected abstract Task PreprocessInputs(string text, InferStateArgs args);
protected abstract Task PreprocessInputs(string? text, InferStateArgs args);

/// <summary>
/// Do some post processing after the inference.
Expand Down Expand Up @@ -296,11 +296,11 @@ protected virtual void TryReuseMatchingPrefix()
/// <summary>
/// Execute the inference.
/// </summary>
/// <param name="text"></param>
/// <param name="text">The prompt. If null, generation will continue where it left off previously.</param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
inferenceParams ??= new InferenceParams();
Expand Down
28 changes: 17 additions & 11 deletions LLama/LLamaInstructExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,30 +117,36 @@ protected override Task<bool> GetLoopCondition(InferStateArgs args)
}

/// <inheritdoc />
protected override Task PreprocessInputs(string text, InferStateArgs args)
protected override Task PreprocessInputs(string? text, InferStateArgs args)
{
args.Antiprompts ??= new List<string>();
args.Antiprompts.Add(_instructionPrefix);
if (!args.Antiprompts.Contains(_instructionPrefix)) args.Antiprompts.Add(_instructionPrefix);
if (_is_prompt_run)
{
// When running the first input (prompt) in interactive mode, we should process it specially.
if (text == null) throw new ArgumentException("Prompt cannot be null on the first run; a null prompt can only be used to continue generation after an initial prompt has been provided.");
_embed_inps = Context.Tokenize(text, true, true).ToList();
}
else
{
if (!text.EndsWith("\n"))
{
text += "\n";
}
_consumedTokensCount = _embed_inps.Count;
_embed_inps.AddRange(_inp_pfx);

var line_inp = Context.Tokenize(text, false, true);
_embed_inps.AddRange(line_inp);
// Don't append the template tokens if continuation is requested (by providing a null prompt)
if (text != null)
{
if (!text.EndsWith("\n"))
{
text += "\n";
}
_embed_inps.AddRange(_inp_pfx);

var line_inp = Context.Tokenize(text, false, true);
_embed_inps.AddRange(line_inp);

_embed_inps.AddRange(_inp_sfx);
_embed_inps.AddRange(_inp_sfx);

args.RemainedTokens -= line_inp.Length;
args.RemainedTokens -= line_inp.Length;
}
}

return Task.CompletedTask;
Expand Down
31 changes: 18 additions & 13 deletions LLama/LLamaInteractExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,12 @@ protected override Task<bool> GetLoopCondition(InferStateArgs args)
}

/// <inheritdoc />
protected override Task PreprocessInputs(string text, InferStateArgs args)
protected override Task PreprocessInputs(string? text, InferStateArgs args)
{
if (_is_prompt_run)
{
// When running the first input (prompt) in interactive mode, we should specially process it.
if (text == null) throw new ArgumentException("Prompt cannot be null on the first run; a null prompt can only be used to continue generation after an initial prompt has been provided.");
if (!this.IsMultiModal)
{
_embed_inps = Context.Tokenize(text, true, true).ToList();
Expand All @@ -128,20 +129,24 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
}
else
{
if (!text.EndsWith("\n"))
// Don't add any tokens if continuation is requested (by providing a null prompt)
if (text != null)
{
text += "\n";
}
if (!text.EndsWith("\n"))
{
text += "\n";
}

if (!this.IsMultiModal)
{
var line_inp = Context.Tokenize(text, false, true);
_embed_inps.AddRange(line_inp);
args.RemainedTokens -= line_inp.Length;
}
else
{
PreprocessLlava(text, args, false);
if (!this.IsMultiModal)
{
var line_inp = Context.Tokenize(text, false, true);
_embed_inps.AddRange(line_inp);
args.RemainedTokens -= line_inp.Length;
}
else
{
PreprocessLlava(text, args, false);
}
}
}

Expand Down