diff --git a/LLama.Examples/ExampleRunner.cs b/LLama.Examples/ExampleRunner.cs index 790a1f9c6..b74170e3a 100644 --- a/LLama.Examples/ExampleRunner.cs +++ b/LLama.Examples/ExampleRunner.cs @@ -8,6 +8,7 @@ public class ExampleRunner { "Chat Session: History", ChatSessionWithHistory.Run }, { "Chat Session: Role names", ChatSessionWithRoleName.Run }, { "Chat Session: Role names stripped", ChatSessionStripRoleName.Run }, + { "Chat Session: Pre-processing and reset", ChatSessionWithRestart.Run }, { "Chat Session: Coding Assistant", CodingAssistant.Run }, { "Chat Session: Automatic conversation", TalkToYourself.Run }, { "Chat Session: Chinese characters", ChatChineseGB2312.Run }, diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs index 17908908d..31b6a7718 100644 --- a/LLama.Examples/Examples/ChatSessionWithHistory.cs +++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs @@ -48,6 +48,10 @@ public static async Task Run() Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The chat session has started."); + Console.WriteLine("Type 'exit' to end the chat session."); + Console.WriteLine("Type 'save' to save the chat session to disk."); + Console.WriteLine("Type 'load' to load the chat session from disk."); + Console.WriteLine("Type 'regenerate' to regenerate the last response."); // show the prompt Console.ForegroundColor = ConsoleColor.Green; @@ -55,12 +59,20 @@ public static async Task Run() while (userInput != "exit") { + // Save the chat state to disk if (userInput == "save") { session.SaveSession("Assets/chat-with-bob"); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Session saved."); } + // Load the chat state from disk + else if (userInput == "load") + { + session.LoadSession("Assets/chat-with-bob"); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Session loaded."); + } else if (userInput == "regenerate") { Console.ForegroundColor = ConsoleColor.Yellow; diff --git a/LLama.Examples/Examples/ChatSessionWithRestart.cs b/LLama.Examples/Examples/ChatSessionWithRestart.cs new file mode 100644 index 000000000..923f78f67 --- /dev/null +++ b/LLama.Examples/Examples/ChatSessionWithRestart.cs @@ -0,0 +1,107 @@ +using LLama.Common; + +namespace LLama.Examples.Examples; + +public class ChatSessionWithRestart +{ + public static async Task Run() + { + string modelPath = UserSettings.GetModelPath(); + + var parameters = new ModelParams(modelPath) + { + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var executor = new InteractiveExecutor(context); + + var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); + ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); + ChatSession prototypeSession = + await ChatSession.InitializeSessionFromHistoryAsync(executor, chatHistory); + prototypeSession.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( + new string[] { "User:", "Assistant:" }, + redundancyLength: 8)); + var resetState = prototypeSession.GetSessionState(); + + ChatSession session = new ChatSession(executor); + session.LoadSession(resetState); + + InferenceParams inferenceParams = new InferenceParams() + { + Temperature = 0.9f, + AntiPrompts = new List { "User:" } + }; + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("The chat session has started. 
Starting point saved."); + Console.WriteLine("Type 'exit' to end the chat session."); + Console.WriteLine("Type 'save' to save chat session state in memory."); + Console.WriteLine("Type 'reset' to reset the chat session to its saved state."); + Console.WriteLine("Type 'answer for assistant' to add and process provided user and assistant messages."); + + // show the prompt + Console.ForegroundColor = ConsoleColor.Green; + string userInput = Console.ReadLine() ?? ""; + + while (userInput != "exit") + { + // Load the session state from the reset state + if(userInput == "reset") + { + session.LoadSession(resetState); + Console.WriteLine($"Reset to history:\n{session.HistoryTransform.HistoryToText(session.History)}"); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Session reset."); + } + // Assign new reset state. + else if (userInput == "save") + { + resetState = session.GetSessionState(); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Session saved."); + } + // Provide user and override assistant answer with your own. + else if (userInput == "answer for assistant") + { + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Provide user input: "); + + Console.ForegroundColor = ConsoleColor.Green; + string userInputOverride = Console.ReadLine() ?? ""; + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Provide assistant input: "); + + Console.ForegroundColor = ConsoleColor.Green; + string assistantInputOverride = Console.ReadLine() ?? ""; + + await session.AddAndProcessUserMessage(userInputOverride); + await session.AddAndProcessAssistantMessage(assistantInputOverride); + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("User and assistant messages processed. Provide next user message:"); + } + else + { + await foreach ( + var text + in session.ChatAsync( + new ChatHistory.Message(AuthorRole.User, userInput), + inferenceParams)) + { + Console.ForegroundColor = ConsoleColor.White; + Console.Write(text); + } + } + + Console.ForegroundColor = ConsoleColor.Green; + userInput = Console.ReadLine() ?? ""; + + Console.ForegroundColor = ConsoleColor.White; + } + } +} diff --git a/LLama/Abstractions/IHistoryTransform.cs b/LLama/Abstractions/IHistoryTransform.cs index c9217ae0f..9644b3e1d 100644 --- a/LLama/Abstractions/IHistoryTransform.cs +++ b/LLama/Abstractions/IHistoryTransform.cs @@ -1,10 +1,12 @@ using LLama.Common; +using System.Text.Json.Serialization; namespace LLama.Abstractions { /// /// Transform history to plain text and vice versa. /// + [JsonConverter(typeof(PolymorphicJSONConverter))] public interface IHistoryTransform { /// @@ -21,5 +23,11 @@ public interface IHistoryTransform /// The chat history as plain text. /// The updated history. ChatHistory TextToHistory(AuthorRole role, string text); + + /// + /// Copy the transform. + /// + /// + IHistoryTransform Clone(); } } diff --git a/LLama/Abstractions/ITextStreamTransform.cs b/LLama/Abstractions/ITextStreamTransform.cs index 2725214f5..3ebdba675 100644 --- a/LLama/Abstractions/ITextStreamTransform.cs +++ b/LLama/Abstractions/ITextStreamTransform.cs @@ -1,10 +1,13 @@ -using System.Collections.Generic; +using LLama.Common; +using System.Collections.Generic; +using System.Text.Json.Serialization; namespace LLama.Abstractions { /// /// Takes a stream of tokens and transforms them. 
/// + [JsonConverter(typeof(PolymorphicJSONConverter))] public interface ITextStreamTransform { /// @@ -13,5 +16,11 @@ public interface ITextStreamTransform /// /// IAsyncEnumerable TransformAsync(IAsyncEnumerable tokens); + + /// + /// Copy the transform. + /// + /// + ITextStreamTransform Clone(); } } diff --git a/LLama/Abstractions/ITextTransform.cs b/LLama/Abstractions/ITextTransform.cs index ac196644e..f6f743f9f 100644 --- a/LLama/Abstractions/ITextTransform.cs +++ b/LLama/Abstractions/ITextTransform.cs @@ -1,4 +1,7 @@ -namespace LLama.Abstractions +using System.Text.Json.Serialization; +using LLama.Common; + +namespace LLama.Abstractions { /// /// An interface for text transformations. @@ -9,6 +12,7 @@ /// - Trimming /// - etc. /// + [JsonConverter(typeof(PolymorphicJSONConverter))] public interface ITextTransform { /// @@ -17,5 +21,11 @@ public interface ITextTransform /// /// string Transform(string text); + + /// + /// Copy the transform. + /// + /// + ITextTransform Clone(); } } diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 45985b21c..0a5accc5e 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -3,11 +3,14 @@ using System.IO; using System.Linq; using System.Runtime.CompilerServices; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using LLama.Abstractions; using LLama.Common; using static LLama.InteractiveExecutor; +using static LLama.LLamaContext; +using static LLama.StatefulExecutorBase; namespace LLama; @@ -16,9 +19,30 @@ namespace LLama; /// public class ChatSession { - private const string _modelStateFilename = "ModelState.st"; - private const string _executorStateFilename = "ExecutorState.json"; - private const string _hsitoryFilename = "ChatHistory.json"; + /// + /// The filename for the serialized model state (KV cache, etc). + /// + public const string MODEL_STATE_FILENAME = "ModelState.st"; + /// + /// The filename for the serialized executor state. + /// + public const string EXECUTOR_STATE_FILENAME = "ExecutorState.json"; + /// + /// The filename for the serialized chat history. + /// + public const string HISTORY_STATE_FILENAME = "ChatHistory.json"; + /// + /// The filename for the serialized input transform pipeline. + /// + public const string INPUT_TRANSFORM_FILENAME = "InputTransform.json"; + /// + /// The filename for the serialized output transform. + /// + public const string OUTPUT_TRANSFORM_FILENAME = "OutputTransform.json"; + /// + /// The filename for the serialized history transform. + /// + public const string HISTORY_TRANSFORM_FILENAME = "HistoryTransform.json"; /// /// The executor for this session. @@ -45,6 +69,24 @@ public class ChatSession /// public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform(); + /// + /// Create a new chat session and preprocess history. + /// + /// The executor for this session + /// History for this session + /// + public static async Task InitializeSessionFromHistoryAsync( + ILLamaExecutor executor, ChatHistory history) + { + if (executor is not StatefulExecutorBase statefulExecutor) + { + throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor)); + } + var session = new ChatSession(executor, history); + await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history)); + return session; + } + /// /// Create a new chat session. 
/// @@ -112,56 +154,76 @@ public ChatSession WithOutputTransform(ITextStreamTransform transform) /// public void SaveSession(string path) { - if (string.IsNullOrWhiteSpace(path)) + GetSessionState().Save(path); + } + + /// + /// Get the session state. + /// + /// SessionState object representing session state in-memory + public SessionState GetSessionState() + { + var executorState = ((StatefulExecutorBase)Executor).GetStateData(); + return new SessionState( + executorState.PastTokensCount > 0 + ? Executor.Context.GetState() : null, + executorState, + History, + InputTransformPipeline, + OutputTransform, + HistoryTransform); + } + + /// + /// Load a session from a session state. + /// + /// + /// If true loads transforms saved in the session state. + /// + /// + public void LoadSession(SessionState state, bool loadTransforms = true) + { + if (Executor is StatefulExecutorBase statefulExecutor) { - throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); + if (state.ExecutorState is not null) + { + statefulExecutor.LoadState(state.ExecutorState); + } } - - if (Directory.Exists(path)) + if (state.ContextState is null) { - Directory.Delete(path, recursive: true); + Executor.Context.NativeHandle.KvCacheClear(); + } + else + { + Executor.Context.LoadState(state.ContextState); + } + History = new ChatHistory(state.History); + if (loadTransforms) + { + InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList(); + OutputTransform = state.OutputTransform.Clone(); + HistoryTransform = state.HistoryTransform.Clone(); } - - Directory.CreateDirectory(path); - - string modelStateFilePath = Path.Combine(path, _modelStateFilename); - Executor.Context.SaveState(modelStateFilePath); - - string executorStateFilepath = Path.Combine(path, _executorStateFilename); - ((StatefulExecutorBase)Executor).SaveState(executorStateFilepath); - - string historyFilepath = Path.Combine(path, _hsitoryFilename); - File.WriteAllText(historyFilepath, History.ToJson()); } /// /// Load a session from a directory. /// /// + /// If true loads transforms saved in the session state. /// /// - public void LoadSession(string path) + public void LoadSession(string path, bool loadTransforms = true) { - if (string.IsNullOrWhiteSpace(path)) + var state = SessionState.Load(path); + // Handle non-polymorphic serialization of executor state + if (state.ExecutorState is null) { - throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); - } - - if (!Directory.Exists(path)) - { - throw new ArgumentException("Directory does not exist", nameof(path)); + var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME); + ((StatefulExecutorBase) Executor).LoadState(filename: executorPath); } - - string modelStateFilePath = Path.Combine(path, _modelStateFilename); - Executor.Context.LoadState(modelStateFilePath); - - string executorStateFilepath = Path.Combine(path, _executorStateFilename); - ((StatefulExecutorBase)Executor).LoadState(executorStateFilepath); - - string historyFilepath = Path.Combine(path, _hsitoryFilename); - string historyJson = File.ReadAllText(historyFilepath); - History = ChatHistory.FromJson(historyJson) - ?? throw new ArgumentException("History file is invalid", nameof(path)); + LoadSession(state, loadTransforms); } /// @@ -238,6 +300,49 @@ public ChatSession RemoveLastMessage() return this; } + /// + /// Compute KV cache for the message and add it to the chat history. 
+ /// + /// + /// + public async Task AddAndProcessMessage(ChatHistory.Message message) + { + if (Executor is not StatefulExecutorBase statefulExecutor) + { + throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages."); + } + AddMessage(message); + var content = message.Content; + if (message.AuthorRole != AuthorRole.Assistant) + { + foreach (var inputTransform in InputTransformPipeline) + { + content = inputTransform.Transform(content); + } + } + + await statefulExecutor.PrefillPromptAsync(content); + return this; + } + + /// + /// Compute KV cache for the system message and add it to the chat history. + /// + public Task AddAndProcessSystemMessage(string content) + => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content)); + + /// + /// Compute KV cache for the user message and add it to the chat history. + /// + public Task AddAndProcessUserMessage(string content) + => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content)); + + /// + /// Compute KV cache for the assistant message and add it to the chat history. + /// + public Task AddAndProcessAssistantMessage(string content) + => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content)); + /// /// Replace a user message with a new message and remove all messages after the new message. /// This is useful when the user wants to edit a message. And regenerate the response. @@ -494,3 +599,185 @@ in OutputTransform } } } + +/// +/// The state of a chat session in-memory. +/// +public record SessionState +{ + /// + /// Saved executor state for the session in JSON format. + /// + public ExecutorBaseState? ExecutorState { get; set; } + + /// + /// Saved context state (KV cache) for the session. + /// + public State? ContextState { get; set; } + + /// + /// The input transform pipeline used in this session. + /// + public ITextTransform[] InputTransformPipeline { get; set; } = Array.Empty(); + + /// + /// The output transform used in this session. + /// + public ITextStreamTransform OutputTransform { get; set; } = new LLamaTransforms.EmptyTextOutputStreamTransform(); + + /// + /// The history transform used in this session. + /// + public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform(); + + /// + /// The the chat history messages for this session. + /// + public ChatHistory.Message[] History { get; set; } = Array.Empty(); + + /// + /// Create a new session state. + /// + /// + /// + /// + /// + /// + /// + public SessionState( + State? contextState, ExecutorBaseState executorState, + ChatHistory history, List inputTransformPipeline, + ITextStreamTransform outputTransform, IHistoryTransform historyTransform) + { + ContextState = contextState; + ExecutorState = executorState; + History = history.Messages.ToArray(); + InputTransformPipeline = inputTransformPipeline.Select(t => t.Clone()).ToArray(); + OutputTransform = outputTransform.Clone(); + HistoryTransform = historyTransform.Clone(); + } + + /// + /// Save the session state to folder. 
+ /// + /// + public void Save(string path) + { + if (string.IsNullOrWhiteSpace(path)) + { + throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); + } + + if (Directory.Exists(path)) + { + Directory.Delete(path, recursive: true); + } + + Directory.CreateDirectory(path); + + string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME); + var bytes = ContextState?.ToByteArray(); + if (bytes is not null) + { + File.WriteAllBytes(modelStateFilePath, bytes); + } + + string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME); + File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState)); + + string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME); + File.WriteAllText(historyFilepath, new ChatHistory(History).ToJson()); + + string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME); + File.WriteAllText(inputTransformFilepath, JsonSerializer.Serialize(InputTransformPipeline)); + + string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME); + File.WriteAllText(outputTransformFilepath, JsonSerializer.Serialize(OutputTransform)); + + string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME); + File.WriteAllText(historyTransformFilepath, JsonSerializer.Serialize(HistoryTransform)); + } + + /// + /// Load the session state from folder. + /// + /// + /// + /// Throws when session state is incorrect + public static SessionState Load(string path) + { + if (string.IsNullOrWhiteSpace(path)) + { + throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); + } + + if (!Directory.Exists(path)) + { + throw new ArgumentException("Directory does not exist", nameof(path)); + } + + string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME); + var contextState = File.Exists(modelStateFilePath) ? + State.FromByteArray(File.ReadAllBytes(modelStateFilePath)) + : null; + + string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME); + var executorState = JsonSerializer.Deserialize(File.ReadAllText(executorStateFilepath)); + + string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME); + string historyJson = File.ReadAllText(historyFilepath); + var history = ChatHistory.FromJson(historyJson) + ?? throw new ArgumentException("History file is invalid", nameof(path)); + + string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME); + ITextTransform[] inputTransforms; + try + { + inputTransforms = File.Exists(inputTransformFilepath) ? + (JsonSerializer.Deserialize(File.ReadAllText(inputTransformFilepath)) + ?? throw new ArgumentException("Input transform file is invalid", nameof(path))) + : Array.Empty(); + } + catch (JsonException) + { + throw new ArgumentException("Input transform file is invalid", nameof(path)); + } + + string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME); + + ITextStreamTransform outputTransform; + try + { + outputTransform = File.Exists(outputTransformFilepath) ? + (JsonSerializer.Deserialize(File.ReadAllText(outputTransformFilepath)) + ?? 
throw new ArgumentException("Output transform file is invalid", nameof(path))) + : new LLamaTransforms.EmptyTextOutputStreamTransform(); + } + catch (JsonException) + { + throw new ArgumentException("Output transform file is invalid", nameof(path)); + } + + string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME); + IHistoryTransform historyTransform; + try + { + historyTransform = File.Exists(historyTransformFilepath) ? + (JsonSerializer.Deserialize(File.ReadAllText(historyTransformFilepath)) + ?? throw new ArgumentException("History transform file is invalid", nameof(path))) + : new LLamaTransforms.DefaultHistoryTransform(); + } + catch (JsonException) + { + throw new ArgumentException("History transform file is invalid", nameof(path)); + } + + return new SessionState( + contextState, + executorState, + history, + inputTransforms.ToList(), + outputTransform, + historyTransform); + } +} \ No newline at end of file diff --git a/LLama/Common/ChatHistory.cs b/LLama/Common/ChatHistory.cs index dc7414490..c22cc7c06 100644 --- a/LLama/Common/ChatHistory.cs +++ b/LLama/Common/ChatHistory.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Linq; using System.Text.Json; using System.Text.Json.Serialization; @@ -80,6 +81,15 @@ public Message(AuthorRole authorRole, string content) [JsonConstructor] public ChatHistory() { } + /// + /// Create a new instance of the chat history from array of messages + /// + /// + public ChatHistory(Message[] messageHistory) + { + this.Messages = messageHistory.ToList(); + } + /// /// Add a message to the chat history /// diff --git a/LLama/Common/PolymorphicJSONConverter.cs b/LLama/Common/PolymorphicJSONConverter.cs new file mode 100644 index 000000000..1af4011cc --- /dev/null +++ b/LLama/Common/PolymorphicJSONConverter.cs @@ -0,0 +1,57 @@ +using LLama.Abstractions; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace LLama.Common +{ + internal class PolymorphicJSONConverter : JsonConverter + { + public override T? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType != JsonTokenType.StartObject) + throw new JsonException(); + reader.Read(); + if (reader.TokenType != JsonTokenType.PropertyName) + throw new JsonException(); + string? propertyName = reader.GetString(); + if (propertyName != "Name") + return default; + reader.Read(); + if (reader.TokenType != JsonTokenType.String) + throw new JsonException(); + string? name = reader.GetString() ?? 
throw new JsonException(); + var inheritedTypes = Assembly.GetExecutingAssembly().GetTypes().Where( + t => typeof(T).IsAssignableFrom(t) && !t.IsAbstract && !t.IsInterface + ); + var type = inheritedTypes.FirstOrDefault(t => t.Name == name); + if (type == null) + throw new JsonException(); + reader.Read(); + if (reader.TokenType != JsonTokenType.PropertyName) + throw new JsonException(); + propertyName = reader.GetString(); + if (propertyName != "Data") + throw new JsonException(); + var data = JsonSerializer.Deserialize(ref reader, type, options); + if (data == null) + throw new JsonException(); + reader.Read(); + reader.Read(); + return (T)data; + } + + public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + writer.WriteString("Name", value.GetType().Name); + writer.WritePropertyName("Data"); + JsonSerializer.Serialize(writer, value, value.GetType(), options); + writer.WriteEndObject(); + } + } +} diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index d8b418c31..4a63be362 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -166,7 +166,7 @@ public State GetState() memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize); // Wrap memory in a "state" - var state = new State(memory); + var state = new State(memory, actualSize); // Set memory to zero, to prevent it being freed in finally block memory = IntPtr.Zero; @@ -384,9 +384,12 @@ public void Dispose() public class State : SafeLLamaHandleBase { - internal State(IntPtr memory) + private ulong _size; + + internal State(IntPtr memory, ulong size) : base(memory, true) { + _size = size; } /// @@ -395,6 +398,29 @@ protected override bool ReleaseHandle() Marshal.FreeHGlobal(handle); return true; } + + /// + /// Convert this state to a byte array + /// + /// + public byte[] ToByteArray() + { + var bytes = new byte[_size]; + Marshal.Copy(handle, bytes, 0, (int)_size); + return bytes; + } + + /// + /// Load state from a byte array + /// + /// + /// + public static State FromByteArray(byte[] bytes) + { + var memory = Marshal.AllocHGlobal(bytes.Length); + Marshal.Copy(bytes, 0, memory, bytes.Length); + return new State(memory, (ulong)bytes.Length); + } } } } diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 3a697507b..ec72a25ad 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -315,6 +315,34 @@ public virtual async IAsyncEnumerable InferAsync(string text, IInference } } + /// + /// Asynchronously runs a prompt through the model to compute KV cache without generating any new tokens. + /// It could reduce the latency of the first time response if the first input from the user is not immediate. 
+ /// + /// Prompt to process + /// + public virtual async Task PrefillPromptAsync(string prompt) + { + var inferenceParams = new InferenceParams + { + MaxTokens = 0 + }; + var args = new InferStateArgs + { + Antiprompts = new List(), + RemainedTokens = 0, + ReturnValue = false, + WaitForInput = true, + NeedToSaveSession = false + }; + + await PreprocessInputs(prompt, args); + // First run adds the prompt to the _embeds + await InferInternal(inferenceParams, args); + // Second run puts it through decode + await InferInternal(inferenceParams, args); + } + /// /// State arguments that are used in single inference /// @@ -342,6 +370,7 @@ protected class InferStateArgs public bool NeedToSaveSession { get; set; } } + [JsonConverter(typeof(PolymorphicJSONConverter))] public class ExecutorBaseState { [JsonPropertyName("n_past")] @@ -360,13 +389,13 @@ public class ExecutorBaseState public string? SessionFilePath { get; set; } [JsonPropertyName("embd")] - public List Embeds { get; set; } + public LLamaToken[] Embeds { get; set; } [JsonPropertyName("embd_inps")] - public List EmbedInps { get; set; } + public LLamaToken[] EmbedInps { get; set; } [JsonPropertyName("session_tokens")] - public List SessionTokens { get; set; } + public LLamaToken[] SessionTokens { get; set; } [JsonPropertyName("last_n_tokens")] public LLamaToken[] LastTokens { get; set; } diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 9476976e9..99d45e5a5 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -49,17 +49,17 @@ public override ExecutorBaseState GetStateData() InstructExecutorState state = new() { ConsumedSessionCount = _n_session_consumed, - EmbedInps = _embed_inps, + EmbedInps = _embed_inps.ToArray(), IsPromptRun = _is_prompt_run, ConsumedTokensCount = _consumedTokensCount, - Embeds = _embeds, + Embeds = _embeds.ToArray(), LastTokens = _last_n_tokens.ToArray(), InputPrefixTokens = _inp_pfx, InputSuffixTokens = _inp_sfx, MatchingSessionTokensCount = _n_matching_session_tokens, PastTokensCount = _pastTokensCount, SessionFilePath = _pathSession, - SessionTokens = _session_tokens, + SessionTokens = _session_tokens.ToArray(), LastTokensCapacity = _last_n_tokens.Capacity, MirostatMu = MirostatMu }; @@ -71,17 +71,17 @@ public override Task LoadState(ExecutorBaseState data) if(data is InstructExecutorState state) { _n_session_consumed = state.ConsumedSessionCount; - _embed_inps = state.EmbedInps; + _embed_inps = state.EmbedInps.ToList(); _is_prompt_run = state.IsPromptRun; _consumedTokensCount = state.ConsumedTokensCount; - _embeds = state.Embeds; + _embeds = state.Embeds.ToList(); _last_n_tokens = new FixedSizeQueue(state.LastTokensCapacity, state.LastTokens); _inp_pfx = state.InputPrefixTokens; _inp_sfx = state.InputSuffixTokens; _n_matching_session_tokens = state.MatchingSessionTokensCount; _pastTokensCount = state.PastTokensCount; _pathSession = state.SessionFilePath; - _session_tokens = state.SessionTokens; + _session_tokens = state.SessionTokens.ToList(); } else { diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 79f1b8cc4..2a14eeaf5 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -39,15 +39,15 @@ public override ExecutorBaseState GetStateData() InteractiveExecutorState state = new() { ConsumedSessionCount = _n_session_consumed, - EmbedInps = _embed_inps, + EmbedInps = _embed_inps.ToArray(), IsPromptRun = _is_prompt_run, ConsumedTokensCount = _consumedTokensCount, - Embeds = _embeds, + Embeds 
= _embeds.ToArray(), LastTokens = _last_n_tokens.ToArray(), MatchingSessionTokensCount = _n_matching_session_tokens, PastTokensCount = _pastTokensCount, SessionFilePath = _pathSession, - SessionTokens = _session_tokens, + SessionTokens = _session_tokens.ToArray(), LastTokensCapacity = _last_n_tokens.Capacity, MirostatMu = MirostatMu }; @@ -59,15 +59,15 @@ public override Task LoadState(ExecutorBaseState data) if (data is InteractiveExecutorState state) { _n_session_consumed = state.ConsumedSessionCount; - _embed_inps = state.EmbedInps; + _embed_inps = state.EmbedInps.ToList(); _is_prompt_run = state.IsPromptRun; _consumedTokensCount = state.ConsumedTokensCount; - _embeds = state.Embeds; + _embeds = state.Embeds.ToList(); _last_n_tokens = new FixedSizeQueue(state.LastTokensCapacity, state.LastTokens); _n_matching_session_tokens = state.MatchingSessionTokensCount; _pastTokensCount = state.PastTokensCount; _pathSession = state.SessionFilePath; - _session_tokens = state.SessionTokens; + _session_tokens = state.SessionTokens.ToList(); } else throw new ArgumentException("Invalid state data type."); diff --git a/LLama/LLamaTransforms.cs b/LLama/LLamaTransforms.cs index 29c16c187..d74d9ddaf 100644 --- a/LLama/LLamaTransforms.cs +++ b/LLama/LLamaTransforms.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using System.Text.Json.Serialization; namespace LLama { @@ -29,6 +30,12 @@ public class DefaultHistoryTransform : IHistoryTransform private readonly string _unknownName; private readonly bool _isInstructMode; + public string UserName => _userName; + public string AssistantName => _assistantName; + public string SystemName => _systemName; + public string UnknownName => _unknownName; + public bool IsInstructMode => _isInstructMode; + /// /// /// @@ -47,6 +54,12 @@ public DefaultHistoryTransform(string? userName = null, string? assistantName = _isInstructMode = isInstructMode; } + /// + public IHistoryTransform Clone() + { + return new DefaultHistoryTransform(_userName, _assistantName, _systemName, _unknownName, _isInstructMode); + } + /// public virtual string HistoryToText(ChatHistory history) { @@ -116,6 +129,12 @@ public string Transform(string text) { return text.Trim(); } + + /// + public ITextTransform Clone() + { + return new NaiveTextInputTransform(); + } } /// @@ -129,6 +148,12 @@ public IAsyncEnumerable TransformAsync(IAsyncEnumerable tokens) { return tokens; } + + /// + public ITextStreamTransform Clone() + { + return new EmptyTextOutputStreamTransform(); + } } /// @@ -140,6 +165,42 @@ public class KeywordTextOutputStreamTransform : ITextStreamTransform private readonly int _maxKeywordLength; private readonly bool _removeAllMatchedTokens; + /// + /// Keywords that you want to remove from the response. + /// This property is used for JSON serialization. + /// + [JsonPropertyName("keywords")] + public HashSet Keywords => _keywords; + + /// + /// Maximum length of the keywords. + /// This property is used for JSON serialization. + /// + [JsonPropertyName("maxKeywordLength")] + public int MaxKeywordLength => _maxKeywordLength; + + /// + /// If set to true, when getting a matched keyword, all the related tokens will be removed. + /// Otherwise only the part of keyword will be removed. + /// This property is used for JSON serialization. + /// + [JsonPropertyName("removeAllMatchedTokens")] + public bool RemoveAllMatchedTokens => _removeAllMatchedTokens; + + /// + /// JSON constructor. 
+ /// + [JsonConstructor] + public KeywordTextOutputStreamTransform( + HashSet keywords, + int maxKeywordLength, + bool removeAllMatchedTokens) + { + _keywords = new(keywords); + _maxKeywordLength = maxKeywordLength; + _removeAllMatchedTokens = removeAllMatchedTokens; + } + /// /// /// @@ -157,6 +218,12 @@ public KeywordTextOutputStreamTransform(IEnumerable keywords, int redund _removeAllMatchedTokens = removeAllMatchedTokens; } + /// + public ITextStreamTransform Clone() + { + return new KeywordTextOutputStreamTransform(_keywords, _maxKeywordLength, _removeAllMatchedTokens); + } + /// public async IAsyncEnumerable TransformAsync(IAsyncEnumerable tokens) {
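
For reference, a minimal usage sketch (not part of the diff) of the in-memory session-state APIs this change introduces: InitializeSessionFromHistoryAsync, GetSessionState, LoadSession(SessionState) and SessionState.Save. It follows the patterns in ChatSessionWithRestart.cs; the model path, prompt strings and checkpoint folder name are placeholders, and generic type arguments are written out even where the diff above shows them stripped.

using System;
using System.Collections.Generic;
using LLama;
using LLama.Common;

// Placeholder model path; replace with a real GGUF file.
var parameters = new ModelParams(@"path/to/model.gguf") { ContextSize = 1024 };
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

// Pre-process the initial history so its KV cache is computed before the first user turn.
var history = new ChatHistory();
history.AddMessage(AuthorRole.System, "You are a helpful assistant named Bob.");
ChatSession session = await ChatSession.InitializeSessionFromHistoryAsync(executor, history);

// Take an in-memory checkpoint of executor state, KV cache, history and transforms.
SessionState checkpoint = session.GetSessionState();

var inferenceParams = new InferenceParams
{
    AntiPrompts = new List<string> { "User:" }
};

await foreach (var token in session.ChatAsync(
    new ChatHistory.Message(AuthorRole.User, "Hello, who are you?"), inferenceParams))
{
    Console.Write(token);
}

// Roll the session back to the checkpoint (transforms are restored by default),
// or persist the checkpoint to disk for a later run ("Assets/chat-checkpoint" is a placeholder).
session.LoadSession(checkpoint);
checkpoint.Save("Assets/chat-checkpoint");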