-
Notifications
You must be signed in to change notification settings - Fork 476
add LLamaReranker and tests #1150
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6f4c53c
a69f814
15c5247
c604359
d99670c
05677fe
4258cc1
8d61a92
e1939eb
49ae0a8
474cfd1
a53f503
9ed7378
37bb3c3
2bcb62e
69a5f42
7b2ee55
8a34866
371fdcd
87059e8
14ba50f
9f4bd96
8fde3bc
3f51a7f
63ae374
d838e1c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| using LLama.Common; | ||
| using LLama.Extensions; | ||
| using LLama.Native; | ||
| using Microsoft.Extensions.AI; | ||
| using System.Runtime.InteropServices; | ||
| using Xunit.Abstractions; | ||
|
|
||
| namespace LLama.Unittest; | ||
|
|
||
public sealed class LLamaRerankerTests: IDisposable
{
    private readonly ITestOutputHelper _testOutputHelper;

    // NOTE: the weights must stay alive for the lifetime of the reranker.
    // Loading them with `using var` in the constructor would dispose the model
    // while _reranker still references it (use-after-dispose), so they are
    // kept in a field and released in Dispose() instead.
    private readonly LLamaWeights _weights;
    private readonly LLamaReranker _reranker;

    public LLamaRerankerTests(ITestOutputHelper testOutputHelper)
    {
        _testOutputHelper = testOutputHelper;

        var @params = new ModelParams(Constants.RerankingModelPath)
        {
            // ContextSize = 0 lets the model pick its own training context size.
            ContextSize = 0,
            // Rank pooling is required by LLamaReranker's constructor.
            PoolingType = LLamaPoolingType.Rank,
            GpuLayerCount = Constants.CIGpuLayerCount,
        };
        _weights = LLamaWeights.LoadFromFile(@params);
        _reranker = new LLamaReranker(_weights, @params);
    }

    public void Dispose()
    {
        _reranker.Dispose();
        _weights.Dispose();
    }

    [Fact]
    public async Task CompareRerankingScore()
    {
        // Three candidate documents of increasing relevance to the query.
        var input = "what is panda?";
        var documents = new string[] {
            "hi",
            "it's a bear",
            string.Join(", ","The giant panda (Ailuropoda melanoleuca)",
                "sometimes called a panda bear or simply panda",
                "is a bear species endemic to China.")
        };
        var scores = await _reranker.GetRelevanceScores(input, documents, normalize: false);

        // One raw (un-normalized) score per document.
        Assert.Equal(documents.Length, scores.Count);

        _testOutputHelper.WriteLine($"Rerank score 0: {scores[0]:F4}");
        _testOutputHelper.WriteLine($"Rerank score 1: {scores[1]:F4}");
        _testOutputHelper.WriteLine($"Rerank score 2: {scores[2]:F4}");
    }

    [Fact]
    public async Task MostRelevantDocument()
    {
        var input = "what is panda?";
        var documents = new string[] {
            "hi",
            "it's a bear",
            string.Join(", ","The giant panda (Ailuropoda melanoleuca)",
                "sometimes called a panda bear or simply panda",
                "is a bear species endemic to China.")
        };
        var scores = await _reranker.GetRelevanceScores(input, documents, normalize: true);

        Assert.NotNull(scores);
        Assert.Equal(documents.Length, scores.Count);

        // The document with the highest normalized score should be the
        // detailed panda description (index 2).
        int maxIndex = scores.Select((score, index) => (score, index))
            .MaxBy(x => x.score)
            .index;

        var maxScoreDocument = documents[maxIndex];
        Assert.Equal(documents[2], maxScoreDocument);
    }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| using System.Text; | ||
| using System.Xml.Linq; | ||
| using LLama.Common; | ||
| using LLama.Extensions; | ||
| using Microsoft.Extensions.Logging; | ||
|
|
||
|
|
||
| namespace LLama.Unittest.Native; | ||
|
|
||
/// <summary>
/// Tests for vocabulary token-to-string translation on the model handle.
/// </summary>
public class SafeLlamaModelHandleVocabularyTests: IDisposable
{
    // Loaded once per test-class instance; released in Dispose().
    private readonly LLamaWeights _model;

    public SafeLlamaModelHandleVocabularyTests()
    {
        var modelParams = new ModelParams(Constants.RerankingModelPath)
        {
            ContextSize = 0,
            PoolingType = LLama.Native.LLamaPoolingType.Rank,
            GpuLayerCount = Constants.CIGpuLayerCount
        };
        _model = LLamaWeights.LoadFromFile(modelParams);
    }

    public void Dispose() => _model.Dispose();

