<!---
Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Version 5 Migration guide

## Library-wide changes with widespread impact

### Removal of TensorFlow and Jax

We're removing the TensorFlow and Jax parts of the library. This will help us focus fully on `torch`
going forward and will greatly reduce the maintenance cost of models. We are still working with tools
from the Jax ecosystem (such as MaxText) to see how we can remain compatible with them while keeping
`torch` as the only backend for now.

Linked PR: https://github.com/huggingface/transformers/pull/40760

### Dynamic weight loading

We introduce a new weight loading API in `transformers`, which significantly improves on the previous one. It
is designed to apply operations to the checkpoints loaded by `transformers`.

Instead of loading the checkpoint exactly as it is serialized, these operations can reshape, merge,
and split the layers according to how they're defined in this new API. These operations are often a necessity when
working with quantization or parallelism algorithms.

This new API is centered around the new `WeightConverter` class:

```python
class WeightConverter(WeightTransform):
    operations: list[ConversionOps]
    source_keys: Union[str, list[str]]
    target_keys: Union[str, list[str]]
```

The weight converter is designed to apply a list of operations on the source keys, resulting in target keys. A common
operation on attention layers is to fuse the query, key, and value projections. Doing so with this API would amount
to defining the following conversion:

```python
conversion = WeightConverter(
    ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],  # The input layers
    "self_attn.qkv_proj",  # The single layer as output
    operations=[Concatenate(dim=0)],
)
```

In this situation, we apply the `Concatenate` operation, which accepts a list of layers as input and returns a single
layer.
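
As another illustration, a conversion that only renames a key could presumably be expressed without any tensor
operation. This is a hedged sketch: the key names are purely illustrative, and the ability to omit `operations`
for a pure rename is an assumption rather than a documented guarantee.

```python
# Hypothetical rename-only conversion; key names are illustrative.
rename = WeightConverter(
    "encoder.layer_norm.gamma",   # key as serialized in the checkpoint
    "encoder.layer_norm.weight",  # key expected by the model definition
)
```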

This allows us to define a mapping from an architecture to a list of weight conversions, and applying those
conversions can perform arbitrary transformations on the layers themselves. This significantly simplifies the
`from_pretrained` method and helped us remove a lot of technical debt accumulated over the past few years.

This results in several improvements:
- Much cleaner definition of the transformations applied to the checkpoint
- Reversible transformations, so loading and then saving a checkpoint should result in the same checkpoint
- Faster model loading thanks to scheduling of tensor materialization
- Enables complex combinations of transformations that wouldn't otherwise be possible (such as quantization + MoEs, or TP + MoEs)

While this is being implemented, expect varying levels of support across different release candidates.

Linked PR: https://github.com/huggingface/transformers/pull/41580

## Library-wide changes with lesser impact

### `use_auth_token`

The `use_auth_token` argument/parameter is deprecated in favor of `token` everywhere.
You should be able to search and replace `use_auth_token` with `token` and get the same logic.
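
For example (a minimal sketch; the checkpoint name and token value are illustrative):

```python
from transformers import AutoModelForCausalLM

# Before (v4): use_auth_token="hf_..." was accepted here.
# Now: pass `token` instead.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B",
    token="hf_xxx",  # your Hugging Face access token
)
```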

Linked PR: https://github.com/huggingface/transformers/pull/41666

### Removal of rarely used attention features

We decided to remove some features for the upcoming v5 as they are currently only supported in a few old models and no longer integrated in current model additions. It's recommended to stick to v4.x in case you need them. The following features are affected:
- No more head masking, see #41076. This feature allowed turning off certain heads during the attention calculation and only worked with eager attention.
- No more relative positional biases in Bert-like models, see #41170. This feature was introduced to allow relative position scores within attention calculations (similar to T5). However, it is barely used in official models and added a lot of complexity. It also only worked with eager attention.
- No more head pruning, see #41417 by @gante. As the name suggests, it allowed pruning heads within your attention layers.

### Updates to supported torch APIs

We dropped support for two torch APIs:
- `torchscript` in https://github.com/huggingface/transformers/pull/41688
- `torch.fx` in https://github.com/huggingface/transformers/pull/41683

Those APIs were deprecated by the PyTorch team, and we're instead focusing on the supported APIs `dynamo` and `export`.
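
As a minimal sketch of the supported path (the checkpoint is illustrative), models remain usable with `torch.compile`:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any PyTorch checkpoint works the same way.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B")

# torch.compile (dynamo) is the supported replacement for the removed
# torchscript / torch.fx paths.
compiled_model = torch.compile(model)

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    outputs = compiled_model(**inputs)
```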

## Quantization changes

We cleaned up the quantization API in `transformers` and significantly refactored the weight loading, as highlighted
above.

We drop support for two quantization arguments that have been deprecated for some time:
- `load_in_4bit`
- `load_in_8bit`

We remove them in favor of the `quantization_config` argument, which is much more complete. As an example, here is how
you would load a 4-bit bitsandbytes model using this argument:

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model_4bit = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B",
    device_map="auto",
    quantization_config=quantization_config
)
```

## Configuration

- Methods to initialize a nested config, such as `from_xxx_config`, are deleted. Configs can be initialized through the `__init__` method in the same way, as sketched below (https://github.com/huggingface/transformers/pull/41314)
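
A hedged sketch of what this looks like in practice (the composite model and sub-configs are illustrative):

```python
from transformers import CLIPVisionConfig, LlamaConfig, LlavaConfig

# Build the nested configs directly and pass them to __init__ instead of
# calling a dedicated from_xxx_config classmethod.
vision_config = CLIPVisionConfig()
text_config = LlamaConfig()

config = LlavaConfig(vision_config=vision_config, text_config=text_config)
```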

## Processing

### Tokenization

- Slow tokenizer files (i.e. `tokenization_<model>.py`) will be removed in favor of the fast tokenizer files `tokenization_<model>_fast.py`, which will be renamed to `tokenization_<model>.py`. As fast tokenizers are backed by :hugs: `tokenizers`, they include a wider range of features that are maintainable and reliable.
- Other backends (SentencePiece, etc.) will be supported with a light layer if loading a fast tokenizer fails
- Remove legacy files like `special_tokens_map.json` and `added_tokens.json`
- Remove `_eventually_correct_t5_max_length`
- `encode_plus` --> `__call__`
- `batch_decode` --> `decode` (see the sketch after this list)
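
A hedged sketch of the replacement calls (the tokenizer checkpoint is illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint

# encode_plus(...)  ->  tokenizer(...)
encoding = tokenizer("Hello world", return_tensors="pt")

# batch_decode(...) ->  decode(...)
text = tokenizer.decode(encoding["input_ids"][0], skip_special_tokens=True)
```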

`apply_chat_template` used to return bare `input_ids` by default rather than a `BatchEncoding` dict.
This was inconvenient: it should return a `BatchEncoding` dict like `tokenizer.__call__()`, but we were stuck with
that behavior for backward compatibility. The method now returns a `BatchEncoding`.
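
A minimal sketch of the new behavior (model and messages are illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")  # illustrative checkpoint

messages = [{"role": "user", "content": "Hello!"}]

# The result is now a BatchEncoding, so it can be indexed like the output of
# tokenizer(...) and unpacked directly into the model.
encoding = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
)
input_ids = encoding["input_ids"]
```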

Linked PRs:
- https://github.com/huggingface/transformers/issues/40938
- https://github.com/huggingface/transformers/pull/40936
- https://github.com/huggingface/transformers/pull/41626

### Processing classes

- In processing classes, each attribute will be serialized under `processor_config.json` as a nested dict, instead of serializing attributes in their own config files. Loading will be supported for all old-format processors (https://github.com/huggingface/transformers/pull/41474)
- `XXXFeatureExtractor` classes are completely removed in favor of the `XXXImageProcessor` class for all vision models (https://github.com/huggingface/transformers/pull/41174)
- Minor change: `XXXFastImageProcessorKwargs` is removed in favor of `XXXImageProcessorKwargs`, which will be shared between fast and slow processors (https://github.com/huggingface/transformers/pull/40931)

## Modeling

- Some `RotaryEmbeddings` layers will start returning a dict of tuples when the model uses several RoPE configurations (Gemma2, ModernBert). Each value will be a "(cos, sin)" tuple per RoPE type.
- Config attributes for the `RotaryEmbeddings` layer will be unified and accessed via `config.rope_parameters`. The `rope_theta` config attribute might not be accessible anymore for some models and will instead live in `config.rope_parameters['rope_theta']`, as sketched after this list. BC will be supported for a while as much as possible, and in the near future we'll gradually move to the new RoPE format (https://github.com/huggingface/transformers/pull/39847)
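
A hedged sketch of the new access pattern (the checkpoint and the exact set of keys are illustrative):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-3B")  # illustrative checkpoint

# Unified location for RoPE settings; the available keys may differ per model.
rope_theta = config.rope_parameters["rope_theta"]
```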

### Generate

- Old, deprecated output type aliases were removed (e.g. `GreedySearchEncoderDecoderOutput`). We now only have 4 output classes built from the following matrix: decoder-only vs encoder-decoder, uses beams vs doesn't use beams (https://github.com/huggingface/transformers/pull/40998)
- Removed deprecated classes regarding decoding methods that were moved to the Hub due to low usage (constraints and beam scores) (https://github.com/huggingface/transformers/pull/41223)
- If `generate` doesn't receive any KV Cache argument, the default cache class used is now defined by the model (as opposed to always being `DynamicCache`), as sketched after this list (https://github.com/huggingface/transformers/pull/41505)
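
If you relied on the previous default, you can presumably keep it by passing a cache explicitly (a minimal sketch; the checkpoint is illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B")

inputs = tokenizer("Hello", return_tensors="pt")

# Passing a cache instance explicitly keeps the previous DynamicCache behavior.
outputs = model.generate(**inputs, past_key_values=DynamicCache(), max_new_tokens=20)
```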

## Trainer

### Removing arguments without deprecation cycle in `TrainingArguments` due to low usage

- `mp_parameters` -> legacy parameter that was later added to the SageMaker trainer
- `_n_gpu` -> not intended for users to set; we will initialize it correctly instead of exposing it in `TrainingArguments`
- `overwrite_output_dir` -> replaced by `resume_from_checkpoint`; it was only used in example scripts and has no impact on `Trainer`
- `logging_dir` -> only used for TensorBoard; set the `TENSORBOARD_LOGGING_DIR` env var instead
- `jit_mode_eval` -> use `use_torch_compile` instead, as torchscript is not recommended anymore
- `tpu_num_cores` -> it is not recommended to set the number of cores; by default, all TPU cores are used. Set the `TPU_NUM_CORES` env var instead
- `past_index` -> it was only used for a very small number of models with special architectures, like Transformer-XL, and it was never documented how to train those models
- `ray_scope` -> a minor argument for the Ray integration. Set the `RAY_SCOPE` env var instead
- `warmup_ratio` -> use `warmup_step` instead. We merged both arguments by allowing float values in `warmup_step`.

### Removing deprecated arguments in `TrainingArguments`

- `fsdp_min_num_params` and `fsdp_transformer_layer_cls_to_wrap` -> use `fsdp_config`
- `tpu_metrics_debug` -> `debug`
- `push_to_hub_token` -> `hub_token`
- `push_to_hub_model_id` and `push_to_hub_organization` -> `hub_model_id`
- `include_inputs_for_metrics` -> `include_for_metrics`
- `per_gpu_train_batch_size` -> `per_device_train_batch_size`
- `per_gpu_eval_batch_size` -> `per_device_eval_batch_size`
- `use_mps_device` -> mps will be used by default if detected
- `fp16_backend` and `half_precision_backend` -> we will only rely on `torch.amp`, as everything has been upstreamed to torch
- `no_cuda` -> `use_cpu`
- `include_tokens_per_second` -> `include_num_input_tokens_seen`
- `use_legacy_prediction_loop` -> we only use the `evaluation_loop` function from now on

### Removing deprecated arguments in `Trainer`

- `tokenizer` in initialization -> `processor`
- `model_path` in train() -> `resume_from_checkpoint`

### Removed features for `Trainer`

- SigOpt integration for hp search was removed, as the library was archived and the API stopped working
- Drop support for SageMaker API < 1.10
- Bump the minimum `accelerate` version to 1.1.0

### New defaults for `Trainer`

- `use_cache` in the model config will be set to `False`. You can still change the cache value through the `use_cache` argument of `TrainingArguments` if needed, as sketched below.
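
A minimal sketch, assuming the `use_cache` argument mentioned above:

```python
from transformers import TrainingArguments

# use_cache in the model config now defaults to False during training;
# re-enable it here if your setup needs the KV cache.
args = TrainingArguments(output_dir="out", use_cache=True)
```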

## CLI

The deprecated `transformers-cli ...` command has been removed; `transformers ...` is now the only CLI entry point.

The `transformers` CLI has been migrated to `Typer`, making it easier to maintain and adding some nice features out of
the box (improved `--help` section, autocompletion).

The biggest breaking change is in `transformers chat`. This command starts a terminal UI to interact with a chat model.
It used to also be able to start a Chat Completion server powered by `transformers` and chat with it. In this revamped
version, that feature has been removed in favor of `transformers serve`. The goal of splitting `transformers chat`
and `transformers serve` is to define clear boundaries between client and server code. It helps with maintenance
but also makes the commands less bloated. The new signature of `transformers chat` is:

```
Usage: transformers chat [OPTIONS] BASE_URL MODEL_ID [GENERATE_FLAGS]...

  Chat with a model from the command line.
```

Example:

```sh
transformers chat https://router.huggingface.co/v1 HuggingFaceTB/SmolLM3-3B
```

Linked PRs:
- https://github.com/huggingface/transformers/pull/40997
- https://github.com/huggingface/transformers/pull/41487