Skip to content

Commit 528bb01

Browse files
committed
Factored out a safer llama_sample_apply_guidance method based on spans
1 parent 62381ab commit 528bb01

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

LLama/Native/LLamaTokenDataArray.cs

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -217,14 +217,7 @@ public void Guidance(SafeLLamaContextHandle context, ReadOnlySpan<float> guidanc
217217
}
218218

219219
// Apply guidance
220-
unsafe
221-
{
222-
fixed (float* logitsPtr = logits)
223-
fixed (float* guidanceLogitsPtr = guidanceLogits)
224-
{
225-
NativeApi.llama_sample_apply_guidance(context, logitsPtr, guidanceLogitsPtr, guidance);
226-
}
227-
}
220+
NativeApi.llama_sample_apply_guidance(context, logits, guidanceLogits, guidance);
228221

229222
// Copy logits back into data array
230223
for (var i = 0; i < data.Length; i++)

LLama/Native/NativeApi.Sampling.cs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Runtime.InteropServices;
1+
using System;
2+
using System.Runtime.InteropServices;
23

34
namespace LLama.Native
45
{
@@ -23,6 +24,33 @@ public static extern unsafe void llama_sample_repetition_penalties(SafeLLamaCont
2324
float penalty_freq,
2425
float penalty_present);
2526

27+
/// <summary>
28+
/// Apply classifier-free guidance to the logits as described in the academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
/// <para>
/// Span-based wrapper: validates the span lengths against <c>ctx.VocabCount</c>, pins both spans,
/// and forwards to the pointer-based overload of the same name.
/// </para>
29+
/// </summary>
30+
/// <param name="ctx">Context handle; its <c>VocabCount</c> is the length both spans must have.</param>
31+
/// <param name="logits">Logits extracted from the original generation context.</param>
32+
/// <param name="logits_guidance">Logits extracted from a separate context from the same model.
33+
/// Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.</param>
34+
/// <param name="scale">Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.</param>
/// <exception cref="ArgumentNullException">Thrown when the <c>== null</c> comparison on <paramref name="logits"/> or <paramref name="logits_guidance"/> evaluates to true.</exception>
/// <exception cref="ArgumentException">Thrown when the length of <paramref name="logits"/> or <paramref name="logits_guidance"/> differs from <c>ctx.VocabCount</c>.</exception>
35+
public static void llama_sample_apply_guidance(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<float> logits_guidance, float scale)
36+
{
37+
// NOTE(review): Span<float> is a struct, so this is not a reference null test. Presumably `null`
// is converted to a default (empty) span and compared with Span's == operator - confirm that
// this is the intended check. The length checks below also reject an empty span unless
// VocabCount is 0.
if (logits == null)
38+
throw new ArgumentNullException(nameof(logits));
39+
// NOTE(review): same caveat as above - ReadOnlySpan<float> is a struct as well.
if (logits_guidance == null)
40+
throw new ArgumentNullException(nameof(logits_guidance));
41+
// Both spans must hold exactly ctx.VocabCount entries before their pointers are handed on,
// since the pointer-based call below carries no length information.
if (logits.Length != ctx.VocabCount)
42+
throw new ArgumentException("Logits count must have equal context vocab size", nameof(logits));
43+
if (logits_guidance.Length != ctx.VocabCount)
44+
throw new ArgumentException("Guidance logits count must have equal context vocab size", nameof(logits_guidance));
45+
46+
unsafe
47+
{
48+
// Pin both spans for the duration of the call so raw float pointers can be passed on.
fixed (float* logitsPtr = logits)
49+
fixed (float* logitsGuidancePtr = logits_guidance)
50+
// Resolves to the pointer-taking overload (the arguments are float*), not to this method.
// Presumably `logits` is updated in place by that call - TODO confirm against the native API.
llama_sample_apply_guidance(ctx, logitsPtr, logitsGuidancePtr, scale);
51+
}
52+
}
53+
2654
/// <summary>
2755
/// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
2856
/// </summary>

0 commit comments

Comments
 (0)