Fix init empty weights without accelerate #37337
Merged
Commits (8 total; the diff below shows changes from 5 of them):
- 0c55f78: add the integration (Cyrilvallez)
- f14c00f: Update accelerate.py (Cyrilvallez)
- 63ae9e8: Update accelerate.py (Cyrilvallez)
- 49c29ba: add find_tied_params as well (Cyrilvallez)
- 0dfcf5b: Update accelerate.py (Cyrilvallez)
- 456a63f: add where copied from (Cyrilvallez)
- 27425b5: simplify (Cyrilvallez)
- c6dec0b: add error (Cyrilvallez)
New file: accelerate.py (+211 lines)
````python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Since https://github.com/huggingface/transformers/pull/36963, loading is always performed with models on the meta
device. But since the `init_empty_weights` and `find_tied_parameters` functions are from accelerate, and accelerate
is still somewhat a soft dependency, we copy the functions here to be used natively in Transformers.
"""

import warnings
from contextlib import contextmanager

from ..utils import is_torch_available, logging


if is_torch_available():
    import torch
    import torch.nn as nn


logger = logging.get_logger(__name__)
````
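With these helpers defined locally, call sites can keep accelerate a soft dependency. A minimal sketch of the kind of import this enables (hypothetical call site; the actual wiring elsewhere in Transformers may differ):

```python
# Hypothetical call-site sketch: fall back to the local copy when accelerate is absent.
try:
    from accelerate import init_empty_weights  # use accelerate's version if installed
except ImportError:
    from transformers.integrations.accelerate import init_empty_weights  # local copy
```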
The first helper, `init_empty_weights`, delegates to `init_on_device` with the meta device:

````python
@contextmanager
def init_empty_weights(include_buffers: bool = False):
    """
    A context manager under which models are initialized with all parameters on the meta device, therefore creating an
    empty model. Useful when just initializing the model would blow the available RAM.

    Args:
        include_buffers (`bool`, *optional*):
            Whether or not to also put all buffers on the meta device while initializing.

    Example:

    ```python
    import torch.nn as nn
    from accelerate import init_empty_weights

    # Initialize a model with 100 billion parameters in no time and without using any RAM.
    with init_empty_weights():
        tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
    ```

    <Tip warning={true}>

    Any model created under this context manager has no weights. As such you can't do something like
    `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
    Make sure to overwrite the default `device_map` param of [`load_checkpoint_and_dispatch`], otherwise dispatch is
    not called.

    </Tip>
    """
    with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
        yield f
````
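A quick sanity check of what the context manager yields (a sketch, assuming the helper above is importable and torch is installed):

```python
import torch.nn as nn

with init_empty_weights():
    model = nn.Linear(4, 4)

print(model.weight.device)  # meta: no storage is allocated, so ~0 RAM used
# A meta tensor cannot be moved with .to(); materialize storage explicitly instead:
model = model.to_empty(device="cpu")  # allocates uninitialized storage on CPU
```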
`init_on_device` does the actual work, patching `nn.Module.register_parameter` (and, when buffers are included, `register_buffer` and the `torch.empty`/`zeros`/`ones`/`full` constructors) so that everything lands on the target device:

````python
@contextmanager
def init_on_device(device: "torch.device", include_buffers: bool = False):
    """
    A context manager under which models are initialized with all parameters on the specified device.

    Args:
        device (`torch.device`):
            Device to initialize all parameters on.
        include_buffers (`bool`, *optional*):
            Whether or not to also put all buffers on the specified device while initializing.

    Example:

    ```python
    import torch
    import torch.nn as nn
    from accelerate import init_on_device

    with init_on_device(device=torch.device("cuda")):
        tst = nn.Linear(100, 100)  # on `cuda` device
    ```
    """
    if include_buffers:
        # torch.device as a context manager (torch >= 2.0) reroutes all tensor creation, buffers included
        with device:
            yield
        return

    old_register_parameter = nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    def register_empty_buffer(module, name, buffer, persistent=True):
        old_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    # Patch tensor creation
    if include_buffers:
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ["empty", "zeros", "ones", "full"]
        }
    else:
        tensor_constructors_to_patch = {}

    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    try:
        nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            nn.Module.register_buffer = register_empty_buffer
        for torch_function_name in tensor_constructors_to_patch.keys():
            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            nn.Module.register_buffer = old_register_buffer
        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)
````
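The `include_buffers` flag matters for modules whose state lives in buffers rather than parameters, such as the running statistics of `nn.BatchNorm2d`. A sketch of the difference (assumes torch >= 2.0, where `torch.device` works as a context manager):

```python
import torch.nn as nn

with init_empty_weights():  # include_buffers=False
    bn = nn.BatchNorm2d(8)
print(bn.weight.device, bn.running_mean.device)  # meta cpu

with init_empty_weights(include_buffers=True):
    bn = nn.BatchNorm2d(8)
print(bn.weight.device, bn.running_mean.device)  # meta meta
```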
The result type for `find_tied_parameters` is a thin `list` subclass kept for backward compatibility:

````python
class FindTiedParametersResult(list):
    """
    This is a subclass of a list to handle backward compatibility for Transformers. Do not rely on the fact this is not
    a list or on the `values` method as in the future this will be removed.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def values(self):
        warnings.warn(
            "The 'values' method of FindTiedParametersResult is deprecated and will be removed in Accelerate v1.3.0. ",
            FutureWarning,
        )
        return sum([x[1:] for x in self], [])
````
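For reference, the deprecated `values()` flattens every name except the first of each tied group (sketch, with hypothetical parameter names):

```python
result = FindTiedParametersResult([["linear1.weight", "linear2.weight", "decoder.weight"]])
print(list(result))     # [['linear1.weight', 'linear2.weight', 'decoder.weight']]
print(result.values())  # ['linear2.weight', 'decoder.weight'] (emits a FutureWarning)
```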
Finally, `find_tied_parameters` groups the parameter names that point at the same underlying tensor:

````python
def find_tied_parameters(model: "nn.Module", **kwargs):
    """
    Find the tied parameters in a given model.

    <Tip warning={true}>

    The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
    them.

    </Tip>

    Args:
        model (`torch.nn.Module`): The model to inspect.

    Returns:
        List[List[str]]: A list of lists of parameter names being all tied together.

    Example:

    ```py
    >>> from collections import OrderedDict
    >>> import torch.nn as nn

    >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
    >>> model.linear2.weight = model.linear1.weight
    >>> find_tied_parameters(model)
    [['linear1.weight', 'linear2.weight']]
    ```
    """

    # get ALL model parameters and their names
    all_named_parameters = dict(model.named_parameters(remove_duplicate=False))

    # get ONLY unique named parameters,
    # if a parameter is tied and has multiple names, it will be included only once
    no_duplicate_named_parameters = dict(model.named_parameters(remove_duplicate=True))

    # the difference of the two sets will give us the tied parameters
    tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())

    # 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know
    # which names refer to the same parameter. To identify this, we need to group them together.
    tied_param_groups = {}
    for tied_param_name in tied_param_names:
        tied_param = all_named_parameters[tied_param_name]
        for param_name, param in no_duplicate_named_parameters.items():
            # compare if parameters are the same, if so, group their names together
            if param is tied_param:
                if param_name not in tied_param_groups:
                    tied_param_groups[param_name] = []
                tied_param_groups[param_name].append(tied_param_name)

    return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()])
````
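And on a transformers-style tie between the input embedding and the LM head (a sketch with a hypothetical module, just for illustration):

```python
import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 16)
        self.lm_head = nn.Linear(16, 100, bias=False)
        self.lm_head.weight = self.embed.weight  # tie the weights

print(find_tied_parameters(TinyLM()))  # [['embed.weight', 'lm_head.weight']]
```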
Review discussion (about the `FutureWarning` in `FindTiedParametersResult.values`):
> Small nit: not sure if we should replace this warning.
> You mean, remove it entirely? We can, as we do not expect people to use this function; it's simply internal.
> Yes, I think so.
> This, or warning once.
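For context, Transformers' logging utilities provide a `logger.warning_once` helper that deduplicates repeated warnings; a sketch of that alternative, assuming it is what the last comment refers to:

```python
# Hypothetical alternative inside `values()`, using the module-level logger defined above:
logger.warning_once(
    "The 'values' method of FindTiedParametersResult is deprecated and will be removed."
)
```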