Skip to content

Commit b868b05

Browse files
committed
Added metadata overrides to IModelParams
1 parent b22d8b7 commit b868b05

File tree

9 files changed

+157
-21
lines changed

9 files changed

+157
-21
lines changed

LLama.Examples/Examples/BatchedDecoding.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System.Diagnostics;
22
using System.Text;
3+
using LLama.Abstractions;
34
using LLama.Common;
45
using LLama.Native;
56

@@ -30,6 +31,7 @@ public static async Task Run()
3031

3132
// Load model
3233
var parameters = new ModelParams(modelPath);
34+
3335
using var model = LLamaWeights.LoadFromFile(parameters);
3436

3537
// Tokenize prompt

LLama.Examples/LLama.Examples.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
<Import Project="..\LLama\LLamaSharp.Runtime.targets" />
33
<PropertyGroup>
44
<OutputType>Exe</OutputType>
5-
<TargetFrameworks>net6.0;net7.0;net8.0</TargetFrameworks>
5+
<TargetFrameworks>net6.0;net8.0</TargetFrameworks>
66
<ImplicitUsings>enable</ImplicitUsings>
77
<Nullable>enable</Nullable>
88
<Platforms>AnyCPU;x64</Platforms>

LLama.Examples/Program.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
NativeLibraryConfig
1111
.Instance
1212
.WithCuda()
13-
.WithLogs()
14-
.WithAvx(NativeLibraryConfig.AvxLevel.Avx512);
13+
.WithLogs();
1514

1615
NativeApi.llama_empty_call();
1716
Console.WriteLine();

LLama.Unittest/ModelsParamsTests.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using LLama.Common;
22
using System.Text.Json;
3+
using LLama.Abstractions;
34

45
namespace LLama.Unittest
56
{
@@ -14,7 +15,12 @@ public void SerializeRoundTripSystemTextJson()
1415
ContextSize = 42,
1516
Seed = 42,
1617
GpuLayerCount = 111,
17-
TensorSplits = { [0] = 3 }
18+
TensorSplits = { [0] = 3 },
19+
MetadataOverrides =
20+
{
21+
MetadataOverride.Create("hello", true),
22+
MetadataOverride.Create("world", 17),
23+
}
1824
};
1925

2026
var json = JsonSerializer.Serialize(expected);

LLama.Web/Common/ModelOptions.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ public class ModelOptions
5959
/// <inheritdoc />
6060
public TensorSplitsCollection TensorSplits { get; set; } = new();
6161

62+
/// <inheritdoc />
63+
public List<MetadataOverride> MetadataOverrides { get; } = new();
64+
6265
/// <inheritdoc />
6366
public float? RopeFrequencyBase { get; set; }
6467

LLama/Abstractions/IModelParams.cs

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ public interface IModelParams
5959
/// base model path for the lora adapter (lora_base)
6060
/// </summary>
6161
string LoraBase { get; set; }
62+
63+
/// <summary>
64+
/// Override specific metadata items in the model
65+
/// </summary>
66+
List<MetadataOverride> MetadataOverrides { get; }
6267
}
6368

6469
/// <summary>
@@ -186,7 +191,7 @@ public class TensorSplitsCollectionConverter
186191
: JsonConverter<TensorSplitsCollection>
187192
{
188193
/// <inheritdoc/>
189-
public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
194+
public override TensorSplitsCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
190195
{
191196
var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
192197
return new TensorSplitsCollection(arr);
@@ -198,4 +203,97 @@ public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value,
198203
JsonSerializer.Serialize(writer, value.Splits, options);
199204
}
200205
}
206+
/// <summary>
/// An override for a single key/value pair in model metadata
/// </summary>
/// <remarks>
/// Instances are created through the <c>Create</c> overloads; the concrete
/// int/float/bool implementations are private nested records.
/// JSON (de)serialization is handled by <see cref="MetadataOverrideConverter"/>, which
/// represents an override as <c>{"Key": "...", "Type": "int|float|bool", "Value": ...}</c>.
/// </remarks>
[JsonConverter(typeof(MetadataOverrideConverter))]
public abstract record MetadataOverride
{
    // JSON property names and type tags shared between the concrete overrides
    // (which write them) and MetadataOverrideConverter (which reads them).
    internal const string JsonKeyProperty = "Key";
    internal const string JsonTypeProperty = "Type";
    internal const string JsonValueProperty = "Value";
    internal const string JsonIntType = "int";
    internal const string JsonFloatType = "float";
    internal const string JsonBoolType = "bool";

