Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 27 additions & 76 deletions LLama.Examples/Examples/LLama3ChatSession.cs
Original file line number Diff line number Diff line change
@@ -1,49 +1,63 @@
using System.Text;
using LLama.Abstractions;
using LLama.Common;
using LLama.Common;
using LLama.Transformers;

namespace LLama.Examples.Examples;

// When using chatsession, it's a common case that you want to strip the role names
// rather than display them. This example shows how to use transforms to strip them.
/// <summary>
/// This sample shows a simple chatbot
/// It's configured to use the default prompt template as provided by llama.cpp and supports
/// models such as llama3, llama2, phi3, qwen1.5, etc.
/// </summary>
public class LLama3ChatSession
{
public static async Task Run()
{
string modelPath = UserSettings.GetModelPath();

var modelPath = UserSettings.GetModelPath();
var parameters = new ModelParams(modelPath)
{
Seed = 1337,
GpuLayerCount = 10
};

using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
var chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

ChatSession session = new(executor, chatHistory);
session.WithHistoryTransform(new LLama3HistoryTransform());

// add the default templator. If llama.cpp doesn't support the template by default,
// you'll need to write your own transformer to format the prompt correctly
session.WithHistoryTransform(new PromptTemplateTransformer(model, withAssistant: true));

// Add a transformer to eliminate printing the end of turn tokens, llama 3 specifically has an odd LF that gets printed sometimes
session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
new string[] { "User:", "Assistant:", "�" },
[model.Tokens.EndOfTurnToken!, "�"],
redundancyLength: 5));

InferenceParams inferenceParams = new InferenceParams()
var inferenceParams = new InferenceParams()
{
MaxTokens = -1, // keep generating tokens until the anti prompt is encountered
Temperature = 0.6f,
AntiPrompts = new List<string> { "User:" }
AntiPrompts = [model.Tokens.EndOfTurnToken!] // model specific end of turn string
};

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started.");

// show the prompt
Console.ForegroundColor = ConsoleColor.Green;
string userInput = Console.ReadLine() ?? "";
Console.Write("User> ");
var userInput = Console.ReadLine() ?? "";

while (userInput != "exit")
{
Console.ForegroundColor = ConsoleColor.White;
Console.Write("Assistant> ");

// as each token (partial or whole word is streamed back) print it to the console, stream to web client, etc
await foreach (
var text
in session.ChatAsync(
Expand All @@ -56,71 +70,8 @@ in session.ChatAsync(
Console.WriteLine();

Console.ForegroundColor = ConsoleColor.Green;
Console.Write("User> ");
userInput = Console.ReadLine() ?? "";

Console.ForegroundColor = ConsoleColor.White;
}
}

/// <summary>
/// Formats a <see cref="ChatHistory"/> using the llama 3 prompt template:
/// each message is rendered as
/// <c>&lt;|start_header_id|&gt;Role&lt;|end_header_id|&gt;\n\nContent&lt;|eot_id|&gt;</c>,
/// prefixed once with <c>&lt;|begin_of_text|&gt;</c>.
/// </summary>
class LLama3HistoryTransform : IHistoryTransform
{
    private const string StartHeaderId = "<|start_header_id|>";
    private const string EndHeaderId = "<|end_header_id|>";
    private const string Bos = "<|begin_of_text|>";
    private const string EndofTurn = "<|eot_id|>";

    /// <summary>
    /// Convert a ChatHistory instance to plain text.
    /// </summary>
    /// <param name="history">The ChatHistory instance</param>
    /// <returns>
    /// The full prompt, terminated by an empty assistant header so the model
    /// continues generating as the assistant.
    /// </returns>
    public string HistoryToText(ChatHistory history)
    {
        // StringBuilder avoids O(n^2) repeated string concatenation over the history.
        var builder = new StringBuilder(Bos);
        foreach (var message in history.Messages)
        {
            AppendMessage(builder, message);
        }

        // Trailing empty assistant header prompts the model to answer as the assistant.
        AppendHeader(builder, new ChatHistory.Message(AuthorRole.Assistant, ""));
        return builder.ToString();
    }

    // Appends "<|start_header_id|>Role<|end_header_id|>\n\n" for the message's role.
    private static void AppendHeader(StringBuilder builder, ChatHistory.Message message)
    {
        builder.Append(StartHeaderId)
               .Append(message.AuthorRole.ToString())
               .Append(EndHeaderId)
               .Append("\n\n");
    }

    // Appends one complete message: role header, content, then the end-of-turn marker.
    private static void AppendMessage(StringBuilder builder, ChatHistory.Message message)
    {
        AppendHeader(builder, message);
        builder.Append(message.Content)
               .Append(EndofTurn);
    }

    /// <summary>
    /// Converts plain text to a ChatHistory instance.
    /// </summary>
    /// <param name="role">The role for the author.</param>
    /// <param name="text">The chat history as plain text.</param>
    /// <returns>The updated history.</returns>
    public ChatHistory TextToHistory(AuthorRole role, string text)
    {
        return new ChatHistory(new ChatHistory.Message[] { new ChatHistory.Message(role, text) });
    }

    /// <summary>
    /// Copy the transform.
    /// </summary>
    /// <returns>A fresh instance; the transform holds no state, so no fields are copied.</returns>
    public IHistoryTransform Clone()
    {
        return new LLama3HistoryTransform();
    }
}
}
39 changes: 39 additions & 0 deletions LLama.Unittest/Native/SafeLlamaModelHandleTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using System.Text;
using LLama.Common;
using LLama.Native;
using LLama.Extensions;

namespace LLama.Unittest.Native;

/// <summary>
/// Tests for metadata access on <see cref="SafeLlamaModelHandle"/>.
/// </summary>
public class SafeLlamaModelHandleTests
    : IDisposable
{
    private readonly LLamaWeights _model;

    public SafeLlamaModelHandleTests()
    {
        var @params = new ModelParams(Constants.GenerativeModelPath)
        {
            ContextSize = 1,
            GpuLayerCount = Constants.CIGpuLayerCount
        };
        _model = LLamaWeights.LoadFromFile(@params);
    }

    // xUnit calls Dispose after each test; release the loaded model weights
    // (previously this class leaked them — sibling TemplateTests already disposes).
    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact]
    public void MetadataValByKey_ReturnsCorrectly()
    {
        const string key = "general.name";
        var template = _model.NativeHandle.MetadataValueByKey(key);
        var name = Encoding.UTF8.GetStringFromSpan(template!.Value.Span);

        const string expected = "LLaMA v2";
        Assert.Equal(expected, name);

        // The managed metadata dictionary must agree with the native lookup.
        var metadataLookup = _model.Metadata[key];
        Assert.Equal(expected, metadataLookup);
        Assert.Equal(name, metadataLookup);
    }
}
78 changes: 46 additions & 32 deletions LLama.Unittest/TemplateTests.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
using System.Text;
using LLama.Common;
using LLama.Native;
using LLama.Extensions;

