-
Couldn't load subscription status.
- Fork 474
Memory efficient context handling #1183
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
LLama/LLamaEmbedder.cs
Outdated
| /// </summary> | ||
| /// <param name="text"></param> | ||
| /// <returns></returns> | ||
| public int CountTokens(string text) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CountTokens and GetTokens methods are duplicated on LLamaEmbedder and LLamaStatelessExecutor. Also I don't think either of them actually requires a context (which is a very expensive object to create and destroy)!
Can these methods be moved up to LLamaWeights class instead? That's a more appropriate place for methods relating to tokens/vocabulary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CountTokensandGetTokensmethods are duplicated onLLamaEmbedderandLLamaStatelessExecutor.
The reason for this is because the contexts are made with the parameters of each specific object (text generator or embedding generator).
Also I don't think either of them actually requires a context (which is a very expensive object to create and destroy)!
The code is now streamlined to not have the context around, but created when needed and then destroyed. The logic of doing this is the same as in each object itself, for example, GetEmbeddingsWithTokenCount() does the same in LLamaEmbedder, and InferAsync() does the same in StatelessExecutor. So, the code is logical in all sense now. The overhead of creating the context on the fly is very small, and when using KernelMemory with this update, compared to before, 30% less GPU memory is used.
Can these methods be moved up to
LLamaWeightsclass instead? That's a more appropriate place for methods relating to tokens/vocabulary.
If we would move them to LLamaWeights, then we would need to change the code to keep the parameters in each object (required to make the context - different for LLamaEmbedder and for LLamaStatelessExecutor), and pass these parameters on several places in the code to these functions, etc. A lot of modifications on several places where these functions are used. I think that the simplest and cleanest would be to leave them where they are now.
As a conclusion, I would keep it how it is now. Please let me know what you think.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Martin, on a second thought, I think that you may be right (only the params need to be kept). I will look at it!
LLama/LLamaStatelessExecutor.cs
Outdated
| } | ||
|
|
||
| /// <inheritdoc/> | ||
| public int CountTokens(string text) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See other comment about these methods
|
Moved the code to |
| var embeddings = await generator.GenerateAsync( | ||
| [ | ||
| "The cat is cute", | ||
| if (false) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this need resolving before merge?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
generator.GetService<EmbeddingGeneratorMetadata>() uses the context and thus will fail because for a context efficient handling we do not keep the context. This was the main aim of this PR.
The code in the test assumes that there is a context. I think that for the test code to work we would need some extra work to create an embedding service that keeps the context (this could be done in a next PR, if anybody is interested in to do it). The aim of the embedder in our code is different. My opinion is that the test code is wrong because it assume that the embedder is a live service, and it should not be for efficiently handling of GPU memory. There are two options, delete the test code or leave it in switched off with the TODO comment I have added.
LLama/LLamaWeights.cs
Outdated
| using var context = CreateContext(parameters); | ||
| var count = context.Tokenize(text, special: true).Length; | ||
| return count; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a Tokenize method on the lower level model handle, no need to use a context: https://github.com/SciSharp/LLamaSharp/blob/master/LLama/Native/SafeLlamaModelHandle.cs#L480
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(i.e. just writing a wrapper over that method should suffice)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did not realize that there is a tokenize method on parent level. I will use that then and since it does not need a context I can move all code back to KM. The main aim of this PR was to remove the context that was create on several places because it unnecessarily fills GPU memory (saves about 30% of memory!).
LLama/LLamaWeights.cs
Outdated
| /// <remarks> | ||
| /// It throws if text is null and Includes empty stop token because addBos is left true to be consistent with the CountTokens implementation.</remarks> | ||
| /// <see cref="CountTokens(string, IContextParams)"/> | ||
| public IReadOnlyList<string> GetTokens(string text, IContextParams parameters) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This implementation doesn't seem correct to me (I realise you're just moving it from LlamaSharpTextGenerator to LLamaWeights, but I don't really work on the KM stuff, so I haven't closely reviewed these methods before).
In general LLamaSharp is quite careful about never treating tokens as text, it's not safe for a number of reasons - for example a token could be half of a character, in which case it can't be decoded into text. That's what the StreamingTokenDecoder is for, you could add 10 tokens and get back just one character of text. At the very least, that means that GetTokens and CountTokens would have a mismatch.
Obviously KM needs something back to satisfy the contract of ITextTokenizer etc, so I'm not really sure what the right answer is here. Maybe we should move back into KM, an an extension method on LLamaWeights? That way you can still use it as if it's here, but it's not part of the main lib. I'm open to other ideas though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main aim of this PR was to remove the context that was created on several places in the KM code because it unnecessarily fills GPU memory (this saves about 30% of memory!). With using the native tokenize method and moving back the code to KM, I think that we have the right solution.
|
I have updated the code. |
|
Just gave this a quick skim, looks pretty good after those last changes 👍 I'll try to find time tomorrow to give it one alst in depth review and get it merged. |
|
Thanks for all the work on this! |
You too, Martin! |
Martin, in the LLamaEmbedderTests in CompareEmbeddings() I had to disable a Microsoft.Extensions.AI.IEmbeddingGenerator related code segment that does not work with the new efficient context handling. Please look at that code to decide what can be done to use it or we can also remove it. I guess that if we want that kind of functionality, then we will need to create a LLamaEmbedderService that is compatible with Microsoft.Extensions.AI.IEmbeddingGenerator.
I had to disable in SafeLlamaModelHandleTests -> MetadataValByKey_ReturnsCorrectly for Mac and Linux. On Windows it is OK. Please look at this also to decide if this is important. This has nothing to do with this PR, but I guess with llama.cpp.