Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
6f4c53c
add LLamaReranker and tests
nipeone Apr 3, 2025
a69f814
Merge branch 'feature-llamareranker'
nipeone Apr 3, 2025
15c5247
Merge branch 'SciSharp:master' into feature-llamareranker
nipeone Apr 11, 2025
c604359
optimize LLamaReranker function
nipeone Apr 11, 2025
d99670c
fix Reranking when documents are too large
nipeone Apr 11, 2025
05677fe
fix Reranking if document contains null
nipeone Apr 15, 2025
4258cc1
Merge branch 'SciSharp:master' into feature-llamareranker
nipeone Apr 18, 2025
8d61a92
Merge branch 'SciSharp:master' into feature-llamareranker
nipeone Apr 21, 2025
e1939eb
Merge branch 'SciSharp:master' into master
nipeone Apr 29, 2025
49ae0a8
Merge branch 'SciSharp:master' into feature-llamareranker
nipeone Apr 29, 2025
474cfd1
Merge branch 'SciSharp:master' into master
nipeone May 6, 2025
a53f503
Merge branch 'master' of https://github.com/nipeone/LLamaSharp
nipeone May 6, 2025
9ed7378
Merge upstream/master and resolve conflicts
nipeone May 6, 2025
37bb3c3
Merge branch 'master' into feature-llamareranker
nipeone May 6, 2025
2bcb62e
Merge branch 'SciSharp:master' into master
nipeone May 7, 2025
69a5f42
Merge branch 'SciSharp:master' into feature-llamareranker
nipeone May 7, 2025
7b2ee55
Merge remote-tracking branch 'upstream/master'
nipeone May 12, 2025
8a34866
Merge remote-tracking branch 'upstream/master'
nipeone May 12, 2025
371fdcd
optimize LLamaReranker function
nipeone Apr 11, 2025
87059e8
fix Reranking when documents are too large
nipeone Apr 11, 2025
14ba50f
fix Reranking if document contains null
nipeone Apr 15, 2025
9f4bd96
optimize LLamaReranker function
nipeone Apr 11, 2025
8fde3bc
Merge branch 'feature-llamareranker' of https://github.com/nipeone/LL…
nipeone May 12, 2025
3f51a7f
Merge branch 'SciSharp:master' into feature-llamareranker
nipeone May 12, 2025
63ae374
fix code comments in llamareranker file
nipeone May 12, 2025
d838e1c
implement IDisposable in LLamaRerankerTests and SafeLlamaModelHandleV…
nipeone May 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions LLama.Unittest/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ internal static class Constants
public static readonly string GenerativeModelPath = "Models/Llama-3.2-1B-Instruct-Q4_0.gguf";
public static readonly string GenerativeModelPath2 = "Models/smollm-360m-instruct-add-basics-q8_0.gguf";
public static readonly string EmbeddingModelPath = "Models/all-MiniLM-L12-v2.Q8_0.gguf";
public static readonly string RerankingModelPath = "Models/jina-reranker-v1-tiny-en-FP16.gguf";

public static readonly string LLavaModelPath = "Models/llava-v1.6-mistral-7b.Q3_K_XS.gguf";
public static readonly string LLavaMmpPath = "Models/mmproj-model-f16.gguf";
Expand Down
9 changes: 9 additions & 0 deletions LLama.Unittest/LLama.Unittest.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@
<LocalFileName>smollm-360m-instruct-add-basics-q8_0.gguf</LocalFileName>
</DownloadFileItem>

<DownloadFileItem Include="jina-reranker-v1-tiny-en-FP16.gguf">
<SourceUrl>https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-FP16.gguf</SourceUrl>
<DestinationFolder>Models</DestinationFolder>
<LocalFileName>jina-reranker-v1-tiny-en-FP16.gguf</LocalFileName>
</DownloadFileItem>

<DownloadFileItem Include="llava-v1.6-mistral-7b">
<SourceUrl>https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/llava-v1.6-mistral-7b.Q3_K_XS.gguf</SourceUrl>
<DestinationFolder>Models</DestinationFolder>
Expand Down Expand Up @@ -130,6 +136,9 @@
<None Update="Models\Llama-3.2-1B-Instruct-Q4_0.gguf">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Models\jina-reranker-v1-tiny-en-FP16.gguf">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Models\smollm-360m-instruct-add-basics-q8_0.gguf">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
Expand Down
79 changes: 79 additions & 0 deletions LLama.Unittest/LLamaRerankerTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
using LLama.Common;
using LLama.Extensions;
using LLama.Native;
using Microsoft.Extensions.AI;
using System.Runtime.InteropServices;
using Xunit.Abstractions;

