diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs
index 1d54a7e94..ea02280d6 100644
--- a/LLama.Unittest/BasicTest.cs
+++ b/LLama.Unittest/BasicTest.cs
@@ -1,6 +1,4 @@
-using System.Text;
 using LLama.Common;
-using LLama.Native;
 using Xunit.Abstractions;

 namespace LLama.Unittest
diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs
index 3d75a0cbc..c8d23c2ed 100644
--- a/LLama.Web/Common/ModelOptions.cs
+++ b/LLama.Web/Common/ModelOptions.cs
@@ -47,12 +47,6 @@ public class ModelOptions
     /// <inheritdoc />
     public string ModelPath { get; set; }

-    /// <inheritdoc />
-    public AdapterCollection LoraAdapters { get; set; } = new();
-
-    /// <inheritdoc />
-    public string LoraBase { get; set; } = string.Empty;
-
     /// <inheritdoc />
     public uint? Threads { get; set; }
diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs
index e6a64c930..7dc28f671 100644
--- a/LLama/Abstractions/IModelParams.cs
+++ b/LLama/Abstractions/IModelParams.cs
@@ -2,7 +2,6 @@
 using System.Collections;
 using System.Collections.Generic;
 using System.ComponentModel;
-using System.Linq;
 using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
@@ -69,67 +68,12 @@ public interface IModelParams
     /// </summary>
     bool VocabOnly { get; }

-    /// <summary>
-    /// List of LoRA adapters to apply
-    /// </summary>
-    AdapterCollection LoraAdapters { get; }
-
-    /// <summary>
-    /// base model path for the lora adapter (lora_base)
-    /// </summary>
-    string LoraBase { get; }
-
     /// <summary>
     /// Override specific metadata items in the model
     /// </summary>
     List<MetadataOverride> MetadataOverrides { get; }
 }

-/// <summary>
-/// A LoRA adapter to apply to a model
-/// </summary>
-/// <param name="Path">Path to the LoRA file</param>
-/// <param name="Scale">Strength of this LoRA</param>
-public readonly record struct LoraAdapter(string Path, float Scale);
-
-/// <summary>
-/// A list of LoraAdapter objects
-/// </summary>
-public sealed class AdapterCollection
-    : List<LoraAdapter>, IEquatable<AdapterCollection>
-{
-    /// <inheritdoc />
-    public bool Equals(AdapterCollection? other)
-    {
-        if (other == null)
-            return false;
-
-        return this.SequenceEqual(other);
-    }
-
-    /// <inheritdoc />
-    public override bool Equals(object? obj)
-    {
-        return Equals(obj as AdapterCollection);
-    }
-
-    /// <inheritdoc />
-    public override int GetHashCode()
-    {
-        unchecked
-        {
-            var hash = 17;
-            for (var i = 0; i < Count; i++)
-            {
-                hash += this[i].GetHashCode();
-                hash *= 7823;
-            }
-            return hash;
-        }
-    }
-}
-
 /// <summary>
 /// A fixed size array to set the tensor splits across multiple GPUs
 /// </summary>
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index b93959a49..2f2ab4eb6 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -39,12 +39,6 @@ public record ModelParams
     /// <inheritdoc />
     public string ModelPath { get; set; }

-    /// <inheritdoc />
-    public AdapterCollection LoraAdapters { get; set; } = new();
-
-    /// <inheritdoc />
-    public string LoraBase { get; set; } = string.Empty;
-
     /// <inheritdoc />
     public uint? Threads { get; set; }
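The removals above are a breaking change: code that set `LoraAdapters` or `LoraBase` on `ModelParams` no longer compiles. A minimal before/after sketch with placeholder paths; the replacement APIs (`LoadLoraFromFile` / `AddLoraAdapter`) appear later in this diff and are demonstrated at the end:

```csharp
using LLama.Common;

// Before: LoRA settings travelled with the model parameters.
// var @params = new ModelParams("model.gguf")
// {
//     LoraAdapters = { new LoraAdapter("adapter.gguf", 0.8f) },
//     LoraBase = "base.gguf",
// };

// After: parameters describe only the model itself; adapters are
// loaded against the model and attached per context (see below).
var @params = new ModelParams("model.gguf"); // placeholder path
```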
diff --git a/LLama/LLamaQuantizer.cs b/LLama/LLamaQuantizer.cs
index 2ddbe2273..23f0f8b4e 100644
--- a/LLama/LLamaQuantizer.cs
+++ b/LLama/LLamaQuantizer.cs
@@ -62,7 +62,7 @@ public static bool Quantize(string srcFileName, string dstFilename, string ftype
     private static bool ValidateFtype(LLamaFtype ftype)
     {
         // Validation copies from here:
-        // https://github.com/ggerganov/llama.cpp/blob/f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7/llama.cpp#L13450
+        // https://github.com/ggerganov/llama.cpp/blob/345c8c0c87a97c1595f9c8b14833d531c8c7d8df/src/llama.cpp#L15624

         switch (ftype)
         {
@@ -105,9 +105,12 @@ private static bool ValidateFtype(LLamaFtype ftype)
             case LLamaFtype.MOSTLY_IQ3_S:
             case LLamaFtype.MOSTLY_IQ3_M:
+
+            case LLamaFtype.MOSTLY_Q4_0_4_4:
+            case LLamaFtype.MOSTLY_Q4_0_4_8:
+            case LLamaFtype.MOSTLY_Q4_0_8_8:
                 return true;

-            case LLamaFtype.MOSTLY_Q4_1_SOME_F16:
             case LLamaFtype.GUESSED:
             default:
                 return false;
diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj
index 272c81e37..efc639e72 100644
--- a/LLama/LLamaSharp.csproj
+++ b/LLama/LLamaSharp.csproj
@@ -53,7 +53,7 @@

-    368645698ab648e390dc
+    345c8c0c87a97c1595f9c8b

diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs
index 8646e4d93..a134f7aca 100644
--- a/LLama/LLamaWeights.cs
+++ b/LLama/LLamaWeights.cs
@@ -1,4 +1,4 @@
-using System;
+using System;
 using System.Collections.Generic;
 using System.Text;
 using System.Threading;
@@ -72,17 +72,6 @@ public static LLamaWeights LoadFromFile(IModelParams @params)
     {
         using var pin = @params.ToLlamaModelParams(out var lparams);
         var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
-
-        foreach (var adapter in @params.LoraAdapters)
-        {
-            if (string.IsNullOrEmpty(adapter.Path))
-                continue;
-            if (adapter.Scale <= 0)
-                continue;
-
-            weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase);
-        }
-
         return new LLamaWeights(weights);
     }
@@ -100,14 +89,6 @@ public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, C
         // don't touch the @params object inside the task, it might be changed
         // externally! Save a copy of everything that we need later.
         var modelPath = @params.ModelPath;
-        var loraBase = @params.LoraBase;
-        var loraAdapters = @params.LoraAdapters.ToArray();
-
-        // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
-        // slightly smaller range to allow some space for reporting LoRA loading too.
-        var modelLoadProgressRange = 1f;
-        if (loraAdapters.Length > 0)
-            modelLoadProgressRange = 0.9f;

         using (@params.ToLlamaModelParams(out var lparams))
         {
@@ -119,7 +100,7 @@ public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, C
             lparams.progress_callback = (progress, ctx) =>
             {
-                // Update the progress reporter (remapping the value into the smaller range).
-                progressReporter?.Report(Math.Clamp(progress, 0, 1) * modelLoadProgressRange);
+                // Update the progress reporter.
+                progressReporter?.Report(Math.Clamp(progress, 0, 1));

                 // If the user set a callback in the model params, call that and see if we should cancel
                 if (internalCallback != null && !internalCallback(progress, ctx))
@@ -141,30 +122,6 @@ public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, C
             // Load the model
             var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams);
-
-            // Apply the LoRA adapters
-            for (var i = 0; i < loraAdapters.Length; i++)
-            {
-                // Interrupt applying LoRAs if the token is cancelled
-                if (token.IsCancellationRequested)
-                {
-                    weights.Dispose();
-                    token.ThrowIfCancellationRequested();
-                }
-
-                // Don't apply invalid adapters
-                var adapter = loraAdapters[i];
-                if (string.IsNullOrEmpty(adapter.Path))
-                    continue;
-                if (adapter.Scale <= 0)
-                    continue;
-
-                weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, loraBase);
-
-                // Report progress. Model loading reported progress from 0 -> 0.9, use
-                // the last 0.1 to represent all of the LoRA adapters being applied.
-                progressReporter?.Report(0.9f + (0.1f / loraAdapters.Length) * (i + 1));
-            }
-
             // Update progress reporter to indicate completion
             progressReporter?.Report(1);
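With adapter application removed from `LoadFromFileAsync`, the native progress value now maps straight onto the reporter. A small sketch of loading with progress under these changes, assuming a placeholder model path and the `LoadFromFileAsync` signature shown above:

```csharp
using System;
using LLama;
using LLama.Common;

var @params = new ModelParams("model.gguf"); // placeholder path
var progress = new Progress<float>(p => Console.WriteLine($"loading: {p:P0}"));

// The reporter now receives the full 0 -> 1 range from llama.cpp;
// there is no longer a 0.9 -> 1.0 band reserved for LoRA loading.
using var weights = await LLamaWeights.LoadFromFileAsync(@params, progressReporter: progress);
```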
diff --git a/LLama/Native/LLamaAttentionType.cs b/LLama/Native/LLamaAttentionType.cs
index 543f89b47..f26c73278 100644
--- a/LLama/Native/LLamaAttentionType.cs
+++ b/LLama/Native/LLamaAttentionType.cs
@@ -1,5 +1,9 @@
 namespace LLama.Native;

+/// <summary>
+///
+/// </summary>
+/// <remarks>llama_attention_type</remarks>
 public enum LLamaAttentionType
 {
     Unspecified = -1,
diff --git a/LLama/Native/LLamaFtype.cs b/LLama/Native/LLamaFtype.cs
index 15216b71b..4e05de999 100644
--- a/LLama/Native/LLamaFtype.cs
+++ b/LLama/Native/LLamaFtype.cs
@@ -3,6 +3,7 @@ namespace LLama.Native
     /// <summary>
     /// Supported model file types
     /// </summary>
+    /// <remarks>C# representation of llama_ftype</remarks>
     public enum LLamaFtype
     {
         /// <summary>
@@ -35,10 +36,10 @@ public enum LLamaFtype
         /// Benchmark@7B: 3.90GB, +0.1846 ppl
         /// </summary>
         MOSTLY_Q4_1 = 3,

-        /// <summary>
-        /// Mostly 4 bit, tok_embeddings.weight and output.weight are f16
-        /// </summary>
-        MOSTLY_Q4_1_SOME_F16 = 4,
+        ///// <summary>
+        ///// Mostly 4 bit, tok_embeddings.weight and output.weight are f16
+        ///// </summary>
+        //MOSTLY_Q4_1_SOME_F16 = 4,

         /// <summary>
         /// Mostly 5 bit
diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs
index cc2515639..c341cca83 100644
--- a/LLama/Native/LLamaTokenDataArray.cs
+++ b/LLama/Native/LLamaTokenDataArray.cs
@@ -97,7 +97,7 @@ public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? gra

         using (LLamaTokenDataArrayNative.Create(this, out var st))
         {
-            NativeApi.llama_sample_grammar(ctx, ref st, grammar);
+            NativeApi.llama_grammar_sample(grammar, ctx, ref st);
             Sorted = st.sorted;
         }
     }
diff --git a/LLama/Native/LLamaVocabPreType.cs b/LLama/Native/LLamaVocabPreType.cs
index 6a3c89f53..3e5bc287c 100644
--- a/LLama/Native/LLamaVocabPreType.cs
+++ b/LLama/Native/LLamaVocabPreType.cs
@@ -27,4 +27,7 @@ internal enum LLamaVocabPreType
     CHATGLM4 = 17,
     VIKING = 18,
     JAIS = 19,
+    TEKKEN = 20,
+    SMOLLM = 21,
+    CODESHELL = 22,
 }
\ No newline at end of file
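The three Q4_0_4_x types added to `ValidateFtype` can now be requested when quantizing, while `MOSTLY_Q4_1_SOME_F16` is rejected. A sketch with placeholder file names, assuming the `LLamaFtype`-based `Quantize` overload:

```csharp
using System;
using LLama;
using LLama.Native;

// MOSTLY_Q4_0_4_4 (and _4_8 / _8_8) now pass validation instead of
// being rejected before the native quantizer is even invoked.
var ok = LLamaQuantizer.Quantize("model-f16.gguf", "model-q4_0_4_4.gguf", LLamaFtype.MOSTLY_Q4_0_4_4);
Console.WriteLine(ok ? "quantized" : "quantize failed");
```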
diff --git a/LLama/Native/LoraAdapter.cs b/LLama/Native/LoraAdapter.cs
new file mode 100644
index 000000000..27d142275
--- /dev/null
+++ b/LLama/Native/LoraAdapter.cs
@@ -0,0 +1,46 @@
+using System;
+
+namespace LLama.Native;
+
+/// <summary>
+/// A LoRA adapter which can be applied to a context for a specific model
+/// </summary>
+public class LoraAdapter
+{
+    /// <summary>
+    /// The model which this LoRA adapter was loaded with.
+    /// </summary>
+    public SafeLlamaModelHandle Model { get; }
+
+    /// <summary>
+    /// The full path of the file this adapter was loaded from
+    /// </summary>
+    public string Path { get; }
+
+    /// <summary>
+    /// Native pointer of the loaded adapter, will be automatically freed when the model is unloaded
+    /// </summary>
+    internal IntPtr Pointer { get; }
+
+    /// <summary>
+    /// Indicates whether this adapter is still loaded (i.e. has not been unloaded)
+    /// </summary>
+    internal bool Loaded { get; private set; }
+
+    internal LoraAdapter(SafeLlamaModelHandle model, string path, IntPtr nativePtr)
+    {
+        Model = model;
+        Path = path;
+        Pointer = nativePtr;
+        Loaded = true;
+    }
+
+    /// <summary>
+    /// Unload this adapter
+    /// </summary>
+    public void Unload()
+    {
+        Loaded = false;
+        NativeApi.llama_lora_adapter_free(Pointer);
+    }
+}
\ No newline at end of file
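A note on the lifetime implied by this class: the native pointer is owned by the model, so `Unload` is an early-free optimisation rather than a requirement. A hypothetical sketch (the handle parameters stand in for objects created elsewhere; `AddLoraAdapter` and `RemoveLoraAdapter` are added to the context handle further down in this diff):

```csharp
using LLama.Native;

static void UseAdapterOnce(SafeLlamaModelHandle model, SafeLLamaContextHandle ctx)
{
    var adapter = model.LoadLoraFromFile("adapter.gguf"); // placeholder path
    ctx.AddLoraAdapter(adapter, scale: 1.0f);

    // ... run inference with the adapter applied ...

    ctx.RemoveLoraAdapter(adapter);

    // Optional: free the native adapter now. Afterwards Loaded is false
    // and AddLoraAdapter will refuse to use it; otherwise the memory is
    // released when the model itself is unloaded.
    adapter.Unload();
}
```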
diff --git a/LLama/Native/NativeApi.Grammar.cs b/LLama/Native/NativeApi.Grammar.cs
index 2e9811371..8e5512e1c 100644
--- a/LLama/Native/NativeApi.Grammar.cs
+++ b/LLama/Native/NativeApi.Grammar.cs
@@ -35,16 +35,16 @@ public static partial class NativeApi
        /// <param name="ctx"></param>
        /// <param name="candidates"></param>
        /// <param name="grammar"></param>
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern void llama_sample_grammar(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, SafeLLamaGrammarHandle grammar);
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern void llama_grammar_sample(SafeLLamaGrammarHandle grammar, SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates);

-        /// <summary>
-        /// Accepts the sampled token into the grammar
-        /// </summary>
-        /// <param name="ctx"></param>
-        /// <param name="grammar"></param>
-        /// <param name="token"></param>
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern void llama_grammar_accept_token(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar, LLamaToken token);
+        /// <summary>
+        /// Accepts the sampled token into the grammar
+        /// </summary>
+        /// <param name="grammar"></param>
+        /// <param name="ctx"></param>
+        /// <param name="token"></param>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern void llama_grammar_accept_token(SafeLLamaGrammarHandle grammar, SafeLLamaContextHandle ctx, LLamaToken token);
    }
}
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 41d3a130b..f7b97bead 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -445,5 +445,12 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
        /// <returns>Returns the split_prefix length.</returns>
        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
        public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no, int split_count);
+
+        /// <summary>
+        /// Manually free a LoRA adapter. Loaded adapters will be freed automatically when the associated model is deleted.
+        /// </summary>
+        /// <param name="adapter"></param>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern void llama_lora_adapter_free(IntPtr adapter);
    }
}
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index c520a61d1..628936352 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -352,6 +352,58 @@ static SafeLLamaContextHandle()
        /// </summary>
        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
        private static extern void llama_synchronize(SafeLLamaContextHandle ctx);
+
+        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+        private static extern int llama_lora_adapter_set(SafeLLamaContextHandle context, IntPtr adapter, float scale);
+
+        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+        private static extern int llama_lora_adapter_remove(SafeLLamaContextHandle context, IntPtr adapter);
+
+        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+        private static extern int llama_lora_adapter_clear(SafeLLamaContextHandle context);
        #endregion
+
+        #region LoRA
+        /// <summary>
+        /// Add a LoRA adapter to this context
+        /// </summary>
+        /// <param name="lora"></param>
+        /// <param name="scale"></param>
+        /// <exception cref="ArgumentException"></exception>
+        public void AddLoraAdapter(LoraAdapter lora, float scale)
+        {
+            if (lora.Model != ModelHandle)
+                throw new ArgumentException("Cannot add LoRA adapter which was loaded for a different model");
+            if (!lora.Loaded)
+                throw new ArgumentException("Cannot add LoRA adapter which has been unloaded");
+
+            var err = llama_lora_adapter_set(this, lora.Pointer, scale);
+            if (err != 0)
+                throw new RuntimeError("Failed to set lora adapter");
+        }
+
+        /// <summary>
+        /// Remove a LoRA adapter from this context
+        /// </summary>
+        /// <param name="lora"></param>
+        /// <returns>Indicates if the LoRA was in this context and was removed</returns>
+        public bool RemoveLoraAdapter(LoraAdapter lora)
+        {
+            if (lora.Model != ModelHandle)
+                return false;
+
+            var err = llama_lora_adapter_remove(this, lora.Pointer);
+            return err == 0;
+        }
+
+        /// <summary>
+        /// Remove all LoRA adapters from this context
+        /// </summary>
+        public void ClearLoraAdapters()
+        {
+            llama_lora_adapter_clear(this);
+        }
+        #endregion

        ///
@@ -734,6 +786,16 @@ public void KvCacheSequenceDivide(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int
        {
            NativeApi.llama_kv_cache_seq_div(this, seq, p0, p1, divisor);
        }
+
+        /// <summary>
+        /// Returns the largest position present in the KV cache for the specified sequence
+        /// </summary>
+        /// <param name="seq"></param>
+        /// <returns></returns>
+        public LLamaPos KvCacheMaxPosition(LLamaSeqId seq)
+        {
+            return NativeApi.llama_kv_cache_seq_pos_max(this, seq);
+        }
        #endregion
    }
}
diff --git a/LLama/Native/SafeLLamaGrammarHandle.cs b/LLama/Native/SafeLLamaGrammarHandle.cs
index 53c8fda28..0a81afc12 100644
--- a/LLama/Native/SafeLLamaGrammarHandle.cs
+++ b/LLama/Native/SafeLLamaGrammarHandle.cs
@@ -111,7 +111,7 @@ public SafeLLamaGrammarHandle Clone()
        /// </summary>
        public void AcceptToken(SafeLLamaContextHandle ctx, LLamaToken token)
        {
-            NativeApi.llama_grammar_accept_token(ctx, this, token);
+            NativeApi.llama_grammar_accept_token(this, ctx, token);
        }
    }
}
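Because adapters now attach to contexts instead of being baked into the weights, one model can serve different adapter mixes. A sketch of toggling adapters with the new context-level methods, assuming handles obtained as in the earlier sketches:

```csharp
using System;
using LLama.Native;

static void SwapAdapters(SafeLLamaContextHandle ctx, LoraAdapter a, LoraAdapter b)
{
    ctx.AddLoraAdapter(a, scale: 1.0f);

    // RemoveLoraAdapter returns false (rather than throwing) if the
    // adapter belongs to a different model or was never set on this context.
    if (!ctx.RemoveLoraAdapter(a))
        Console.WriteLine("adapter 'a' was not removed");

    ctx.AddLoraAdapter(b, scale: 0.5f);

    // Drop every adapter currently set on the context in one call.
    ctx.ClearLoraAdapters();
}
```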
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index 0e55e8eea..eaf76421a 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -3,6 +3,7 @@
 using System.Diagnostics;
 using System.IO;
 using System.Text;
+using System.Threading;
 using LLama.Exceptions;

 namespace LLama.Native
@@ -111,8 +112,7 @@ protected override bool ReleaseHandle()
        /// <param name="modelPath"></param>
        /// <param name="lparams"></param>
        /// <returns></returns>
-        ///
-        ///
+        ///
        public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams)
        {
            // Try to open the model file, this will check:
@@ -433,39 +433,32 @@ private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string k
        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
        [return: MarshalAs(UnmanagedType.U1)]
        private static extern bool llama_model_has_encoder(SafeLlamaModelHandle model);
+
+        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+        private static extern IntPtr llama_lora_adapter_init(SafeLlamaModelHandle model, string path);
        #endregion

        #region LoRA
        /// <summary>
-        /// Apply a LoRA adapter to a loaded model
+        /// Load a LoRA adapter from file. The adapter will be associated with this model but will not be applied
        /// </summary>
-        /// <param name="lora"></param>
-        /// <param name="scale"></param>
-        /// <param name="modelBase">A path to a higher quality model to use as a base for the layers modified by the
-        /// adapter. Can be NULL to use the current loaded model.</param>
-        /// <param name="threads"></param>
-        /// <exception cref="RuntimeError"></exception>
-        public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, int? threads = null)
+        /// <param name="path"></param>
+        /// <returns></returns>
+        public LoraAdapter LoadLoraFromFile(string path)
        {
+            path = Path.GetFullPath(path);
+
            // Try to open the model file, this will check:
            // - File exists (automatically throws FileNotFoundException)
            // - File is readable (explicit check)
-            // This provides better error messages that llama.cpp, which would throw an access violation exception in both cases.
-            using (var fs = new FileStream(lora, FileMode.Open))
+            // This provides better error messages than llama.cpp, which would throw an access violation exception in both cases.
+            using (var fs = new FileStream(path, FileMode.Open))
                if (!fs.CanRead)
-                    throw new InvalidOperationException($"LoRA file '{lora}' is not readable");
-
-            var err = llama_model_apply_lora_from_file(
-                this,
-                lora,
-                scale,
-                string.IsNullOrEmpty(modelBase) ? null : modelBase,
-                threads ?? Math.Max(1, Environment.ProcessorCount / 2)
-            );
-
-            if (err != 0)
-                throw new RuntimeError($"Failed to apply lora adapter (err={err}).");
+                    throw new InvalidOperationException($"LoRA file '{path}' is not readable");
+
+            var ptr = llama_lora_adapter_init(this, path);
+            return new LoraAdapter(this, path, ptr);
        }
        #endregion
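End to end, the replacement for the removed `ApplyLoraFromFile` flow looks roughly like this. A sketch with placeholder paths, assuming the surrounding LLamaSharp API (`LLamaWeights.NativeHandle`, `CreateContext`) is unchanged by this PR; note that `LoadLoraFromFile` above does not yet check `llama_lora_adapter_init` for a null pointer, so a bad adapter file would only surface later:

```csharp
using LLama;
using LLama.Common;

var @params = new ModelParams("model.gguf");                      // placeholder path
using var weights = LLamaWeights.LoadFromFile(@params);

// Load the adapter once against the model, then apply it to any
// context created from those weights.
var lora = weights.NativeHandle.LoadLoraFromFile("adapter.gguf"); // placeholder path

using var context = weights.CreateContext(@params);
context.NativeHandle.AddLoraAdapter(lora, scale: 0.8f);
```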