Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 6 additions & 13 deletions LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
using LLama;
using LLama.Abstractions;
using LLama.Common;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.AI;
using Microsoft.SemanticKernel.AI.Embeddings;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace LLamaSharp.KernelMemory
{
Expand Down Expand Up @@ -80,24 +73,24 @@ public void Dispose()
}

/// <inheritdoc/>
public Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, CancellationToken cancellationToken = default)
public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, CancellationToken cancellationToken = default)
{
IList<ReadOnlyMemory<float>> results = new List<ReadOnlyMemory<float>>();

foreach (var d in data)
{
var embeddings = _embedder.GetEmbeddings(d);
var embeddings = await _embedder.GetEmbeddings(d, cancellationToken);
results.Add(new ReadOnlyMemory<float>(embeddings));
}

return Task.FromResult(results);
return results;
}

/// <inheritdoc/>
public Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
public async Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
{
var embeddings = _embedder.GetEmbeddings(text);
return Task.FromResult(new Embedding(embeddings));
var embeddings = await _embedder.GetEmbeddings(text, cancellationToken);
return new Embedding(embeddings);
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace LLamaSharp.SemanticKernel.TextEmbedding;

public sealed class LLamaSharpEmbeddingGeneration : ITextEmbeddingGenerationService
{
private LLamaEmbedder _embedder;
private readonly LLamaEmbedder _embedder;

private readonly Dictionary<string, object?> _attributes = new();

Expand All @@ -20,7 +20,11 @@ public LLamaSharpEmbeddingGeneration(LLamaEmbedder embedder)
/// <inheritdoc/>
public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
var embeddings = data.Select(text => new ReadOnlyMemory<float>(_embedder.GetEmbeddings(text))).ToList();
return await Task.FromResult(embeddings);
var result = new List<ReadOnlyMemory<float>>();

foreach (var item in data)
result.Add(await _embedder.GetEmbeddings(item, cancellationToken));

return result;
}
}
25 changes: 15 additions & 10 deletions LLama.Unittest/LLamaEmbedderTests.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
using LLama.Common;
using Xunit.Abstractions;

namespace LLama.Unittest;

public sealed class LLamaEmbedderTests
: IDisposable
{
private readonly ITestOutputHelper _testOutputHelper;
private readonly LLamaEmbedder _embedder;

public LLamaEmbedderTests()
public LLamaEmbedderTests(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
var @params = new ModelParams(Constants.ModelPath)
{
EmbeddingMode = true,
Expand Down Expand Up @@ -41,21 +44,23 @@ private static float Dot(float[] a, float[] b)
}

[Fact]
public void EmbedCompare()
public async Task EmbedCompare()
{
var cat = _embedder.GetEmbeddings("cat");
var kitten = _embedder.GetEmbeddings("kitten");
var spoon = _embedder.GetEmbeddings("spoon");
var cat = await _embedder.GetEmbeddings("cat");
var kitten = await _embedder.GetEmbeddings("kitten");
var spoon = await _embedder.GetEmbeddings("spoon");

Normalize(cat);
Normalize(kitten);
Normalize(spoon);

var close = Dot(cat, kitten);
var far = Dot(cat, spoon);
var close = 1 - Dot(cat, kitten);
var far = 1 - Dot(cat, spoon);

// This comparison seems backwards, but remember that with a
// dot product 1.0 means **identical** and 0.0 means **completely opposite**!
Assert.True(close > far);
Assert.True(close < far);

_testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");
}
}
67 changes: 40 additions & 27 deletions LLama/LLamaEmbedder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
using LLama.Exceptions;
using LLama.Abstractions;
using Microsoft.Extensions.Logging;
using System.Threading;
using System.Threading.Tasks;

namespace LLama
{
Expand Down Expand Up @@ -40,50 +42,61 @@ public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logg
/// Get the embeddings of the text.
/// </summary>
/// <param name="text"></param>
/// <param name="threads">unused</param>
/// <param name="addBos">Add bos to the text.</param>
/// <param name="encoding">unused</param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("'threads' and 'encoding' parameters are no longer used")]
// ReSharper disable once MethodOverloadWithOptionalParameter
public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
public Task<float[]> GetEmbeddings(string text, CancellationToken cancellationToken = default)
{
return GetEmbeddings(text, addBos);
}

/// <summary>
/// Get the embeddings of the text.
/// </summary>
/// <param name="text"></param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public float[] GetEmbeddings(string text)
{
return GetEmbeddings(text, true);
return GetEmbeddings(text, true, cancellationToken);
}

/// <summary>
/// Get the embeddings of the text.
/// </summary>
/// <param name="text"></param>
/// <param name="addBos">Add bos to the text.</param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public float[] GetEmbeddings(string text, bool addBos)
public async Task<float[]> GetEmbeddings(string text, bool addBos, CancellationToken cancellationToken = default)
{
var embed_inp_array = Context.Tokenize(text, addBos);
var tokens = Context.Tokenize(text, addBos);
if (tokens.Length > Context.ContextSize)
throw new ArgumentException($"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", nameof(text));

// Evaluate prompt in batch-size chunks
var n_past = 0;
var batch = new LLamaBatch();
var batchSize = (int)Context.Params.BatchSize;
for (var i = 0; i < tokens.Length; i += batchSize)
{
var n_eval = tokens.Length - i;
if (n_eval > batchSize)
n_eval = batchSize;

batch.Clear();
for (var j = 0; j < n_eval; j++)
batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, false);

var returnCode = await Context.DecodeAsync(batch, cancellationToken);
if (returnCode != 0)
throw new LLamaDecodeError(returnCode);
}

// TODO(Rinne): deal with log of prompt
var embeddings = GetEmbeddingsArray();

if (embed_inp_array.Length > 0)
Context.Eval(embed_inp_array.AsSpan(), 0);
// Remove everything we just evaluated from the context cache
Context.NativeHandle.KvCacheClear();

var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
if (embeddings == null)
return Array.Empty<float>();
return embeddings;

return embeddings.ToArray();
float[] GetEmbeddingsArray()
{
var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
if (embeddings == null)
return Array.Empty<float>();
return embeddings.ToArray();
}
}

/// <summary>
Expand Down