@@ -8,7 +8,7 @@ namespace LLama.Native;
88/// A chain of sampler stages that can be used to select tokens from logits.
99/// </summary>
1010/// <remarks>Wraps a handle returned from `llama_sampler_chain_init`. Other samplers are owned by this chain and are never directly exposed.</remarks>
11- public class SafeLLamaSamplerChainHandle
11+ public sealed class SafeLLamaSamplerChainHandle
1212 : SafeLLamaHandleBase
1313{
1414 /// <summary>
@@ -415,6 +415,62 @@ public void AddGrammar(SafeLlamaModelHandle model, string grammar, string root)
415415 // ReSharper restore InconsistentNaming
416416 }
417417
/// <summary>
/// Create a sampler using lazy grammar sampling: https://github.com/ggerganov/llama.cpp/pull/9639
/// </summary>
/// <param name="model">Model whose vocabulary is passed to the native grammar sampler</param>
/// <param name="grammar">Grammar in GBNF form</param>
/// <param name="root">Root rule of the grammar</param>
/// <param name="triggerWords">A list of words that will trigger the grammar sampler.</param>
/// <param name="triggerTokens">A list of tokens that will trigger the grammar sampler.</param>
/// <exception cref="ArgumentNullException">Thrown if any entry of <paramref name="triggerWords"/> is null</exception>
public void AddLazyGrammar(
    SafeLlamaModelHandle model,
    string grammar, string root,
    ReadOnlySpan<string> triggerWords,
    ReadOnlySpan<LLamaToken> triggerTokens)
{
    unsafe
    {
        // One pinned buffer per trigger word; the pointer table itself lives on the stack.
        var handles = new List<MemoryHandle>(triggerWords.Length);
        var triggerWordsPtrs = stackalloc byte*[triggerWords.Length];

        try
        {
            for (var i = 0; i < triggerWords.Length; i++)
            {
                var word = triggerWords[i];

                // The native side only receives a bare pointer per word (no length), so every
                // buffer must end with a NUL byte. Allocating one extra byte and leaving it at
                // its default value of zero provides the terminator.
                var bytes = new byte[Encoding.UTF8.GetByteCount(word) + 1];
                Encoding.UTF8.GetBytes(word, 0, word.Length, bytes, 0);

                // Pin the buffer so its address stays valid for the duration of the native call.
                var handle = bytes.AsMemory().Pin();
                handles.Add(handle);

                triggerWordsPtrs[i] = (byte*)handle.Pointer;
            }

            fixed (LLamaToken* triggerTokensPtr = triggerTokens)
            {
                llama_sampler_chain_add(
                    this,
                    llama_sampler_init_grammar_lazy(
                        model.Vocab.VocabNative,
                        grammar, root,
                        triggerWordsPtrs, (nuint)triggerWords.Length,
                        triggerTokensPtr, (nuint)triggerTokens.Length
                    )
                );
            }
        }
        finally
        {
            // Always release the pins, even if encoding or the native call threw,
            // otherwise the buffers would stay pinned indefinitely.
            for (var i = 0; i < handles.Count; i++)
            {
                handles[i].Dispose();
            }
        }
    }

    // ReSharper disable InconsistentNaming
    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
    static extern unsafe IntPtr llama_sampler_init_grammar_lazy(
        LLamaVocabNative* model,
        string grammar_str, string grammar_root,
        byte** trigger_words, nuint num_trigger_words,
        LLamaToken* trigger_tokens, nuint num_trigger_tokens);
    // ReSharper restore InconsistentNaming
}
473+
418474 /// <summary>
419475 /// Create a sampler that applies various repetition penalties.
420476 ///
@@ -625,9 +681,9 @@ internal struct LLamaSamplerINative
625681 /// Apply this sampler to a set of logits
626682 /// </summary>
627683 /// <param name="smpl"></param>
628- /// <param name="cur_p "></param>
684+ /// <param name="logits "></param>
629685 [ UnmanagedFunctionPointer ( CallingConvention . Cdecl ) ]
630- public delegate void ApplyDelegate ( ref LLamaSamplerNative smpl , ref LLamaTokenDataArrayNative cur_p ) ;
686+ public delegate void ApplyDelegate ( ref LLamaSamplerNative smpl , ref LLamaTokenDataArrayNative logits ) ;
631687
632688 /// <summary>
633689 /// Reset the internal state of this sampler
0 commit comments