Add continuous generation #781
This PR solves #667.
API
`continuous`

`continuous` wraps any `SequenceGenerator` object; it could be:

- `outlines.generate.choice`
- `outlines.generate.text`
- `outlines.generate.json`
- ...

The `continuous` wrapper allows the generator to save the state of a `Sequence`. This means that if you continuously generate a sequence, as sketched below, the KV Cache will be saved (under some conditions), and algorithms such as beam search could be used to optimize the whole sequence rather than each part separately.
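A minimal sketch of what this could look like. The import path `outlines.generate.continuous`, the call signature, and the example model are assumptions for illustration, not the confirmed API of this PR:

```python
import outlines
from outlines import generate

model = outlines.models.transformers("mistralai/Mistral-7B-v0.1")  # any example model

# Hypothetical import path for the wrapper added in this PR.
generator = generate.continuous(generate.text(model))

# The first call turns the prompt into a SequenceState...
state = generator("Write a short story about a fox.", max_tokens=32)

# ...and later calls extend that same state, reusing the saved KV Cache
# (under the conditions listed below) instead of re-encoding the prefix.
state = generator(state, max_tokens=32)
state = generator(state, max_tokens=32)

print(list(state))  # list(SequenceState) gives the generated strings
```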
You can mix different types of `SequenceGenerator` objects, as sketched below:
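Illustrative sketch only, with the same assumed wrapper name and signatures as above:

```python
import outlines
from outlines import generate

model = outlines.models.transformers("mistralai/Mistral-7B-v0.1")

# Hypothetical wrapper calls, one per generator type.
free_text = generate.continuous(generate.text(model))
yes_or_no = generate.continuous(generate.choice(model, ["yes", "no"]))

# Generate free text first, then continue the same sequence with a constrained
# generator: the SequenceState carries the prompt, the generated tokens and,
# when possible, the KV Cache from one call to the next.
state = free_text("Is the following statement plausible? The sky is green.", max_tokens=64)
state = yes_or_no(state)
```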
Once a prompt is given to the `continuous` wrapper, it becomes a `SequenceState` object.

`SequenceState`

Indexing
Each `SequenceState` has three dimensions: `SequenceState[batch_key: Union[int, slice], sample_key: Union[int, slice], ids_size_key: Union[int, slice]]`. However, there are three cases where this is handled differently:

- `batch_size == 1` and `sample_size == 1`: `SequenceState[ids_size_key]` instead of `SequenceState[0, 0, ids_size_key]`.
- `batch_size == 1`: `SequenceState[sample_key, ids_size_key]` instead of `SequenceState[0, sample_key, ids_size_key]`.
- `sample_size == 1`: `SequenceState[batch_key, ids_size_key]` instead of `SequenceState[batch_key, 0, ids_size_key]`.
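As an illustration of these conventions (not taken from the PR's tests; the slices are assumed to be string-level, based on the `token_level_slice_from_string_level_slice` utility mentioned below):

```python
# `state` is a SequenceState; the comments give the equivalent full 3-D form.

answer = state[0, 1, :]    # batch element 0, sample 1, whole generated text

# batch_size == 1 and sample_size == 1: only the ids/size key is given.
head = state[:20]          # instead of state[0, 0, :20]

# batch_size == 1: the batch key is dropped.
sample = state[2, :]       # instead of state[0, 2, :]

# sample_size == 1: the sample key is dropped.
second = state[1, :]       # instead of state[1, 0, :]
```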
Operations

You can apply two operations on a `SequenceState`:
- Slicing
- Adding (`SequenceState` to a `SequenceState` and `SequenceState` to a prompt)

Adding `SequenceState` to a `SequenceState`

This won't save the first part of the KV Cache for the moment, but it does accumulate the weights across both sequences.
I don't have an idea of how to implement it yet: the KV Cache implementation from HuggingFace accepts either (1) a `None` value or (2) a KV Cache whose context size is exactly one less than that of the `token_ids`. I ran an experiment where I use the model itself to compute (or complete) the KV Cache of the second sequence so that (2) is satisfied. The function is called `complete_kv_cache_from_token_ids`, but it is not used because it is too slow.

`SequenceState` to a prompt

This will reinitialize everything.
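A rough sketch of the two kinds of addition, assuming the operation is exposed through the `+` operator (the description above only says "adding", so the operator is an assumption):

```python
# Hypothetical addition of sequence states; `state_a` and `state_b` are
# SequenceState objects produced by continuous generators.

merged = state_a + state_b            # SequenceState + SequenceState:
                                      # weights accumulate across both parts,
                                      # but the KV Cache of the first part is
                                      # not reused for now.

restarted = state_a + " Summarize:"   # SequenceState + prompt:
                                      # everything is reinitialized.
```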
Slicing
Conditions under which KV Cache is saved:
1. The slice keeps only one element (`batch_size_after_the_slice == 1`, `sample_size_after_the_slice == 1`); slicing more than one element will reset the KV Cache. This condition includes the base case where `batch_size == 1` and `sample_size == 1`.
2. The slice starts from the first index of the prompt (`SequenceState[..., :M]` or `SequenceState[..., 0:M]`).
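For example (illustrative only; `state` is assumed to have `batch_size == 1` and `sample_size == 1`, while `batched` has `batch_size > 1`):

```python
kept = state[:120]       # starts at index 0 and keeps a single element:
                         # conditions 1. and 2. hold, the KV Cache can be saved

tail = state[10:120]     # does not start from the first index of the prompt:
                         # the KV Cache is reset

many = batched[0:2, :]   # the slice keeps more than one element, so the
                         # KV Cache is reset as well
```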
There are some technical intricacies that prevent saving the KV Cache even under 1. and 2.; see the `[NOTE] [SPECIAL CASE]` flags in the `token_level_slice_from_string_level_slice` utility. This is also one of the reasons not to go further trying to make the KV Cache work when `batch_size > 1` and `num_samples > 1`: the complexity/usefulness tradeoff seems way off to me.

Using `list(SequenceState)`

`list(SequenceState)` converts the `SequenceState` object into a list of strings.

Exceptions
Three types of exceptions can be raised while using the `continuous` wrapper:

- `SampleMismatch`: raised when (1) the sequence's samples are sliced and the result is then passed back to the wrapper (a mismatch between the number of samples in the sequence and the number of samples in the generator), or (2) two sequences with different numbers of samples are added.
- `BatchMismatch`: raised when two sequences with different batch sizes are added.
- `SlicingError`: raised when the slice doesn't allow the KV Cache to be saved; it is handled by resetting the KV Cache.
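As a rough illustration of when the first two would surface (the import path is a guess, addition via `+` is assumed, and `generator` / `sampling_generator` are hypothetical continuous-wrapped generators):

```python
from outlines.continuous import BatchMismatch, SampleMismatch  # hypothetical path

two_prompts = generator(["prompt 1", "prompt 2"])  # batch_size == 2
one_prompt = generator("prompt 3")                 # batch_size == 1

try:
    two_prompts + one_prompt                       # different batch sizes
except BatchMismatch:
    pass

several = sampling_generator("prompt 4")           # generator drawing several samples
only_first = several[0, :]                         # slice away the other samples

try:
    sampling_generator(only_first)                 # sample counts no longer match
except SampleMismatch:
    pass
```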
FLAGs
You will see multiple flags that I've put in the code comments:
- `[NOTE]`: general notes explaining how I approached some problems.
- `[QUESTION]`: questions I had while coding certain mechanisms.
- `[POTENTIAL BUG]`: lines of code that could potentially trigger bugs.
Modifications
- `GenerationState` also returns `attention_masks`.
- `sequence_generator` takes `kv_cache` as a keyword argument with `None` as the default value.
- `test_continuous.py`: some tests I added for the different parts of the code.
PS: I use the name `SequenceState` instead of `Sequence` just because it made the coding more obvious to me; tell me if you want to switch it back to `Sequence`.