[Bug] Suggested Fixes for mathematical inaccuracy in llama_sample_repetition_penalty function #2970


Closed
tysam-code opened this issue Sep 2, 2023 · 12 comments

@tysam-code

tysam-code commented Sep 2, 2023

Hello,

I am working now with llama.cpp for a personal project. My absolute hat is off to whoever is working on code quality, structure, and readability; it is an absolute delight to work with and a very rare example of C/C++ code that I am not afraid to work with despite not knowing it inside and out. My thanks to anyone doing code structure policing, BDFLing, or the like.

I'll likely open a few issue tickets here as I notice things while working on the lower-level parts of the code and getting more involved in it. An explanation is below. As for technical qualifications, I've worked with neural networks for years and have a couple of successful open source projects.

I am volunteering to fix this code, should any suggestions be approved.

Expected Behavior

In the llama_sample_repetition_penalty function, we expect to penalize a token based upon how many times it is used. This is done by dividing the token's logit by the penalty if it is above zero, and multiplying it by the penalty if it is below zero. What this is doing is not really documented on the command line, so one has to read the code to find it out. It's somewhat implied that the penalty needs to be >1, but I think we can make it a bit clearer.
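
For reference, here is a minimal sketch of the rule as described above (illustrative only, not the exact llama.cpp source; the function and parameter names are placeholders):

```cpp
#include <vector>

// Sketch of the described behavior: logits of previously seen tokens are
// divided by the penalty when positive and multiplied by it when negative,
// which makes 0 a fixed point of the transformation.
void apply_repetition_penalty_sketch(std::vector<float> & logits,
                                     const std::vector<int> & last_tokens,
                                     float penalty) {
    for (int tok : last_tokens) {
        float & logit = logits[tok];
        if (logit > 0.0f) {
            logit /= penalty;  // positive logits shrink toward 0
        } else {
            logit *= penalty;  // negative logits get pushed further down
        }
    }
}
```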

What we expect to happen is a somewhat universal behavior where the token likelihood smoothly goes down over time based upon how often it is repeated.

In practice, this seems to hit a few snags and is mathematically incorrect. Looking at the function, we can see why.

Current Behavior

If we look at the linear behavior of the pre-softmax function, we can see that it changes drastically around zero.

[Screenshot: 2023-09-02, 12:49 PM]

One problem with this is that softmax is supposed to be shift-invariant (adding the same constant to every logit leaves the probabilities unchanged), and many things depend upon this, for example the trick where we subtract the maximum value so that everything is at most 0 before softmaxing for numerical stability. This, in turn, can cause some bizarre inflection points where the token-suppressing behavior does not work as expected. This is likely an unintended side effect.
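
To make the shift-invariance point concrete, here is a tiny standalone check (illustrative only) showing that adding a constant to every logit leaves the softmax output unchanged, so 0 has no special meaning to softmax itself:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// softmax with the usual max-subtraction trick for numerical stability
static std::vector<double> softmax(std::vector<double> z) {
    const double m = *std::max_element(z.begin(), z.end());
    double sum = 0.0;
    for (double & v : z) { v = std::exp(v - m); sum += v; }
    for (double & v : z) { v /= sum; }
    return z;
}

int main() {
    const std::vector<double> logits = { 2.0, -1.0, -3.0 };
    std::vector<double> shifted = logits;
    for (double & v : shifted) { v += 10.0; }  // shift every logit by a constant

    const auto p = softmax(logits);
    const auto q = softmax(shifted);
    for (size_t i = 0; i < p.size(); ++i) {
        std::printf("%zu: %.6f vs %.6f\n", i, p[i], q[i]);  // identical values
    }
    return 0;
}
```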

Possible Solutions

-- Scale based on minimum value

The frequency penalty already seems to add a bias based upon the counts found, so falling back to a simple bias here to remove the fixed point around 0 would simply duplicate that behavior in a separate function.

What we could do instead, similar to the max-subtraction trick above, is shift the logits by the minimum value, scale multiplicatively (no division branch), then shift back by the minimum value. This makes the minimum of the incoming logits the fixed point, and prevents any bizarrely-scaled 'division in the middle'. (A sketch follows after these options.)

-- Scale based on maximum value

This is also possible; however, it might be unintuitive due to 0 being the upper fixed point.
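
Here is a rough sketch of the minimum-anchored idea (my reading of the proposal, with illustrative names; the maximum-anchored variant would simply use the max instead):

```cpp
#include <algorithm>
#include <vector>

// Minimum-anchored penalty sketch: shift so the current minimum logit becomes
// the fixed point, apply a single multiplicative scale (no divide-vs-multiply
// branch), then shift back. The minimum logit is left unchanged and there is
// no special behavior around 0.
void apply_min_anchored_penalty_sketch(std::vector<float> & logits,
                                       const std::vector<int> & last_tokens,
                                       float penalty) {
    const float min_logit = *std::min_element(logits.begin(), logits.end());
    const float scale     = 1.0f / penalty;  // uniform multiplicative factor
    for (int tok : last_tokens) {
        logits[tok] = (logits[tok] - min_logit) * scale + min_logit;
    }
}
```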

Impact From Changes

Fixing this will likely impact the behavior of some currently-deployed scripts, including people's parameter tuning around the bug. It should have a net long-term benefit, however.

Code contributions

I am willing to write and walk the code through a PR/code review, as well as document the original paper this is based on (if someone can arxiv link me please).

Conclusion

Thank you for your attention in this matter. I plan on making at least one of these changes (or something similar if I personally find a more fitting solution) in my personal version of llama.cpp, and I hope to see them in the main version as well. This is an excellent program and I am glad to be a potential contributor to it.

@tysam-code tysam-code changed the title Fixes for mathematical inaccuracy in llama_sample_repetition_penalty function [Bug] Suggested Fixes for mathematical inaccuracy in llama_sample_repetition_penalty function Sep 2, 2023
@tysam-code
Author

tysam-code commented Sep 2, 2023

Examples in Action (Current Code)

Here is the original code in action.

The base prompt used is ./main --model ../modelfile.bin --prompt "Hi...........I hope you are well today............It.....is.....very.....much...a.....nice......day.......indeed..........love.....grandma............I....was.....hoping........" -ngl 1

Randomly sampled examples follow.

Baseline Prompt (1.0)

[Screenshots: 2023-09-02, 1:39 PM (two images)]

(the above output went on for far too long; it is truncated)

Light Repetition Penalty (1.4)

[Screenshots: 2023-09-02, 1:37 PM and 1:38 PM]

Extreme Repetition Penalty (5000) [Be forewarned: Semi-cherrypicked due to sampling weirdness]

[Screenshot: 2023-09-02, 1:42 PM]

!!! Note that our repetition penalty does not stop the .that................ string here! This is a clear failure of what we want (and a source of personal confusion as to why the LLM was 'ignoring' the repetition penalty)! :'((((

[Screenshot: 2023-09-02, 1:44 PM]

More '........' scattered about, unfortunately :'(

Explanation of Why This Likely Happens

Because LLMs are autoregressive, they attempt to predict the most likely next token. Once the output collapses to a string of the same token, that same token remains the most likely continuation.

However, since this function has a fixed point at 0 and the softmax is nonlinear, we introduce a weird neighborhood of sensitivity. That is, no matter how strong we make the penalty, if the logit happens to be above 0, then at worst it will be reduced to... 0. This means that if all the other logits are below 0, we can and will get repetition, which I think you can see above.
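
As a concrete (made-up) numerical example of this failure mode: suppose the repeated token has logit 2.0 and every alternative is negative.

```cpp
#include <cstdio>

// Even an extreme penalty can only push a positive logit toward 0; if every
// competing logit is negative, the repeated token still wins after softmax.
int main() {
    const float repeated = 2.0f, other_a = -1.0f, other_b = -3.0f;
    const float penalty  = 5000.0f;
    const float penalized = (repeated > 0.0f) ? repeated / penalty
                                              : repeated * penalty;
    std::printf("penalized repeated logit: %f\n", penalized);  // ~0.0004
    std::printf("still the argmax? %s\n",
                (penalized > other_a && penalized > other_b) ? "yes" : "no");
    return 0;
}
```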

However, removing 0 as a fixed point should fix this sensitivity and let us rep-pen to our heart's content. Some examples with a prototype fix of this code should follow shortly.

@ggerganov
Member

Nice analysis - I can see that the current implementation is problematic. Any suggestions and proposals to make it better are welcome. We don't worry that it will change the old behaviour, as long as the new version is correct.

Btw, wouldn't it be simpler to just penalize after the softmax, where we have actual probs? What do we gain by penalizing logits?

@tysam-code
Author

tysam-code commented Sep 2, 2023

Many thanks (also, woohoo on the 'correct-version-moving-forward' philosophy! makes life a lot easier :))! Working on a suggested code fix now (and examples as necessary)

My personal 2 cents: fixing it after the softmax is a bit problematic, I think, because in a lot of the LLM world the expected behavior is 'pre-softmax' unless post-softmax is absolutely necessary (like in top-p, etc.), both in terms of what future developers expect the implementation to do and in how it plays with other sampling techniques.

I think it's a bit cleaner too, as it doesn't require any sum-rescaling, and it maybe still lets us play nicely with any assumptions the sampling methods make. There's something that just feels a bit hacky about post-softmax scaling; it has a bit of an 'if-then-else' feel to it, which could go south.

That's at least what my initial gut feeling is, to the point where I would even prefer an exp-transform -> manipulate -> log-transform-back step pre-softmax over a post-softmax change.

Maybe because it would leave the other logits less-touched? I'm not sure. Just the spidey sense talking here ATM.
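
For what it's worth, that exp-transform -> manipulate -> log-transform-back idea could look something like the sketch below (illustrative only; note that dividing the exponentiated value by the penalty is mathematically the same as subtracting log(penalty) from the logit, i.e. a constant bias with no special point at 0):

```cpp
#include <cmath>
#include <vector>

// Penalize in (unnormalized) probability space, then map back to logit space
// before any other sampling steps run. Equivalent to logit -= log(penalty);
// a real implementation would likely do that subtraction directly to avoid
// exp/log underflow on very negative logits.
void apply_exp_space_penalty_sketch(std::vector<float> & logits,
                                    const std::vector<int> & last_tokens,
                                    float penalty) {
    for (int tok : last_tokens) {
        const float p = std::exp(logits[tok]) / penalty;  // scale the "prob"
        logits[tok]   = std::log(p);                      // back to logit space
    }
}
```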

@tysam-code
Author

If there's a good reference to the original paper, I can take a look through it too. I had some trouble finding it based on the descriptions I've run into in the code and online. I'm a sucker for following a method's pedigree for wisdom/insights/etc.

@tysam-code
Author

tysam-code commented Sep 2, 2023

Update: Error in previous assertion

Quick update: the bug still seems to be a problem, but when we print out the token id after each token with an extremely high suppression value (50000), we can see that the model is using token-dodging here.

[Screenshot: 2023-09-02, 3:35 PM]

I think the bug is still there and bears fixing, but it looks like there's also the classic issue of token-dodging to get around. I have some friends who have dealt with this for a few years; I'll see if they have anything to say about it.

@Jipok

Jipok commented Sep 2, 2023

Does this make sense if there are plans to make penalties by sequence (#2593)?

@KerfuffleV2
Collaborator

KerfuffleV2 commented Sep 2, 2023

Does this make sense if there are plans to make penalties by sequence (#2593)?

I've actually been following this thread, since my code takes the same approach for applying the penalty in flag_divide_by_penalty mode (which I think is better than using a flat amount). My sampler there attempts to penalize sequences, but it can (currently) only do that by penalizing the tokens that would continue a sequence. In other words, it penalizes a token the same way the repetition penalty sampler does; it just chooses the token to penalize in a different way.

So if this is a problem, the seqrep sampler has it too, and if a solution is identified I'll update it to yoink that fix.

@tysam-code
Author

I think some of the token-dodging will be quite hard to get around unless there is regex matching, or unless the tokenizer is in the BPE family of tokenizers (I have no idea what the llama1/llama2 tokenizers are, unfortunately :'/).

@KerfuffleV2
Collaborator

I think some of the token-dodging will be quite hard to get around

I'm not familiar with the term "token-dodging". Could you explain a little more for the slower members in the class (cough yours truly)?

Do you mean something like if we ban the LLM from producing blah maybe it just builds blah out of other individual tokens. That kind of thing?

@tysam-code
Author

Yeah, that's correct. It's simpler for some of the simpler tokens, like '.'. With BPE, you should be able to just do matches on the derived tokens and such, since the more complex tokens are all built from combinations of simpler tokens, IIRC. So you can traverse that tree if you pick the simplest one. That doesn't cover all of the cases, just some of them.

Another approach is a regex-based one that just checks the actual output. That is a bit stricter and can catch anything that has weird token boundaries but produces exactly the same text in the end.
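
As a toy illustration of that regex idea (the pattern and function name are just examples, not a proposed llama.cpp feature): detokenize the output and check the text itself, regardless of which tokens produced it.

```cpp
#include <regex>
#include <string>

// Returns true if the generated text contains a run of four or more periods
// (optionally separated by whitespace), however the tokens were split.
bool output_has_period_run(const std::string & text) {
    static const std::regex period_run(R"((\.\s*){4,})");
    return std::regex_search(text, period_run);
}
```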

In my "grandma's period-filled email" testing, for one of the much longer passages, near the end as it ran out of tokens it started using the actual ellipses symbols and such, since those have similar semantic meaning to the tokens. One could do some slightly more arcane tricks to try to find token semantic similarity, but I think that might be a bit more involved.

I'm sure there are some decent solutions out there as I saw the NovelAI (and I'm sure Kobold too) people struggle quite viscerally with this in 2021. These models can be quite persistent in Count Grey-ing you (in loose general reference to a phenomenon where a character named Count Grey would appear out of thin air and instantly kill your character in the original AI Dungeon no matter what you did) if they have a mind to.

At some point, it transitions into the territory of being a model problem; we can only constrain the output so much before it becomes prohibitive. But there are a few ways to sorta 'stretch our budget', as it were, before we go completely broke on the external methods. That's about the limit of my knowledge, though. The most knowledgeable people on this, I think, would be in the NovelAI Discord; some of them have worked with this problem quite closely for years and would know much more than me on the subject.

Hope this helps. :)

@KerfuffleV2
Collaborator

Sorry, I thought I replied here already. Your explanation did help. Thanks!

I wouldn't really look at stuff like "token dodging" as a bug in the sampler, just the problem of a relatively simple approach. Of course, there's still room for using more sophisticated approaches that try harder to constrain the output.

I think all sampling methods at the logit level are going to have significant limitations. The LLM "wants" to express something, and you can only do so much to stop it. Something like #2654 would probably be needed to actually reliably redirect the concept that gets expressed rather than just the way the concept is worded.

@github-actions github-actions bot added the stale label Mar 21, 2024
Contributor

github-actions bot commented Apr 5, 2024

This issue was closed because it has been inactive for 14 days since being marked as stale.

@github-actions github-actions bot closed this as completed Apr 5, 2024