
Commit 402a110

Merge pull request #404 from martindevans/switched_to_LLamaToken_struct
LLamaToken Struct
2 parents: d9b4e1f + 82727c4

29 files changed (+196 −168 lines)

LLama.Examples/Examples/BatchedDecoding.cs

Lines changed: 1 addition & 2 deletions

@@ -1,6 +1,5 @@
 using System.Diagnostics;
 using System.Text;
-using LLama.Abstractions;
 using LLama.Common;
 using LLama.Native;

@@ -94,7 +93,7 @@ public static async Task Run()
             var n_cur = batch.NativeBatch.n_tokens;
             var n_decode = 0;

-            var streams = new List<int>[n_parallel];
+            var streams = new List<LLamaToken>[n_parallel];
             for (var i = 0; i < n_parallel; i++)
                 streams[i] = new();
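Every file in this PR follows the same pattern: raw int token ids (and the file-local "using llama_token = Int32;" aliases) are replaced by the LLamaToken struct. The diffs imply two conversions on the struct: an implicit one from int (the test arrays in LLama.Unittest below are built from plain integer literals) and an explicit one back to int (see logits[(int)key] in LLamaContext.cs below). A minimal sketch of such a wrapper, assuming a readonly record struct design; the actual definition merged here may differ:

    // Hypothetical sketch of a strongly typed token id; not the exact
    // definition from this PR.
    public readonly record struct LLamaToken
    {
        private readonly int _value;

        private LLamaToken(int value) => _value = value;

        // Implicit from int: integer literals work wherever a token is expected.
        public static implicit operator LLamaToken(int value) => new(value);

        // Explicit back to int: indexing a logit buffer by token is a visible cast.
        public static explicit operator int(LLamaToken token) => token._value;

        public override string ToString() => _value.ToString();
    }

A record struct also gets value-based Equals and GetHashCode for free, which the dictionary and test changes below depend on.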

LLama.Unittest/LLamaContextTests.cs

Lines changed: 4 additions & 3 deletions

@@ -1,4 +1,5 @@
 using LLama.Common;
+using LLama.Native;

 namespace LLama.Unittest
 {

@@ -37,23 +38,23 @@ public void Tokenize()
         {
            var tokens = _context.Tokenize("The quick brown fox", true);

-            Assert.Equal(new[] { 1, 450, 4996, 17354, 1701, 29916 }, tokens);
+            Assert.Equal(new LLamaToken[] { 1, 450, 4996, 17354, 1701, 29916 }, tokens);
         }

         [Fact]
         public void TokenizeWithoutBOS()
         {
             var tokens = _context.Tokenize("The quick brown fox", false);

-            Assert.Equal(new[] { 450, 4996, 17354, 1701, 29916 }, tokens);
+            Assert.Equal(new LLamaToken[] { 450, 4996, 17354, 1701, 29916 }, tokens);
         }

         [Fact]
         public void TokenizeEmpty()
         {
             var tokens = _context.Tokenize("", false);

-            Assert.Equal(Array.Empty<int>(), tokens);
+            Assert.Equal(Array.Empty<LLamaToken>(), tokens);
         }
     }
 }
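These assertions compile and pass only if integer literals convert implicitly to LLamaToken and tokens compare by value, so xunit's structural comparison needs no custom comparer. A small illustration under those assumed semantics:

    // Implicit int -> LLamaToken conversion plus assumed value equality:
    LLamaToken a = 450;
    LLamaToken b = 450;
    Assert.True(a == b);
    Assert.Equal(new LLamaToken[] { 450, 4996 }, new LLamaToken[] { 450, 4996 });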

LLama.Web/Common/InferenceOptions.cs

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@ public class InferenceOptions
         public int MaxTokens { get; set; } = -1;

         /// <inheritdoc />
-        public Dictionary<int, float>? LogitBias { get; set; } = null;
+        public Dictionary<LLamaToken, float>? LogitBias { get; set; } = null;

         /// <inheritdoc />
         public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();

LLama/Abstractions/IInferenceParams.cs

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ public interface IInferenceParams
        /// <summary>
        /// logit bias for specific tokens
        /// </summary>
-        public Dictionary<int, float>? LogitBias { get; set; }
+        public Dictionary<LLamaToken, float>? LogitBias { get; set; }

        /// <summary>
        /// Sequences where the model will stop generating further tokens.
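With LLamaToken as the dictionary key, lookups ride on the struct's value-based hashing, and callers can keep writing integer literals thanks to the implicit conversion. A hedged usage sketch; the token ids and the inferenceParams variable are illustrative, not taken from this PR:

    // Build a logit bias table; keys are LLamaToken, but int literals still
    // work because of the implicit conversion.
    var logitBias = new Dictionary<LLamaToken, float>
    {
        [13] = float.NegativeInfinity, // hypothetical id: never sample this token
        [450] = 2.0f,                  // hypothetical id: boost this token
    };
    inferenceParams.LogitBias = logitBias;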

LLama/Common/InferenceParams.cs

Lines changed: 1 addition & 3 deletions

@@ -6,8 +6,6 @@

 namespace LLama.Common
 {
-    using llama_token = Int32;
-
     /// <summary>
     /// The paramters used for inference.
     /// </summary>

@@ -28,7 +26,7 @@ public record InferenceParams
        /// <summary>
        /// logit bias for specific tokens
        /// </summary>
-        public Dictionary<llama_token, float>? LogitBias { get; set; } = null;
+        public Dictionary<LLamaToken, float>? LogitBias { get; set; } = null;

        /// <summary>
        /// Sequences where the model will stop generating further tokens.

LLama/Extensions/IReadOnlyListExtensions.cs

Lines changed: 2 additions & 2 deletions

@@ -38,7 +38,7 @@ internal static class IReadOnlyListExtensions
        /// <returns></returns>
        [Obsolete("Use an Antiprompt processor instead")]
        internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding)
-            where TTokens : IReadOnlyList<int>
+            where TTokens : IReadOnlyList<LLamaToken>
            where TQueries : IReadOnlyList<string>
        {
            if (queries == null || queries.Count == 0 || tokens.Count == 0)

@@ -79,7 +79,7 @@ internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tok
        /// <returns></returns>
        [Obsolete("Use an Antiprompt processor instead")]
        internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
-            where TTokens : IReadOnlyList<int>
+            where TTokens : IReadOnlyList<LLamaToken>
        {
            if (queries == null || queries.Count == 0 || tokens.Count == 0)
                return false;

LLama/LLamaContext.cs

Lines changed: 17 additions & 19 deletions

@@ -15,8 +15,6 @@

 namespace LLama
 {
-    using llama_token = Int32;
-
     /// <summary>
     /// A llama_context, which holds all the context required to interact with a model
     /// </summary>

@@ -93,7 +91,7 @@ public void SetSeed(uint seed)
        /// <param name="addBos">Whether to add a bos to the text.</param>
        /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
        /// <returns></returns>
-        public llama_token[] Tokenize(string text, bool addBos = true, bool special = false)
+        public LLamaToken[] Tokenize(string text, bool addBos = true, bool special = false)
        {
            return NativeHandle.Tokenize(text, addBos, special, Encoding);
        }

@@ -104,7 +102,7 @@ public llama_token[] Tokenize(string text, bool addBos = true, bool special = fa
        /// <param name="tokens"></param>
        /// <returns></returns>
        [Obsolete("Use a `StreamingTokenDecoder` instead")]
-        public string DeTokenize(IReadOnlyList<llama_token> tokens)
+        public string DeTokenize(IReadOnlyList<LLamaToken> tokens)
        {
            // Do **not** use this method as an example of how to correctly use the StreamingTokenDecoder!
            // It should be kept around for the entire time you are decoding one stream of tokens.

@@ -219,7 +217,7 @@ public void LoadState(State state)
        /// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
        /// <param name="lastTokens">The tokens recently returned from the model</param>
        /// <returns>The selected token</returns>
-        public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens)
+        public LLamaToken Sample(ISamplingPipeline pipeline, ReadOnlySpan<LLamaToken> lastTokens)
        {
            return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
        }

@@ -240,11 +238,11 @@ public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token>
        /// <param name="grammar"></param>
        /// <param name="minP"></param>
        /// <returns></returns>
-        public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat,
-                                  float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP,
-                                  SafeLLamaGrammarHandle? grammar, float minP)
+        public LLamaToken Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat,
+                                 float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP,
+                                 SafeLLamaGrammarHandle? grammar, float minP)
        {
-            llama_token id;
+            LLamaToken id;

            if (grammar != null)
            {

@@ -301,7 +299,7 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu
        /// <param name="alphaPresence"></param>
        /// <param name="penalizeNL"></param>
        /// <returns></returns>
-        public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dictionary<llama_token, float>? logitBias = null,
+        public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
            int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
            bool penalizeNL = true)
        {

@@ -311,12 +309,12 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dic
            if (logitBias is not null)
            {
                foreach (var (key, value) in logitBias)
-                    logits[key] += value;
+                    logits[(int)key] += value;
            }

            // Save the newline logit value
-            var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
-            var nl_logit = logits[nl_token];
+            var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
+            var nl_logit = logits[(int)nl_token];

            // Convert logits into token candidates
            var candidates_p = LLamaTokenDataArray.Create(logits);

@@ -353,7 +351,7 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dic
        /// <returns>The updated `pastTokensCount`.</returns>
        /// <exception cref="RuntimeError"></exception>
        [Obsolete("use llama_decode() instead")]
-        public int Eval(llama_token[] tokens, int pastTokensCount)
+        public int Eval(LLamaToken[] tokens, int pastTokensCount)
        {
            return Eval(tokens.AsSpan(), pastTokensCount);
        }

@@ -366,7 +364,7 @@ public int Eval(llama_token[] tokens, int pastTokensCount)
        /// <returns>The updated `pastTokensCount`.</returns>
        /// <exception cref="RuntimeError"></exception>
        [Obsolete("use llama_decode() instead")]
-        public int Eval(List<llama_token> tokens, int pastTokensCount)
+        public int Eval(List<LLamaToken> tokens, int pastTokensCount)
        {
 #if NET5_0_OR_GREATER
            var span = CollectionsMarshal.AsSpan(tokens);

@@ -376,15 +374,15 @@ public int Eval(List<llama_token> tokens, int pastTokensCount)
            // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
            // avoid the copying.

-            var rented = System.Buffers.ArrayPool<llama_token>.Shared.Rent(tokens.Count);
+            var rented = System.Buffers.ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
            try
            {
                tokens.CopyTo(rented, 0);
                return Eval(rented.AsSpan(0, tokens.Count), pastTokensCount);
            }
            finally
            {
-                System.Buffers.ArrayPool<llama_token>.Shared.Return(rented);
+                System.Buffers.ArrayPool<LLamaToken>.Shared.Return(rented);
            }
 #endif
        }

@@ -397,7 +395,7 @@ public int Eval(List<llama_token> tokens, int pastTokensCount)
        /// <returns>The updated `pastTokensCount`.</returns>
        /// <exception cref="RuntimeError"></exception>
        [Obsolete("use llama_decode() instead")]
-        public int Eval(ReadOnlyMemory<llama_token> tokens, int pastTokensCount)
+        public int Eval(ReadOnlyMemory<LLamaToken> tokens, int pastTokensCount)
        {
            return Eval(tokens.Span, pastTokensCount);
        }

@@ -410,7 +408,7 @@ public int Eval(ReadOnlyMemory<llama_token> tokens, int pastTokensCount)
        /// <returns>The updated `pastTokensCount`.</returns>
        /// <exception cref="RuntimeError"></exception>
        [Obsolete("use llama_decode() instead")]
-        public int Eval(ReadOnlySpan<llama_token> tokens, int pastTokensCount)
+        public int Eval(ReadOnlySpan<LLamaToken> tokens, int pastTokensCount)
        {
            var total = tokens.Length;
            for(var i = 0; i < total; i += (int)Params.BatchSize)
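The ApplyPenalty hunks show the boundary rule this change establishes: logits stay a flat float buffer indexed by raw position, so turning a token back into an index is an explicit, visible cast. A sketch of the pattern, assuming GetLogits() exposes the logit buffer as a span, as the surrounding calls suggest:

    // Typed tokens cross back to raw indices only via an explicit cast,
    // so int/token mix-ups no longer compile silently.
    Span<float> logits = context.NativeHandle.GetLogits();
    foreach (var (token, bias) in logitBias)
        logits[(int)token] += bias; // explicit LLamaToken -> int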

LLama/LLamaExecutorBase.cs

Lines changed: 10 additions & 11 deletions

@@ -14,7 +14,6 @@

 namespace LLama
 {
-    using llama_token = Int32;
    /// <summary>
    /// The base class for stateful LLama executors.
    /// </summary>

@@ -47,19 +46,19 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
        /// <summary>
        /// A container of the tokens to be processed and after processed.
        /// </summary>
-        protected List<llama_token> _embeds = new(); // embd
+        protected List<LLamaToken> _embeds = new(); // embd
        /// <summary>
        /// A container for the tokens of input.
        /// </summary>
-        protected List<llama_token> _embed_inps = new();
+        protected List<LLamaToken> _embed_inps = new();
        /// <summary>
        ///
        /// </summary>
-        protected List<llama_token> _session_tokens = new();
+        protected List<LLamaToken> _session_tokens = new();
        /// <summary>
        /// The last tokens generated by the model.
        /// </summary>
-        protected FixedSizeQueue<llama_token> _last_n_tokens;
+        protected FixedSizeQueue<LLamaToken> _last_n_tokens;
        /// <summary>
        /// The context used by the executor.
        /// </summary>

@@ -84,7 +83,7 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null)
            _pastTokensCount = 0;
            _consumedTokensCount = 0;
            _n_session_consumed = 0;
-            _last_n_tokens = new FixedSizeQueue<llama_token>(Context.ContextSize);
+            _last_n_tokens = new FixedSizeQueue<LLamaToken>(Context.ContextSize);
            _decoder = new StreamingTokenDecoder(context);
        }

@@ -105,7 +104,7 @@ public StatefulExecutorBase WithSessionFile(string filename)
            if (File.Exists(filename))
            {
                _logger?.LogInformation($"[LLamaExecutor] Attempting to load saved session from {filename}");
-                var session_tokens = new llama_token[Context.ContextSize];
+                var session_tokens = new LLamaToken[Context.ContextSize];
                if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, out var n_token_count_out))
                {
                    _logger?.LogError($"[LLamaExecutor] Failed to load session file {filename}");

@@ -361,16 +360,16 @@ public class ExecutorBaseState
        public string? SessionFilePath { get; set; }

        [JsonPropertyName("embd")]
-        public List<llama_token> Embeds { get; set; }
+        public List<LLamaToken> Embeds { get; set; }

        [JsonPropertyName("embd_inps")]
-        public List<llama_token> EmbedInps { get; set; }
+        public List<LLamaToken> EmbedInps { get; set; }

        [JsonPropertyName("session_tokens")]
-        public List<llama_token> SessionTokens { get; set; }
+        public List<LLamaToken> SessionTokens { get; set; }

        [JsonPropertyName("last_n_tokens")]
-        public llama_token[] LastTokens { get; set; }
+        public LLamaToken[] LastTokens { get; set; }

        [JsonPropertyName("last_tokens_maximum_count")]
        public int LastTokensCapacity { get; set; }
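ExecutorBaseState is round-tripped through System.Text.Json, so swapping List<int> for List<LLamaToken> keeps saved session state compatible only if the struct serializes as a bare number. The mechanism is not visible in this diff; one plausible approach, sketched purely as an assumption (the real code may instead use a [JsonConverter] attribute on the struct or other plumbing):

    using System;
    using System.Text.Json;
    using System.Text.Json.Serialization;

    // Illustrative converter: persists each LLamaToken as its raw integer so
    // existing int-based session state files keep deserializing.
    public sealed class LLamaTokenJsonConverter : JsonConverter<LLamaToken>
    {
        public override LLamaToken Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
            => reader.GetInt32(); // implicit int -> LLamaToken

        public override void Write(Utf8JsonWriter writer, LLamaToken value, JsonSerializerOptions options)
            => writer.WriteNumberValue((int)value); // explicit LLamaToken -> int
    }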

LLama/LLamaInstructExecutor.cs

Lines changed: 6 additions & 7 deletions

@@ -13,7 +13,6 @@

 namespace LLama
 {
-    using llama_token = Int32;
    /// <summary>
    /// The LLama executor for instruct mode.
    /// </summary>

@@ -22,8 +21,8 @@ public class InstructExecutor
    {
        private bool _is_prompt_run = true;
        private readonly string _instructionPrefix;
-        private llama_token[] _inp_pfx;
-        private llama_token[] _inp_sfx;
+        private LLamaToken[] _inp_pfx;
+        private LLamaToken[] _inp_sfx;

        /// <summary>
        ///

@@ -75,7 +74,7 @@ public override Task LoadState(ExecutorBaseState data)
            _is_prompt_run = state.IsPromptRun;
            _consumedTokensCount = state.ConsumedTokensCount;
            _embeds = state.Embeds;
-            _last_n_tokens = new FixedSizeQueue<llama_token>(state.LastTokensCapacity, state.LastTokens);
+            _last_n_tokens = new FixedSizeQueue<LLamaToken>(state.LastTokensCapacity, state.LastTokens);
            _inp_pfx = state.InputPrefixTokens;
            _inp_sfx = state.InputSuffixTokens;
            _n_matching_session_tokens = state.MatchingSessionTokensCount;

@@ -210,7 +209,7 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                SaveSessionFile(_pathSession);
            }

-            llama_token id;
+            LLamaToken id;
            if (inferenceParams.SamplingPipeline is not null)
            {
                id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());

@@ -266,12 +265,12 @@ public class InstructExecutorState : ExecutorBaseState
            /// Instruction prefix tokens.
            /// </summary>
            [JsonPropertyName("inp_pfx")]
-            public llama_token[] InputPrefixTokens { get; set; }
+            public LLamaToken[] InputPrefixTokens { get; set; }
            /// <summary>
            /// Instruction suffix tokens.
            /// </summary>
            [JsonPropertyName("inp_sfx")]
-            public llama_token[] InputSuffixTokens { get; set; }
+            public LLamaToken[] InputSuffixTokens { get; set; }
        }
    }
 }

LLama/LLamaInteractExecutor.cs

Lines changed: 3 additions & 4 deletions

@@ -13,14 +13,13 @@

 namespace LLama
 {
-    using llama_token = Int32;
    /// <summary>
    /// The LLama executor for interactive mode.
    /// </summary>
    public class InteractiveExecutor : StatefulExecutorBase
    {
        private bool _is_prompt_run = true;
-        private readonly llama_token _llama_token_newline;
+        private readonly LLamaToken _llama_token_newline;

        /// <summary>
        ///

@@ -63,7 +62,7 @@ public override Task LoadState(ExecutorBaseState data)
            _is_prompt_run = state.IsPromptRun;
            _consumedTokensCount = state.ConsumedTokensCount;
            _embeds = state.Embeds;
-            _last_n_tokens = new FixedSizeQueue<llama_token>(state.LastTokensCapacity, state.LastTokens);
+            _last_n_tokens = new FixedSizeQueue<LLamaToken>(state.LastTokensCapacity, state.LastTokens);
            _n_matching_session_tokens = state.MatchingSessionTokensCount;
            _pastTokensCount = state.PastTokensCount;
            _pathSession = state.SessionFilePath;

@@ -189,7 +188,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
                SaveSessionFile(_pathSession);
            }

-            llama_token id;
+            LLamaToken id;
            if (inferenceParams.SamplingPipeline is not null)
            {
                id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
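Both executors now end their inference step with the same strongly typed sampling call, as the hunks above show:

    // The sampler returns a LLamaToken, so a token id can no longer be
    // confused with a count or buffer offset without an explicit cast.
    LLamaToken id = inferenceParams.SamplingPipeline.Sample(
        Context.NativeHandle,
        Context.NativeHandle.GetLogits(),
        _last_n_tokens.ToArray());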
