diff --git a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
index 8148adc88..7806282de 100644
--- a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
+++ b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
@@ -1,14 +1,7 @@
using LLama;
-using LLama.Abstractions;
using LLama.Common;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.AI;
-using Microsoft.SemanticKernel.AI.Embeddings;
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading.Tasks;
namespace LLamaSharp.KernelMemory
{
@@ -80,24 +73,24 @@ public void Dispose()
}
/// <inheritdoc/>
- public Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, CancellationToken cancellationToken = default)
+ public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, CancellationToken cancellationToken = default)
{
IList<ReadOnlyMemory<float>> results = new List<ReadOnlyMemory<float>>();
foreach (var d in data)
{
- var embeddings = _embedder.GetEmbeddings(d);
+ var embeddings = await _embedder.GetEmbeddings(d, cancellationToken);
results.Add(new ReadOnlyMemory<float>(embeddings));
}
- return Task.FromResult(results);
+ return results;
}
/// <inheritdoc/>
- public Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
+ public async Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
{
- var embeddings = _embedder.GetEmbeddings(text);
- return Task.FromResult(new Embedding(embeddings));
+ var embeddings = await _embedder.GetEmbeddings(text, cancellationToken);
+ return new Embedding(embeddings);
}
/// <inheritdoc/>
diff --git a/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs b/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
index 73ceb0f21..6889ba6a7 100644
--- a/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
+++ b/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
@@ -6,7 +6,7 @@ namespace LLamaSharp.SemanticKernel.TextEmbedding;
public sealed class LLamaSharpEmbeddingGeneration : ITextEmbeddingGenerationService
{
- private LLamaEmbedder _embedder;
+ private readonly LLamaEmbedder _embedder;
private readonly Dictionary<string, object?> _attributes = new();
@@ -20,7 +20,11 @@ public LLamaSharpEmbeddingGeneration(LLamaEmbedder embedder)
/// <inheritdoc/>
public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
- var embeddings = data.Select(text => new ReadOnlyMemory<float>(_embedder.GetEmbeddings(text))).ToList();
- return await Task.FromResult(embeddings);
+ var result = new List<ReadOnlyMemory<float>>();
+
+ foreach (var item in data)
+ result.Add(await _embedder.GetEmbeddings(item, cancellationToken));
+
+ return result;
}
}
diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs
index 4c8fb37fa..052690480 100644
--- a/LLama.Unittest/LLamaEmbedderTests.cs
+++ b/LLama.Unittest/LLamaEmbedderTests.cs
@@ -1,14 +1,17 @@
using LLama.Common;
+using Xunit.Abstractions;
namespace LLama.Unittest;
public sealed class LLamaEmbedderTests
: IDisposable
{
+ private readonly ITestOutputHelper _testOutputHelper;
private readonly LLamaEmbedder _embedder;
- public LLamaEmbedderTests()
+ public LLamaEmbedderTests(ITestOutputHelper testOutputHelper)
{
+ _testOutputHelper = testOutputHelper;
var @params = new ModelParams(Constants.ModelPath)
{
EmbeddingMode = true,
@@ -41,21 +44,23 @@ private static float Dot(float[] a, float[] b)
}
[Fact]
- public void EmbedCompare()
+ public async Task EmbedCompare()
{
- var cat = _embedder.GetEmbeddings("cat");
- var kitten = _embedder.GetEmbeddings("kitten");
- var spoon = _embedder.GetEmbeddings("spoon");
+ var cat = await _embedder.GetEmbeddings("cat");
+ var kitten = await _embedder.GetEmbeddings("kitten");
+ var spoon = await _embedder.GetEmbeddings("spoon");
Normalize(cat);
Normalize(kitten);
Normalize(spoon);
- var close = Dot(cat, kitten);
- var far = Dot(cat, spoon);
+ var close = 1 - Dot(cat, kitten);
+ var far = 1 - Dot(cat, spoon);
- // This comparison seems backwards, but remember that with a
- // dot product 1.0 means **identical** and 0.0 means **completely opposite**!
- Assert.True(close > far);
+ Assert.True(close < far);
+
+ _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
+ _testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
+ _testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");
}
}
\ No newline at end of file
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index 8dfc4aaba..c375d2e93 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -3,6 +3,8 @@
using LLama.Exceptions;
using LLama.Abstractions;
using Microsoft.Extensions.Logging;
+using System.Threading;
+using System.Threading.Tasks;
namespace LLama
{
@@ -40,27 +42,12 @@ public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logg
/// Get the embeddings of the text.
/// </summary>
/// <param name="text"></param>
- /// <param name="threads">unused</param>
- /// <param name="addBos">Add bos to the text.</param>
- /// <param name="encoding">unused</param>
+ /// <param name="cancellationToken"></param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
- [Obsolete("'threads' and 'encoding' parameters are no longer used")]
- // ReSharper disable once MethodOverloadWithOptionalParameter
- public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
+ public Task<float[]> GetEmbeddings(string text, CancellationToken cancellationToken = default)
{
- return GetEmbeddings(text, addBos);
- }
-
- /// <summary>
- /// Get the embeddings of the text.
- /// </summary>
- /// <param name="text"></param>
- /// <returns></returns>
- /// <exception cref="RuntimeError"></exception>
- public float[] GetEmbeddings(string text)
- {
- return GetEmbeddings(text, true);
+ return GetEmbeddings(text, true, cancellationToken);
}
/// <summary>
@@ -68,22 +55,48 @@ public float[] GetEmbeddings(string text)
/// </summary>
/// <param name="text"></param>
/// <param name="addBos">Add bos to the text.</param>
+ /// <param name="cancellationToken"></param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
- public float[] GetEmbeddings(string text, bool addBos)
+ public async Task<float[]> GetEmbeddings(string text, bool addBos, CancellationToken cancellationToken = default)
{
- var embed_inp_array = Context.Tokenize(text, addBos);
+ var tokens = Context.Tokenize(text, addBos);
+ if (tokens.Length > Context.ContextSize)
+ throw new ArgumentException($"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", nameof(text));
+
+ // Evaluate prompt in batch-size chunks
+ var n_past = 0;
+ var batch = new LLamaBatch();
+ var batchSize = (int)Context.Params.BatchSize;
+ for (var i = 0; i < tokens.Length; i += batchSize)
+ {
+ var n_eval = tokens.Length - i;
+ if (n_eval > batchSize)
+ n_eval = batchSize;
+
+ batch.Clear();
+ for (var j = 0; j < n_eval; j++)
+ batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, false);
+
+ var returnCode = await Context.DecodeAsync(batch, cancellationToken);
+ if (returnCode != 0)
+ throw new LLamaDecodeError(returnCode);
+ }
- // TODO(Rinne): deal with log of prompt
+ var embeddings = GetEmbeddingsArray();
- if (embed_inp_array.Length > 0)
- Context.Eval(embed_inp_array.AsSpan(), 0);
+ // Remove everything we just evaluated from the context cache
+ Context.NativeHandle.KvCacheClear();
- var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
- if (embeddings == null)
- return Array.Empty<float>();
+ return embeddings;
- return embeddings.ToArray();
+ float[] GetEmbeddingsArray()
+ {
+ var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
+ if (embeddings == null)
+ return Array.Empty<float>();
+ return embeddings.ToArray();
+ }
}
///