    /// <summary>
    /// Create a new override for an int key
    /// </summary>
    /// <param name="key">Metadata key to override</param>
    /// <param name="value">Integer value to use for that key</param>
    /// <returns>An override carrying an int value</returns>
    public static MetadataOverride Create(string key, int value)
    {
        return new IntOverride(key, value);
    }

    /// <summary>
    /// Create a new override for a float key
    /// </summary>
    /// <param name="key">Metadata key to override</param>
    /// <param name="value">Float value to use for that key</param>
    /// <returns>An override carrying a float value</returns>
    public static MetadataOverride Create(string key, float value)
    {
        return new FloatOverride(key, value);
    }

    /// <summary>
    /// Create a new override for a boolean key
    /// </summary>
    /// <param name="key">Metadata key to override</param>
    /// <param name="value">Boolean value to use for that key</param>
    /// <returns>An override carrying a bool value</returns>
    public static MetadataOverride Create(string key, bool value)
    {
        return new BoolOverride(key, value);
    }

    /// <summary>
    /// Write the tag and value of this override into a native override struct.
    /// The key is not written here; the caller is responsible for copying it.
    /// </summary>
    /// <param name="dest">Native struct to populate</param>
    internal abstract void Write(ref LLamaModelMetadataOverride dest);

    /// <summary>
    /// Write the "Type" and "Value" JSON properties describing this override.
    /// The caller must already have started a JSON object on <paramref name="writer"/>.
    /// </summary>
    /// <param name="writer">Writer positioned inside a JSON object</param>
    internal abstract void WriteJson(Utf8JsonWriter writer);

    /// <summary>
    /// Get the key being overridden by this override
    /// </summary>
    public abstract string Key { get; init; }

    private sealed record IntOverride(string Key, int Value) : MetadataOverride
    {
        internal override void Write(ref LLamaModelMetadataOverride dest)
        {
            dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
            dest.IntValue = Value;
        }

        internal override void WriteJson(Utf8JsonWriter writer)
        {
            writer.WriteString(JsonTypeProperty, JsonIntType);
            writer.WriteNumber(JsonValueProperty, Value);
        }
    }

    private sealed record FloatOverride(string Key, float Value) : MetadataOverride
    {
        internal override void Write(ref LLamaModelMetadataOverride dest)
        {
            dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
            dest.FloatValue = Value;
        }

        internal override void WriteJson(Utf8JsonWriter writer)
        {
            writer.WriteString(JsonTypeProperty, JsonFloatType);
            // NOTE(review): Utf8JsonWriter rejects NaN/Infinity; such values cannot be serialized.
            writer.WriteNumber(JsonValueProperty, Value);
        }
    }

    private sealed record BoolOverride(string Key, bool Value) : MetadataOverride
    {
        internal override void Write(ref LLamaModelMetadataOverride dest)
        {
            dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;
            // NOTE(review): true is encoded as -1 (presumably so every byte of the
            // native field is non-zero) - confirm against the native struct layout.
            dest.BoolValue = Value ? -1 : 0;
        }

        internal override void WriteJson(Utf8JsonWriter writer)
        {
            writer.WriteString(JsonTypeProperty, JsonBoolType);
            writer.WriteBoolean(JsonValueProperty, Value);
        }
    }
}

