Skip to content

Commit 2990b47

Browse files
authored
Merge pull request #787 from patrick-hovsepian/generic_prompt
Generic Prompt Formatter
2 parents a5de5f7 + 8c9bbb6 commit 2990b47

14 files changed

+386
-137
lines changed
Lines changed: 27 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,63 @@
1-
using LLama.Abstractions;
2-
using LLama.Common;
1+
using LLama.Common;
2+
using LLama.Transformers;
33

44
namespace LLama.Examples.Examples;
55

6-
// When using chatsession, it's a common case that you want to strip the role names
7-
// rather than display them. This example shows how to use transforms to strip them.
6+
/// <summary>
7+
/// This sample shows a simple chatbot.
8+
/// It's configured to use the default prompt template as provided by llama.cpp and supports
9+
/// models such as llama3, llama2, phi3, qwen1.5, etc.
10+
/// </summary>
811
public class LLama3ChatSession
912
{
1013
public static async Task Run()
1114
{
12-
string modelPath = UserSettings.GetModelPath();
13-
15+
var modelPath = UserSettings.GetModelPath();
1416
var parameters = new ModelParams(modelPath)
1517
{
1618
Seed = 1337,
1719
GpuLayerCount = 10
1820
};
21+
1922
using var model = LLamaWeights.LoadFromFile(parameters);
2023
using var context = model.CreateContext(parameters);
2124
var executor = new InteractiveExecutor(context);
2225

2326
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
24-
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
27+
var chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
2528

2629
ChatSession session = new(executor, chatHistory);
27-
session.WithHistoryTransform(new LLama3HistoryTransform());
30+
31+
// add the default templater. If llama.cpp doesn't support the template by default,
32+
// you'll need to write your own transformer to format the prompt correctly
33+
session.WithHistoryTransform(new PromptTemplateTransformer(model, withAssistant: true));
34+
35+
// Add a transformer to eliminate printing the end of turn tokens; llama 3 specifically has an odd LF that gets printed sometimes
2836
session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
29-
new string[] { "User:", "Assistant:", "�" },
37+
[model.Tokens.EndOfTurnToken!, "�"],
3038
redundancyLength: 5));
3139

32-
InferenceParams inferenceParams = new InferenceParams()
40+
var inferenceParams = new InferenceParams()
3341
{
42+
MaxTokens = -1, // keep generating tokens until the anti prompt is encountered
3443
Temperature = 0.6f,
35-
AntiPrompts = new List<string> { "User:" }
44+
AntiPrompts = [model.Tokens.EndOfTurnToken!] // model specific end of turn string
3645
};
3746

3847
Console.ForegroundColor = ConsoleColor.Yellow;
3948
Console.WriteLine("The chat session has started.");
4049

4150
// show the prompt
4251
Console.ForegroundColor = ConsoleColor.Green;
43-
string userInput = Console.ReadLine() ?? "";
52+
Console.Write("User> ");
53+
var userInput = Console.ReadLine() ?? "";
4454

