|
2 | 2 | import importlib.metadata |
3 | 3 | import json |
4 | 4 | import os |
5 | | -import warnings |
6 | 5 | from dataclasses import dataclass |
7 | 6 | from typing import Any, Dict, Iterable, List, Optional, Tuple, Union |
8 | 7 |
|
@@ -1063,199 +1062,18 @@ def _dequantize(self, qtensor): |
1063 | 1062 |
|
1064 | 1063 | class SinkCache(Cache): |
1065 | 1064 | """ |
1066 | | - Deprecated. |
1067 | | -
|
1068 | | - A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to |
1069 | | - generate beyond the length of its context window, without losing fluency in the conversation. As it discards past |
1070 | | - tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. |
1071 | | -
|
1072 | | - It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is |
1073 | | - `[batch_size, num_heads, seq_len, head_dim]`. |
1074 | | -
|
1075 | | - Parameters: |
1076 | | - window_length (`int`): |
1077 | | - The length of the context window. |
1078 | | - num_sink_tokens (`int`): |
1079 | | - The number of sink tokens. See the original paper for more information. |
1080 | | -
|
1081 | | - Example: |
1082 | | -
|
1083 | | - ```python |
1084 | | - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache |
1085 | | -
|
1086 | | - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") |
1087 | | - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") |
1088 | | -
|
1089 | | - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") |
1090 | | -
|
1091 | | - >>> # Prepare a cache class and pass it to model's forward |
1092 | | - >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4) |
1093 | | - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) |
1094 | | - >>> outputs.past_key_values # access cache filled with key/values from generation |
1095 | | - SinkCache() |
1096 | | - ``` |
| 1065 | + It is now a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache. |
| 1066 | + See [these docs](https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods) for |
| 1067 | + general `custom_generate` usage. |
1097 | 1068 | """ |
1098 | 1069 |
|
1099 | | - def __init__(self, window_length: int, num_sink_tokens: int) -> None: |
1100 | | - super().__init__() |
1101 | | - self.key_cache: List[torch.Tensor] = [] |
1102 | | - self.value_cache: List[torch.Tensor] = [] |
1103 | | - self.window_length = window_length |
1104 | | - self.num_sink_tokens = num_sink_tokens |
1105 | | - self.cos_sin_rerotation_cache = {} |
1106 | | - self._cos_cache = None |
1107 | | - self._sin_cache = None |
1108 | | - self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen |
1109 | | - |
1110 | | - warnings.warn( |
1111 | | - "`SinkCache` is deprecated and will be removed in v4.53.0. You can achieve similar functionality by " |
1112 | | - "using a model with a sliding window attention mechanism, or by expanding RoPE and optionally using an " |
1113 | | - "offloaded cache implementation.", |
1114 | | - FutureWarning, |
| 1070 | + # TODO (joao, manuel): Remove this class in v4.59.0 |
| 1071 | + def __init__(self, **kwargs) -> None: |
| 1072 | + raise NotImplementedError( |
| 1073 | + "`SinkCache` has been moved as a `custom_generate` repository on the Hub: " |
| 1074 | + "https://huggingface.co/transformers-community/sink_cache. See the repository for usage examples." |
1115 | 1075 | ) |
1116 | 1076 |
|
1117 | | - @staticmethod |
1118 | | - def _rotate_half(x): |
1119 | | - x1 = x[..., : x.shape[-1] // 2] |
1120 | | - x2 = x[..., x.shape[-1] // 2 :] |
1121 | | - return torch.cat((-x2, x1), dim=-1) |
1122 | | - |
1123 | | - def _apply_key_rotary_pos_emb( |
1124 | | - self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor |
1125 | | - ) -> torch.Tensor: |
1126 | | - rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) |
1127 | | - return rotated_key_states |
1128 | | - |
1129 | | - def _get_rerotation_cos_sin( |
1130 | | - self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor |
1131 | | - ) -> Tuple[torch.Tensor, torch.Tensor]: |
1132 | | - if key_states.shape[-2] not in self.cos_sin_rerotation_cache: |
1133 | | - # Upcast to float32 temporarily for better accuracy |
1134 | | - cos = cos.to(torch.float32) |
1135 | | - sin = sin.to(torch.float32) |
1136 | | - |
1137 | | - # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence |
1138 | | - original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :] |
1139 | | - shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]] |
1140 | | - original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :] |
1141 | | - shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]] |
1142 | | - rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin |
1143 | | - rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin |
1144 | | - |
1145 | | - self.cos_sin_rerotation_cache[key_states.shape[-2]] = ( |
1146 | | - rerotation_cos.to(key_states.dtype).unsqueeze(0), |
1147 | | - rerotation_sin.to(key_states.dtype).unsqueeze(0), |
1148 | | - ) |
1149 | | - return self.cos_sin_rerotation_cache[key_states.shape[-2]] |
1150 | | - |
1151 | | - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: |
1152 | | - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" |
1153 | | - # TODO: deprecate this function in favor of `cache_position` |
1154 | | - # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length |
1155 | | - if len(self.key_cache) <= layer_idx: |
1156 | | - return 0 |
1157 | | - return self.key_cache[layer_idx].shape[-2] |
1158 | | - |
1159 | | - def get_max_cache_shape(self) -> Optional[int]: |
1160 | | - """Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length.""" |
1161 | | - return self.window_length |
1162 | | - |
1163 | | - def update( |
1164 | | - self, |
1165 | | - key_states: torch.Tensor, |
1166 | | - value_states: torch.Tensor, |
1167 | | - layer_idx: int, |
1168 | | - cache_kwargs: Optional[Dict[str, Any]] = None, |
1169 | | - ) -> Tuple[torch.Tensor, torch.Tensor]: |
1170 | | - """ |
1171 | | - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. |
1172 | | -
|
1173 | | - Parameters: |
1174 | | - key_states (`torch.Tensor`): |
1175 | | - The new key states to cache. |
1176 | | - value_states (`torch.Tensor`): |
1177 | | - The new value states to cache. |
1178 | | - layer_idx (`int`): |
1179 | | - The index of the layer to cache the states for. |
1180 | | - cache_kwargs (`Dict[str, Any]`, `optional`): |
1181 | | - Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, |
1182 | | - `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the |
1183 | | - rotation as the tokens are shifted. |
1184 | | -
|
1185 | | - Return: |
1186 | | - A tuple containing the updated key and value states. |
1187 | | - """ |
1188 | | - # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models |
1189 | | - # with partially rotated position embeddings, like Phi or Persimmon. |
1190 | | - if cache_kwargs is None: |
1191 | | - cache_kwargs = {} |
1192 | | - sin = cache_kwargs.get("sin") |
1193 | | - cos = cache_kwargs.get("cos") |
1194 | | - partial_rotation_size = cache_kwargs.get("partial_rotation_size") |
1195 | | - using_rope = cos is not None and sin is not None |
1196 | | - |
1197 | | - # Update the number of seen tokens |
1198 | | - if layer_idx == 0: |
1199 | | - self._seen_tokens += key_states.shape[-2] |
1200 | | - |
1201 | | - # Update the sin/cos cache, which holds sin/cos values for all possible positions |
1202 | | - if using_rope and layer_idx == 0: |
1203 | | - # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove |
1204 | | - # after all RoPE models have a llama-like cache utilization. |
1205 | | - if cos.dim() == 2: |
1206 | | - self._cos_cache = cos |
1207 | | - self._sin_cache = sin |
1208 | | - else: |
1209 | | - if self._cos_cache is None: |
1210 | | - self._cos_cache = cos[0, ...] |
1211 | | - self._sin_cache = sin[0, ...] |
1212 | | - elif self._cos_cache.shape[0] < self.window_length: |
1213 | | - self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0) |
1214 | | - self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0) |
1215 | | - |
1216 | | - # [bsz, num_heads, seq_len, head_dim] |
1217 | | - if len(self.key_cache) <= layer_idx: |
1218 | | - # Empty cache |
1219 | | - self.key_cache.append(key_states) |
1220 | | - self.value_cache.append(value_states) |
1221 | | - |
1222 | | - elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: |
1223 | | - # Growing cache |
1224 | | - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) |
1225 | | - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) |
1226 | | - |
1227 | | - else: |
1228 | | - # Shifting cache |
1229 | | - keys_to_keep = self.key_cache[layer_idx][ |
1230 | | - :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] : |
1231 | | - ] |
1232 | | - |
1233 | | - # On RoPE models, we need to recompute the Key rotation as the tokens are shifted |
1234 | | - if using_rope: |
1235 | | - rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( |
1236 | | - key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length] |
1237 | | - ) |
1238 | | - if partial_rotation_size is not None: |
1239 | | - keys_to_keep, keys_pass = ( |
1240 | | - keys_to_keep[..., :partial_rotation_size], |
1241 | | - keys_to_keep[..., partial_rotation_size:], |
1242 | | - ) |
1243 | | - keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin) |
1244 | | - if partial_rotation_size is not None: |
1245 | | - keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) |
1246 | | - |
1247 | | - # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens |
1248 | | - sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] |
1249 | | - self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) |
1250 | | - |
1251 | | - sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] |
1252 | | - values_to_keep = self.value_cache[layer_idx][ |
1253 | | - :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] : |
1254 | | - ] |
1255 | | - self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2) |
1256 | | - |
1257 | | - return self.key_cache[layer_idx], self.value_cache[layer_idx] |
1258 | | - |
1259 | 1077 |
|
1260 | 1078 | class StaticCache(Cache): |
1261 | 1079 | """ |
|
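For reference, a minimal sketch of how SinkCache-style generation might now be invoked through the Hub repository referenced in the new docstring. It assumes a `transformers` version that accepts the `custom_generate` argument to `generate()`, and that the `transformers-community/sink_cache` repository forwards `window_length` and `num_sink_tokens` (parameter names assumed here, matching the old class) to its cache implementation:

```python
# Sketch only: relies on the `custom_generate` mechanism described in
# https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
inputs = tokenizer("My name is Qwen2", return_tensors="pt")

# `trust_remote_code=True` is required because the Hub repository ships the
# decoding loop as executable code; extra kwargs are passed through to it.
gen_out = model.generate(
    **inputs,
    custom_generate="transformers-community/sink_cache",
    trust_remote_code=True,
    window_length=256,  # assumed: context-window length, as in the removed SinkCache
    num_sink_tokens=4,  # assumed: number of attention-sink tokens kept at the start
)
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
```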