Skip to content

Commit 201e1df

Browse files
committed
Added lazy grammar method
1 parent 7447a41 commit 201e1df

File tree

1 file changed

+59
-3
lines changed

1 file changed

+59
-3
lines changed

LLama/Native/SafeLLamaSamplerHandle.cs

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace LLama.Native;
88
/// A chain of sampler stages that can be used to select tokens from logits.
99
/// </summary>
1010
/// <remarks>Wraps a handle returned from `llama_sampler_chain_init`. Other samplers are owned by this chain and are never directly exposed.</remarks>
11-
public class SafeLLamaSamplerChainHandle
11+
public sealed class SafeLLamaSamplerChainHandle
1212
: SafeLLamaHandleBase
1313
{
1414
/// <summary>
@@ -415,6 +415,62 @@ public void AddGrammar(SafeLlamaModelHandle model, string grammar, string root)
415415
// ReSharper restore InconsistentNaming
416416
}
417417

418+
/// <summary>
419+
/// Create a sampler using lazy grammar sampling: https://github.com/ggerganov/llama.cpp/pull/9639
420+
/// </summary>
421+
/// <param name="model"></param>
422+
/// <param name="grammar">Grammar in GBNF form</param>
423+
/// <param name="root">Root rule of the grammar</param>
424+
/// <param name="triggerTokens">A list of tokens that will trigger the grammar sampler.</param>
425+
/// <param name="triggerWords">A list of words that will trigger the grammar sampler.</param>
426+
/// <returns></returns>
427+
public void AddLazyGrammar(
428+
SafeLlamaModelHandle model,
429+
string grammar, string root,
430+
ReadOnlySpan<string> triggerWords,
431+
ReadOnlySpan<LLamaToken> triggerTokens)
432+
{
433+
unsafe
434+
{
435+
// Convert strings, fix memory in place, build array of pointers
436+
var handles = new List<MemoryHandle>();
437+
var triggerWordsPtrs = stackalloc byte*[triggerWords.Length];
438+
for (var i = 0; i < triggerWords.Length; i++)
439+
{
440+
var chars = Encoding.Default.GetBytes(triggerWords[i]);
441+
handles.Add(chars.AsMemory().Pin());
442+
443+
triggerWordsPtrs[i] = (byte*)handles[i].Pointer;
444+
}
445+
446+
fixed (LLamaToken* triggerTokensPtr = triggerTokens)
447+
{
448+
llama_sampler_chain_add(
449+
this,
450+
llama_sampler_init_grammar_lazy(
451+
model.Vocab.VocabNative,
452+
grammar, root,
453+
triggerWordsPtrs, (nuint)triggerWords.Length,
454+
triggerTokensPtr, (nuint)triggerTokens.Length
455+
)
456+
);
457+
}
458+
459+
// Clear up all the handles fixing the memory in place
460+
for (var i = 0; i < handles.Count; i++)
461+
handles[i].Dispose();
462+
}
463+
464+
// ReSharper disable InconsistentNaming
465+
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
466+
static extern unsafe IntPtr llama_sampler_init_grammar_lazy(
467+
LLamaVocabNative* model,
468+
string grammar_str, string grammar_root,
469+
byte** trigger_words, nuint num_trigger_words,
470+
LLamaToken* trigger_tokens, nuint num_trigger_tokens);
471+
// ReSharper restore InconsistentNaming
472+
}
473+
418474
/// <summary>
419475
/// Create a sampler that applies various repetition penalties.
420476
///
@@ -625,9 +681,9 @@ internal struct LLamaSamplerINative
625681
/// Apply this sampler to a set of logits
626682
/// </summary>
627683
/// <param name="smpl"></param>
628-
/// <param name="cur_p"></param>
684+
/// <param name="logits"></param>
629685
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
630-
public delegate void ApplyDelegate(ref LLamaSamplerNative smpl, ref LLamaTokenDataArrayNative cur_p);
686+
public delegate void ApplyDelegate(ref LLamaSamplerNative smpl, ref LLamaTokenDataArrayNative logits);
631687

632688
/// <summary>
633689
/// Reset the internal state of this sampler

0 commit comments

Comments
 (0)