38 changes: 37 additions & 1 deletion LLama.Unittest/TemplateTests.cs
@@ -1,15 +1,20 @@
using System.Text;
using LLama.Common;
using LLama.Extensions;
using LLama.Native;
using Xunit.Abstractions;

namespace LLama.Unittest;

public sealed class TemplateTests
: IDisposable
{
private readonly ITestOutputHelper _output;
private readonly LLamaWeights _model;

public TemplateTests()
public TemplateTests(ITestOutputHelper output)
{
_output = output;
var @params = new ModelParams(Constants.GenerativeModelPath)
{
ContextSize = 1,
@@ -260,6 +265,37 @@ public void EndOfTurnToken_ReturnsExpected()
[Fact]
public void EndOfSpeechToken_ReturnsExpected()
{
_output.WriteLine($"EOS: {_model.Tokens.EOS}");
_output.WriteLine($"EOT: {_model.Tokens.EOT}");
_output.WriteLine($"BOS: {_model.Tokens.BOS}");

var eosStr = ConvertTokenToString(_model.Tokens.EOS!.Value);
_output.WriteLine(eosStr ?? "null");

Assert.Equal("</s>", _model.Tokens.EndOfSpeechToken);
}

private string? ConvertTokenToString(LLamaToken token)
{
_output.WriteLine($"ConvertTokenToString: {token}");

const int buffSize = 32;
Span<byte> buff = stackalloc byte[buffSize];
var tokenLength = _model.NativeHandle.TokenToSpan(token, buff, 0, true);

_output.WriteLine($"tokenLength = {tokenLength}");
if (tokenLength <= 0)
return null;

// if the original buffer wasn't large enough, create a new one
_output.WriteLine($"tokenLength = {tokenLength}, buffSize = {buffSize}");
if (tokenLength > buffSize)
{
buff = new byte[(int)tokenLength]; // heap fallback: a stackalloc result cannot be reassigned to a span declared in an outer scope
_ = _model.NativeHandle.TokenToSpan(token, buff, 0, true);
}

var slice = buff.Slice(0, (int)tokenLength);
return Encoding.UTF8.GetStringFromSpan(slice);
}
}
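
The helper above tries a small stack buffer first and retries with a larger buffer only when the token's UTF-8 form is longer. Note that C# forbids assigning a nested `stackalloc` result to a span declared in an outer scope, which is why the retry buffer must come from the heap. A minimal standalone sketch of the same idiom, where the `DetokenizeFn` delegate is a hypothetical stand-in for `NativeHandle.TokenToSpan`:

using System;
using System.Text;

internal static class TokenText
{
    // Hypothetical stand-in for SafeLlamaModelHandle.TokenToSpan: writes the
    // token's bytes into `dest` and returns the length actually required.
    internal delegate uint DetokenizeFn(int token, Span<byte> dest);

    internal static string? Decode(int token, DetokenizeFn detokenize)
    {
        const int buffSize = 32;
        Span<byte> buff = stackalloc byte[buffSize];
        var length = detokenize(token, buff);

        if (length == 0)
            return null;

        // Heap fallback when the stack buffer was too small.
        if (length > buffSize)
        {
            buff = new byte[(int)length];
            _ = detokenize(token, buff);
        }

        return Encoding.UTF8.GetString(buff.Slice(0, (int)length));
    }
}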
3 changes: 3 additions & 0 deletions LLama.Web/Common/ModelOptions.cs
@@ -118,5 +118,8 @@ public class ModelOptions

/// <inheritdoc />
public LLamaPoolingType PoolingType { get; set; }

/// <inheritdoc />
public LLamaAttentionType AttentionType { get; set; } = LLamaAttentionType.Unspecified;
}
}
5 changes: 5 additions & 0 deletions LLama/Abstractions/IContextParams.cs
@@ -123,4 +123,9 @@ public interface IContextParams
/// How to pool (sum) embedding results by sequence id (ignored if no pooling layer)
/// </summary>
LLamaPoolingType PoolingType { get; }

/// <summary>
/// Attention type to use for embeddings
/// </summary>
LLamaAttentionType AttentionType { get; }
}
3 changes: 3 additions & 0 deletions LLama/Common/ModelParams.cs
@@ -109,6 +109,9 @@ public record ModelParams
/// <inheritdoc />
public LLamaPoolingType PoolingType { get; set; } = LLamaPoolingType.Unspecified;

/// <inheritdoc />
public LLamaAttentionType AttentionType { get; set; } = LLamaAttentionType.Unspecified;

/// <inheritdoc />
public bool VocabOnly { get; set; }

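For context, a minimal usage sketch of the new property, assuming the usual LLamaWeights/LLamaEmbedder loading pattern (the model path and pooling choice are illustrative, not taken from this PR). An embedding model generally wants non-causal attention, while Unspecified leaves the decision to llama.cpp:

using LLama;
using LLama.Common;
using LLama.Native;

// Illustrative parameters for an embedding workload.
var @params = new ModelParams("models/embedding-model.gguf")
{
    Embeddings = true,
    PoolingType = LLamaPoolingType.Mean,
    AttentionType = LLamaAttentionType.NonCausal,
};

using var weights = LLamaWeights.LoadFromFile(@params);
using var embedder = new LLamaEmbedder(weights, @params);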
1 change: 1 addition & 0 deletions LLama/Extensions/IContextParamsExtensions.cs
@@ -52,6 +52,7 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaContextParams result)
result.offload_kqv = [email protected];
result.flash_attention = @params.FlashAttention;
result.llama_pooling_type = @params.PoolingType;
result.attention_type = @params.AttentionType;

result.n_threads = Threads(@params.Threads);
result.n_threads_batch = Threads(@params.BatchThreads);
64 changes: 32 additions & 32 deletions LLama/LLamaQuantizer.cs
@@ -1,4 +1,4 @@
using LLama.Native;
using LLama.Native;
using System;
using System.Collections.Generic;

@@ -66,49 +66,49 @@ private static bool ValidateFtype(LLamaFtype ftype)

switch (ftype)
{
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_0:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_0:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_1:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q8_0:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_F16:
case LLamaFtype.LLAMA_FTYPE_ALL_F32:
case LLamaFtype.MOSTLY_Q4_0:
case LLamaFtype.MOSTLY_Q4_1:
case LLamaFtype.MOSTLY_Q5_0:
case LLamaFtype.MOSTLY_Q5_1:
case LLamaFtype.MOSTLY_Q8_0:
case LLamaFtype.MOSTLY_F16:
case LLamaFtype.ALL_F32:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q2_K_S:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q2_K:
case LLamaFtype.MOSTLY_Q2_K_S:
case LLamaFtype.MOSTLY_Q2_K:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_K_XS:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_S:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_M:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_L:
case LLamaFtype.MOSTLY_IQ3_K_XS:
case LLamaFtype.MOSTLY_Q3_K_S:
case LLamaFtype.MOSTLY_Q3_K_M:
case LLamaFtype.MOSTLY_Q3_K_L:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_K_S:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_K_M:
case LLamaFtype.MOSTLY_Q4_K_S:
case LLamaFtype.MOSTLY_Q4_K_M:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_K_S:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q5_K_M:
case LLamaFtype.MOSTLY_Q5_K_S:
case LLamaFtype.MOSTLY_Q5_K_M:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q6_K:
case LLamaFtype.MOSTLY_Q6_K:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_XXS:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_XS:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_S:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_M:
case LLamaFtype.MOSTLY_IQ2_XXS:
case LLamaFtype.MOSTLY_IQ2_XS:
case LLamaFtype.MOSTLY_IQ2_S:
case LLamaFtype.MOSTLY_IQ2_M:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_XXS:
case LLamaFtype.MOSTLY_IQ3_XXS:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ1_S:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ1_M:
case LLamaFtype.MOSTLY_IQ1_S:
case LLamaFtype.MOSTLY_IQ1_M:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ4_NL:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ4_XS:
case LLamaFtype.MOSTLY_IQ4_NL:
case LLamaFtype.MOSTLY_IQ4_XS:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_S:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_M:
case LLamaFtype.MOSTLY_IQ3_S:
case LLamaFtype.MOSTLY_IQ3_M:
return true;

case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
case LLamaFtype.LLAMA_FTYPE_GUESSED:
case LLamaFtype.MOSTLY_Q4_1_SOME_F16:
case LLamaFtype.GUESSED:
default:
return false;
}
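The quantizer logic is unchanged; the LLamaFtype members simply lose their redundant LLAMA_FTYPE_ prefix. A brief sketch with the new names (file paths are hypothetical):

using System;
using LLama;
using LLama.Native;

// Quantize an f16 GGUF down to Q4_K_M, one of the ftypes ValidateFtype accepts.
var ok = LLamaQuantizer.Quantize("model-f16.gguf", "model-q4_k_m.gguf", LLamaFtype.MOSTLY_Q4_K_M);
Console.WriteLine(ok ? "quantized" : "ftype not supported");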
2 changes: 1 addition & 1 deletion LLama/LLamaSharp.csproj
@@ -53,7 +53,7 @@
</ItemGroup>

<PropertyGroup>
<BinaryReleaseId>1c5eba6f8e62</BinaryReleaseId>
<BinaryReleaseId>368645698ab648e390dc</BinaryReleaseId>
</PropertyGroup>

<PropertyGroup>
8 changes: 8 additions & 0 deletions LLama/Native/LLamaAttentionType.cs
@@ -0,0 +1,8 @@
namespace LLama.Native;

public enum LLamaAttentionType
{
Unspecified = -1,
Causal = 0,
NonCausal = 1,
}
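
These values mirror llama.cpp's llama_attention_type enum (LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1, LLAMA_ATTENTION_TYPE_CAUSAL = 0, LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1): causal masking is what autoregressive generation uses, non-causal attention suits bidirectional embedding models, and Unspecified lets the native library choose based on the model.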
5 changes: 5 additions & 0 deletions LLama/Native/LLamaContextParams.cs
@@ -65,6 +65,11 @@ public struct LLamaContextParams
/// whether to pool (sum) embedding results by sequence id
/// </summary>
public LLamaPoolingType llama_pooling_type;

/// <summary>
/// Attention type to use for embeddings
/// </summary>
public LLamaAttentionType attention_type;

/// <summary>
/// RoPE base frequency, 0 = from model
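Putting the pieces together: the new field travels from the managed parameters into this native struct via the extension method shown earlier. A minimal sketch, assuming IContextParamsExtensions is accessible from user code (model path hypothetical):

using LLama.Common;
using LLama.Extensions;
using LLama.Native;

var @params = new ModelParams("model.gguf")
{
    AttentionType = LLamaAttentionType.Causal,
};

// ToLlamaContextParams copies the managed setting into the struct llama.cpp sees.
@params.ToLlamaContextParams(out LLamaContextParams native);
// native.attention_type is now LLamaAttentionType.Causal.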