namespace LLama.Unittest;

public sealed class TemplateTests
: IDisposable
{
private readonly LLamaWeights _model;

public TemplateTests()
{
var @params = new ModelParams(Constants.GenerativeModelPath)
Expand All @@ -18,12 +18,12 @@ public TemplateTests()
};
_model = LLamaWeights.LoadFromFile(@params);
}

public void Dispose()
{
_model.Dispose();
}

[Fact]
public void BasicTemplate()
{
Expand All @@ -47,18 +47,10 @@ public void BasicTemplate()
templater.Add("user", "ccc");
Assert.Equal(8, templater.Count);

// Call once with empty array to discover length
var length = templater.Apply(Array.Empty<byte>());
var dest = new byte[length];

Assert.Equal(8, templater.Count);

// Call again to get contents
length = templater.Apply(dest);

var dest = templater.Apply();
Assert.Equal(8, templater.Count);

var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
var templateResult = Encoding.UTF8.GetString(dest);
const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
"<|im_start|>user\nworld<|im_end|>\n" +
"<|im_start|>assistant\n" +
Expand Down Expand Up @@ -93,17 +85,10 @@ public void CustomTemplate()
Assert.Equal(4, templater.Count);

// Call once with empty array to discover length
var length = templater.Apply(Array.Empty<byte>());
var dest = new byte[length];

var dest = templater.Apply();
Assert.Equal(4, templater.Count);

// Call again to get contents
length = templater.Apply(dest);

Assert.Equal(4, templater.Count);

var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
var templateResult = Encoding.UTF8.GetString(dest);
const string expected = "<start_of_turn>model\n" +
"hello<end_of_turn>\n" +
"<start_of_turn>user\n" +
Expand Down Expand Up @@ -143,17 +128,10 @@ public void BasicTemplateWithAddAssistant()
Assert.Equal(8, templater.Count);

// Call once with empty array to discover length
var length = templater.Apply(Array.Empty<byte>());
var dest = new byte[length];

var dest = templater.Apply();
Assert.Equal(8, templater.Count);

// Call again to get contents
length = templater.Apply(dest);

Assert.Equal(8, templater.Count);

var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
var templateResult = Encoding.UTF8.GetString(dest);
const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
"<|im_start|>user\nworld<|im_end|>\n" +
"<|im_start|>assistant\n" +
Expand Down Expand Up @@ -249,4 +227,40 @@ public void RemoveOutOfRange()
Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(-1));
Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(2));
}

// Clear should drop all queued messages and leave the templater reusable from scratch.
[Fact]
public void Clear_ResetsTemplateState()
{
var templater = new LLamaTemplate(_model);
templater.Add("assistant", "1")
.Add("user", "2");

Assert.Equal(2, templater.Count);

templater.Clear();

// Count returns to zero and no residue of the earlier messages survives.
Assert.Equal(0, templater.Count);

const string userData = nameof(userData);
templater.Add("user", userData);

// Generate the template string
var dest = templater.Apply();
var templateResult = Encoding.UTF8.GetString(dest);

// Only the post-Clear message appears in the rendered template.
const string expectedTemplate = $"<|im_start|>user\n{userData}<|im_end|>\n";
Assert.Equal(expectedTemplate, templateResult);
}

// NOTE(review): method name looks like a typo for "EndOfTurnToken_ReturnsExpected";
// left unchanged here because the test name is externally visible to runners/CI filters.
[Fact]
public void EndOTurnToken_ReturnsExpected()
{
// Expected to be null for this unit-test model — presumably its chat template
// defines no end-of-turn token; verify against the model's metadata if it changes.
Assert.Null(_model.Tokens.EndOfTurnToken);
}

// NOTE(review): method name looks like a typo for "EndOfSpeechToken_ReturnsExpected";
// left unchanged here because the test name is externally visible to runners/CI filters.
[Fact]
public void EndOSpeechToken_ReturnsExpected()
{
// The test model reports "</s>" as its end-of-speech token.
Assert.Equal("</s>", _model.Tokens.EndOfSpeechToken);
}
}
Loading