Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions LLama/LLamaExecutorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ protected virtual void TryReuseMatchingPrefix()
/// </summary>
/// <param name="text"></param>
/// <param name="args"></param>
protected abstract Task PreprocessInputs(string text, InferStateArgs args);
protected abstract Task PreprocessInputs(string? text, InferStateArgs args);

/// <summary>
/// Do some post processing after the inference.
Expand Down Expand Up @@ -296,11 +296,11 @@ protected virtual void TryReuseMatchingPrefix()
/// <summary>
/// Execute the inference.
/// </summary>
/// <param name="text"></param>
/// <param name="text">The prompt. If null, generation will continue where it left off previously.</param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
inferenceParams ??= new InferenceParams();
Expand Down
30 changes: 19 additions & 11 deletions LLama/LLamaInstructExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,30 +116,38 @@ protected override Task<bool> GetLoopCondition(InferStateArgs args)
}

/// <inheritdoc />
protected override Task PreprocessInputs(string text, InferStateArgs args)
protected override Task PreprocessInputs(string? text, InferStateArgs args)
{
args.Antiprompts ??= [ ];
args.Antiprompts.Add(_instructionPrefix);
if (!args.Antiprompts.Contains(_instructionPrefix))
args.Antiprompts.Add(_instructionPrefix);

if (_is_prompt_run)
{
// When running the first input (prompt) in inteactive mode, we should specially process it.
if (text == null) throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously.");
_embed_inps = Context.Tokenize(text, true, true).ToList();
}
else
{
if (!text.EndsWith("\n"))
{
text += "\n";
}
_consumedTokensCount = _embed_inps.Count;
_embed_inps.AddRange(_inp_pfx);

var line_inp = Context.Tokenize(text, false, true);
_embed_inps.AddRange(line_inp);
// Don't append the template tokens if continuation is requested (by providing a null prompt)
if (text != null)
{
if (!text.EndsWith("\n"))
{
text += "\n";
}
_embed_inps.AddRange(_inp_pfx);

var line_inp = Context.Tokenize(text, false, true);
_embed_inps.AddRange(line_inp);

_embed_inps.AddRange(_inp_sfx);
_embed_inps.AddRange(_inp_sfx);

args.RemainedTokens -= line_inp.Length;
args.RemainedTokens -= line_inp.Length;
}
}

return Task.CompletedTask;
Expand Down
31 changes: 18 additions & 13 deletions LLama/LLamaInteractExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,12 @@ protected override Task<bool> GetLoopCondition(InferStateArgs args)
}

/// <inheritdoc />
protected override Task PreprocessInputs(string text, InferStateArgs args)
protected override Task PreprocessInputs(string? text, InferStateArgs args)
{
if (_is_prompt_run)
{
// When running the first input (prompt) in interactive mode, we should specially process it.
if (text == null) throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously.");
if (!this.IsMultiModal)
{
_embed_inps = Context.Tokenize(text, true, true).ToList();
Expand All @@ -127,20 +128,24 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
}
else
{
if (!text.EndsWith("\n"))
// Don't add any tokens if continuation is requested (by providing a null prompt)
if (text != null)
{
text += "\n";
}
if (!text.EndsWith("\n"))
{
text += "\n";
}

if (!this.IsMultiModal)
{
var line_inp = Context.Tokenize(text, false, true);
_embed_inps.AddRange(line_inp);
args.RemainedTokens -= line_inp.Length;
}
else
{
PreprocessLlava(text, args, false);
if (!this.IsMultiModal)
{
var line_inp = Context.Tokenize(text, false, true);
_embed_inps.AddRange(line_inp);
args.RemainedTokens -= line_inp.Length;
}
else
{
PreprocessLlava(text, args, false);
}
}
}

Expand Down