Conversation

@remi-or (Collaborator) commented Nov 7, 2025

This PR adds a prefix sharing mechanism to the continuous batching API, like the one present in VLLM.
It is only activated if the model is a full-attention model, as is also the case in VLLM.

The mechanism has three main components:

  • block hashing: once a block in the cache is filled up, it is given a hash that depends on all the tokens in the sequence up to and including those in the block (a minimal sketch follows after this list)
  • prefix detection: when starting prefill for a request, we first look for a prefix whose KV cache has already been computed; if such a prefix is found, we skip the KV computation for it and reuse references to the completed blocks to save compute
  • block de-referencing: when a block is given a hash, we check that no other block shares the same hash. This ensures that no two blocks hold the same information and helps keep the cache size under control
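
For illustration, here is a minimal sketch of the chained block hashing and prefix lookup described above. The block size, the helper names, and the hash_to_block_id mapping are assumptions made for the example; this is not the PR's actual code.

```python
import hashlib

BLOCK_SIZE = 256  # tokens per KV-cache block (assumed for the example)


def compute_block_hash(parent_hash: int | None, block_tokens: list[int]) -> int:
    """Hash a completed block, chaining in the parent block's hash so the result
    depends on every token in the sequence up to and including this block."""
    payload = "|".join([str(parent_hash), *map(str, block_tokens)])
    return int(hashlib.sha256(payload.encode()).hexdigest()[:16], 16)


def find_cached_prefix(prompt: list[int], hash_to_block_id: dict[int, int]) -> list[int]:
    """Walk the prompt block by block and collect the ids of blocks whose KV cache
    already exists; prefill can then skip the KV computation for those tokens."""
    matched_block_ids: list[int] = []
    parent_hash: int | None = None
    for start in range(0, len(prompt) - BLOCK_SIZE + 1, BLOCK_SIZE):
        block_hash = compute_block_hash(parent_hash, prompt[start:start + BLOCK_SIZE])
        if block_hash not in hash_to_block_id:
            break  # no cached block for this prefix; compute KV from here on
        matched_block_ids.append(hash_to_block_id[block_hash])
        parent_hash = block_hash
    return matched_block_ids
```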

What is missing from this PR:

  • more documentation
  • another pass over the code
  • edge case: if the prefix is the entire initial request, we still need to do a forward pass with the last token of the request
  • TODOs to address: the ones left are for a PR down the line

The PR will stay a draft until these are resolved, but any early comments are welcome.

@remi-or requested a review from McPatate November 7, 2025 15:57
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

raise ValueError(f"Invalid group type: {group_type}")
self.group_cache_managers.append(cm)

# We only use prefix sharing if the whole model has only full attention layers

Member:
Is that a "for the moment" thing?

Collaborator (Author):
No, it is not compatible with sliding window attention (VLLM agrees).

Member:
I understand, but why not enable it only on the full attention layers?

Collaborator (Author):
For now, there is only sliding window or full attention. The only other type of attention I know of in transformers is block attention.

Member:
Sorry, my question wasn't clear: in models that have a mix of sliding and full attention, why not enable prefix caching?

Collaborator (Author):
Because layers with a sliding window overwrite their KV cache once they reach the end of the window, so we have to disable prefix caching for those layers. And if we disable prefix caching for one layer of the model, we have to disable it for all layers: we need a full forward pass to build the KV cache for every layer, we cannot run a forward pass only for the layers that have no prefix caching.
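
To make that constraint concrete, here is a hedged sketch of such a gate. The `layer_types` and `sliding_window` config attributes are assumptions about how the per-layer attention type might be exposed, not necessarily what this PR checks.

```python
def supports_prefix_sharing(config) -> bool:
    """Enable prefix sharing only when every layer uses full attention."""
    layer_types = getattr(config, "layer_types", None)
    if layer_types is not None:
        # A single sliding-window layer disables prefix sharing for the whole model.
        return all(layer_type == "full_attention" for layer_type in layer_types)
    # Fallback: treat any configured sliding window as disqualifying.
    return getattr(config, "sliding_window", None) is None
```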

remi-or and others added 3 commits November 12, 2025 16:22
Co-authored-by: Luc Georges <McPatate@users.noreply.github.com>
Co-authored-by: Luc Georges <McPatate@users.noreply.github.com>
@remi-or (Collaborator, Author) commented Nov 12, 2025

I may proofread it again, but I am removing the draft status for now. Will address the review soon.

@remi-or marked this pull request as ready for review November 13, 2025 08:39
@remi-or (Collaborator, Author) commented Nov 13, 2025

Removed draft, ready to merge IMO

@remi-or (Collaborator, Author) commented Nov 13, 2025

Benchmarks on an H100, with the latest version of the example script, which adds the --add-prefix option and turns sampling off by default.
Command: python examples/pytorch/continuous_batching.py --attn $attn -mp none --add-prefix --samples 500

| Attention       | Version             | Generated tokens | Duration (s) | Throughput (tok/s) | Total prefill length matched |
|-----------------|---------------------|------------------|--------------|--------------------|------------------------------|
| SDPA            | With prefix sharing | 112033           | 90.47        | 1238.37            | 16096                        |
| SDPA            | No prefix sharing   | 112223           | 89.95        | 1247.55            | -                            |
| SDPA            | Main branch         | 112223           | 90.31        | 1242.60            | -                            |
| Flash attention | With prefix sharing | 111494           | 24.76        | 4503.30            | 16096                        |
| Flash attention | No prefix sharing   | 112599           | 26.87        | 4190.22            | -                            |
| Flash attention | Main branch         | 112599           | 27.10        | 4155.44            | -                            |

The throughput gain increases as the prefix length increases. For reference, the prefixes in the table above are roughly 0, 40, 60, or 80 tokens long. If we instead add a large system prompt so that all requests share a 2500-token prefix, the gap is more noticeable. Here are the numbers with flash attention:

No prefix sharing:    195.49 seconds for 122435 tokens. 626.29 tok/s
W/ prefix sharing:    129.37 seconds for 121763 tokens. 941.17 tok/s

@remi-or requested a review from McPatate November 13, 2025 13:44
@ArthurZucker (Collaborator) left a comment

As discussed offline, we might want to simplify the flow: when we request the cache for n blocks, we know the first n-1 are full/completed, and so we can keep the hash-computation logic there. This means it can be "scheduled", since the result is not needed until the next forward pass (and can thus be computed in the background while the model runs a forward).
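
A rough sketch of that idea, assuming a simple thread pool and hypothetical allocate / compute_hashes helpers; none of this is the PR's actual API.

```python
from concurrent.futures import ThreadPoolExecutor

_hash_executor = ThreadPoolExecutor(max_workers=1)


def request_blocks(cache_manager, request, n_blocks: int):
    """Allocate n blocks; the first n-1 are known to be full, so their hashes
    can be computed in the background while the model runs the current forward."""
    blocks = cache_manager.allocate(request, n_blocks)  # hypothetical helper
    full_blocks = blocks[:-1]  # all but the last block are complete
    hash_future = _hash_executor.submit(cache_manager.compute_hashes, full_blocks)  # hypothetical helper
    return blocks, hash_future  # resolve hash_future before the next forward needs it
```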

"""Returns the number of free blocks left. Both initialized and uninitialized blocks are considered free."""
return len(self._uninit_block_ids) + len(self._init_block_ids)

def is_enough_free_blocks(self, n_blocks: int) -> bool:
Collaborator:

Suggested change:
- def is_enough_free_blocks(self, n_blocks: int) -> bool:
+ def has_enough_free_blocks(self, n_blocks: int) -> bool:

Collaborator (Author):

Will do with F2, since it's used in other places. Good catch, thanks!

# Update loop variables
parent_hash = block.hash

def compute_hash(self, parent_hash: int | None, tokens: list[int]) -> int:
Collaborator:

Suggested change:
- def compute_hash(self, parent_hash: int | None, tokens: list[int]) -> int:
+ def __hash__(self, parent_hash: int | None, tokens: list[int]) -> int:

@ArthurZucker merged commit 47227f4 into main Nov 17, 2025
20 of 24 checks passed
@ArthurZucker deleted the cb-prefix-sharing branch November 17, 2025 12:20