namespace LLama.Unittest;

/// <summary>
/// Tests for <see cref="LLamaReranker"/>: scoring the relevance of documents against a query.
/// </summary>
public sealed class LLamaRerankerTests : IDisposable
{
    private readonly ITestOutputHelper _testOutputHelper;

    // Keep the weights alive for the lifetime of the fixture: the reranker's context
    // references the underlying model, so disposing the weights inside the constructor
    // (e.g. with a `using` declaration) would rely solely on native ref-counting.
    private readonly LLamaWeights _weights;
    private readonly LLamaReranker _reranker;

    // Shared query used by every test in this class.
    private const string Query = "what is panda?";

    public LLamaRerankerTests(ITestOutputHelper testOutputHelper)
    {
        _testOutputHelper = testOutputHelper;

        var @params = new ModelParams(Constants.RerankingModelPath)
        {
            // 0 lets the runtime pick the model's trained context length.
            ContextSize = 0,
            // Rank pooling is required by LLamaReranker's constructor.
            PoolingType = LLamaPoolingType.Rank,
            GpuLayerCount = Constants.CIGpuLayerCount,
        };
        _weights = LLamaWeights.LoadFromFile(@params);
        _reranker = new LLamaReranker(_weights, @params);
    }

    public void Dispose()
    {
        _reranker.Dispose();
        _weights.Dispose();
    }

    /// <summary>
    /// Test documents: two weak matches and one clearly relevant document (index 2).
    /// </summary>
    private static string[] MakeDocuments() => new[]
    {
        "hi",
        "it's a bear",
        string.Join(", ","The giant panda (Ailuropoda melanoleuca)",
            "sometimes called a panda bear or simply panda",
            "is a bear species endemic to China.")
    };

    [Fact]
    public async Task CompareRerankingScore()
    {
        var documents = MakeDocuments();

        var scores = await _reranker.GetRelevanceScores(Query, documents, normalize: false);

        // One raw (un-normalized) score per document, in input order.
        Assert.Equal(documents.Length, scores.Count);

        for (var i = 0; i < scores.Count; i++)
            _testOutputHelper.WriteLine($"Rerank score {i}: {scores[i]:F4}");
    }

    [Fact]
    public async Task MostRelevantDocument()
    {
        var documents = MakeDocuments();

        var scores = await _reranker.GetRelevanceScores(Query, documents, normalize: true);

        Assert.NotNull(scores);
        Assert.Equal(documents.Length, scores.Count);

        // The document with the highest normalized score should be the panda description.
        var maxIndex = scores.Select((score, index) => (score, index))
                             .MaxBy(x => x.score)
                             .index;

        Assert.Equal(documents[2], documents[maxIndex]);
    }
}
42 changes: 42 additions & 0 deletions LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using System.Text;
using System.Xml.Linq;
using LLama.Common;
using LLama.Extensions;
using Microsoft.Extensions.Logging;


namespace LLama.Unittest.Native;

/// <summary>
/// Tests for the vocabulary exposed by SafeLlamaModelHandle.
/// </summary>
public class SafeLlamaModelHandleVocabularyTests : IDisposable
{
    // Reranking model whose vocabulary is inspected by the tests below.
    private readonly LLamaWeights _model;

    public SafeLlamaModelHandleVocabularyTests()
    {
        var modelParams = new ModelParams(Constants.RerankingModelPath)
        {
            ContextSize = 0,
            PoolingType = LLama.Native.LLamaPoolingType.Rank,
            GpuLayerCount = Constants.CIGpuLayerCount
        };
        _model = LLamaWeights.LoadFromFile(modelParams);
    }

    /// <summary>Release the native model weights.</summary>
    public void Dispose() => _model.Dispose();

    [Fact]
    public void GetLLamaTokenString()
    {
        var vocab = _model.Vocab;

        // Render the special begin/end-of-sequence tokens as text
        // (isSpecialToken: true keeps their literal markers).
        var bosText = vocab.LLamaTokenToString(vocab.BOS, true);
        var eosText = vocab.LLamaTokenToString(vocab.EOS, true);

        Assert.Equal("<s>", bosText);
        Assert.Equal("</s>", eosText);
    }
}
201 changes: 201 additions & 0 deletions LLama/LLamaReranker.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Xml.Linq;
using LLama.Abstractions;
using LLama.Exceptions;
using LLama.Native;
using Microsoft.Extensions.Logging;

namespace LLama;

