Add prefix sharing to continuous batching #42094
Conversation
Force-pushed from 43bb315 to 3e1b4f3
Review comment on:

    raise ValueError(f"Invalid group type: {group_type}")
    self.group_cache_managers.append(cm)

    # We only use prefix sharing if the whole model has only full attention layers
Is that a "for the moment" thing?
No, it is not compatible w/ sliding window (VLLM agrees)
I understand, but why not enable it only on the full-attention layers?
For now, there is only sliding-window or full attention. The only other type of attention I know of in transformers is block attention.
Sorry, my question wasn't clear: in models that have a mix of sliding-window and full attention, why not enable prefix caching?
Because the layers with a sliding window overwrite their KV cache once they reach the end of their window, we have to disable prefix caching for those layers. And if we disable prefix caching for one layer of the model, we have to disable it for all layers: we need a full forward pass to build the KV cache for every layer, and we cannot run a forward pass only for the layers without prefix caching.
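A minimal sketch of that gating check, assuming the model exposes a list of per-layer attention types (the names below are illustrative, not necessarily the exact attributes used in this PR):

```python
def supports_prefix_sharing(layer_types: list[str]) -> bool:
    """Prefix sharing is only safe when every layer keeps its full KV cache.

    Sliding-window layers overwrite old cache entries once their window is full,
    so a cached prefix block would no longer match what a fresh forward pass
    would have produced for those layers.
    """
    return all(layer_type == "full_attention" for layer_type in layer_types)


# A model mixing sliding-window and full attention disables the feature entirely.
print(supports_prefix_sharing(["full_attention"] * 4))                       # True
print(supports_prefix_sharing(["sliding_attention", "full_attention"] * 2))  # False
```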
src/transformers/generation/continuous_batching/continuous_api.py (outdated, resolved)
src/transformers/generation/continuous_batching/cache_manager.py (outdated, resolved)
Co-authored-by: Luc Georges <McPatate@users.noreply.github.com>
Maybe I will proofread again, but I am removing draft for now. Will address review soon.
Removed draft, ready to merge IMO
Benchmarks on H100, with the last version of the example script that adds the

The throughput gain increases as the prefix length increases. For reference, the prefixes in the table above are roughly 0, 40, 60, or 80 tokens long. If we instead add a large system prompt, so that all requests share a 2500-token prefix, the gap is more noticeable. Here are the numbers with flash attention:
ArthurZucker left a comment:
As discussed offline, we might want to simplify the stream: when we request the cache for n blocks, we know the first n-1 are FULL / completed and can thus keep the logic to compute the hash there. This means it can be "scheduled", since you don't need the result until the next forward (so it can be done in the background while the model runs a forward).
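A rough sketch of that scheduling idea, assuming block hashes are plain Python ints and using a single background thread; every name here is hypothetical, not the PR's actual API:

```python
from concurrent.futures import ThreadPoolExecutor

_executor = ThreadPoolExecutor(max_workers=1)


def _hash_blocks(parent_hash: int | None, completed_blocks: list[list[int]]) -> list[int]:
    """Hash the already-completed (FULL) blocks, chaining each hash to its parent."""
    hashes = []
    for tokens in completed_blocks:
        parent_hash = hash((parent_hash, tuple(tokens)))
        hashes.append(parent_hash)
    return hashes


def schedule_prefix_hashes(parent_hash: int | None, completed_blocks: list[list[int]]):
    """Submit the hashing as a background task: the result is only needed
    before the next forward pass, so it can overlap with the current one."""
    return _executor.submit(_hash_blocks, parent_hash, completed_blocks)


# While the current forward pass runs:
future = schedule_prefix_hashes(None, [[1, 2, 3, 4], [5, 6, 7, 8]])
# ... model forward happens here ...
# Before scheduling the next batch, collect the hashes:
block_hashes = future.result()
```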
| """Returns the number of free blocks left. Both initialized and uninitialized blocks are considered free.""" | ||
| return len(self._uninit_block_ids) + len(self._init_block_ids) | ||
|
|
||
| def is_enough_free_blocks(self, n_blocks: int) -> bool: |
Suggested change:

    -    def is_enough_free_blocks(self, n_blocks: int) -> bool:
    +    def has_enough_free_blocks(self, n_blocks: int) -> bool:
Will do with f2, it's used in other places. Good catch, thanks!
Review comment on:

        # Update loop variables
        parent_hash = block.hash

    def compute_hash(self, parent_hash: int | None, tokens: list[int]) -> int:
Suggested change:

    -    def compute_hash(self, parent_hash: int | None, tokens: list[int]) -> int:
    +    def __hash__(self, parent_hash: int | None, tokens: list[int]) -> int:
`__hash__` only takes `self` as an argument :(
https://docs.python.org/3.5/reference/datamodel.html#object.__hash__
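To illustrate the constraint with a hypothetical `Block` class (not the PR's actual implementation): `object.__hash__` must have the signature `__hash__(self)`, so a hash that also depends on the parent block's hash needs a separate method such as `compute_hash`:

```python
class Block:
    def __init__(self, tokens: list[int]):
        self.tokens = tokens
        self.hash: int | None = None  # filled in once the block is complete

    def __hash__(self) -> int:
        # Python only ever calls this as hash(block): no extra arguments allowed.
        return hash(tuple(self.tokens))

    def compute_hash(self, parent_hash: int | None) -> int:
        # The prefix-sharing hash must also depend on the parent block's hash,
        # so the same tokens at a different position yield a different hash.
        return hash((parent_hash, tuple(self.tokens)))


block = Block([1, 2, 3])
print(hash(block))                 # dispatches to __hash__(self)
print(block.compute_hash(None))    # first block of a sequence
print(block.compute_hash(0xBEEF))  # same tokens, different parent -> different hash in practice
```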
This PR adds a prefix sharing mechanism to the continuous batching API, like the one present in VLLM.
It only activates if the model uses full attention in every layer, as is the case in VLLM.
The mechanism has two main components:
What is missing from this PR:
The PR is a draft until these are resolved, but any early comment is welcome.
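For reference, a minimal, self-contained sketch of the general block-level prefix sharing idea (illustrative only, with hypothetical names and a made-up block size, not the PR's actual data structures): completed KV-cache blocks are indexed by a chained hash of their tokens, and a new request reuses every leading block whose hash is already in the index.

```python
BLOCK_SIZE = 4  # tokens per KV-cache block (made up for the example)


def block_hashes(tokens: list[int]) -> list[int]:
    """Chained hashes of the complete blocks covering a token prefix."""
    hashes, parent = [], None
    for start in range(0, len(tokens) - len(tokens) % BLOCK_SIZE, BLOCK_SIZE):
        parent = hash((parent, tuple(tokens[start:start + BLOCK_SIZE])))
        hashes.append(parent)
    return hashes


def match_prefix(tokens: list[int], index: dict[int, int]) -> list[int]:
    """Return the block ids of the longest cached block prefix of `tokens`."""
    matched = []
    for block_hash in block_hashes(tokens):
        if block_hash not in index:
            break
        matched.append(index[block_hash])
    return matched


# A finished request registers its complete blocks in the index...
index: dict[int, int] = {}
prompt_a = list(range(10))
for block_id, block_hash in enumerate(block_hashes(prompt_a)):
    index[block_hash] = block_id

# ...and a new request sharing the first block skips prefill for it.
prompt_b = [0, 1, 2, 3, 99, 100, 101, 102]
print(match_prefix(prompt_b, index))  # -> [0]: only the first block is shared
```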