Skip to content

Commit 5bce923

Browse files
authored
Merge pull request #993 from Lyrcaxis/distribution-seed-fix
Non-deterministic default seed (+minor sampling parameters names & comments update)
2 parents 9aff11c + 4772eb7 commit 5bce923

File tree

4 files changed

+75
-22
lines changed

4 files changed

+75
-22
lines changed

LLama.KernelMemory/LlamaSharpTextGenerator.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ private static InferenceParams OptionsToParams(TextGenerationOptions options, In
9292
SamplingPipeline = new DefaultSamplingPipeline()
9393
{
9494
Temperature = (float)options.Temperature,
95-
AlphaFrequency = (float)options.FrequencyPenalty,
96-
AlphaPresence = (float)options.PresencePenalty,
95+
FrequencyPenalty = (float)options.FrequencyPenalty,
96+
PresencePenalty = (float)options.PresencePenalty,
9797
TopP = (float)options.NucleusSampling,
9898
}
9999
};
@@ -107,8 +107,8 @@ private static InferenceParams OptionsToParams(TextGenerationOptions options, In
107107
SamplingPipeline = new DefaultSamplingPipeline()
108108
{
109109
Temperature = (float)options.Temperature,
110-
AlphaFrequency = (float)options.FrequencyPenalty,
111-
AlphaPresence = (float)options.PresencePenalty,
110+
FrequencyPenalty = (float)options.FrequencyPenalty,
111+
PresencePenalty = (float)options.PresencePenalty,
112112
TopP = (float)options.NucleusSampling,
113113
}
114114
};

LLama.SemanticKernel/ExtensionMethods.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ internal static LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this LL
5353
{
5454
Temperature = (float)requestSettings.Temperature,
5555
TopP = (float)requestSettings.TopP,
56-
AlphaPresence = (float)requestSettings.PresencePenalty,
57-
AlphaFrequency = (float)requestSettings.FrequencyPenalty,
56+
PresencePenalty = (float)requestSettings.PresencePenalty,
57+
FrequencyPenalty = (float)requestSettings.FrequencyPenalty,
5858
}
5959
};
6060
}

LLama/Extensions/LLamaExecutorExtensions.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,9 @@ private string CreatePrompt(IList<ChatMessage> messages)
142142
MaxTokens = options?.MaxOutputTokens ?? 256, // arbitrary upper limit
143143
SamplingPipeline = new DefaultSamplingPipeline()
144144
{
145-
AlphaFrequency = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaFrequency), out float af) is true ? af : s_defaultPipeline.AlphaFrequency,
146-
AlphaPresence = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaPresence), out float ap) is true ? ap : s_defaultPipeline.AlphaPresence,
147-
PenalizeEOS = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeEOS), out bool eos) is true ? eos : s_defaultPipeline.PenalizeEOS,
145+
FrequencyPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.FrequencyPenalty), out float af) is true ? af : s_defaultPipeline.FrequencyPenalty,
146+
PresencePenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PresencePenalty), out float ap) is true ? ap : s_defaultPipeline.PresencePenalty,
147+
PreventEOS = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PreventEOS), out bool eos) is true ? eos : s_defaultPipeline.PreventEOS,
148148
PenalizeNewline = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeNewline), out bool pnl) is true ? pnl : s_defaultPipeline.PenalizeNewline,
149149
RepeatPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenalty), out float rp) is true ? rp : s_defaultPipeline.RepeatPenalty,
150150
RepeatPenaltyCount = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenaltyCount), out int rpc) is true ? rpc : s_defaultPipeline.RepeatPenaltyCount,

LLama/Sampling/DefaultSamplingPipeline.cs

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,38 +25,76 @@ public sealed class DefaultSamplingPipeline
2525
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
2626
/// so far, decreasing the model's likelihood to repeat the same line verbatim.
2727
/// </summary>
28+
[Obsolete($"Use {nameof(FrequencyPenalty)} instead.")]
2829
public float AlphaFrequency
2930
{
30-
get => _alphaFreq;
31+
get => _frequencyPenalty;
3132
init
3233
{
3334
if (value < -2)
34-
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
35+
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(AlphaFrequency)} must be greater than -2");
3536
if (value > 2)
36-
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
37-
_alphaFreq = value;
37+
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(AlphaFrequency)} must be less than 2");
38+
_frequencyPenalty = value;
3839
}
3940
}
40-
private readonly float _alphaFreq;
4141