    [Fact]
    public void GetLLamaTokenString()
    {
        // BOS/EOS special tokens should translate to their canonical text forms.
        var vocab = _model.Vocab;

        var bosStr = vocab.LLamaTokenToString(vocab.BOS, true);
        var eosStr = vocab.LLamaTokenToString(vocab.EOS, true);

        Assert.Equal("<s>", bosStr);
        Assert.Equal("</s>", eosStr);
    }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,201 @@ | ||
| using System; | ||
| using System.Collections.Generic; | ||
| using System.IO; | ||
| using System.Linq; | ||
| using System.Text; | ||
| using System.Threading; | ||
| using System.Threading.Tasks; | ||
| using System.Xml.Linq; | ||
| using LLama.Abstractions; | ||
| using LLama.Exceptions; | ||
| using LLama.Native; | ||
| using Microsoft.Extensions.Logging; | ||
|
|
||
| namespace LLama; | ||
|
|
||
/// <summary>
/// Get rank scores between prompt and documents
/// </summary>
public sealed partial class LLamaReranker
    : IDisposable
{
    /// <summary>
    /// Dimension of embedding vectors
    /// </summary>
    public int EmbeddingSize => Context.EmbeddingSize;

    /// <summary>
    /// LLama Context
    /// </summary>
    public LLamaContext Context { get; }

    /// <summary>
    /// Create a new reranker, using the given LLamaWeights
    /// </summary>
    /// <param name="weights">Model weights to create the reranking context from</param>
    /// <param name="params">Context parameters; must use <see cref="LLamaPoolingType.Rank"/> pooling with equal batch and ubatch sizes</param>
    /// <param name="logger"></param>
    /// <exception cref="ArgumentException">Batch size and ubatch size differ</exception>
    /// <exception cref="NotSupportedException">Model is encoder-decoder, or pooling type is not Rank</exception>
    public LLamaReranker(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
    {
        if (@params.UBatchSize != @params.BatchSize)
            throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size", nameof(@params));
        if (weights.NativeHandle is { HasEncoder: true, HasDecoder: true })
            throw new NotSupportedException("Computing rank in encoder-decoder models is not supported");
        if (@params.PoolingType != LLamaPoolingType.Rank)
            throw new NotSupportedException("Computing rank score, PoolingType must be equal to LLamaPoolingType.Rank");

        Context = weights.CreateContext(@params, logger);
        // Rank scores are read back through the embeddings API, so embeddings must be enabled.
        NativeApi.llama_set_embeddings(Context.NativeHandle, true);
    }

    /// <inheritdoc />
    public void Dispose()
    {
        Context.Dispose();
    }

    /// <summary>
    /// Retrieve relevance scores for input and documents by reranking, execute once.
    /// </summary>
    /// <param name="input">The query text</param>
    /// <param name="documents">Candidate documents to score against the query</param>
    /// <param name="normalize">Whether to normalize the score to the range (0, 1)</param>
    /// <param name="cancellationToken"></param>
    /// <returns>One relevance score per document, in the same order as <paramref name="documents"/></returns>
    /// <exception cref="RuntimeError"></exception>
    /// <exception cref="NotSupportedException"></exception>
    public async Task<IReadOnlyList<float>> GetRelevanceScores(string input, IReadOnlyList<string> documents, bool normalize = false, CancellationToken cancellationToken = default)
    {
        List<float> scores = new List<float>(documents.Count);
        var inputTokens = Context.Tokenize(input);
        var batch = new LLamaBatch();

        // Sequence ids restart at 0 after every flush; this records the document
        // index at which the current batch began.
        var firstDocInBatch = 0;

        for (var idx = 0; idx < documents.Count; idx++)
        {
            var docTokens = Context.Tokenize(documents[idx] ?? "");
            LLamaToken[] tokens = [.. inputTokens, .. docTokens];

            // Flush the accumulated batch before it would exceed the context window.
            // NOTE(review): a single (input + document) pair longer than the context
            // size will still be submitted on its own and fail in decode — the caller
            // is responsible for keeping individual pairs within the context window.
            if (batch.TokenCount + tokens.Length > Context.ContextSize)
            {
                scores.AddRange(await CalcRelevanceScores(batch, normalize, cancellationToken));
                batch.Clear();
                firstDocInBatch = idx;
            }

            for (var i = 0; i < tokens.Length; i++)
                batch.Add(tokens[i], i, (LLamaSeqId)(idx - firstDocInBatch), true);
        }

        // Score whatever remains in the final (possibly partial) batch.
        if (batch.LogitPositionCount > 0)
        {
            scores.AddRange(await CalcRelevanceScores(batch, normalize, cancellationToken));
            batch.Clear();
        }

        return scores;
    }

    /// <summary>
    /// Retrieve relevance score for input and document by reranking
    /// </summary>
    /// <param name="input">The query text</param>
    /// <param name="document">The candidate document</param>
    /// <param name="normalize">Whether to normalize the score to the range (0, 1)</param>
    /// <param name="cancellationToken"></param>
    /// <returns>The relevance score and the total number of tokens submitted (input + document)</returns>
    /// <exception cref="RuntimeError"></exception>
    /// <exception cref="NotSupportedException"></exception>
    public async Task<(float Score, int Tokens)> GetRelevanceScoreWithTokenCount(string input, string document, bool normalize = false, CancellationToken cancellationToken = default)
    {
        var inputTokens = Context.Tokenize(input);
        var docTokens = Context.Tokenize(document);
        LLamaToken[] tokens = [..inputTokens, ..docTokens];
        var batch = new LLamaBatch();
        for (var i = 0; i < tokens.Length; i++)
            batch.Add(tokens[i], i, LLamaSeqId.Zero, true);

        // clear previous kv_cache values
        Context.NativeHandle.KvCacheClear();

        // Check if we should cancel the work, just before doing anything expensive (encode/decode)
        cancellationToken.ThrowIfCancellationRequested();

        await RunModelAsync(batch, cancellationToken);

        // With Rank pooling the per-sequence "embedding" holds the rank score in element 0.
        var score = Context.NativeHandle.GetEmbeddingsSeq(LLamaSeqId.Zero)[0];

        Context.NativeHandle.KvCacheClear();

        return (normalize ? Sigmoid(score) : score, tokens.Length);
    }

    /// <summary>
    /// Score every sequence currently in <paramref name="batch"/>, one score per sequence.
    /// </summary>
    /// <param name="batch">Batch containing one sequence per (input, document) pair, seq ids 0..N-1</param>
    /// <param name="normalize">Whether to normalize scores to the range (0, 1)</param>
    /// <param name="cancellationToken"></param>
    private async Task<IReadOnlyList<float>> CalcRelevanceScores(LLamaBatch batch, bool normalize = false, CancellationToken cancellationToken = default)
    {
        // The highest sequence id in the batch tells us how many sequences to read back.
        var (lastSeqId, _) = batch.GetLogitPositions()[batch.LogitPositionCount - 1];
        var seqNum = lastSeqId.Value + 1;
        List<float> scores = new List<float>(seqNum);

        // clear previous kv_cache values
        Context.NativeHandle.KvCacheClear();

        // Check if we should cancel the work, just before doing anything expensive (encode/decode)
        cancellationToken.ThrowIfCancellationRequested();

        await RunModelAsync(batch, cancellationToken);

        for (var seq = 0; seq < seqNum; seq++)
        {
            // With Rank pooling the per-sequence "embedding" holds the rank score in element 0.
            var score = Context.NativeHandle.GetEmbeddingsSeq((LLamaSeqId)seq)[0];
            scores.Add(normalize ? Sigmoid(score) : score);
        }

        Context.NativeHandle.KvCacheClear();

        return scores;
    }

    /// <summary>
    /// Run the model over the batch, choosing encoder or decoder based on the model architecture.
    /// Shared by the single-document and batched scoring paths.
    /// </summary>
    /// <exception cref="RuntimeError">Encode/decode reported a failure</exception>
    /// <exception cref="NotSupportedException">Model is neither encoder-only nor decoder-only</exception>
    private async Task RunModelAsync(LLamaBatch batch, CancellationToken cancellationToken)
    {
        switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder)
        {
            case (true, false):
            {
                var result = await Context.EncodeAsync(batch, cancellationToken);
                if (result != EncodeResult.Ok)
                    throw new RuntimeError($"Failed to encode: {result}");
                break;
            }

            case (false, true):
            {
                var result = await Context.DecodeAsync(batch, cancellationToken);
                if (result != DecodeResult.Ok)
                    throw new RuntimeError($"Failed to decode: {result}");
                break;
            }

            default:
                throw new NotSupportedException("Unsupported model type");
        }
    }

    // Logistic function mapping a raw rank score into (0, 1).
    private static float Sigmoid(float x)
    {
        return (float)(1 / (1 + Math.Exp(-x)));
    }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -651,7 +651,18 @@ internal Vocabulary(SafeLlamaModelHandle model) | |
| _model = model; | ||
| } | ||
|
|
||
| private string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken) | ||
| private static LLamaToken? Normalize(LLamaToken token) | ||
| { | ||
| return token == -1 ? null : token; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Translate LLamaToken to String | ||
| /// </summary> | ||
| /// <param name="token"></param> | ||
| /// <param name="isSpecialToken"></param> | ||
| /// <returns></returns> | ||
| public string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this still needed after the latest changes? It looks like it's not used any more
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No longer needed in llamareranker, but I suggest that this can be opened as a public function |
||
| { | ||
| if (!token.HasValue) | ||
| return null; | ||
|
|
@@ -676,11 +687,6 @@ internal Vocabulary(SafeLlamaModelHandle model) | |
| return Encoding.UTF8.GetStringFromSpan(slice); | ||
| } | ||
|
|
||
| private static LLamaToken? Normalize(LLamaToken token) | ||
| { | ||
| return token == -1 ? null : token; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Total number of tokens in this vocabulary | ||
| /// </summary> | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`weights` is disposed when exiting this method, but `_reranker` holds onto a reference and uses it later. Should probably make this class disposable (e.g. see https://github.com/SciSharp/LLamaSharp/blob/master/LLama.Unittest/BasicTest.cs#L13-L27)