4555
while (userInput != "exit")
4656
{
57+
Console.ForegroundColor = ConsoleColor.White;
58+
Console.Write("Assistant> ");
59+
60+
// as each token (partial or whole word) is streamed back, print it to the console, stream it to a web client, etc.
4761
await foreach (
4862
var text
4963
in session.ChatAsync(
@@ -56,71 +70,8 @@ in session.ChatAsync(
5670
Console.WriteLine();
5771

5872
Console.ForegroundColor = ConsoleColor.Green;
73+
Console.Write("User> ");
5974
userInput = Console.ReadLine() ?? "";
60-
61-
Console.ForegroundColor = ConsoleColor.White;
62-
}
63-
}
64-
65-
class LLama3HistoryTransform : IHistoryTransform
66-
{
67-
/// <summary>
68-
/// Convert a ChatHistory instance to plain text.
69-
/// </summary>
70-
/// <param name="history">The ChatHistory instance</param>
71-
/// <returns></returns>
72-
public string HistoryToText(ChatHistory history)
73-
{
74-
string res = Bos;
75-
foreach (var message in history.Messages)
76-
{
77-
res += EncodeMessage(message);
78-
}
79-
res += EncodeHeader(new ChatHistory.Message(AuthorRole.Assistant, ""));
80-
return res;
81-
}
82-
83-
private string EncodeHeader(ChatHistory.Message message)
84-
{
85-
string res = StartHeaderId;
86-
res += message.AuthorRole.ToString();
87-
res += EndHeaderId;
88-
res += "\n\n";
89-
return res;
90-
}
91-
92-
private string EncodeMessage(ChatHistory.Message message)
93-
{
94-
string res = EncodeHeader(message);
95-
res += message.Content;
96-
res += EndofTurn;
97-
return res;
9875
}
99-
100-
/// <summary>
101-
/// Converts plain text to a ChatHistory instance.
102-
/// </summary>
103-
/// <param name="role">The role for the author.</param>
104-
/// <param name="text">The chat history as plain text.</param>
105-
/// <returns>The updated history.</returns>
106-
public ChatHistory TextToHistory(AuthorRole role, string text)
107-
{
108-
return new ChatHistory(new ChatHistory.Message[] { new ChatHistory.Message(role, text) });
109-
}
110-
111-
/// <summary>
112-
/// Copy the transform.
113-
/// </summary>
114-
/// <returns></returns>
115-
public IHistoryTransform Clone()
116-
{
117-
return new LLama3HistoryTransform();
118-
}
119-
120-
private const string StartHeaderId = "<|start_header_id|>";
121-
private const string EndHeaderId = "<|end_header_id|>";
122-
private const string Bos = "<|begin_of_text|>";
123-
private const string Eos = "<|end_of_text|>";
124-
private const string EndofTurn = "<|eot_id|>";
12576
}
12677
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using System.Text;
2+
using LLama.Common;
3+
using LLama.Native;
4+
using LLama.Extensions;
5+
6+
namespace LLama.Unittest.Native;
7+
8+
public class SafeLlamaModelHandleTests
9+
{
10+
private readonly LLamaWeights _model;
11+
private readonly SafeLlamaModelHandle TestableHandle;
12+
13+
public SafeLlamaModelHandleTests()
14+
{
15+
var @params = new ModelParams(Constants.GenerativeModelPath)
16+
{
17+
ContextSize = 1,
18+
GpuLayerCount = Constants.CIGpuLayerCount
19+
};
20+
_model = LLamaWeights.LoadFromFile(@params);
21+
22+
TestableHandle = _model.NativeHandle;
23+
}
24+
25+
[Fact]
26+
public void MetadataValByKey_ReturnsCorrectly()
27+
{
28+
const string key = "general.name";
29+
var template = _model.NativeHandle.MetadataValueByKey(key);
30+
var name = Encoding.UTF8.GetStringFromSpan(template!.Value.Span);
31+
32+
const string expected = "LLaMA v2";
33+
Assert.Equal(expected, name);
34+
35+
var metadataLookup = _model.Metadata[key];
36+
Assert.Equal(expected, metadataLookup);
37+
Assert.Equal(name, metadataLookup);
38+
}
39+
}

LLama.Unittest/TemplateTests.cs

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
using System.Text;
22
using LLama.Common;
3-
using LLama.Native;
3+
using LLama.Extensions;
44

55
namespace LLama.Unittest;
66

77
public sealed class TemplateTests
88
: IDisposable
99
{
1010
private readonly LLamaWeights _model;
11-
11+
1212
public TemplateTests()
1313
{
1414
var @params = new ModelParams(Constants.GenerativeModelPath)
@@ -18,12 +18,12 @@ public TemplateTests()
1818
};
1919
_model = LLamaWeights.LoadFromFile(@params);
2020
}
21-
21+
2222
public void Dispose()
2323
{
2424
_model.Dispose();
2525
}
26-
26+
2727
[Fact]
2828
public void BasicTemplate()
2929
{
@@ -47,18 +47,10 @@ public void BasicTemplate()
4747
templater.Add("user", "ccc");
4848
Assert.Equal(8, templater.Count);
4949

50-
// Call once with empty array to discover length
51-
var length = templater.Apply(Array.Empty<byte>());
52-
var dest = new byte[length];
53-
54-
Assert.Equal(8, templater.Count);
55-
56-
// Call again to get contents
57-
length = templater.Apply(dest);
58-
50+
var dest = templater.Apply();
5951
Assert.Equal(8, templater.Count);
6052

61-
var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
53+
var templateResult = Encoding.UTF8.GetString(dest);
6254
const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
6355
"<|im_start|>user\nworld<|im_end|>\n" +
6456
"<|im_start|>assistant\n" +
@@ -93,17 +85,10 @@ public void CustomTemplate()
9385
Assert.Equal(4, templater.Count);
9486

9587
// Apply the template to get the formatted bytes
96-
var length = templater.Apply(Array.Empty<byte>());
97-
var dest = new byte[length];
98-
88+
var dest = templater.Apply();
9989
Assert.Equal(4, templater.Count);
10090

101-
// Call again to get contents
102-
length = templater.Apply(dest);
103-
104-
Assert.Equal(4, templater.Count);
105-
106-
var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
91+
var templateResult = Encoding.UTF8.GetString(dest);
10792
const string expected = "<start_of_turn>model\n" +
10893
"hello<end_of_turn>\n" +
10994
"<start_of_turn>user\n" +
@@ -143,17 +128,10 @@ public void BasicTemplateWithAddAssistant()
143128
Assert.Equal(8, templater.Count);
144129

145130
// Apply the template to get the formatted bytes
146-
var length = templater.Apply(Array.Empty<byte>());
147-
var dest = new byte[length];
148-
131+
var dest = templater.Apply();
149132
Assert.Equal(8, templater.Count);
150133

151-
// Call again to get contents
152-
length = templater.Apply(dest);
153-
154-
Assert.Equal(8, templater.Count);
155-
156-
var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
134+
var templateResult = Encoding.UTF8.GetString(dest);
157135
const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
158136
"<|im_start|>user\nworld<|im_end|>\n" +
159137
"<|im_start|>assistant\n" +
@@ -249,4 +227,40 @@ public void RemoveOutOfRange()
249227
Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(-1));
250228
Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(2));
251229
}
230+
231+
[Fact]
232+
public void Clear_ResetsTemplateState()
233+
{
234+
var templater = new LLamaTemplate(_model);
235+
templater.Add("assistant", "1")
236+
.Add("user", "2");
237+
238+
Assert.Equal(2, templater.Count);
239+
240+
templater.Clear();
241+
242+
Assert.Equal(0, templater.Count);
243+
244+
const string userData = nameof(userData);
245+
templater.Add("user", userData);
246+
247+
// Generate the template string
248+
var dest = templater.Apply();
249+
var templateResult = Encoding.UTF8.GetString(dest);
250+
251+
const string expectedTemplate = $"<|im_start|>user\n{userData}<|im_end|>\n";
252+
Assert.Equal(expectedTemplate, templateResult);
253+
}
254+
255+
[Fact]
256+
public void EndOTurnToken_ReturnsExpected()
257+
{
258+
Assert.Null(_model.Tokens.EndOfTurnToken);
259+
}
260+
261+
[Fact]
262+
public void EndOSpeechToken_ReturnsExpected()
263+
{
264+
Assert.Equal("</s>", _model.Tokens.EndOfSpeechToken);
265+
}
252266
}

0 commit comments

Comments
 (0)