4242
/// <summary>
4343
/// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
4444
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
4545
/// text so far, increasing the model's likelihood to talk about new topics.
4646
/// </summary>
47+
[Obsolete($"Use {nameof(PresencePenalty)} instead.")]
4748
public float AlphaPresence
4849
{
49-
get => _alphaPresence;
50+
get => _presencePenalty;
5051
init
5152
{
5253
if (value < -2)
53-
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
54+
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(AlphaPresence)} must be greater than -2");
5455
if (value > 2)
55-
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
56-
_alphaPresence = value;
56+
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(AlphaPresence)} must be less than 2");
57+
_presencePenalty = value;
5758
}
5859
}
59-
private readonly float _alphaPresence;
60+
61+
/// <summary>
/// Frequency penalty, following the OpenAI definition: https://platform.openai.com/docs/api-reference/chat/create<br />
/// Accepts values from -2.0 to 2.0 (both bounds included). Positive values penalize new tokens based on how
/// frequently they already occur in the text so far, which lowers the chance of repeating a line verbatim.
/// </summary>
/// <exception cref="ArgumentOutOfRangeException">The assigned value is below -2 or above 2.</exception>
public float FrequencyPenalty
{
    get => _frequencyPenalty;
    init => _frequencyPenalty = value switch
    {
        // Reject anything outside [-2, 2]; every other value is stored unchanged.
        < -2 => throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(FrequencyPenalty)} must be greater than -2"),
        > 2 => throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(FrequencyPenalty)} must be less than 2"),
        _ => value,
    };
}
private readonly float _frequencyPenalty;
79+
80+
/// <summary>
/// Presence penalty, following the OpenAI definition: https://platform.openai.com/docs/api-reference/chat/create<br />
/// Accepts values from -2.0 to 2.0 (both bounds included). Positive values penalize new tokens based on whether
/// they have appeared in the text so far, which nudges the model towards new topics.
/// </summary>
/// <exception cref="ArgumentOutOfRangeException">The assigned value is below -2 or above 2.</exception>
public float PresencePenalty
{
    get => _presencePenalty;
    init => _presencePenalty = value switch
    {
        // Reject anything outside [-2, 2]; every other value is stored unchanged.
        < -2 => throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(PresencePenalty)} must be greater than -2"),
        > 2 => throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(PresencePenalty)} must be less than 2"),
        _ => value,
    };
}
private readonly float _presencePenalty;
6098

6199
/// <summary>
62100
/// How many tokens should be considered for penalizing repetition
@@ -71,8 +109,14 @@ public float AlphaPresence
71109
/// <summary>
72110
/// Whether the EOS token should be protected from being modified by penalty
73111
/// </summary>
112+
[Obsolete($"This doesn't do what the name implies. If you're sure you want to use it, use {nameof(PreventEOS)}.")]
74113
public bool PenalizeEOS { get; init; } = false;
75114

115+
/// <summary>
116+
/// Whether the EOS token should be suppressed. Setting this to 'true' prevents EOS from being sampled
117+
/// </summary>
118+
public bool PreventEOS { get; init; } = false;
119+
76120
/// <summary>
77121
/// Temperature to apply (higher temperature is more "creative")
78122
/// </summary>
@@ -111,7 +155,16 @@ public float AlphaPresence
111155
/// <summary>
112156
/// Seed to use for random sampling
113157
/// </summary>
114-
public uint Seed { get; set; } = 42;
158+
public uint Seed { get; set; } = GetRandomSeed();
159+
160+
161+
// Shared generator used only to derive default seeds. It is readonly because it also serves as
// the lock object: access to it is serialized below, and reassigning the field would silently
// break that mutual exclusion.
private static readonly Random RandomSeedGenerator = new();

/// <summary>
/// Produces a non-deterministic seed drawn uniformly from the full <see cref="uint"/> range.
/// </summary>
/// <returns>A random 32-bit unsigned value.</returns>
private static uint GetRandomSeed()
{
    // Four random bytes give every uint value the same probability. Summing two
    // Next(0, int.MaxValue) draws instead would produce a triangular distribution that
    // favours mid-range seeds and can never reach the top of the uint range.
    var buffer = new byte[sizeof(uint)];
    lock (RandomSeedGenerator)
    {
        RandomSeedGenerator.NextBytes(buffer);
    }

    return BitConverter.ToUInt32(buffer, 0);
}
167+
115168

116169
/// <inheritdoc />
117170
protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context)
@@ -147,8 +200,8 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl
147200
context.VocabCount,
148201
context.ModelHandle.Tokens.EOS, context.ModelHandle.Tokens.Newline ?? 0,
149202
RepeatPenaltyCount, RepeatPenalty,
150-
AlphaFrequency, AlphaPresence,
151-
PenalizeNewline, PenalizeEOS
203+
FrequencyPenalty, PresencePenalty,
204+
PenalizeNewline, PreventEOS
152205
);
153206

154207
chain.AddTopK(TopK);

0 commit comments

Comments
 (0)