@@ -8490,8 +8490,18 @@ def set_vocab(self):
class NemotronHModel(GraniteHybridModel):
    """Hybrid mamba2/attention model from NVIDIA"""
    model_arch = gguf.MODEL_ARCH.NEMOTRON_H
+    is_moe: bool = False

    def __init__(self, *args, **kwargs):
+        # We have to determine the correct model architecture (MoE vs non-MoE) before
+        # calling the parent __init__. This is because the parent constructor
+        # uses self.model_arch to build the tensor name map, and all MoE-specific
+        # mappings would be missed if it were called with the default non-MoE arch.
+        hparams = ModelBase.load_hparams(args[0], self.is_mistral_format)
+        if "num_experts_per_tok" in hparams:
+            self.model_arch = gguf.MODEL_ARCH.NEMOTRON_H_MOE
+            self.is_moe = True
+
        super().__init__(*args, **kwargs)

        # Save the top-level head_dim for later
@@ -8503,9 +8513,11 @@ def __init__(self, *args, **kwargs):

        # Update the ssm / attn / mlp layers
        # M: Mamba2, *: Attention, -: MLP
+        # MoE:
+        # M: Mamba2, *: Attention, E: Expert
        hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
        self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
-        self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"]
+        self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == ("E" if self.is_moe else "-")]
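+        # illustrative example: for a pattern like "M*M-M", _ssm_layers would be [0, 2, 4]
+        # and _mlp_layers [3]; the MoE variant marks the MLP positions with "E" instead of "-"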

    def get_attn_layers(self):
        hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
@@ -8521,18 +8533,110 @@ def set_gguf_parameters(self):
        # Set feed_forward_length
        # NOTE: This will trigger an override warning. This is preferable to
        # duplicating all the parent logic
-        n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
-        self.gguf_writer.add_feed_forward_length([
-            n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
-        ])
+        if not self.is_moe:
+            n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
+            self.gguf_writer.add_feed_forward_length([
+                n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
+            ])
+        else:
+            moe_intermediate_size = self.hparams["moe_intermediate_size"]
+            self.gguf_writer.add_feed_forward_length([
+                moe_intermediate_size if i in self._mlp_layers else 0 for i in range(self.block_count)
+            ])
+            self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
+            self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
+            self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["moe_shared_expert_intermediate_size"])
+            self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"])
+            self.gguf_writer.add_expert_shared_count(self.hparams["n_shared_experts"])
+            self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])
+            self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
+            self.gguf_writer.add_expert_group_count(self.hparams["n_group"])
+
+        # number of experts used per token (top-k)
+        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
+            self.gguf_writer.add_expert_used_count(n_experts_used)

    def set_vocab(self):
        super().set_vocab()

        # The tokenizer _does_ add a BOS token (via post_processor type
        # TemplateProcessing) but does not set add_bos_token to true in the
        # config, so we need to explicitly override it here.
-        self.gguf_writer.add_add_bos_token(True)
+        if not self.is_moe:
+            self.gguf_writer.add_add_bos_token(True)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if self.is_moe and bid is not None:
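+            # a couple of tensors use flat "_bias" name suffixes in the HF checkpoint;
+            # rename them to ".bias" so they can be resolved through map_tensor_name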
+            if name.endswith("mixer.gate.e_score_correction_bias"):
+                new_name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+                mapped_name = self.map_tensor_name(new_name)
+                return [(mapped_name, data_torch)]
+
+            if name.endswith("mixer.dt_bias"):
+                new_name = name.replace("dt_bias", "dt.bias")
+                mapped_name = self.map_tensor_name(new_name)
+                return [(mapped_name, data_torch)]
+
+            if name.endswith("mixer.conv1d.weight"):
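+                # the depthwise conv weight is typically stored as (channels, 1, kernel_size)
+                # in the HF checkpoint; squeeze out the singleton dimension before writing it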
+                squeezed_data = data_torch.squeeze()
+                mapped_name = self.map_tensor_name(name)
+                return [(mapped_name, squeezed_data)]
+
+            if name.endswith("mixer.A_log"):
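+                # the SSM stores A in log space; convert to A = -exp(A_log) and lay it
+                # out as one value per head in a column vector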
+                transformed_data = -torch.exp(data_torch)
+                reshaped_data = transformed_data.squeeze().reshape(-1, 1)
+                mapped_name = self.map_tensor_name(name)
+                return [(mapped_name, reshaped_data)]
+
+            if name.endswith("mixer.D"):
+                reshaped_data = data_torch.squeeze().reshape(-1, 1)
+                mapped_name = self.map_tensor_name(name)
+                return [(mapped_name, reshaped_data)]
+
+            if name.endswith("mixer.norm.weight"):
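+                # reshape the flat group-norm weight into (n_groups, group_size); the
+                # 8 x 512 shape is hard-coded here and assumed to match this model's config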
+                reshaped_data = data_torch.reshape(8, 512)
+                mapped_name = self.map_tensor_name(name)
+                return [(mapped_name, reshaped_data)]
+
+            if name.find("mixer.experts") != -1:
+                n_experts = self.hparams["n_routed_experts"]
+                assert bid is not None
+
+                if self._experts is None:
+                    self._experts = [{} for _ in range(self.block_count)]
+
+                self._experts[bid][name] = data_torch
+
+                if len(self._experts[bid]) >= n_experts * 2:
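+                    # each expert contributes a down_proj and an up_proj weight, so the
+                    # layer is complete once 2 * n_experts tensors have been collected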
+                    # merge the experts into a single tensor
+                    tensors: list[tuple[str, Tensor]] = []
+                    for w_name in ["down_proj", "up_proj"]:
+                        datas: list[Tensor] = []
+
+                        for xid in range(n_experts):
+                            ename = f"backbone.layers.{bid}.mixer.experts.{xid}.{w_name}.weight"
+                            datas.append(self._experts[bid][ename])
+                            del self._experts[bid][ename]
+
+                        data_torch = torch.stack(datas, dim=0)
+                        merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+                        new_name = self.map_tensor_name(merged_name)
+                        tensors.append((new_name, data_torch))
+
+                    return tensors
+                else:
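+                    # not all of this layer's expert tensors have arrived yet; hold them
+                    # back and emit nothing for now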
+                    return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("BailingMoeForCausalLM")