using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
6
namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache:
/// a single conversation is repeatedly forked, so all branches share the KV cache of the
/// prompt (and of any tokens generated before the fork).
/// </summary>
public class BatchedExecutorFork
{
    /// <summary>Number of tokens to generate between each fork of the conversation tree.</summary>
    private const int n_split = 16;

    /// <summary>Total number of tokens to generate.</summary>
    private const int n_len = 64;

    /// <summary>
    /// Entry point for the demo: loads a model, evaluates a prompt once, then runs an
    /// inference loop that periodically forks every active conversation and samples one
    /// token from each, finally printing the resulting tree of continuations.
    /// </summary>
    public static async Task Run()
    {
        Console.Write("Please input your model path: ");
        var modelPath = Console.ReadLine();
        if (string.IsNullOrWhiteSpace(modelPath))
        {
            // Guard: ReadLine may return null/empty; ModelParams needs a real path.
            Console.WriteLine("A model path is required.");
            return;
        }

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        Console.WriteLine("Prompt (leave blank to select automatically):");
        var prompt = Console.ReadLine();
        if (string.IsNullOrWhiteSpace(prompt))
            prompt = "Not many people know that";

        // Create an executor that can evaluate a batch of conversations together.
        // NOTE(review): disposed via `using` — the original leaked it even though
        // the model weights above were disposed.
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Evaluate the initial prompt to create one conversation
        var start = executor.Prompt(prompt);
        await executor.Infer();

        // Create the root node of the tree
        var root = new Node(start);

        // Run inference loop
        for (var i = 0; i < n_len; i++)
        {
            // The prompt was already inferred before the loop, so skip on iteration 0.
            if (i != 0)
                await executor.Infer();

            // Occasionally fork all the active conversations
            if (i != 0 && i % n_split == 0)
                root.Split();

            // Sample all active conversations
            root.Sample();
        }

        Console.WriteLine($"{prompt}...");
        root.Print(1);

        Console.WriteLine("Press any key to exit demo");
        Console.ReadKey(true);
    }

    /// <summary>
    /// A node in a binary tree of conversations. A node is either a leaf holding a live
    /// <see cref="Conversation"/>, or an interior node whose two children were created
    /// by forking that conversation (in which case <c>_conversation</c> is null).
    /// </summary>
    private class Node
    {
        private readonly StreamingTokenDecoder _decoder;
        private readonly DefaultSamplingPipeline _sampler;

        // Null once this node has been split into two children.
        private Conversation? _conversation;

        private Node? _left;
        private Node? _right;

        /// <summary>Number of live (leaf) conversations in this subtree.</summary>
        // Invariant: an interior node always has BOTH children (see Split), so the
        // null-forgiving operators are safe here.
        public int ActiveConversationCount => _conversation != null
            ? 1
            : _left!.ActiveConversationCount + _right!.ActiveConversationCount;

        public Node(Conversation conversation)
        {
            _sampler = new DefaultSamplingPipeline();
            _conversation = conversation;
            _decoder = new StreamingTokenDecoder(conversation.Executor.Context);
        }

        /// <summary>
        /// Sample one token from every live conversation in this subtree, accumulate it
        /// into this node's decoder, and prompt the conversation with it so generation
        /// continues from there on the next inference pass.
        /// </summary>
        public void Sample()
        {
            if (_conversation == null)
            {
                // Interior node: recurse into both children.
                _left?.Sample();
                _right?.Sample();
                return;
            }

            // Cannot sample until the pending prompt has been run through inference.
            if (_conversation.RequiresInference)
                return;

            // Sample one token
            var ctx = _conversation.Executor.Context.NativeHandle;
            var logitsCopy = _conversation.Sample().ToArray();
            var token = _sampler.Sample(ctx, logitsCopy, Array.Empty<LLamaToken>());
            _sampler.Accept(ctx, token);
            _decoder.Add(token);

            // Prompt the conversation with this token, to continue generating from there
            _conversation.Prompt(token);
        }

        /// <summary>
        /// Fork every live conversation in this subtree: each leaf becomes an interior
        /// node with two forked children sharing the parent's cache.
        /// </summary>
        public void Split()
        {
            if (_conversation != null)
            {
                _left = new Node(_conversation.Fork());
                _right = new Node(_conversation.Fork());

                // The forks carry the shared state forward; the parent conversation
                // itself is no longer needed.
                _conversation.Dispose();
                _conversation = null;
            }
            else
            {
                _left?.Split();
                _right?.Split();
            }
        }

        /// <summary>
        /// Print the text generated by this subtree, with each tree level indented
        /// further and shown in a different console colour.
        /// </summary>
        /// <param name="indentation">Indentation level of this node (also selects the colour).</param>
        public void Print(int indentation)
        {
            var colors = new[] { ConsoleColor.Red, ConsoleColor.Green, ConsoleColor.Blue, ConsoleColor.Yellow, ConsoleColor.White };
            Console.ForegroundColor = colors[indentation % colors.Length];

            var message = _decoder.Read().ReplaceLineEndings("");

            var prefix = new string(' ', indentation * 3);
            // An interior node's text continues in its children, so mark it with "...".
            var suffix = _conversation == null ? "..." : "";
            Console.WriteLine($"{prefix}...{message}{suffix}");

            _left?.Print(indentation + 2);
            _right?.Print(indentation + 2);
        }
    }
}