@@ -409,7 +409,7 @@ def convert(
         config=None,
         hf_quantizer=None,
         missing_keys: Optional[MutableSet[str]] = None,
-        misc: Optional[MutableMapping[str, str]] = None,
+        conversion_errors: Optional[MutableMapping[str, str]] = None,
     ):
         # Collect the tensors here - we use a new dictionary to avoid keeping them in memory in the internal
         # attribute during the whole process
@@ -421,7 +421,9 @@ def convert(
         collected_tensors = {target_key: collected_tensors[self.source_patterns[0]]}

         if hf_quantizer is not None and self.quantization_operation is not None:
-            with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), self.quantization_operation):
+            with log_conversion_errors(
+                layer_name, conversion_errors, (len(collected_tensors), layer_name), self.quantization_operation
+            ):
                 collected_tensors = self.quantization_operation.convert(
                     collected_tensors,
                     source_patterns=self.source_patterns,
@@ -432,7 +434,7 @@ def convert(
                     missing_keys=missing_keys,
                 )

-        return collected_tensors, misc
+        return collected_tensors, conversion_errors


 @dataclass(slots=True)
@@ -455,14 +457,14 @@ def convert(
         config=None,
         hf_quantizer=None,
         missing_keys: Optional[MutableSet[str]] = None,
-        misc: Optional[MutableMapping[str, str]] = None,
+        conversion_errors: Optional[MutableMapping[str, str]] = None,
     ):
         # Collect the tensors here - we use a new dictionary to avoid keeping them in memory in the internal
         # attribute during the whole process
         collected_tensors = self.materialize_tensors()

         for op in self.operations:
-            with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), op):
+            with log_conversion_errors(layer_name, conversion_errors, (len(collected_tensors), layer_name), op):
                 collected_tensors = op.convert(
                     collected_tensors,
                     source_patterns=self.source_patterns,
@@ -489,7 +491,9 @@ def convert(
                 pass

         if hf_quantizer is not None and self.quantization_operation is not None:
-            with log_to_misc(layer_name, misc, (len(collected_tensors), layer_name), self.quantization_operation):
+            with log_conversion_errors(
+                layer_name, conversion_errors, (len(collected_tensors), layer_name), self.quantization_operation
+            ):
                 collected_tensors = self.quantization_operation.convert(
                     collected_tensors,
                     source_patterns=self.source_patterns,
@@ -499,7 +503,7 @@ def convert(
                     model=model,
                     missing_keys=missing_keys,
                 )
-        return collected_tensors, misc
+        return collected_tensors, conversion_errors


 # For I/O bound operations (i.e. here reading files), it is better to have fewer threads, e.g. 4 is a good default.
@@ -560,13 +564,14 @@ def dot_natural_key(s: str):


 @contextmanager
-def log_to_misc(
+def log_conversion_errors(
     first_target_key: str,
-    misc: MutableMapping[str, str],
+    conversion_errors: MutableMapping[str, str],
     extras: Any = None,
     op: Union[list[ConversionOps], ConversionOps, None] = None,
 ):
-    # A simple helper to handle errors with contextual messages.
+    """Catch all exceptions during `convert` calls, and log the errors for later. Re-raise a `SkipParameters` exception
+    that will be caught later to skip the parameters that raised the original Exception."""
     try:
         yield
     except Exception as e:
@@ -585,17 +590,19 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) ->
         if isinstance(extras, tuple) and len(extras) == 2:
             length, target_keys = extras
             descriptor = f"{op_name} " if op_name else ""
-            misc[first_target_key] = (
+            conversion_errors[first_target_key] = (
                 f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {length}"
             )
         elif isinstance(extras, str):
             suffix = f" via {op_name}" if op_name else ""
-            misc[first_target_key] = f"{e}\nError{suffix} when processing parameter {extras}"
+            conversion_errors[first_target_key] = f"{e}\nError{suffix} when processing parameter {extras}"
         elif extras is None and op_name:
-            misc[first_target_key] = f"{op_name}: {e}"
+            conversion_errors[first_target_key] = f"{op_name}: {e}"
         else:
-            misc[first_target_key] = f"{extras}|Error: {e}"
-        raise SkipLayer()
+            conversion_errors[first_target_key] = f"{extras}|Error: {e}"
+
+        # Raise a specific Exception that we can catch easily
+        raise SkipParameters()


 def set_param_for_module(
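For context, a minimal, self-contained sketch of the pattern this helper implements (hypothetical simplified names, not the code above): record the real exception in a shared mapping keyed by the target parameter, then raise a lightweight sentinel so the loading loop can skip only the affected parameters and keep going.

```python
from contextlib import contextmanager


class SkipParameters(Exception):
    """Sentinel: abort only the parameters currently being converted."""


@contextmanager
def log_conversion_errors(first_target_key, conversion_errors):
    try:
        yield
    except Exception as e:
        # Keep the original error text for reporting once loading has finished
        conversion_errors[first_target_key] = str(e)
        raise SkipParameters()


conversion_errors = {}
for name, tensor in [("good.weight", 2.0), ("bad.weight", None)]:
    try:
        with log_conversion_errors(name, conversion_errors):
            doubled = tensor * 2  # raises TypeError for None
        print(name, doubled)
    except SkipParameters:
        continue  # skip only this parameter, keep converting the rest
print(conversion_errors)  # {'bad.weight': "unsupported operand type(s) ..."}
```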
@@ -604,44 +611,42 @@ def set_param_for_module(
     param_value: torch.Tensor,
     mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
     missing_keys: MutableSet[str],
-    misc: MutableMapping[str, Any],
     unexpected_keys: MutableSet[str],
     distributed_operation: Optional[TensorParallelLayer],
     hf_quantizer: HfQuantizer,
 ):
-    with log_to_misc(target_name, misc, target_name):
-        module_path, _, param_name = target_name.rpartition(".")
-        module_obj = model.get_submodule(module_path) if module_path else model
+    module_path, _, param_name = target_name.rpartition(".")
+    module_obj = model.get_submodule(module_path) if module_path else model

-        ref = getattr(module_obj, param_name)
-        if ref is None:
-            unexpected_keys.add(target_name)
+    ref = getattr(module_obj, param_name)
+    if ref is None:
+        unexpected_keys.add(target_name)
+    else:
+        use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
+        if not isinstance(param_value, torch.nn.Parameter):
+            if distributed_operation is not None:
+                param_value = DTensor.from_local(
+                    param_value,
+                    distributed_operation.device_mesh,
+                    getattr(distributed_operation, "shard", Replicate()),
+                    run_check=False,
+                    shape=ref.size(),
+                    stride=ref.stride(),
+                )
+                if not use_dtensor:
+                    # we convert to local
+                    param_value = param_value.to_local()
+            if param_name not in module_obj._buffers:
+                param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
+
+        # Remove from missing keys (it's either mismatched, or all good)
+        missing_keys.discard(target_name)
+        if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
+            mismatch_keys.add((target_name, param_value.shape, ref.shape))
         else:
-            use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
-            if not isinstance(param_value, torch.nn.Parameter):
-                if distributed_operation is not None:
-                    param_value = DTensor.from_local(
-                        param_value,
-                        distributed_operation.device_mesh,
-                        getattr(distributed_operation, "shard", Replicate()),
-                        run_check=False,
-                        shape=ref.size(),
-                        stride=ref.stride(),
-                    )
-                    if not use_dtensor:
-                        # we convert to local
-                        param_value = param_value.to_local()
-                if param_name not in module_obj._buffers:
-                    param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
-
-            # Remove from missing keys (it's either mismatched, or all good)
-            missing_keys.discard(target_name)
-            if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
-                mismatch_keys.add((target_name, param_value.shape, ref.shape))
-            else:
-                # super important otherwise _init_weight will re-init the param
-                param_value._is_hf_initialized = True
-                setattr(module_obj, param_name, param_value)
+            # super important otherwise _init_weight will re-init the param
+            param_value._is_hf_initialized = True
+            setattr(module_obj, param_name, param_value)


 def offload_and_maybe_resave_param(
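A minimal sketch of what the non-distributed path above does (hypothetical helper name; DTensor and quantizer handling omitted): resolve the owning submodule from the dotted name, wrap plain tensors in `nn.Parameter` unless the target is a buffer, keep the missing/mismatched bookkeeping, and mark the value as already initialized so a later init pass will not overwrite it.

```python
import torch
from torch import nn


def set_param_sketch(model, target_name, value, missing_keys, mismatch_keys):
    # Hypothetical, simplified version of set_param_for_module: no DTensor, no quantizer.
    module_path, _, param_name = target_name.rpartition(".")
    module = model.get_submodule(module_path) if module_path else model
    ref = getattr(module, param_name)

    if param_name not in module._buffers:
        value = nn.Parameter(value, requires_grad=value.is_floating_point())

    missing_keys.discard(target_name)
    if ref is not None and ref.shape != value.shape:
        mismatch_keys.add((target_name, value.shape, ref.shape))
    else:
        value._is_hf_initialized = True  # keeps a later weight-init pass from re-initializing it
        setattr(module, param_name, value)


model = nn.Sequential(nn.Linear(2, 2))
missing_keys, mismatch_keys = {"0.weight", "0.bias"}, set()
set_param_sketch(model, "0.weight", torch.ones(2, 2), missing_keys, mismatch_keys)
print(missing_keys, mismatch_keys)  # {'0.bias'} set()
```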
@@ -663,8 +668,9 @@ def offload_and_maybe_resave_param(
     return disk_offload_index


-class SkipLayer(Exception):
-    """Control-flow sentinel: abort processing of the current layer only."""
+class SkipParameters(Exception):
+    """Control-flow sentinel: abort processing of the current parameters only (those that were supposed to be
+    created by a WeightConverter)."""

     pass

@@ -818,7 +824,7 @@ def convert_and_load_state_dict_in_model(
     meta_model_state_dict = model.state_dict()
     missing_keys = set(meta_model_state_dict.keys())

-    misc = {}
+    conversion_errors = {}
     mismatch_keys = set()
     unexpected_keys = set()

@@ -925,13 +931,13 @@ def convert_and_load_state_dict_in_model(
             pbar.set_postfix({"Materializing param": first_param_name})
             pbar.refresh()
             try:
-                realized_value, misc = mapping.convert(
+                realized_value, conversion_errors = mapping.convert(
                     first_param_name,
                     model=model,
                     config=model.config,
                     hf_quantizer=hf_quantizer,
                     missing_keys=missing_keys,
-                    misc=misc,
+                    conversion_errors=conversion_errors,
                 )
                 for target_name, param in realized_value.items():
                     param = param[0] if isinstance(param, list) else param
@@ -949,7 +955,6 @@ def convert_and_load_state_dict_in_model(
                         param,
                         mismatch_keys,
                         missing_keys,
-                        misc,
                         unexpected_keys,
                         mapping.distributed_operation,
                         hf_quantizer,
@@ -958,7 +963,7 @@ def convert_and_load_state_dict_in_model(
                 # Cleanup all the tensors that were gathered before next iteration
                 del realized_value

-            except SkipLayer:
+            except SkipParameters:
                 continue

     # Close the pool, independently of whether the code was interrupted or finished successfully
@@ -969,7 +974,7 @@ def convert_and_load_state_dict_in_model(

     # Keep the current weight conversion mapping for later saving (in case it was coming directly from the user)
     model._weight_conversions = weight_mapping
-    return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, misc
+    return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, conversion_errors


 def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch.Tensor]):
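The loader now returns `conversion_errors` as its fifth value instead of `misc`. A hypothetical caller-side sketch of how those collected messages could be surfaced after loading (the dict contents shown are made up; the real entries are written by `log_conversion_errors` above):

```python
import warnings

# Hypothetical example of reporting collected conversion errors after loading;
# keys are target parameter names, values are the formatted messages built above.
conversion_errors = {
    "model.layers.0.mlp.gate_proj.weight": "RuntimeError: shape mismatch\nError: ... Ckpt contains: 3",
}
if conversion_errors:
    summary = "\n".join(f"- {key}: {msg.splitlines()[0]}" for key, msg in conversion_errors.items())
    warnings.warn(f"{len(conversion_errors)} parameter(s) failed conversion:\n{summary}")
```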
@@ -1016,7 +1021,7 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch
     new_state_dict = {}
     for first_param_name, reversed_converter in conversion_mapping.items():
         # Apply the reverse converter
-        realized_value, misc = reversed_converter.convert(first_param_name, model=model, config=model.config)
+        realized_value, _ = reversed_converter.convert(first_param_name, model=model, config=model.config)
         for target_name, param in realized_value.items():
             param = param[0] if isinstance(param, list) else param
             new_state_dict[target_name] = param