Commit 464dfa0

Raise conversion errors after loading (#42807)
* raise
* comment
* fix
* add test
* fix
* add back return
* small
* raise after report
* typos
* fix
* patch
* switch name
* doc
* oupsi that was commented out
1 parent 6a93635 commit 464dfa0

File tree

4 files changed: +109 -73 lines changed
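
Before the per-file diffs, a minimal sketch of the user-visible effect: conversion failures are still summarized in the loading report, but they now also raise instead of only being logged. This is a sketch, not the exact API surface; the checkpoint name is hypothetical, and the real entry point is `from_pretrained`, which ends up in the code changed below.

# Sketch only: illustrates the new failure mode introduced by this commit.
from transformers import AutoModelForCausalLM

try:
    # If a weight converter fails for some parameters, the loading report is logged
    # with CONVERSION entries and a RuntimeError is then raised (previously the
    # failures were only collected under a "MISC" status).
    model = AutoModelForCausalLM.from_pretrained("some-org/some-checkpoint")  # hypothetical repo id
except RuntimeError as err:
    print("Weight conversion failed:", err)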

src/transformers/core_model_loading.py

Lines changed: 61 additions & 56 deletions
@@ -409,7 +409,7 @@ def convert(
         config=None,
         hf_quantizer=None,
         missing_keys: Optional[MutableSet[str]] = None,
-        misc: Optional[MutableMapping[str, str]] = None,
+        conversion_errors: Optional[MutableMapping[str, str]] = None,
     ):
         # Collect the tensors here - we use a new dictionary to avoid keeping them in memory in the internal
         # attribute during the whole process
@@ -421,7 +421,9 @@ def convert(
             collected_tensors = {target_key: collected_tensors[self.source_patterns[0]]}
 
         if hf_quantizer is not None and self.quantization_operation is not None:
-            with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), self.quantization_operation):
+            with log_conversion_errors(
+                layer_name, conversion_errors, (len(collected_tensors), layer_name), self.quantization_operation
+            ):
                 collected_tensors = self.quantization_operation.convert(
                     collected_tensors,
                     source_patterns=self.source_patterns,
@@ -432,7 +434,7 @@ def convert(
                     missing_keys=missing_keys,
                 )
 
-        return collected_tensors, misc
+        return collected_tensors, conversion_errors
 
 
 @dataclass(slots=True)
@@ -455,14 +457,14 @@ def convert(
         config=None,
         hf_quantizer=None,
         missing_keys: Optional[MutableSet[str]] = None,
-        misc: Optional[MutableMapping[str, str]] = None,
+        conversion_errors: Optional[MutableMapping[str, str]] = None,
     ):
         # Collect the tensors here - we use a new dictionary to avoid keeping them in memory in the internal
         # attribute during the whole process
         collected_tensors = self.materialize_tensors()
 
         for op in self.operations:
-            with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), op):
+            with log_conversion_errors(layer_name, conversion_errors, (len(collected_tensors), layer_name), op):
                 collected_tensors = op.convert(
                     collected_tensors,
                     source_patterns=self.source_patterns,
@@ -489,7 +491,9 @@ def convert(
                 pass
 
         if hf_quantizer is not None and self.quantization_operation is not None:
-            with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), self.quantization_operation):
+            with log_conversion_errors(
+                layer_name, conversion_errors, (len(collected_tensors), layer_name), self.quantization_operation
+            ):
                 collected_tensors = self.quantization_operation.convert(
                     collected_tensors,
                     source_patterns=self.source_patterns,
@@ -499,7 +503,7 @@ def convert(
                     model=model,
                     missing_keys=missing_keys,
                 )
-        return collected_tensors, misc
+        return collected_tensors, conversion_errors
 
 
 # For I/O bound operations (i.e. here reading files), it is better to have fewer threads, e.g. 4 is a good default.
@@ -560,13 +564,14 @@ def dot_natural_key(s: str):
 
 
 @contextmanager
-def log_to_misc(
+def log_conversion_errors(
     first_target_key: str,
-    misc: MutableMapping[str, str],
+    conversion_errors: MutableMapping[str, str],
     extras: Any = None,
     op: Union[list[ConversionOps], ConversionOps, None] = None,
 ):
-    # A simple helper to handle errors with contextual messages.
+    """Catch all exceptions during `convert` calls, and log the errors for later. Re-raise a `SkipParameters` exception
+    that will be catched later to skip the parameters that raised the original Exception."""
     try:
         yield
     except Exception as e:
@@ -585,17 +590,19 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) ->
         if isinstance(extras, tuple) and len(extras) == 2:
             length, target_keys = extras
             descriptor = f"{op_name} " if op_name else ""
-            misc[first_target_key] = (
+            conversion_errors[first_target_key] = (
                 f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {length}"
             )
         elif isinstance(extras, str):
             suffix = f" via {op_name}" if op_name else ""
-            misc[first_target_key] = f"{e}\nError{suffix} when processing parameter {extras}"
+            conversion_errors[first_target_key] = f"{e}\nError{suffix} when processing parameter {extras}"
         elif extras is None and op_name:
-            misc[first_target_key] = f"{op_name}: {e}"
+            conversion_errors[first_target_key] = f"{op_name}: {e}"
         else:
-            misc[first_target_key] = f"{extras} |Error: {e}"
-        raise SkipLayer()
+            conversion_errors[first_target_key] = f"{extras} |Error: {e}"
+
+        # Raise a specific Exception that we can catch easily
+        raise SkipParameters()
 
 
 def set_param_for_module(
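
The two hunks above contain the core mechanism: a context manager wraps each `convert` call, records a formatted message in `conversion_errors` keyed by the first target parameter, and re-raises a sentinel exception so the caller can skip only those parameters. A stripped-down sketch of the same pattern, independent of the transformers internals (names and signatures simplified for illustration):

from contextlib import contextmanager
from typing import Callable


class SkipParameters(Exception):
    """Control-flow sentinel: abort only the parameters currently being converted."""


@contextmanager
def log_conversion_errors(target_key: str, conversion_errors: dict):
    try:
        yield
    except Exception as e:
        # Record the failure for the final loading report, then signal the caller to skip.
        conversion_errors[target_key] = f"{type(e).__name__}: {e}"
        raise SkipParameters()


def convert_all(converters: dict[str, Callable[[], None]]) -> dict:
    conversion_errors: dict[str, str] = {}
    for name, convert in converters.items():
        try:
            with log_conversion_errors(name, conversion_errors):
                convert()
        except SkipParameters:
            continue  # skip only this parameter, keep loading the rest
    return conversion_errors


print(convert_all({"ok.weight": lambda: None, "bad.weight": lambda: 1 / 0}))
# {'bad.weight': 'ZeroDivisionError: division by zero'}
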
@@ -604,44 +611,42 @@ def set_param_for_module(
     param_value: torch.Tensor,
     mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
     missing_keys: MutableSet[str],
-    misc: MutableMapping[str, Any],
     unexpected_keys: MutableSet[str],
     distributed_operation: Optional[TensorParallelLayer],
     hf_quantizer: HfQuantizer,
 ):
-    with log_to_misc(target_name, misc, target_name):
-        module_path, _, param_name = target_name.rpartition(".")
-        module_obj = model.get_submodule(module_path) if module_path else model
+    module_path, _, param_name = target_name.rpartition(".")
+    module_obj = model.get_submodule(module_path) if module_path else model
 
-        ref = getattr(module_obj, param_name)
-        if ref is None:
-            unexpected_keys.add(target_name)
+    ref = getattr(module_obj, param_name)
+    if ref is None:
+        unexpected_keys.add(target_name)
+    else:
+        use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
+        if not isinstance(param_value, torch.nn.Parameter):
+            if distributed_operation is not None:
+                param_value = DTensor.from_local(
+                    param_value,
+                    distributed_operation.device_mesh,
+                    getattr(distributed_operation, "shard", Replicate()),
+                    run_check=False,
+                    shape=ref.size(),
+                    stride=ref.stride(),
+                )
+                if not use_dtensor:
+                    # we convert to local
+                    param_value = param_value.to_local()
+            if param_name not in module_obj._buffers:
+                param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
+
+        # Remove from missing keys (it's either mismatched, or all good)
+        missing_keys.discard(target_name)
+        if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
+            mismatch_keys.add((target_name, param_value.shape, ref.shape))
         else:
-            use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
-            if not isinstance(param_value, torch.nn.Parameter):
-                if distributed_operation is not None:
-                    param_value = DTensor.from_local(
-                        param_value,
-                        distributed_operation.device_mesh,
-                        getattr(distributed_operation, "shard", Replicate()),
-                        run_check=False,
-                        shape=ref.size(),
-                        stride=ref.stride(),
-                    )
-                    if not use_dtensor:
-                        # we convert to local
-                        param_value = param_value.to_local()
-                if param_name not in module_obj._buffers:
-                    param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
-
-            # Remove from missing keys (it's either mismatched, or all good)
-            missing_keys.discard(target_name)
-            if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
-                mismatch_keys.add((target_name, param_value.shape, ref.shape))
-            else:
-                # super important otherwise _init_weight will re-init the param
-                param_value._is_hf_initialized = True
-                setattr(module_obj, param_name, param_value)
+            # super important otherwise _init_weight will re-init the param
+            param_value._is_hf_initialized = True
+            setattr(module_obj, param_name, param_value)
 
 
 def offload_and_maybe_resave_param(
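
For readers less familiar with how `set_param_for_module` attaches a tensor by its dotted name, the de-indented body above reduces to: split the key into a module path and an attribute, fetch the submodule, wrap the tensor in a `Parameter` unless it is a buffer, and assign it. A minimal self-contained sketch of that mechanism (toy model, no DTensor, quantization, or missing/mismatch bookkeeping):

import torch
from torch import nn


def set_param_by_name(model: nn.Module, target_name: str, value: torch.Tensor) -> None:
    # "0.weight" -> module path "0", attribute "weight"
    module_path, _, param_name = target_name.rpartition(".")
    module_obj = model.get_submodule(module_path) if module_path else model
    if param_name not in module_obj._buffers:
        value = nn.Parameter(value, requires_grad=value.is_floating_point())
    setattr(module_obj, param_name, value)


model = nn.Sequential(nn.Linear(4, 4))
set_param_by_name(model, "0.weight", torch.zeros(4, 4))
print(model[0].weight.sum())  # tensor(0., grad_fn=<SumBackward0>)
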
@@ -663,8 +668,9 @@ def offload_and_maybe_resave_param(
     return disk_offload_index
 
 
-class SkipLayer(Exception):
-    """Control-flow sentinel: abort processing of the current layer only."""
+class SkipParameters(Exception):
+    """Control-flow sentinel: abort processing of the current parameters only (that were supposed to be created
+    by a WeightConverter)."""
 
     pass
 
@@ -818,7 +824,7 @@ def convert_and_load_state_dict_in_model(
     meta_model_state_dict = model.state_dict()
     missing_keys = set(meta_model_state_dict.keys())
 
-    misc = {}
+    conversion_errors = {}
     mismatch_keys = set()
     unexpected_keys = set()
 
@@ -925,13 +931,13 @@ def convert_and_load_state_dict_in_model(
             pbar.set_postfix({"Materializing param": first_param_name})
             pbar.refresh()
             try:
-                realized_value, misc = mapping.convert(
+                realized_value, conversion_errors = mapping.convert(
                     first_param_name,
                     model=model,
                     config=model.config,
                     hf_quantizer=hf_quantizer,
                     missing_keys=missing_keys,
-                    misc=misc,
+                    conversion_errors=conversion_errors,
                 )
                 for target_name, param in realized_value.items():
                     param = param[0] if isinstance(param, list) else param
@@ -949,7 +955,6 @@ def convert_and_load_state_dict_in_model(
                         param,
                         mismatch_keys,
                         missing_keys,
-                        misc,
                         unexpected_keys,
                         mapping.distributed_operation,
                         hf_quantizer,
@@ -958,7 +963,7 @@ def convert_and_load_state_dict_in_model(
                 # Cleanup all the tensors that were gathered before next iteration
                 del realized_value
 
-            except SkipLayer:
+            except SkipParameters:
                 continue
 
     # Close the pool, independently of whether the code was interrupted or finished successfully
@@ -969,7 +974,7 @@ def convert_and_load_state_dict_in_model(
 
     # Keep the current weight conversion mapping for later saving (in case it was coming directly from the user)
    model._weight_conversions = weight_mapping
-    return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, misc
+    return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, conversion_errors
 
 
 def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch.Tensor]):
@@ -1016,7 +1021,7 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch
     new_state_dict = {}
     for first_param_name, reversed_converter in conversion_mapping.items():
         # Apply the reverse converter
-        realized_value, misc = reversed_converter.convert(first_param_name, model=model, config=model.config)
+        realized_value, _ = reversed_converter.convert(first_param_name, model=model, config=model.config)
         for target_name, param in realized_value.items():
             param = param[0] if isinstance(param, list) else param
             new_state_dict[target_name] = param

src/transformers/modeling_utils.py

Lines changed: 3 additions & 3 deletions
@@ -4102,7 +4102,7 @@ def _load_pretrained_model(
         state_dict = merged_state_dict
         error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
         # This is not true but for now we assume only best-case scenario with deepspeed, i.e. perfectly matching checkpoints
-        missing_keys, unexpected_keys, mismatched_keys, misc = set(), set(), set(), set()
+        missing_keys, unexpected_keys, mismatched_keys, conversion_errors = set(), set(), set(), set()
     else:
         all_pointer = set()
         # Checkpoints are safetensors
@@ -4124,7 +4124,7 @@ def _load_pretrained_model(
         else:
             raise ValueError("Neither a state dict nor checkpoint files were found.")
 
-        missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, misc = (
+        missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, conversion_errors = (
             convert_and_load_state_dict_in_model(
                 model,
                 merged_state_dict,
@@ -4198,7 +4198,7 @@ def _load_pretrained_model(
         missing_keys=missing_keys,
         mismatched_keys=mismatched_keys,
         mismatched_shapes=mismatched_keys,
-        misc=misc,
+        conversion_errors=conversion_errors,
         ignore_mismatched_sizes=ignore_mismatched_sizes,
     )
 
src/transformers/utils/loading_report.py

Lines changed: 19 additions & 10 deletions
@@ -148,9 +148,8 @@ def log_state_dict_report(
     mismatched_keys=None,
     mismatched_shapes=None,
     ignore_mismatched_sizes=True,
-    misc=None,
+    conversion_errors=None,
     color=True,  # allow disabling for plain logs
-    min_width_full_table=60,  # terminal min width to attempt full table
 ):
     """Log a readable report about state_dict loading issues.
@@ -165,12 +164,13 @@ def log_state_dict_report(
     missing_keys = missing_keys or []
     mismatched_keys = mismatched_keys or []
     mismatched_shapes = mismatched_shapes or []
-    misc = misc or {}
+    conversion_errors = conversion_errors or {}
 
     # Detect whether the current stdout supports ANSI colors; allow callers to pass `color=False` to force no color
     color_enabled = bool(color and sys.stdout.isatty())
     ansi = ANSI(color_enabled)
 
+    # Re-raise errors early if needed
     if error_msgs:
         error_msg = "\n\t".join(error_msgs)
         if "size mismatch" in error_msg:
@@ -204,9 +204,9 @@
         )
         rows.append(data)
 
-    if misc:
-        for k, v in update_key_name(misc).items():
-            status = "MISC"
+    if conversion_errors:
+        for k, v in update_key_name(conversion_errors).items():
+            status = "CONVERSION"
             status = _color(status, "purple", ansi)
             _details = v[:term_w]
             rows.append([k, status, _details])
@@ -228,16 +228,25 @@
     if unexpected_keys:
         tips += f"\n- {_color('UNEXPECTED', 'orange', ansi) + ansi['italic']}\t:can be ignored when loading from different task/architecture; not ok if you expect identical arch."
     if missing_keys:
-        tips += f"\n- {_color('MISSING', 'red', ansi) + ansi['italic']}\t:those params were newly initialized because missing form the checkpoint. Consider training on your downstream task."
+        tips += f"\n- {_color('MISSING', 'red', ansi) + ansi['italic']}\t:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task."
     if mismatched_keys:
-        tips += f"\n- {_color('MISMATCH', 'yellow', ansi) + ansi['italic']}\t:ckpt weights were loaded, but they did not match the original empty weight."
-    if misc:
-        tips += f"\n- {_color('MISC', 'purple', ansi) + ansi['italic']}\t:originate from the conversion scheme"
+        tips += f"\n- {_color('MISMATCH', 'yellow', ansi) + ansi['italic']}\t:ckpt weights were loaded, but they did not match the original empty weight shapes."
+    if conversion_errors:
+        tips += f"\n- {_color('CONVERSION', 'purple', ansi) + ansi['italic']}\t:originate from the conversion scheme"
     tips += f"{ansi['reset']}"
 
+    # Log the report as warning
     logger.warning(prelude + table + tips)
+
+    # Re-raise in those case, after the report
+    if conversion_errors:
+        raise RuntimeError(
+            "We encountered some issues during automatic conversion of the weights. For details look at the `CONVERSION` entries of "
+            "the above report!"
+        )
     if not ignore_mismatched_sizes and mismatched_keys:
         raise RuntimeError(
             "You set `ignore_mismatched_sizes` to `False`, thus raising an error. For details look at the above report!"
         )
+
     return prelude + table + tips
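
The ordering in this last hunk is what the commit title refers to: the full report is still logged as a warning first, so the `CONVERSION` rows remain visible, and only afterwards is the failure raised. A condensed sketch of that control flow (logger name and messages are simplified, not the actual transformers output):

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("loading_report_sketch")  # hypothetical logger name


def report_then_raise(conversion_errors: dict[str, str]) -> None:
    report = "\n".join(f"{key}\tCONVERSION\t{msg}" for key, msg in conversion_errors.items())
    # Always emit the report first so the details are visible in the logs...
    logger.warning(report or "all keys loaded cleanly")
    # ...then turn conversion problems into a hard failure.
    if conversion_errors:
        raise RuntimeError("Conversion errors occurred; see the CONVERSION entries in the report above.")


report_then_raise({})  # logs only
try:
    report_then_raise({"model.layers.0.q_proj.weight": "shape mismatch"})  # logs, then raises
except RuntimeError as err:
    print(err)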
