[core][optimization] use a pool of numpy ndarray to hold seq data #5942
Conversation
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        self.tokens = _SEQUENCE_DATA_POOL.alloc_array()
        self.prompt_token_ids_list = prompt_token_ids
Is there any opportunity to get rid of this list (and the output token ids list)? It completely duplicates the numpy arrays, and we should avoid that if possible.
I want to delete it, too. However, sometimes we need the prompt token ids as a list of ints because that is what users expect. If we don't store it here, we would have to create a copy from the numpy array, which is expensive.
Fortunately, this is just a reference, so performance-wise it is fine.
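To make the trade-off concrete, here is a minimal sketch (illustrative names, not the PR's actual code) of why returning the stored list reference is cheap while rebuilding it from the ndarray is not:

import numpy as np

prompt_token_ids = [1, 5, 42, 7]                 # list produced upstream by the tokenizer
tokens = np.zeros(128 * 1024, dtype=np.int64)    # pooled backing array (assumed size/dtype)
tokens[:len(prompt_token_ids)] = prompt_token_ids

def get_prompt_token_ids_ref() -> list:
    # O(1): hands back the reference stored at construction time.
    return prompt_token_ids

def get_prompt_token_ids_copy() -> list:
    # Allocates a new list and converts every element from int64 to Python int,
    # which is the copy cost called expensive above.
    return tokens[:len(prompt_token_ids)].tolist()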
I searched the code base and it seems like only batch expansion uses get_prompt_token_ids() and get_output_token_ids(), so it should be possible, as batch expansion is going to be removed by @LiuXiaoxuanPKU.
good to know.
def append_token_id(self, token_id: int, logprob: float) -> None:
    self.output_token_ids.append(token_id)
    self.tokens[self.num_prompt_tokens + self.num_output_tokens] = token_id
Ideally we should have an assertion to check the boundary, even though 128k should always be sufficient at the moment. Let's add an assert if it doesn't hurt performance; otherwise we could just comment that we assume the context length won't go beyond 128k.
I think numpy array indexing already has a bounds check.
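For what it's worth, out-of-bounds scalar assignment on an ndarray does raise an IndexError, so an overflow here would fail loudly rather than silently (a standalone illustration, not code from this PR):

import numpy as np

tokens = np.zeros(4, dtype=np.int64)
tokens[3] = 7       # last valid index, fine
try:
    tokens[4] = 9   # one past the end
except IndexError as exc:
    print(exc)      # index 4 is out of bounds for axis 0 with size 4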
I want to somehow know the max seq length in seqdata, but don't know how to pass that information across so many levels.
Setting a fixed length makes sense to me for simplicity's sake. Hmm, maybe it's ok to keep the current implementation then. If someone really hits the boundary and sees the numpy error, we would know what's going on...
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!
This pull request has been automatically closed due to inactivity. Please feel free to reopen if you intend to continue working on it. Thank you!
The remaining part of #5877 after separating #5882 out.
The same benchmark command:
python benchmarks/benchmark_throughput.py --output-len 256 --input 256 --model meta-llama/Llama-2-7b-hf -tp 8
The same machine: 8× H100
Before (current main): Throughput: 38.89 requests/s, 19909.29 tokens/s
After (this PR): Throughput: 40.12 requests/s, 20541.11 tokens/s — roughly a 3% throughput improvement (40.12 / 38.89 ≈ 1.03).
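For readers who haven't opened the diff, below is a minimal sketch of the kind of ndarray pool this PR relies on. _SEQUENCE_DATA_POOL and alloc_array() appear in the diff hunks above; the internals and the 128k capacity are assumptions drawn from the review discussion, not the actual implementation:

from typing import List

import numpy as np

# Assumed per-sequence capacity; the review thread above discusses a fixed 128k bound.
_MAX_SEQ_LEN = 128 * 1024

class _NdArrayPool:
    """Reuses fixed-size int64 arrays instead of allocating a fresh one per sequence."""

    def __init__(self) -> None:
        self._free: List[np.ndarray] = []

    def alloc_array(self) -> np.ndarray:
        # Hand out a recycled array if one is available, otherwise allocate a new one.
        if self._free:
            return self._free.pop()
        return np.zeros(_MAX_SEQ_LEN, dtype=np.int64)

    def free_array(self, arr: np.ndarray) -> None:
        # Return the array to the pool so a later sequence can reuse it.
        self._free.append(arr)

_SEQUENCE_DATA_POOL = _NdArrayPool()

# Usage mirroring the diff: a sequence grabs a pooled array to hold its token ids.
tokens = _SEQUENCE_DATA_POOL.alloc_array()
tokens[0] = 42
_SEQUENCE_DATA_POOL.free_array(tokens)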