
Commit 8359583

- Removed the object wrappers and the configurable pipeline; they can be better written in code.
- Added `BaseSamplingPipeline`, which provides a base implementation of `ISamplingPipeline`.
- Added `DefaultSamplingPipeline`, which mimics normal llama.cpp sampling.
1 parent 3afc007 commit 8359583

22 files changed: +309 -844 lines

LLama.Unittest/GrammarParserTest.cs

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
-using System.Text;
-using LLama.Exceptions;
+using LLama.Exceptions;
 using LLama.Native;
 using LLama.Grammars;

LLama.Unittest/StatelessExecutorTest.cs

Lines changed: 2 additions & 33 deletions
@@ -1,9 +1,6 @@
 using System.Diagnostics;
 using LLama.Common;
 using LLama.Sampling;
-using LLama.Sampling.Logits;
-using LLama.Sampling.Selection;
-using LLama.Sampling.Tokens;
 using Xunit.Abstractions;

 namespace LLama.Unittest
@@ -35,40 +32,12 @@ public void Dispose()
     public async Task Stateless()
     {
         // Create a custom pipeline that mimics the default pipeline
-        var pipeline = new ConfigurableSamplingPipeline()
-        {
-            ProtectedLogits =
-            {
-                _weights.NewlineToken,
-                _weights.BeginningOfSentenceToken,
-                _weights.EndOfSentenceToken
-            },
-            LogitProcessors =
-            {
-                new LogitBias
-                {
-                    Biases =
-                    {
-                        { _weights.NewlineToken, 1000 }, // This is an insane bias, but because newline is a protected logit it will do nothing!
-                        { 42, 0f },
-                    }
-                }
-            },
-            TokenDataProcessors =
-            {
-                new TailFreeSampling { Z = 1 },
-                new LocallyTypicalSampling { P = 1 },
-                new TopPSampling { P = 0.95f },
-                new MinPSampling { P = 0.05f },
-                new TemperatureSampling { Temperature = 0.8f },
-            },
-            Selector = new StandardSelection(),
-        };
+        var pipeline = new DefaultSamplingPipeline();

         var executor = new StatelessExecutor(_weights, _params);

         const string question = "Question. what is a cat?\nAnswer: ";
-        var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline};
+        var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };

         var timer = new Stopwatch();
         timer.Start();
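For comparison, the deleted ConfigurableSamplingPipeline block above maps onto plain property assignments on the new class. A minimal sketch (property names are taken from DefaultSamplingPipeline below; the values mirror the old test setup):

    var pipeline = new DefaultSamplingPipeline
    {
        TailFreeZ = 1f,
        TypicalP = 1f,
        TopP = 0.95f,
        MinP = 0.05f,
        Temperature = 0.8f,
    };

    // Equivalent of the old LogitBias processor entry for token 42
    pipeline.LogitBias.Add(42, 0f);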

LLama/Native/LLamaTokenDataArray.cs

Lines changed: 28 additions & 1 deletion
@@ -46,14 +46,41 @@ public static LLamaTokenDataArray Create(ReadOnlySpan<float> logits)
         return new LLamaTokenDataArray(candidates);
     }

+    /// <summary>
+    /// Overwrite the logit values for all given tokens
+    /// </summary>
+    /// <param name="values">tuples of token and logit value to overwrite</param>
+    public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values)
+    {
+        if (values.Length == 0)
+            return;
+
+        var dataSpan = data.Span;
+        foreach (var (token, value) in values)
+        {
+            for (var i = 0; i < data.Length; i++)
+            {
+                if (dataSpan[i].id == token)
+                {
+                    dataSpan[i].logit = value;
+                    break;
+                }
+            }
+        }
+        sorted = false;
+    }
+
     #region sampling
     /// <summary>
     /// Apply grammar rules to candidate tokens
     /// </summary>
     /// <param name="ctx"></param>
     /// <param name="grammar"></param>
-    public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
+    public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? grammar)
     {
+        if (grammar == null)
+            return;
+
         using (LLamaTokenDataArrayNative.Create(this, out var st))
         {
             NativeApi.llama_sample_grammar(ctx, ref st, grammar);
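Note: OverwriteLogits does a linear scan of the candidate buffer for each tuple, so it is intended for small sets such as the handful of protected logits, and it clears the sorted flag so later samplers re-sort. A minimal usage sketch (assuming llama_token is the integer token alias used in this file, and that newlineToken and savedLogit were captured by the caller earlier):

    var candidates = LLamaTokenDataArray.Create(logits);

    // Put a saved logit value back after penalties have run
    candidates.OverwriteLogits(new[] { (newlineToken, savedLogit) });

    // ApplyGrammar now tolerates a null grammar as a no-op
    candidates.ApplyGrammar(ctx, null);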
LLama/Sampling/BaseSamplingPipeline.cs

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using LLama.Native;
+
+namespace LLama.Sampling;
+
+/// <summary>
+/// Base class for implementing custom sampling pipelines. This provides a helpful framework for implementing `ISamplingPipeline`.
+/// </summary>
+public abstract class BaseSamplingPipeline
+    : ISamplingPipeline
+{
+    private int _savedLogitsCount;
+    private (int index, float logit)[]? _savedLogits;
+
+    /// <inheritdoc/>
+    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
+    {
+        var protectedLogits = GetProtectedTokens(ctx);
+        _savedLogitsCount = protectedLogits.Count;
+        _savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount);
+        try
+        {
+            // Save the values of protected logits
+            for (var i = 0; i < protectedLogits.Count; i++)
+            {
+                var index = protectedLogits[i];
+                var value = logits[index];
+                _savedLogits[i] = (index, value);
+            }
+
+            // Process raw logits
+            ProcessLogits(ctx, logits, lastTokens);
+
+            // Automatically restore saved logit values after processing
+            RestoreProtectedTokens(logits);
+
+            // Convert logits into token candidates
+            var candidates = LLamaTokenDataArray.Create(logits);
+
+            // Process token data array
+            ProcessTokenDataArray(ctx, candidates, lastTokens);
+
+            // Choose the final value
+            return ChooseToken(ctx, candidates);
+        }
+        finally
+        {
+            ArrayPool<(int, float)>.Shared.Return(_savedLogits);
+            _savedLogits = null;
+            _savedLogitsCount = 0;
+        }
+    }
+
+    #region protected tokens
+    /// <summary>
+    /// Get all of the "protected" tokens that cannot be changed by ProcessLogits
+    /// </summary>
+    /// <returns></returns>
+    protected abstract IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx);
+
+    /// <summary>
+    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
+    /// </summary>
+    /// <param name="logits"></param>
+    protected void RestoreProtectedTokens(Span<float> logits)
+    {
+        if (_savedLogits == null)
+            return;
+
+        // The array may be bigger than necessary, get a span of the valid bit
+        var saved = _savedLogits.AsSpan(0, _savedLogitsCount);
+
+        // Restore the values of protected logits
+        for (var i = 0; i < saved.Length; i++)
+            logits[saved[i].index] = saved[i].logit;
+    }
+
+    /// <summary>
+    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
+    /// </summary>
+    /// <param name="candidates"></param>
+    protected void RestoreProtectedTokens(LLamaTokenDataArray candidates)
+    {
+        if (_savedLogits == null || _savedLogits.Length == 0)
+            return;
+
+        candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount));
+    }
+    #endregion
+
+    /// <summary>
+    /// Process the raw logit values
+    /// </summary>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+    protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
+
+    /// <summary>
+    /// Process the LLamaTokenDataArray and select a single token
+    /// </summary>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="candidates">The LLamaTokenDataArray data produced by the model</param>
+    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+    /// <returns></returns>
+    protected abstract int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens);
+
+    /// <summary>
+    /// Choose the final token from the candidates
+    /// </summary>
+    /// <param name="ctx"></param>
+    /// <param name="candidates"></param>
+    /// <returns></returns>
+    protected abstract int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates);
+
+    /// <inheritdoc/>
+    public virtual void Reset()
+    {
+    }
+
+    /// <inheritdoc/>
+    public virtual void Dispose()
+    {
+        GC.SuppressFinalize(this);
+    }
+}
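To illustrate the contract, here is a hypothetical minimal subclass (not part of this commit): it protects no logits, leaves the raw logits untouched, and samples with temperature only, mirroring the structure DefaultSamplingPipeline uses below.

    public sealed class TemperatureSamplingPipeline
        : BaseSamplingPipeline
    {
        public float Temperature { get; set; } = 0.8f;

        // Nothing is modified in ProcessLogits, so nothing needs protecting
        protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
            => Array.Empty<int>();

        // Leave the raw logits untouched
        protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
        {
        }

        // Apply temperature to the candidates and sample, as DefaultSamplingPipeline does
        protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
        {
            candidates.Temperature(ctx, Temperature);
            return candidates.SampleToken(ctx);
        }

        // Sample() takes its final token from ChooseToken
        protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
            => candidates.SampleToken(ctx);
    }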
LLama/Sampling/DefaultSamplingPipeline.cs

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
+using System;
+using System.Collections.Generic;
+using LLama.Extensions;
+using LLama.Native;
+
+namespace LLama.Sampling;
+
+/// <summary>
+/// An implementation of ISamplingPipeline which mimics the default llama.cpp sampling
+/// </summary>
+public sealed class DefaultSamplingPipeline
+    : BaseSamplingPipeline
+{
+    /// <summary>
+    /// Bias values to add to certain logits
+    /// </summary>
+    public Dictionary<int, float> LogitBias { get; } = new();
+
+    /// <summary>
+    /// Grammar to constrain valid tokens
+    /// </summary>
+    public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+    /// <summary>
+    /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
+    /// </summary>
+    public float RepeatPenalty { get; set; } = 1.1f;
+
+    /// <summary>
+    /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
+    /// so far, decreasing the model's likelihood to repeat the same line verbatim.
+    /// </summary>
+    public float AlphaFrequency
+    {
+        get => _alphaFreq;
+        set
+        {
+            if (value < -2)
+                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
+            if (value > 2)
+                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
+            _alphaFreq = value;
+        }
+    }
+    private float _alphaFreq = 0.1f;
+
+    /// <summary>
+    /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
+    /// text so far, increasing the model's likelihood to talk about new topics.
+    /// </summary>
+    public float AlphaPresence
+    {
+        get => _alphaPresence;
+        set
+        {
+            if (value < -2)
+                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be greater than -2");
+            if (value > 2)
+                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be less than 2");
+            _alphaPresence = value;
+        }
+    }
+    private float _alphaPresence = 0.1f;
+
+    /// <summary>
+    /// Temperature to apply (higher temperature is more "creative")
+    /// </summary>
+    public float Temperature { get; set; } = 0.75f;
+
+    /// <summary>
+    /// Number of tokens to keep in TopK sampling
+    /// </summary>
+    public int TopK { get; set; }
+
+    /// <summary>
+    /// Z value for tail free sampling
+    /// </summary>
+    public float TailFreeZ { get; set; }
+
+    /// <summary>
+    /// P value for locally typical sampling
+    /// </summary>
+    public float TypicalP { get; set; }
+
+    /// <summary>
+    /// P value for TopP sampling
+    /// </summary>
+    public float TopP { get; set; } = 1f;
+
+    /// <summary>
+    /// P value for MinP sampling
+    /// </summary>
+    public float MinP { get; set; }
+
+    /// <summary>
+    /// Whether the newline token may be modified by logit bias and repeat penalty. When false (the default) its logit is protected and restored after penalties are applied.
+    /// </summary>
+    public bool PenalizeNewline { get; set; } = false;
+
+    private readonly int[] _newlineToken = new int[1];
+
+    /// <inheritdoc />
+    protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
+    {
+        if (PenalizeNewline)
+            return Array.Empty<int>();
+
+        _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
+        return _newlineToken;
+    }
+
+    /// <inheritdoc />
+    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
+    {
+        foreach (var (key, value) in LogitBias)
+            logits[key] += value;
+    }
+
+    /// <inheritdoc />
+    protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
+    {
+        // Apply penalties to candidates
+        candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
+
+        // Restore protected tokens, so they are not affected by repetition penalties
+        RestoreProtectedTokens(candidates);
+
+        // Apply the normal llama.cpp pipeline
+        candidates.ApplyGrammar(ctx, Grammar);
+        candidates.TopK(ctx, TopK);
+        candidates.TailFree(ctx, TailFreeZ);
+        candidates.LocallyTypical(ctx, TypicalP);
+        candidates.TopP(ctx, TopP);
+        candidates.MinP(ctx, MinP);
+        candidates.Temperature(ctx, Temperature);
+        var id = candidates.SampleToken(ctx);
+
+        Grammar?.AcceptToken(ctx, id);
+        return id;
+    }
+
+    /// <inheritdoc />
+    protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
+    {
+        return candidates.SampleToken(ctx);
+    }
+}
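A sketch of wiring the new pipeline into inference (values illustrative; grammar stands for a previously constructed SafeLLamaGrammarHandle, or may be left null):

    var pipeline = new DefaultSamplingPipeline
    {
        RepeatPenalty = 1.2f,
        Temperature = 0.6f,
        Grammar = grammar,       // optional: ApplyGrammar skips a null grammar
        PenalizeNewline = false, // default: the newline logit is saved and restored around penalties
    };

    var @params = new InferenceParams { MaxTokens = 64, SamplingPipeline = pipeline };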