/// <summary>
/// A JSON converter for <see cref="MetadataOverride"/>
/// </summary>
/// <remarks>
/// Each override is written as a JSON object with three properties:
/// <c>Key</c> (string), <c>Type</c> (one of "int", "float", "bool") and <c>Value</c>.
/// Property names are fixed; any naming policy set on the serializer options is not applied.
/// </remarks>
public class MetadataOverrideConverter
    : JsonConverter<MetadataOverride>
{
    /// <inheritdoc/>
    /// <exception cref="JsonException">
    /// The JSON is not an object, a required property is missing, the type tag is
    /// unknown, or the value does not match the declared type.
    /// </exception>
    public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
    {
        // Buffer the whole value so properties may appear in any order.
        using var document = JsonDocument.ParseValue(ref reader);
        var root = document.RootElement;

        if (root.ValueKind != JsonValueKind.Object)
            throw new JsonException("Expected a JSON object for MetadataOverride.");

        var keyElement = GetRequiredProperty(root, MetadataOverride.JsonKeyProperty);
        var typeElement = GetRequiredProperty(root, MetadataOverride.JsonTypeProperty);
        var valueElement = GetRequiredProperty(root, MetadataOverride.JsonValueProperty);

        if (keyElement.ValueKind != JsonValueKind.String)
            throw new JsonException("MetadataOverride 'Key' must be a JSON string.");
        var key = keyElement.GetString()
               ?? throw new JsonException("MetadataOverride 'Key' must not be null.");

        var typeName = typeElement.ValueKind == JsonValueKind.String
            ? typeElement.GetString()
            : null;

        switch (typeName)
        {
            case MetadataOverride.JsonIntType
                when valueElement.ValueKind == JsonValueKind.Number && valueElement.TryGetInt32(out var intValue):
                return MetadataOverride.Create(key, intValue);

            case MetadataOverride.JsonFloatType
                when valueElement.ValueKind == JsonValueKind.Number && valueElement.TryGetSingle(out var floatValue):
                return MetadataOverride.Create(key, floatValue);

            case MetadataOverride.JsonBoolType
                when valueElement.ValueKind is JsonValueKind.True or JsonValueKind.False:
                return MetadataOverride.Create(key, valueElement.GetBoolean());

            default:
                throw new JsonException(
                    $"MetadataOverride '{key}' has an unknown type or a value which does not match its type.");
        }
    }

    /// <inheritdoc/>
    public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options)
    {
        writer.WriteStartObject();
        writer.WriteString(MetadataOverride.JsonKeyProperty, value.Key);

        // The concrete override knows its own type tag and how to encode its value.
        value.WriteJson(writer);

        writer.WriteEndObject();
    }

    /// <summary>
    /// Fetch a property from a JSON object, failing with a <see cref="JsonException"/> if it is absent.
    /// </summary>
    private static JsonElement GetRequiredProperty(JsonElement obj, string name)
    {
        if (obj.TryGetProperty(name, out var element))
            return element;

        throw new JsonException($"MetadataOverride JSON is missing the required '{name}' property.");
    }
}
201299
}

LLama/Common/ModelParams.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
using LLama.Abstractions;
2-
using System;
32
using System.Text;
4-
using System.Text.Json;
53
using System.Text.Json.Serialization;
64
using LLama.Native;
5+
using System.Collections.Generic;
76

87
namespace LLama.Common
98
{
@@ -55,6 +54,9 @@ public record ModelParams
5554
/// <inheritdoc />
5655
public TensorSplitsCollection TensorSplits { get; set; } = new();
5756

57+
/// <inheritdoc />
58+
public List<MetadataOverride> MetadataOverrides { get; } = new();
59+
5860
/// <inheritdoc />
5961
public float? RopeFrequencyBase { get; set; }
6062

LLama/Extensions/IModelParamsExtensions.cs

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using System.IO;
22
using System;
3-
using System.Buffers;
3+
using System.Text;
44
using LLama.Abstractions;
55
using LLama.Native;
66

@@ -36,18 +36,44 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam
3636
result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer;
3737
}
3838

39-
//todo: MetadataOverrides
40-
//if (@params.MetadataOverrides.Count == 0)
41-
//{
42-
// unsafe
43-
// {
44-
// result.kv_overrides = (LLamaModelMetadataOverride*)IntPtr.Zero;
45-
// }
46-
//}
47-
//else
48-
//{
49-
// throw new NotImplementedException("MetadataOverrides");
50-
//}
39+
if (@params.MetadataOverrides.Count == 0)
40+
{
41+
unsafe
42+
{
43+
result.kv_overrides = (LLamaModelMetadataOverride*)IntPtr.Zero;
44+
}
45+
}
46+
else
47+
{
48+
// Allocate enough space for all the override items
49+
var overrides = new LLamaModelMetadataOverride[@params.MetadataOverrides.Count + 1];
50+
var overridesPin = overrides.AsMemory().Pin();
51+
unsafe
52+
{
53+
result.kv_overrides = (LLamaModelMetadataOverride*)disposer.Add(overridesPin).Pointer;
54+
}
55+
56+
// Convert each item
57+
for (var i = 0; i < @params.MetadataOverrides.Count; i++)
58+
{
59+
var item = @params.MetadataOverrides[i];
60+
var native = new LLamaModelMetadataOverride();
61+
62+
// Init value and tag
63+
item.Write(ref native);
64+
65+
// Convert key to bytes
66+
unsafe
67+
{
68+
fixed (char* srcKey = item.Key)
69+
{
70+
Encoding.UTF8.GetBytes(srcKey, 0, native.key, 128);
71+
}
72+
}
73+
74+
overrides[i] = native;
75+
}
76+
}
5177

5278
return disposer;
5379
}

LLama/Native/LLamaModelMetadataOverride.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public unsafe struct LLamaModelMetadataOverride
1212
/// Key to override
1313
/// </summary>
1414
[FieldOffset(0)]
15-
public fixed char key[128];
15+
public fixed byte key[128];
1616

1717
/// <summary>
1818
/// Type of value

0 commit comments

Comments
 (0)