Commit beaed8c

[generate] move SinkCache to a custom_generate repo (#38399)
remove sink cache
1 parent fe5bfaa commit beaed8c

File tree: 6 files changed (+13, -240 lines)

docs/source/en/internal/generation_utils.md

Lines changed: 0 additions & 6 deletions
@@ -380,11 +380,6 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 
 [[autodoc]] HQQQuantizedCache
 
-[[autodoc]] SinkCache
-    - update
-    - get_seq_length
-    - reorder_cache
-
 [[autodoc]] OffloadedCache
     - update
     - prefetch_layer
@@ -443,4 +438,3 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 
 [[autodoc]] CompileConfig
     - __call__
-

docs/source/en/kv_cache.md

Lines changed: 3 additions & 31 deletions
@@ -30,7 +30,6 @@ Transformers offers several [`Cache`] classes that implement different caching m
 | Offloaded Static Cache | No | Yes | Yes | High | Yes |
 | Quantized Cache | Yes | No | No | Low | Yes |
 | Sliding Window Cache | No | Yes | Yes | High | No |
-| Sink Cache | Yes | No | Yes | Mid | Yes |
 
 This guide introduces you to the different [`Cache`] classes and shows you how to use them for generation.
 
@@ -174,28 +173,6 @@ I like rock music because it's loud and energetic. It's a great way to express m
 </hfoption>
 </hfoptions>
 
-### Sink cache
-
-[`SinkCache`] is capable of generating very long sequences ("infinite length" according to the paper) by only retaining a few initial tokens from the sequence. These are called the *sink tokens* because they account for a significant portion of the attention scores during generation. Subsequent tokens are discarded on a sliding windowed basis, and only the latest `window_size` tokens are kept. This means most of the previous knowledge is discarded.
-
-The sink tokens allow a model to maintain stable performance even when it's dealing with very long text sequences.
-
-Enable [`SinkCache`] by initializing it first with the [window_length](https://hf.co/docs/transformers/main/en/internal/generation_utils#transformers.SinkCache.window_length) and [num_sink_tokens](https://hf.co/docs/transformers/main/en/internal/generation_utils#transformers.SinkCache.num_sink_tokens) parameters before passing it to [past_key_values](https://hf.co/docs/transformers/internal/generation_utils#transformers.generation.GenerateDecoderOnlyOutput.past_key_values) in [`~GenerationMixin.generate`].
-
-```py
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
-
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
-inputs = tokenizer("This is a long story about unicorns, fairies and magic.", return_tensors="pt").to(model.device)
-
-past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
-out = model.generate(**inputs, do_sample=False, max_new_tokens=30, past_key_values=past_key_values)
-tokenizer.batch_decode(out, skip_special_tokens=True)[0]
-"This is a long story about unicorns, fairies and magic. It is a fantasy world where unicorns and fairies live together in harmony. The story follows a young girl named Lily"
-```
-
 ## Speed optimized caches
 
 The default [`DynamicCache`] prevents you from taking advantage of just-in-time (JIT) optimizations because the cache size isn't fixed. JIT optimizations enable you to maximize latency at the expense of memory usage. All of the following cache types are compatible with JIT optimizations like [torch.compile](./llm_optims#static-kv-cache-and-torchcompile) to accelerate generation.
@@ -247,7 +224,7 @@ Enable [`SlidingWindowCache`] by configuring `cache_implementation="sliding_wind
 
 ```py
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
 tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
 model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16).to("cuda:0")
@@ -284,16 +261,13 @@ A cache can also work in iterative generation settings where there is back-and-f
 
 For iterative generation with a cache, start by initializing an empty cache class and then you can feed in your new prompts. Keep track of dialogue history with a [chat template](./chat_templating).
 
-If you're using [`SinkCache`], the inputs need to be truncated to the maximum length because [`SinkCache`] can generate text that exceeds its maximum window size. However, the first input shouldn't exceed the maximum cache length.
-
 The example below demonstrates how to use a cache for iterative generation.
 
 ```py
 import torch
 from transformers import AutoTokenizer,AutoModelForCausalLM
 from transformers.cache_utils import (
     DynamicCache,
-    SinkCache,
     StaticCache,
     SlidingWindowCache,
     QuantoQuantizedCache,
@@ -313,8 +287,6 @@ messages = []
 for prompt in user_prompts:
     messages.append({"role": "user", "content": prompt})
     inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
-    if isinstance(past_key_values, SinkCache):
-        inputs = {k: v[:, -max_cache_length:] for k, v in inputs.items()}
     input_length = inputs["input_ids"].shape[1]
     outputs = model.generate(**inputs, do_sample=False, max_new_tokens=256, past_key_values=past_key_values)
     completion = tokenizer.decode(outputs[0, input_length: ], skip_special_tokens=True)
@@ -336,7 +308,7 @@ model_id = "meta-llama/Llama-2-7b-chat-hf"
 model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-# Init StaticCache with big enough max-length (1024 tokens for the below example)
+# Init StaticCache with big enough max-length (1024 tokens for the below example)
 # You can also init a DynamicCache, if that suits you better
 prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)
 
@@ -351,7 +323,7 @@ responses = []
 for prompt in prompts:
     new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
     past_key_values = copy.deepcopy(prompt_cache)
-    outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20)
+    outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20)
     response = tokenizer.batch_decode(outputs)[0]
     responses.append(response)

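Note: the SinkCache quickstart removed from kv_cache.md above is meant to be replaced by the Hub-hosted decoding method. The snippet below is a rough sketch, not part of this commit; the `custom_generate` and `trust_remote_code` arguments to `generate` are the documented custom-decoding-method mechanism, while `window_length` and `num_sink_tokens` are assumed to be forwarded by the `transformers-community/sink_cache` repository (mirroring the deleted example, see the repo for its actual API).

```py
# Sketch of the Hub-hosted replacement for the removed SinkCache quickstart.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16
).to("cuda:0")
inputs = tokenizer("This is a long story about unicorns, fairies and magic.", return_tensors="pt").to(model.device)

out = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=30,
    custom_generate="transformers-community/sink_cache",  # decoding method fetched from the Hub
    trust_remote_code=True,  # required to execute code downloaded from the Hub
    window_length=256,       # assumed kwarg, mirroring SinkCache(window_length=256, ...)
    num_sink_tokens=4,       # assumed kwarg, mirroring SinkCache(..., num_sink_tokens=4)
)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```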
docs/source/ko/internal/generation_utils.md

Lines changed: 0 additions & 5 deletions
@@ -366,11 +366,6 @@ generation_output[:2]
 
 [[autodoc]] HQQQuantizedCache
 
-[[autodoc]] SinkCache
-    - update
-    - get_seq_length
-    - reorder_cache
-
 [[autodoc]] OffloadedCache
     - update
     - prefetch_layer

src/transformers/cache_utils.py

Lines changed: 8 additions & 190 deletions
@@ -2,7 +2,6 @@
 import importlib.metadata
 import json
 import os
-import warnings
 from dataclasses import dataclass
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
@@ -1063,199 +1062,18 @@ def _dequantize(self, qtensor):
 
 class SinkCache(Cache):
     """
-    Deprecated.
-
-    A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
-    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
-    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
-
-    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
-    `[batch_size, num_heads, seq_len, head_dim]`.
-
-    Parameters:
-        window_length (`int`):
-            The length of the context window.
-        num_sink_tokens (`int`):
-            The number of sink tokens. See the original paper for more information.
-
-    Example:
-
-        ```python
-        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
-
-        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-
-        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
-
-        >>> # Prepare a cache class and pass it to model's forward
-        >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
-        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
-        >>> outputs.past_key_values # access cache filled with key/values from generation
-        SinkCache()
-        ```
+    It is now a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache.
+    See [these docs](https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods) for
+    general `custom_generate` usage.
     """
 
-    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
-        super().__init__()
-        self.key_cache: List[torch.Tensor] = []
-        self.value_cache: List[torch.Tensor] = []
-        self.window_length = window_length
-        self.num_sink_tokens = num_sink_tokens
-        self.cos_sin_rerotation_cache = {}
-        self._cos_cache = None
-        self._sin_cache = None
-        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
-
-        warnings.warn(
-            "`SinkCache` is deprecated and will be removed in v4.53.0. You can achieve similar functionality by "
-            "using a model with a sliding window attention mechanism, or by expanding RoPE and optionally using an "
-            "offloaded cache implementation.",
-            FutureWarning,
+    # TODO (joao, manuel): Remove this class in v4.59.0
+    def __init__(self, **kwargs) -> None:
+        raise NotImplementedError(
+            "`SinkCache` has been moved to a `custom_generate` repository on the Hub: "
+            "https://huggingface.co/transformers-community/sink_cache. See the repository for usage examples."
         )
 
-    @staticmethod
-    def _rotate_half(x):
-        x1 = x[..., : x.shape[-1] // 2]
-        x2 = x[..., x.shape[-1] // 2 :]
-        return torch.cat((-x2, x1), dim=-1)
-
-    def _apply_key_rotary_pos_emb(
-        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-    ) -> torch.Tensor:
-        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
-        return rotated_key_states
-
-    def _get_rerotation_cos_sin(
-        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
-            # Upcast to float32 temporarily for better accuracy
-            cos = cos.to(torch.float32)
-            sin = sin.to(torch.float32)
-
-            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
-            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
-            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
-            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
-            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
-            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
-            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
-
-            self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
-                rerotation_cos.to(key_states.dtype).unsqueeze(0),
-                rerotation_sin.to(key_states.dtype).unsqueeze(0),
-            )
-        return self.cos_sin_rerotation_cache[key_states.shape[-2]]
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        # TODO: deprecate this function in favor of `cache_position`
-        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
-        if len(self.key_cache) <= layer_idx:
-            return 0
-        return self.key_cache[layer_idx].shape[-2]
-
-    def get_max_cache_shape(self) -> Optional[int]:
-        """Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length."""
-        return self.window_length
-
-    def update(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
-                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
-                rotation as the tokens are shifted.
-
-        Return:
-            A tuple containing the updated key and value states.
-        """
-        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
-        # with partially rotated position embeddings, like Phi or Persimmon.
-        if cache_kwargs is None:
-            cache_kwargs = {}
-        sin = cache_kwargs.get("sin")
-        cos = cache_kwargs.get("cos")
-        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
-        using_rope = cos is not None and sin is not None
-
-        # Update the number of seen tokens
-        if layer_idx == 0:
-            self._seen_tokens += key_states.shape[-2]
-
-        # Update the sin/cos cache, which holds sin/cos values for all possible positions
-        if using_rope and layer_idx == 0:
-            # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
-            # after all RoPE models have a llama-like cache utilization.
-            if cos.dim() == 2:
-                self._cos_cache = cos
-                self._sin_cache = sin
-            else:
-                if self._cos_cache is None:
-                    self._cos_cache = cos[0, ...]
-                    self._sin_cache = sin[0, ...]
-                elif self._cos_cache.shape[0] < self.window_length:
-                    self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
-                    self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)
-
-        # [bsz, num_heads, seq_len, head_dim]
-        if len(self.key_cache) <= layer_idx:
-            # Empty cache
-            self.key_cache.append(key_states)
-            self.value_cache.append(value_states)
-
-        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
-            # Growing cache
-            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
-        else:
-            # Shifting cache
-            keys_to_keep = self.key_cache[layer_idx][
-                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
-            ]
-
-            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
-            if using_rope:
-                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
-                    key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
-                )
-                if partial_rotation_size is not None:
-                    keys_to_keep, keys_pass = (
-                        keys_to_keep[..., :partial_rotation_size],
-                        keys_to_keep[..., partial_rotation_size:],
-                    )
-                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
-                if partial_rotation_size is not None:
-                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
-
-            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
-            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
-            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
-
-            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
-            values_to_keep = self.value_cache[layer_idx][
-                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
-            ]
-            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
-
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
-
 
 class StaticCache(Cache):
     """

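With this change the class body above is only a tombstone: constructing it raises immediately. A minimal sketch of the resulting behavior, assuming the top-level `transformers.SinkCache` export stays in place until the class is deleted in v4.59.0:

```py
# Minimal sketch: the SinkCache stub now fails fast and points users at the Hub repo.
from transformers import SinkCache  # assumed to remain importable until v4.59.0

try:
    past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
except NotImplementedError as err:
    # The message directs users to https://huggingface.co/transformers-community/sink_cache
    print(f"SinkCache is no longer implemented in transformers: {err}")
```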
src/transformers/utils/fx.py

Lines changed: 1 addition & 8 deletions
@@ -34,7 +34,7 @@
 from torch.fx.proxy import ParameterProxy
 
 from .. import logging
-from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache
+from ..cache_utils import Cache, DynamicCache, StaticCache
 from ..modeling_utils import PretrainedConfig, PreTrainedModel
 from ..models.auto import get_values
 from ..models.auto.modeling_auto import (
@@ -832,12 +832,6 @@ def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
     {},
     proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache),
 )
-ProxyableSinkCache = HFProxyableClassMeta(
-    "ProxyableSinkCache",
-    (SinkCache,),
-    {},
-    proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache),
-)
 ProxyableStaticCache = HFProxyableClassMeta(
     "ProxyableStaticCache",
     (StaticCache,),
@@ -880,7 +874,6 @@ class HFTracer(Tracer):
     _CLASSES_TO_PATCH = {
         Cache: ProxyableCache,
         DynamicCache: ProxyableDynamicCache,
-        SinkCache: ProxyableSinkCache,
         StaticCache: ProxyableStaticCache,
     }
 

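The fx.py change only drops the SinkCache proxy from `_CLASSES_TO_PATCH`; tracing with the remaining cache classes is unaffected. A short sanity-check sketch using the existing `symbolic_trace` helper (the checkpoint and input names are illustrative, not taken from this commit):

```py
# Sketch: symbolic tracing still works after ProxyableSinkCache is removed.
from transformers import AutoModelForSequenceClassification
from transformers.utils.fx import symbolic_trace

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
# input_names mirrors the documented fx tracing example for encoder models
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
print(type(traced))  # a torch.fx GraphModule
```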
utils/check_repo.py

Lines changed: 1 addition & 0 deletions
@@ -1050,6 +1050,7 @@ def find_all_documented_objects() -> List[str]:
     "VitPoseBackbone",  # Internal module
     "VitPoseBackboneConfig",  # Internal module
     "get_values",  # Internal object
+    "SinkCache",  # Moved to a custom_generate repository, to be deleted from transformers in v4.59.0
 ]
 
 # This list should be empty. Objects in it should get their own doc page.

0 commit comments