Commit d51f88c — Initial migration guide (1 parent cb739f8)

MIGRATION_GUIDE_V5.md (+236 lines)

<!---
Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Version 5 Migration guide

## Library-wide changes with widespread impact

### Removal of TensorFlow and JAX

We're removing the TensorFlow and JAX parts of the library. This will help us focus fully on `torch`
going forward and will greatly reduce the maintenance cost of models. We are still working with tools from
the JAX ecosystem (such as MaxText) to see how we can remain compatible with them while keeping `torch`
as the only backend for now.

Linked PR: https://github.com/huggingface/transformers/pull/40760

### Dynamic weight loading

We introduce a new weight loading API in `transformers`, which significantly improves on the previous one. This
weight loading API is designed to apply operations to the checkpoints loaded by `transformers`.

Instead of loading the checkpoint exactly as it is serialized, these operations can reshape, merge,
and split the layers according to how they're defined in this new API. These operations are often a necessity when
working with quantization or parallelism algorithms.

This new API is centered around the new `WeightConverter` class:

```python
class WeightConverter(WeightTransform):
    operations: list[ConversionOps]
    source_keys: Union[str, list[str]]
    target_keys: Union[str, list[str]]
```

The weight converter is designed to apply a list of operations on the source keys, resulting in the target keys. A common
operation done on the attention layers is to fuse the query, key, and value layers. Doing so with this API would amount
to defining the following conversion:

```python
conversion = WeightConverter(
    ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],  # The input layers
    "self_attn.qkv_proj",  # The single layer as output
    operations=[Concatenate(dim=0)],
)
```

In this situation, we apply the `Concatenate` operation, which accepts a list of layers as input and returns a single
layer.

This allows us to define a mapping from architecture to a list of weight conversions, as sketched below. Applying those
weight conversions performs arbitrary transformations on the layers themselves. This significantly simplified the
`from_pretrained` method and helped us remove a lot of technical debt that we accumulated over the past few years.

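As an illustration, such a mapping could look like the following hypothetical sketch; the dict name and the way the
library consumes it are assumptions, only `WeightConverter` and `Concatenate` come from the snippets above.

```python
# Hypothetical sketch: map an architecture to the conversions applied at load time.
# The mapping name and how transformers consumes it are assumptions, not the final API.
ARCHITECTURE_TO_CONVERSIONS = {
    "LlamaForCausalLM": [
        WeightConverter(
            ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
            "self_attn.qkv_proj",
            operations=[Concatenate(dim=0)],
        ),
    ],
}
```
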
The new loading API brings several improvements:
- Much cleaner definition of the transformations applied to the checkpoint
- Reversible transformations, so loading and saving a checkpoint should result in the same checkpoint
- Faster model loading thanks to the scheduling of tensor materialization
- Enables complex mixes of transformations that wouldn't otherwise be possible (such as quantization + MoE, or TP + MoE)

While this is being implemented, expect varying levels of support across different release candidates.

Linked PR: https://github.com/huggingface/transformers/pull/41580

## Library-wide changes with lesser impact

### `use_auth_token`

The `use_auth_token` argument/parameter is deprecated in favor of `token` everywhere.
You should be able to search and replace `use_auth_token` with `token` and get the same logic.

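For example, a minimal before/after sketch (the checkpoint name is just an illustration):

```python
from transformers import AutoModel

# v4.x: use_auth_token
# model = AutoModel.from_pretrained("meta-llama/Llama-3.2-3B", use_auth_token="hf_...")

# v5: token ("hf_..." is a placeholder for a real Hugging Face access token)
model = AutoModel.from_pretrained("meta-llama/Llama-3.2-3B", token="hf_...")
```
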
Linked PR: https://github.com/huggingface/transformers/pull/41666

### Removal of legacy attention features

We decided to remove some features for the upcoming v5 as they are currently only supported in a few old models and are no longer integrated in current model additions. It's recommended to stick to v4.x in case you need them. The following features are affected:
- No more head masking, see #41076. This feature allowed turning off certain heads during the attention calculation and only worked with eager attention.
- No more relative positional biases in BERT-like models, see #41170. This feature was introduced to allow relative position scores within attention calculations (similar to T5). However, it is barely used in official models and adds a lot of complexity. It also only worked with eager attention.
- No more head pruning, see #41417 by @gante. As the name suggests, it allowed pruning heads within your attention layers.

### Updates to supported torch APIs

We dropped support for two torch APIs:
- `torchscript` in https://github.com/huggingface/transformers/pull/41688
- `torch.fx` in https://github.com/huggingface/transformers/pull/41683

Those APIs were deprecated by the PyTorch team, and we're instead focusing on the supported APIs, `dynamo` and `export`.

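As a minimal sketch of the supported path, compiling a model through `dynamo` replaces the old TorchScript
tracing/scripting workflow (the checkpoint name is just an illustration):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B")

# dynamo-based compilation instead of torch.jit.trace / torch.jit.script
compiled_model = torch.compile(model)

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    outputs = compiled_model(**inputs)
```
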
## Quantization changes

We clean up the quantization API in `transformers` and significantly refactor weight loading, as highlighted
above.

We drop support for two quantization arguments that have been deprecated for some time:
- `load_in_4bit`
- `load_in_8bit`

We remove them in favor of the `quantization_config` argument, which is much more complete. As an example, here is how
you would load a 4-bit bitsandbytes model using this argument:

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model_4bit = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B",
    device_map="auto",
    quantization_config=quantization_config
)
```

## Configuration

- Methods to initialize a nested config, such as `from_xxx_config`, are deleted. Configs can be initialized from the `__init__` method in the same way, as shown in the sketch below (https://github.com/huggingface/transformers/pull/41314)

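A hedged before/after sketch, assuming a composite config such as `EncoderDecoderConfig`; whether `__init__` takes the
sub-config objects or their dicts is an assumption based on the linked PR:

```python
from transformers import BertConfig, EncoderDecoderConfig, GPT2Config

encoder_config = BertConfig()
decoder_config = GPT2Config()

# v4.x (removed): dedicated classmethod to build the nested config
# config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)

# v5: build the nested config through __init__ directly
# (depending on the final API, .to_dict() may be needed on the sub-configs)
config = EncoderDecoderConfig(encoder=encoder_config, decoder=decoder_config)
```
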
## Processing

### Tokenization

- Slow tokenizer files (i.e. `tokenization_<model>.py`) will be removed in favor of the fast tokenizer files `tokenization_<model>_fast.py`, which will be renamed to `tokenization_<model>.py`. As fast tokenizers are backed by :hugs: `tokenizers`, they include a wider range of features that are maintainable and reliable.
- Other backends (sentencepiece, tokenizers, etc.) will be supported with a light layer if loading a fast tokenizer fails
- Remove legacy files like `special_tokens_map.json` and `added_tokens.json`
- Remove `_eventually_correct_t5_max_length`
- `encode_plus` --> `__call__`
- `batch_decode` --> `decode` (see the sketch after this list)

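A minimal sketch of the replacement calls (the checkpoint name is just an illustration):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")

# v4.x: encode_plus / batch_decode
# encoding = tokenizer.encode_plus("Hello world", return_tensors="pt")
# texts = tokenizer.batch_decode(encoding["input_ids"])

# v5: call the tokenizer directly and use decode
encoding = tokenizer("Hello world", return_tensors="pt")
text = tokenizer.decode(encoding["input_ids"][0])
```
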
`apply_chat_template` used to return naked `input_ids` by default rather than a `BatchEncoding` dict.
This was inconvenient: it should return a `BatchEncoding` dict like `tokenizer.__call__()`, but we were stuck with
that behavior for backward compatibility. The method now returns a `BatchEncoding`.

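A hedged sketch of the new behavior (the checkpoint name and the exact default flags are assumptions):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")
messages = [{"role": "user", "content": "Hello!"}]

# v5: apply_chat_template returns a BatchEncoding, like tokenizer.__call__()
encoding = tokenizer.apply_chat_template(messages, return_tensors="pt")
input_ids = encoding["input_ids"]
```
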
Linked PRs:
- https://github.com/huggingface/transformers/issues/40938
- https://github.com/huggingface/transformers/pull/40936
- https://github.com/huggingface/transformers/pull/41626

### Processing classes

- In processing classes, each attribute will be serialized under `processor_config.json` as a nested dict, instead of serializing attributes in their own config files. Loading will still be supported for all old-format processors (https://github.com/huggingface/transformers/pull/41474)
- `XXXFeatureExtractor` classes are completely removed in favor of the `XXXImageProcessor` classes for all vision models (https://github.com/huggingface/transformers/pull/41174)
- Minor change: `XXXFastImageProcessorKwargs` is removed in favor of `XXXImageProcessorKwargs`, which will be shared between fast and slow image processors (https://github.com/huggingface/transformers/pull/40931)

## Modeling

- Some `RotaryEmbeddings` layers will start returning a dict of tuples, in case the model uses several RoPE configurations (Gemma2, ModernBert). Each value will be a tuple of "cos, sin" per RoPE type.
- The config attributes for the `RotaryEmbeddings` layer will be unified and accessed via `config.rope_parameters`. The `rope_theta` config attribute might not be accessible anymore for some models and will instead live in `config.rope_parameters['rope_theta']` (see the sketch after this list). Backward compatibility will be supported for a while as much as possible, and in the near future we'll gradually move to the new RoPE format (https://github.com/huggingface/transformers/pull/39847)

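A hedged sketch of reading the unified RoPE parameters (the checkpoint name is just an illustration, and any keys
beyond `rope_theta` are assumptions):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-3B")

# v4.x: rope_theta lived directly on the config
# theta = config.rope_theta

# v5: RoPE settings are grouped under config.rope_parameters
theta = config.rope_parameters["rope_theta"]
```
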
### Generate

- Old, deprecated output type aliases were removed (e.g. `GreedySearchEncoderDecoderOutput`). We now only have 4 output classes, built from the following matrix: decoder-only vs encoder-decoder, uses beams vs doesn't use beams (https://github.com/huggingface/transformers/pull/40998)
- Removed deprecated classes related to decoding methods that were moved to the Hub due to low usage (constraints and beam scores) (https://github.com/huggingface/transformers/pull/41223)
- If `generate` doesn't receive any KV cache argument, the default cache class used is now defined by the model (as opposed to always being `DynamicCache`) (https://github.com/huggingface/transformers/pull/41505)

## Trainer

### Removing arguments without a deprecation cycle in `TrainingArguments` due to low usage

- `mp_parameters` -> legacy param that was later added to the SageMaker trainer
- `_n_gpu` -> not intended for users to set; we will initialize it correctly instead of putting it in the `TrainingArguments`
- `overwrite_output_dir` -> replaced by `resume_from_checkpoint`; it was only used in example scripts, so there is no impact on `Trainer`
- `logging_dir` -> only used for tensorboard; set the `TENSORBOARD_LOGGING_DIR` env var instead
- `jit_mode_eval` -> use `use_torch_compile` instead, as torchscript is not recommended anymore
- `tpu_num_cores` -> it is not recommended to set the number of cores; by default, all TPU cores are used. Set the `TPU_NUM_CORES` env var instead
- `past_index` -> it was only used for a very small number of models with special architectures like Transformer-XL, and it was not documented at all how to train those models
- `ray_scope` -> only a minor arg for the Ray integration. Set the `RAY_SCOPE` env var instead
- `warmup_ratio` -> use `warmup_steps` instead. We combined both args by allowing float values to be passed in `warmup_steps` (see the sketch after this list).

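A hedged sketch of the combined argument, assuming the behavior described above (a float expresses a warmup ratio, an
integer a number of steps):

```python
from transformers import TrainingArguments

# v4.x: warmup_ratio=0.1 or warmup_steps=500
# v5 (assumed): warmup_steps accepts both, so a float is interpreted as a ratio
args = TrainingArguments(output_dir="out", warmup_steps=0.1)
```
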
### Removing deprecated arguments in `TrainingArguments`

- `fsdp_min_num_params` and `fsdp_transformer_layer_cls_to_wrap` -> use `fsdp_config`
- `tpu_metrics_debug` -> `debug`
- `push_to_hub_token` -> `hub_token`
- `push_to_hub_model_id` and `push_to_hub_organization` -> `hub_model_id`
- `include_inputs_for_metrics` -> `include_for_metrics`
- `per_gpu_train_batch_size` -> `per_device_train_batch_size`
- `per_gpu_eval_batch_size` -> `per_device_eval_batch_size`
- `use_mps_device` -> mps will be used by default if detected
- `fp16_backend` and `half_precision_backend` -> we will only rely on `torch.amp` as everything has been upstreamed to torch
- `no_cuda` -> `use_cpu`
- `include_tokens_per_second` -> `include_num_input_tokens_seen`
- `use_legacy_prediction_loop` -> we only use the `evaluation_loop` function from now on

### Removing deprecated arguments in `Trainer`

- `tokenizer` in initialization -> `processor`
- `model_path` in `train()` -> `resume_from_checkpoint`

### Removed features for `Trainer`

- The SigOpt integration for hyperparameter search was removed, as the library was archived and the API stopped working
- Drop support for SageMaker API < 1.10
- Bump the `accelerate` minimum version to 1.1.0

### New defaults for `Trainer`

- `use_cache` in the model config will be set to `False`. You can still change the cache value through the `TrainingArguments` `use_cache` argument if needed (see the sketch below).

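A minimal sketch of re-enabling the cache through `TrainingArguments`, based on the note above (`output_dir` is just an
illustration):

```python
from transformers import TrainingArguments

# v5 default: the model config's use_cache is set to False during training.
# Re-enable it via TrainingArguments if needed, per the note above.
args = TrainingArguments(output_dir="out", use_cache=True)
```
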
## CLI

The deprecated `transformers-cli ...` command has been removed; `transformers ...` is now the only CLI entry point.

The `transformers` CLI has been migrated to `Typer`, making it easier to maintain and adding some nice features out of
the box (improved `--help` section, autocompletion).

The biggest breaking change is in `transformers chat`. This command starts a terminal UI to interact with a chat model.
It used to also be able to start a Chat Completion server powered by `transformers` and chat with it. In this revamped
version, that feature has been removed in favor of `transformers serve`. The goal of splitting `transformers chat`
and `transformers serve` is to define clear boundaries between client and server code. It helps with maintenance
but also makes the commands less bloated. The new signature of `transformers chat` is:

```
Usage: transformers chat [OPTIONS] BASE_URL MODEL_ID [GENERATE_FLAGS]...

  Chat with a model from the command line.
```

Example:

```sh
transformers chat https://router.huggingface.co/v1 HuggingFaceTB/SmolLM3-3B
```

Linked PRs:
- https://github.com/huggingface/transformers/pull/40997
- https://github.com/huggingface/transformers/pull/41487