/// <summary>
/// Get rank scores between prompt and documents
/// </summary>
/// <summary>
/// Get rank scores between prompt and documents
/// </summary>
public sealed partial class LLamaReranker
    : IDisposable
{
    /// <summary>
    /// Dimension of embedding vectors
    /// </summary>
    public int EmbeddingSize => Context.EmbeddingSize;

    /// <summary>
    /// LLama Context
    /// </summary>
    public LLamaContext Context { get; }

    /// <summary>
    /// Create a new reranker, using the given LLamaWeights
    /// </summary>
    /// <param name="weights">Model weights. Must be a reranking model (not encoder-decoder).</param>
    /// <param name="params">Context parameters. BatchSize must equal UBatchSize and PoolingType must be Rank.</param>
    /// <param name="logger">Optional logger for the created context.</param>
    /// <exception cref="ArgumentException">Batch size differs from ubatch size.</exception>
    /// <exception cref="NotSupportedException">Encoder-decoder model, or pooling type is not Rank.</exception>
    public LLamaReranker(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
    {
        if (@params.UBatchSize != @params.BatchSize)
            throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size", nameof(@params));
        if (weights.NativeHandle is { HasEncoder: true, HasDecoder: true })
            throw new NotSupportedException("Computing rank in encoder-decoder models is not supported");
        if (@params.PoolingType != LLamaPoolingType.Rank)
            throw new NotSupportedException("Computing rank score, PoolingType must be equal to LLamaPoolingType.Rank");
        Context = weights.CreateContext(@params, logger);
        // Embeddings must be enabled so the pooled rank score can be read back per sequence.
        NativeApi.llama_set_embeddings(Context.NativeHandle, true);
    }

    /// <inheritdoc />
    public void Dispose()
    {
        Context.Dispose();
    }

    /// <summary>
    /// Retrieve relevance scores for input and documents by reranking, execute once.
    /// </summary>
    /// <param name="input">The query to score documents against.</param>
    /// <param name="documents">Documents to score. Null entries are treated as empty strings.</param>
    /// <param name="normalize">Whether to normalize the score to the range (0, 1)</param>
    /// <param name="cancellationToken"></param>
    /// <returns>One score per document, in input order.</returns>
    /// <exception cref="RuntimeError"></exception>
    /// <exception cref="NotSupportedException"></exception>
    public async Task<IReadOnlyList<float>> GetRelevanceScores(string input, IReadOnlyList<string> documents, bool normalize = false, CancellationToken cancellationToken = default)
    {
        List<float> scores = new List<float>(documents.Count);
        var inputTokens = Context.Tokenize(input);
        var batch = new LLamaBatch();
        // Index of the first document in the current batch; sequence ids restart
        // from zero after every flush.
        var firstDocInBatch = 0;

        for (var idx = 0; idx < documents.Count; idx++)
        {
            var docTokens = Context.Tokenize(documents[idx] ?? "");
            LLamaToken[] tokens = [.. inputTokens, .. docTokens];

            // Flush the current batch before it would overflow the context.
            // Only flush a NON-EMPTY batch: if a single (input, document) pair is
            // by itself larger than the context, flushing an empty batch would make
            // CalcRelevanceScores index GetLogitPositions()[-1] and throw.
            if (batch.TokenCount > 0 && batch.TokenCount + tokens.Length > Context.ContextSize)
            {
                scores.AddRange(await CalcRelevanceScores(batch, normalize, cancellationToken));
                batch.Clear();
                firstDocInBatch = idx;
            }

            for (var i = 0; i < tokens.Length; i++)
                batch.Add(tokens[i], i, (LLamaSeqId)(idx - firstDocInBatch), true);
        }

        // Score whatever remains in the final, partially-filled batch.
        if (batch.LogitPositionCount > 0)
        {
            scores.AddRange(await CalcRelevanceScores(batch, normalize, cancellationToken));
            batch.Clear();
        }

        return scores;
    }

    /// <summary>
    /// Retrieve relevance score for input and document by reranking
    /// </summary>
    /// <param name="input">The query to score the document against.</param>
    /// <param name="document">The document to score.</param>
    /// <param name="normalize">Whether to normalize the score to the range (0, 1)</param>
    /// <param name="cancellationToken"></param>
    /// <returns>The score and the total number of tokens submitted (input + document).</returns>
    /// <exception cref="RuntimeError"></exception>
    /// <exception cref="NotSupportedException"></exception>
    public async Task<(float Score, int Tokens)> GetRelevanceScoreWithTokenCount(string input, string document, bool normalize = false, CancellationToken cancellationToken = default)
    {
        var inputTokens = Context.Tokenize(input);
        var docTokens = Context.Tokenize(document);
        LLamaToken[] tokens = [.. inputTokens, .. docTokens];
        var batch = new LLamaBatch();
        for (var i = 0; i < tokens.Length; i++)
            batch.Add(tokens[i], i, LLamaSeqId.Zero, true);

        // clear previous kv_cache values
        Context.NativeHandle.KvCacheClear();

        // Check if we should cancel the work, just before doing anything expensive (encode/decode)
        cancellationToken.ThrowIfCancellationRequested();

        // Run model: encoder-only models use encode, decoder-only models use decode.
        switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder)
        {
            case (true, false):
            {
                var result = await Context.EncodeAsync(batch, cancellationToken);
                if (result != EncodeResult.Ok)
                    throw new RuntimeError($"Failed to encode: {result}");
                break;
            }

            case (false, true):
            {
                var result = await Context.DecodeAsync(batch, cancellationToken);
                if (result != DecodeResult.Ok)
                    throw new RuntimeError($"Failed to decode: {result}");
                break;
            }

            default:
                throw new NotSupportedException("Unsupported model type");
        }

        // With Rank pooling the first element of the sequence embedding is the rank score.
        var score = Context.NativeHandle.GetEmbeddingsSeq(LLamaSeqId.Zero)[0];

        Context.NativeHandle.KvCacheClear();

        return (normalize ? Sigmoid(score) : score, tokens.Length);
    }

    /// <summary>
    /// Run the model over a prepared batch and read back one rank score per sequence.
    /// </summary>
    /// <param name="batch">Batch containing one sequence per (input, document) pair, with logits requested.</param>
    /// <param name="normalize">Whether to normalize each score to the range (0, 1)</param>
    /// <param name="cancellationToken"></param>
    private async Task<IReadOnlyList<float>> CalcRelevanceScores(LLamaBatch batch, bool normalize = false, CancellationToken cancellationToken = default)
    {
        // The highest sequence id in the batch tells us how many sequences to read back.
        var (lastSeqId, _) = batch.GetLogitPositions()[batch.LogitPositionCount - 1];
        var seqNum = lastSeqId.Value + 1;
        List<float> scores = new List<float>(seqNum);
        // clear previous kv_cache values
        Context.NativeHandle.KvCacheClear();

        // Check if we should cancel the work, just before doing anything expensive (encode/decode)
        cancellationToken.ThrowIfCancellationRequested();

        // Run model: encoder-only models use encode, decoder-only models use decode.
        switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder)
        {
            case (true, false):
            {
                var result = await Context.EncodeAsync(batch, cancellationToken);
                if (result != EncodeResult.Ok)
                    throw new RuntimeError($"Failed to encode: {result}");
                break;
            }

            case (false, true):
            {
                var result = await Context.DecodeAsync(batch, cancellationToken);
                if (result != DecodeResult.Ok)
                    throw new RuntimeError($"Failed to decode: {result}");
                break;
            }

            default:
                throw new NotSupportedException("Unsupported model type");
        }

        for (var seq = 0; seq < seqNum; seq++)
        {
            var score = Context.NativeHandle.GetEmbeddingsSeq((LLamaSeqId)seq)[0];
            scores.Add(normalize ? Sigmoid(score) : score);
        }

        Context.NativeHandle.KvCacheClear();

        return scores;
    }

    /// <summary>
    /// Logistic function: maps a raw rank score into the range (0, 1).
    /// </summary>
    private static float Sigmoid(float x)
    {
        return (float)(1 / (1 + Math.Exp(-x)));
    }
}
18 changes: 12 additions & 6 deletions LLama/Native/SafeLlamaModelHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,18 @@ internal Vocabulary(SafeLlamaModelHandle model)
_model = model;
}

private string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
// The native API uses -1 as a "no such token" sentinel; surface that as null.
private static LLamaToken? Normalize(LLamaToken token) => token == -1 ? null : token;

/// <summary>
/// Translate LLamaToken to String
/// </summary>
/// <param name="token"></param>
/// <param name="isSpecialToken"></param>
/// <returns></returns>
public string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this still needed after the latest changes? It looks like it's not used any more

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No longer needed in llamareranker, but I suggest that this can be opened as a public function

{
if (!token.HasValue)
return null;
Expand All @@ -676,11 +687,6 @@ internal Vocabulary(SafeLlamaModelHandle model)
return Encoding.UTF8.GetStringFromSpan(slice);
}

private static LLamaToken? Normalize(LLamaToken token)
{
return token == -1 ? null : token;
}

/// <summary>
/// Total number of tokens in this vocabulary
/// </summary>
Expand Down
Loading