using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates using a batch to generate two sequences and then using one
/// sequence as the negative guidance ("classifier free guidance") for the other.
/// </summary>
public class BatchedExecutorGuidance
{
    private const int n_len = 32;

    public static async Task Run()
    {
        string modelPath = UserSettings.GetModelPath();

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
        var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();

        // Create an executor that can evaluate a batch of conversations together
        var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Load the two prompts into two conversations
        var guided = executor.Prompt(positivePrompt);
        var guidance = executor.Prompt(negativePrompt);

        // Run inference to evaluate prompts
        await AnsiConsole
            .Status()
            .Spinner(Spinner.Known.Line)
            .StartAsync("Evaluating Prompts...", _ => executor.Infer());

        // Fork the "guided" conversation. We'll run this one without guidance for comparison
        var unguided = guided.Fork();

        // Run inference loop
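        // The "unguided" sampler has no guidance conversation attached, so it samples from the raw logits;
        // the "guided" sampler steers its logits using the negative prompt conversation.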
        var unguidedSampler = new GuidedSampler(null);
        var unguidedDecoder = new StreamingTokenDecoder(executor.Context);
        var guidedSampler = new GuidedSampler(guidance);
        var guidedDecoder = new StreamingTokenDecoder(executor.Context);
        await AnsiConsole
            .Progress()
            .StartAsync(async progress =>
            {
                var reporter = progress.AddTask("Running Inference", maxValue: n_len);

                for (var i = 0; i < n_len; i++)
                {
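                    // The prompts were already evaluated before this loop, so inference is only needed on iterations after the first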
                    if (i != 0)
                        await executor.Infer();

                    // Sample from the "unguided" conversation
                    var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample().ToArray(), Array.Empty<LLamaToken>());
                    unguidedDecoder.Add(u);
                    unguided.Prompt(u);

                    // Sample from the "guided" conversation
                    var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample().ToArray(), Array.Empty<LLamaToken>());
                    guidedDecoder.Add(g);

                    // Use this token to advance both guided _and_ guidance, keeping them in sync (except for the initial prompt)
                    guided.Prompt(g);
                    guidance.Prompt(g);

                    // Early exit if we reach the natural end of the guided sentence
                    if (g == model.EndOfSentenceToken)
                        break;

                    reporter.Increment(1);
                }
            });

        AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read()}[/]");
        AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read()}[/]");
    }

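    /// <summary>
    /// A sampling pipeline which applies the logits of the guidance conversation (when one is
    /// supplied) to steer sampling, then samples with temperature and top-k.
    /// </summary>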
    private class GuidedSampler(Conversation? guidance)
        : BaseSamplingPipeline
    {
        public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
        {
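            // This pipeline keeps no sampling state, so accepted tokens do not need to be recorded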
        }

        public override ISamplingPipeline Clone()
        {
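            // Cloning is not needed for this example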
            throw new NotSupportedException();
        }

        protected override IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx)
        {
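            // No tokens are exempted from logit processing by this pipeline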
            return Array.Empty<LLamaToken>();
        }

        protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
        {
            if (guidance != null)
            {
                // Get the logits generated by the guidance sequences
                var guidanceLogits = guidance.Sample();

                // Use those logits to guide this sequence
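                // (the final argument is the guidance strength)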
                NativeApi.llama_sample_apply_guidance(ctx, logits, guidanceLogits, 2);
            }
        }

        protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
        {
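            // Apply temperature, then top-k, and finally pick a single token from the remaining candidates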
            candidates.Temperature(ctx, 0.8f);
            candidates.TopK(ctx, 25);

            return candidates.SampleToken(ctx);
        }
    }
}