
Commit d87d654

Merge pull request #348 from martindevans/new_object_based_sampling_pipeline
Custom Sampling Pipelines
2 parents 50c1b2d + 8359583

13 files changed: +462 -39 lines

LLama.Unittest/GrammarParserTest.cs

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
-using System.Text;
-using LLama.Exceptions;
+using LLama.Exceptions;
 using LLama.Native;
 using LLama.Grammars;

LLama.Unittest/StatelessExecutorTest.cs

Lines changed: 5 additions & 1 deletion
@@ -1,5 +1,6 @@
 using System.Diagnostics;
 using LLama.Common;
+using LLama.Sampling;
 using Xunit.Abstractions;
 
 namespace LLama.Unittest
@@ -30,10 +31,13 @@ public void Dispose()
         [Fact]
         public async Task Stateless()
         {
+            // Create a custom pipeline that mimics the default pipeline
+            var pipeline = new DefaultSamplingPipeline();
+
             var executor = new StatelessExecutor(_weights, _params);
 
             const string question = "Question. what is a cat?\nAnswer: ";
-            var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };
+            var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };
 
             var timer = new Stopwatch();
             timer.Start();
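
The test above exercises the new SamplingPipeline property end to end. For reference, the pipeline object can also be configured before it is attached; a hedged sketch only, where the Temperature and RepeatPenalty property names on DefaultSamplingPipeline are assumptions not shown in this diff:

    // Sketch only: configure a pipeline, then attach it to the inference parameters.
    // Temperature and RepeatPenalty are assumed property names on DefaultSamplingPipeline.
    var pipeline = new DefaultSamplingPipeline
    {
        Temperature = 0.7f,
        RepeatPenalty = 1.1f,
    };
    var @params = new InferenceParams { MaxTokens = 32, SamplingPipeline = pipeline };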

LLama.Web/Common/InferenceOptions.cs

Lines changed: 8 additions & 2 deletions
@@ -1,6 +1,9 @@
-using LLama.Common;
+#nullable enable
+
+using LLama.Common;
 using LLama.Abstractions;
 using LLama.Native;
+using LLama.Sampling;
 
 namespace LLama.Web.Common
 {
@@ -64,6 +67,9 @@ public class InferenceOptions
         /// <summary>
         /// A grammar to constrain possible tokens
        /// </summary>
-        public SafeLLamaGrammarHandle Grammar { get; set; } = null;
+        public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <inheritdoc />
+        public ISamplingPipeline? SamplingPipeline { get; set; }
     }
 }

LLama/Abstractions/IInferenceParams.cs

Lines changed: 6 additions & 0 deletions
@@ -1,6 +1,7 @@
 using System.Collections.Generic;
 using LLama.Common;
 using LLama.Native;
+using LLama.Sampling;
 
 namespace LLama.Abstractions
 {
@@ -108,5 +109,10 @@ public interface IInferenceParams
         /// Grammar to constrain possible tokens
         /// </summary>
         SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <summary>
+        /// Set a custom sampling pipeline to use. <b>If this is set, all other sampling parameters are ignored!</b>
+        /// </summary>
+        ISamplingPipeline? SamplingPipeline { get; set; }
     }
 }
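
The new ISamplingPipeline property is the extension point this PR introduces. A minimal custom implementation might look like the greedy sketch below, assuming the Sample(SafeLLamaContextHandle, Span<float>, ReadOnlySpan<llama_token>) method called by the executors in this diff is the interface's only required member:

    using System;
    using LLama.Native;
    using LLama.Sampling;
    using llama_token = System.Int32; // assumed alias, matching the rest of the codebase

    // Hedged sketch: a pipeline that always returns the highest-logit token.
    // Assumes Sample(...) is the only member ISamplingPipeline requires.
    public sealed class GreedySamplingPipeline : ISamplingPipeline
    {
        public llama_token Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<llama_token> lastTokens)
        {
            var best = 0;
            for (var i = 1; i < logits.Length; i++)
            {
                if (logits[i] > logits[best])
                    best = i;
            }
            return best;
        }
    }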

LLama/Common/InferenceParams.cs

Lines changed: 4 additions & 0 deletions
@@ -2,6 +2,7 @@
 using System;
 using System.Collections.Generic;
 using LLama.Native;
+using LLama.Sampling;
 
 namespace LLama.Common
 {
@@ -76,6 +77,9 @@ public record InferenceParams
 
         /// <inheritdoc />
         public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <inheritdoc />
+        public ISamplingPipeline? SamplingPipeline { get; set; }
     }
 
     /// <summary>

LLama/LLamaContext.cs

Lines changed: 12 additions & 0 deletions
@@ -10,6 +10,7 @@
 using System.Runtime.InteropServices;
 using LLama.Extensions;
 using LLama.Abstractions;
+using LLama.Sampling;
 using Microsoft.Extensions.Logging;
 
 namespace LLama
@@ -212,6 +213,17 @@ public void LoadState(State state)
             }
         }
 
+        /// <summary>
+        /// Sample a single token from this context, using the given sampling pipeline
+        /// </summary>
+        /// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
+        /// <param name="lastTokens">The tokens recently returned from the model</param>
+        /// <returns>The selected token</returns>
+        public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens)
+        {
+            return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
+        }
+
         /// <summary>
         /// Perform the sampling. Please don't use it unless you fully know what it does.
         /// </summary>
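
A usage sketch for the new overload; `context` and `lastTokens` are placeholders for an existing LLamaContext and its recent token history:

    // Sketch: sample one token through a pipeline, mirroring what the executors below do.
    var pipeline = new DefaultSamplingPipeline();
    var next = context.Sample(pipeline, lastTokens);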

LLama/LLamaInstructExecutor.cs

Lines changed: 17 additions & 9 deletions
@@ -210,16 +210,24 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                 SaveSessionFile(_pathSession);
             }
 
-            var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
-                inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+            llama_token id;
+            if (inferenceParams.SamplingPipeline is not null)
+            {
+                id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+            }
+            else
+            {
+                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
-            var mu = MirostatMu;
-            var id = Context.Sample(
-                tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                inferenceParams.MinP
-            );
-            MirostatMu = mu;
+                var mu = MirostatMu;
+                id = Context.Sample(
+                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                    inferenceParams.MinP
+                );
+                MirostatMu = mu;
+            }
 
             _last_n_tokens.Enqueue(id);

LLama/LLamaInteractExecutor.cs

Lines changed: 18 additions & 10 deletions
@@ -189,16 +189,24 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
                 SaveSessionFile(_pathSession);
             }
 
-            var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
-                inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-
-            var mu = MirostatMu;
-            var id = Context.Sample(
-                tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                inferenceParams.MinP
-            );
-            MirostatMu = mu;
+            llama_token id;
+            if (inferenceParams.SamplingPipeline is not null)
+            {
+                id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+            }
+            else
+            {
+                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+
+                var mu = MirostatMu;
+                id = Context.Sample(
+                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                    inferenceParams.MinP
+                );
+                MirostatMu = mu;
+            }
 
             _last_n_tokens.Enqueue(id);

LLama/LLamaStatelessExecutor.cs

Lines changed: 19 additions & 10 deletions
@@ -7,6 +7,7 @@
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Native;
+using LLama.Sampling;
 using Microsoft.Extensions.Logging;
 
 namespace LLama
@@ -85,16 +86,24 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
             var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
             for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
             {
-                // Penalize the generated tokens by various penalties
-                var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-
-                // Sample a single token
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                }
+                else
+                {
+                    // Penalize the generated tokens by various penalties
+                    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+
+                    // Sample a single token
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                }
 
                 // Decode this token into text
                 decoder.Add(id);
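
Taken together with the unit test above, the caller-side flow is roughly the following sketch; the weights, model parameters, and prompt are placeholders:

    // Sketch: attach a pipeline via InferenceParams and run a stateless inference.
    // `weights` and `modelParams` stand in for an already-loaded model.
    var executor = new StatelessExecutor(weights, modelParams);
    var inferenceParams = new InferenceParams
    {
        MaxTokens = 32,
        SamplingPipeline = new DefaultSamplingPipeline(), // classic sampling parameters are then ignored
    };
    await foreach (var text in executor.InferAsync("What is a cat?", inferenceParams))
        Console.Write(text);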

LLama/Native/LLamaTokenDataArray.cs

Lines changed: 34 additions & 5 deletions
@@ -46,14 +46,41 @@ public static LLamaTokenDataArray Create(ReadOnlySpan<float> logits)
             return new LLamaTokenDataArray(candidates);
         }
 
+        /// <summary>
+        /// Overwrite the logit values for all given tokens
+        /// </summary>
+        /// <param name="values">tuples of token and logit value to overwrite</param>
+        public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values)
+        {
+            if (values.Length == 0)
+                return;
+
+            var dataSpan = data.Span;
+            foreach (var (token, value) in values)
+            {
+                for (var i = 0; i < data.Length; i++)
+                {
+                    if (dataSpan[i].id == token)
+                    {
+                        dataSpan[i].logit = value;
+                        break;
+                    }
+                }
+            }
+            sorted = false;
+        }
+
         #region sampling
         /// <summary>
         /// Apply grammar rules to candidate tokens
         /// </summary>
         /// <param name="ctx"></param>
         /// <param name="grammar"></param>
-        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
+        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? grammar)
         {
+            if (grammar == null)
+                return;
+
             using (LLamaTokenDataArrayNative.Create(this, out var st))
             {
                 NativeApi.llama_sample_grammar(ctx, ref st, grammar);
@@ -145,15 +172,17 @@ public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_ke
         /// <param name="penalty_repeat"></param>
         /// <param name="penalty_freq"></param>
         /// <param name="penalty_present"></param>
-        public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
+        public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
         {
             unsafe
             {
                 using (LLamaTokenDataArrayNative.Create(this, out var st))
-                using (var last_tokens_handle = last_tokens.Pin())
                 {
-                    NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
-                    sorted = st.sorted;
+                    fixed (int* last_tokens_handle = last_tokens)
+                    {
+                        NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
+                        sorted = st.sorted;
+                    }
                 }
             }
         }
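
The new OverwriteLogits helper lets a pipeline pin individual token logits before sampling, and the now null-tolerant ApplyGrammar simply returns when no grammar is set. A short usage sketch for OverwriteLogits; the logits span and token id are placeholders:

    // Sketch: suppress one token by forcing its logit to negative infinity.
    // `logits` stands in for the model's output logits; 13 is a placeholder token id.
    var candidates = LLamaTokenDataArray.Create(logits);
    candidates.OverwriteLogits(new[] { (13, float.NegativeInfinity) });
    // Subsequent samplers will now never select token 13.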
