using System.Diagnostics.CodeAnalysis;
using System.Text;
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates generating replies to several different prompts at once, evaluated together in a single batch
/// </summary>
public class BatchedExecutorSimple
{
    /// <summary>
    /// Maximum number of tokens to generate for each conversation
    /// </summary>
    private const int TokenCount = 72;

    public static async Task Run()
    {
        // Load model weights
        var parameters = new ModelParams(UserSettings.GetModelPath());
        using var model = await LLamaWeights.LoadFromFileAsync(parameters);

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);
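        // all conversations created from this executor share a single context and KV cache,
        // so one call to Infer() advances every conversation with pending tokens in one batched decode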

        // we'll need this for evaluating if we are at the end of generation
        var modelTokens = executor.Context.NativeHandle.ModelHandle.Tokens;

        // Print some info
        var name = model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        var messages = new[]
        {
            "What's 2+2?",
            "Where is the coldest part of Texas?",
            "What's the capital of France?",
            "What's a one word name for a food item with ground beef patties on a bun?",
            "What are two toppings for a pizza?",
            "What american football play are you calling on a 3rd and 8 from our own 25?",
            "What liquor should I add to egg nog?",
            "I have two sons, Bert and Ernie. What should I name my daughter?",
            "What day comes after Friday?",
            "What color shoes should I wear with dark blue pants?",
        };

        var conversations = new List<ConversationData>();
        foreach (var message in messages)
        {
            // apply the model's prompt template to our question and system prompt
            var template = new LLamaTemplate(model);
            template.Add("system", "I am a helpful bot that returns short and concise answers. I include a ten word description of my reasoning when I finish.");
            template.Add("user", message);
            template.AddAssistant = true;
            var templatedMessage = Encoding.UTF8.GetString(template.Apply());

            // create a new conversation and prompt it. include special and bos because we are using the template
            var conversation = executor.Create();
            conversation.Prompt(executor.Context.Tokenize(templatedMessage, addBos: true, special: true));
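            // note: prompting only queues these tokens for the batch; nothing is evaluated until executor.Infer() is called below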

            conversations.Add(new ConversationData
            {
                Prompt = message,
                Conversation = conversation,
                Sampler = new GreedySamplingPipeline(),
                Decoder = new StreamingTokenDecoder(executor.Context)
            });
        }

        var table = BuildTable(conversations);
        await AnsiConsole.Live(table).StartAsync(async ctx =>
        {
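            // each pass of this loop decodes all pending tokens in one batch, then samples one new token
            // for every conversation that is still generating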
            for (var i = 0; i < TokenCount; i++)
            {
                // Run inference for all conversations in the batch which have pending tokens.
                var decodeResult = await executor.Infer();
                if (decodeResult == DecodeResult.NoKvSlot)
                    throw new Exception("Could not find a KV slot for the batch. Try reducing the size of the batch or increasing the context size.");
                if (decodeResult == DecodeResult.Error)
                    throw new Exception("Unknown error occurred while inferring.");

                foreach (var conversationData in conversations.Where(c => c.IsComplete == false))
                {
                    if (conversationData.Conversation.RequiresSampling == false) continue;

                    // sample a single token for this conversation, passing the conversation's sample index
                    var token = conversationData.Sampler.Sample(
                        executor.Context.NativeHandle,
                        conversationData.Conversation.GetSampleIndex());

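                    // an end-of-generation token (e.g. EOS) means this conversation has finished its reply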
                    if (modelTokens.IsEndOfGeneration(token))
                    {
                        conversationData.MarkComplete();
                    }
                    else
                    {
                        // it isn't the end of generation, so add this token to the decoder and then add that to our tracked data
                        conversationData.Decoder.Add(token);
                        conversationData.AppendAnswer(conversationData.Decoder.Read().ReplaceLineEndings(" "));

                        // add the token to the conversation
                        conversationData.Conversation.Prompt(token);
                    }
                }

                // render the current state
                table = BuildTable(conversations);
                ctx.UpdateTarget(table);

                if (conversations.All(c => c.IsComplete))
                {
                    break;
                }
            }

            // if we ran out of tokens before completing just mark them as complete for rendering purposes.
            foreach (var data in conversations.Where(i => i.IsComplete == false))
            {
                data.MarkComplete();
            }

            table = BuildTable(conversations);
            ctx.UpdateTarget(table);
        });
    }

    /// <summary>
    /// Helper to build a table to display the conversations.
    /// </summary>
    private static Table BuildTable(List<ConversationData> conversations)
    {
        var table = new Table()
            .RoundedBorder()
            .AddColumns("Prompt", "Response");

        foreach (var data in conversations)
        {
            table.AddRow(data.Prompt.EscapeMarkup(), data.AnswerMarkdown);
        }

        return table;
    }
}

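/// <summary>
/// Tracks the state of a single conversation in the batch: its prompt, the conversation itself,
/// the sampler used to pick tokens, and a decoder used to turn sampled tokens back into text.
/// </summary>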
public class ConversationData
{
    public required string Prompt { get; init; }
    public required Conversation Conversation { get; init; }
    public required BaseSamplingPipeline Sampler { get; init; }
    public required StreamingTokenDecoder Decoder { get; init; }

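    // Spectre.Console markup for the answer: a completed answer renders in green, while an in-progress
    // answer shows the confirmed text in grey and the most recently decoded token in white.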
    public string AnswerMarkdown => IsComplete
        ? $"[green]{_inProgressAnswer.Message.EscapeMarkup()}{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]"
        : $"[grey]{_inProgressAnswer.Message.EscapeMarkup()}[/][white]{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]";

    public bool IsComplete { get; private set; }

    // we are only keeping track of the answer in two parts to render them differently.
    private (string Message, string LatestToken) _inProgressAnswer = (string.Empty, string.Empty);

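    // the previous latest token is now confirmed, so fold it into the message and track the new token separately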
    public void AppendAnswer(string newText) => _inProgressAnswer = (_inProgressAnswer.Message + _inProgressAnswer.LatestToken, newText);

    public void MarkComplete()
    {
        IsComplete = true;
        if (Conversation.IsDisposed == false)
        {
            // clean up the conversation and sampler to release more memory for inference.
            // real life usage would protect against these two being referenced after being disposed.
            Conversation.Dispose();
            Sampler.Dispose();
        }
    }
}