Inference support for T5 and FLAN-T5 model families #8141
Changes from all commits:
45681a5, 1c8d37a, bad0caf, c4ded1a, 7293243, 7d7fff4, 6dc9eb4, 78675f3, 1d1cb01, 7c610fa, b01ce7d, d40c9a1, 03ab5dd, 88270a3, ded682d, 01cd5a6, 8b560e6, 9bcecf1
@@ -485,6 +485,13 @@ extern "C" {
    // Get a llama model tensor
    LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);

    // Returns true if the model contains an encoder that requires llama_encode() call
    LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);

    // For encoder-decoder models, this function returns id of the token that must be provided
    // to the decoder to start generating output sequence. For other models, it returns -1.
    LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);

    // Returns 0 on success
    LLAMA_API uint32_t llama_model_quantize(
            const char * fname_inp,
@@ -770,6 +777,14 @@ extern "C" {
    // Frees a batch of tokens allocated with llama_batch_init()
    LLAMA_API void llama_batch_free(struct llama_batch batch);

    // Processes a batch of tokens with the encoder part of the encoder-decoder model.
    // Stores the encoder output internally for later use by the decoder cross-attention layers.
Review discussion on llama_encode:

vladfaust: In my case, a prompt consists of a static part, which is unchanged and makes use of the KV cache, and a dynamic part, which changes frequently. It works well with GPT, where I can call …

fairydreaming: @vladfaust No, the encoder requires all input tokens to be present in the input batch. That is because attention in the encoder is not causal, so each token in the input sequence attends to all tokens in the input sequence. It doesn't even use the KV cache because there's no need to. Theoretically it would be possible to implement it in a way that allows "adding" tokens to the encoder output by calling llama_encode() multiple times, but the implementation would be much more complicated and is definitely outside the scope of this PR.

vladfaust: Just to clarify, @fairydreaming: one of my use cases is converting a growing chat history into some structured representation for each new message. Do I understand correctly that, for now, I'd have to encode the whole history again for each inference, without any form of caching? (No offence, obviously, as I'm very grateful for the T5 support at all!)

fairydreaming: @vladfaust Yes, there's no caching in the encoder, so if the input sequence grows even by one token you have to encode it again, and in doing so all previous calculations for that token sequence are repeated.
    // 0 - success
    // < 0 - error
    LLAMA_API int32_t llama_encode(
            struct llama_context * ctx,
            struct llama_batch batch);

    // Positive return values does not mean a fatal error, but rather a warning.
    // 0 - success
    // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
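To tie the pieces together, below is a hedged sketch of an encode-then-decode loop built on the new API. It is not taken from this PR: the llama_batch_get_one() signature follows the llama.h version this diff applies to, sampling is reduced to a greedy argmax, and the generation limit is arbitrary. Note, per the discussion above, that the encoder has no caching, so if the input sequence grows the whole thing has to be re-encoded from scratch.

#include "llama.h"
#include <vector>
#include <algorithm>

// Sketch: run the (non-causal) encoder once over the full prompt, then decode
// token by token; the decoder cross-attends to the stored encoder output.
static int encode_then_decode(llama_model * model, llama_context * ctx,
                              std::vector<llama_token> tokens) {
    if (!llama_model_has_encoder(model)) {
        return 0; // decoder-only models go through the usual llama_decode()-only path
    }

    // the encoder needs the whole input present in a single llama_encode() call
    if (llama_encode(ctx, llama_batch_get_one(tokens.data(), (int32_t) tokens.size(), 0, 0)) != 0) {
        return -1;
    }

    llama_token tok = llama_model_decoder_start_token(model);
    if (tok == -1) {
        tok = llama_token_bos(model); // assumption: BOS as fallback start token
    }

    const int n_vocab = llama_n_vocab(model);

    for (int i = 0; i < 64; ++i) {  // arbitrary generation limit
        if (llama_decode(ctx, llama_batch_get_one(&tok, 1, i, 0)) != 0) {
            return -1;
        }
        // logits of the last (only) token in the batch; pick greedily for simplicity
        const float * logits = llama_get_logits_ith(ctx, -1);
        tok = (llama_token) (std::max_element(logits, logits + n_vocab) - logits);
        if (tok == llama_token_eos(model)) {
            break;
        }
        // ... collect or print tok here ...
    }
    return 0;
}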