11using LLama . Common ;
22using LLama . Native ;
3-
43using System . Numerics . Tensors ;
5- using System . Runtime . InteropServices ;
64using System . Text ;
75
86using Xunit . Abstractions ;
97
108namespace LLama . Unittest
119{
12- public class SamplingTests : IDisposable
10+ public class SamplingTests
11+ : IDisposable
1312 {
1413 private readonly ITestOutputHelper _testOutputHelper ;
1514 private readonly LLamaWeights _model ;
@@ -61,7 +60,7 @@ public void Sampling()
6160 var array = LLamaTokenDataArray . Create ( logits ) ;
6261 {
6362 using var _ = LLamaTokenDataArrayNative . Create ( array , out var cur_p ) ;
64- var rawLogits = new float [ _model . VocabCount ] ;
63+ var rawLogits = new float [ _model . Vocab . Count ] ;
6564 for ( int j = 0 ; j < cur_p . Data . Length ; j ++ )
6665 {
6766 rawLogits [ ( int ) cur_p . Data [ j ] . ID ] = cur_p . Data [ j ] . Logit ;
@@ -119,7 +118,7 @@ public void BatchedSampling()
119118
120119 for ( int b = 0 ; b < batch_count ; b ++ )
121120 {
122- var logits = all_logits . Slice ( b * _model . VocabCount , _model . VocabCount ) ;
121+ var logits = all_logits . Slice ( b * _model . Vocab . Count , _model . Vocab . Count ) ;
123122
124123 // Test raw sampling
125124 Assert . Equal ( expected , TensorPrimitives . IndexOfMax ( logits ) ) ;
@@ -128,7 +127,7 @@ public void BatchedSampling()
128127 var array = LLamaTokenDataArray . Create ( logits ) ;
129128 {
130129 using var _ = LLamaTokenDataArrayNative . Create ( array , out var cur_p ) ;
131- var rawLogits = new float [ _model . VocabCount ] ;
130+ var rawLogits = new float [ _model . Vocab . Count ] ;
132131 for ( int j = 0 ; j < cur_p . Data . Length ; j ++ )
133132 {
134133 rawLogits [ ( int ) cur_p . Data [ j ] . ID ] = cur_p . Data [ j ] . Logit ;
@@ -170,7 +169,7 @@ private static SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle co
170169 penaltyCount : 60 , repeat : 1 , freq : 0 , presence : 0
171170 ) ;
172171
173- if ( logit_bias != null ) { chain . AddLogitBias ( context . VocabCount , logit_bias ) ; }
172+ if ( logit_bias != null ) { chain . AddLogitBias ( context . Vocab . Count , logit_bias ) ; }
174173
175174 chain . AddTopK ( 10 ) ;
176175 chain . AddTemperature ( 0.1f ) ;
0 commit comments