
Commit 2995341

Authored by danbev and ggerganov
llama : add support for NVIDIA Nemotron 3 Nano (ggml-org#18058)
* llama : add support for NVIDIA Nemotron Nano 3

This commit adds support for the NVIDIA Nemotron Nano 3 model, enabling it to be converted and run.

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 40d9c39 commit 2995341
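For readers who want to exercise the new support end to end, here is a minimal sketch. It assumes a built llama.cpp checkout and a Nemotron Nano 3 checkpoint already downloaded from Hugging Face; the directory paths, output filename, and generation length are placeholders rather than anything prescribed by this commit.

import subprocess

# Placeholder paths -- adjust to your local checkpoint and build directory.
HF_DIR = "path/to/nemotron-nano-3"      # Hugging Face checkpoint directory
GGUF_OUT = "nemotron-nano-3-f16.gguf"   # converted model file

# Convert the checkpoint to GGUF. With this commit, NemotronHModel switches to
# the nemotron_h_moe architecture when "num_experts_per_tok" is present in the
# model's hparams.
subprocess.run(
    ["python", "convert_hf_to_gguf.py", HF_DIR, "--outfile", GGUF_OUT, "--outtype", "f16"],
    check=True,
)

# Run a short generation with the converted model.
subprocess.run(
    ["./build/bin/llama-cli", "-m", GGUF_OUT, "-p", "Hello", "-n", "32"],
    check=True,
)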

File tree: 9 files changed (+267, -23 lines)


convert_hf_to_gguf.py

Lines changed: 110 additions & 6 deletions
@@ -8490,8 +8490,18 @@ def set_vocab(self):
 class NemotronHModel(GraniteHybridModel):
     """Hybrid mamba2/attention model from NVIDIA"""
     model_arch = gguf.MODEL_ARCH.NEMOTRON_H
+    is_moe: bool = False

     def __init__(self, *args, **kwargs):
+        # We have to determine the correct model architecture (MoE vs non-MoE) before
+        # calling the parent __init__. This is because the parent constructor
+        # uses self.model_arch to build the tensor name map, and all MoE-specific
+        # mappings would be missed if it were called with the default non-MoE arch.
+        hparams = ModelBase.load_hparams(args[0], self.is_mistral_format)
+        if "num_experts_per_tok" in hparams:
+            self.model_arch = gguf.MODEL_ARCH.NEMOTRON_H_MOE
+            self.is_moe = True
+
         super().__init__(*args, **kwargs)

         # Save the top-level head_dim for later
@@ -8503,9 +8513,11 @@ def __init__(self, *args, **kwargs):

         # Update the ssm / attn / mlp layers
         # M: Mamba2, *: Attention, -: MLP
+        # MoE:
+        # M: Mamba2, *: Attention, E: Expert
         hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
         self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
-        self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"]
+        self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == ("E" if self.is_moe else "-")]

     def get_attn_layers(self):
         hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
@@ -8521,18 +8533,110 @@ def set_gguf_parameters(self):
         # Set feed_forward_length
         # NOTE: This will trigger an override warning. This is preferrable to
         # duplicating all the parent logic
-        n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
-        self.gguf_writer.add_feed_forward_length([
-            n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
-        ])
+        if not self.is_moe:
+            n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
+            self.gguf_writer.add_feed_forward_length([
+                n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
+            ])
+        else:
+            moe_intermediate_size = self.hparams["moe_intermediate_size"]
+            self.gguf_writer.add_feed_forward_length([
+                moe_intermediate_size if i in self._mlp_layers else 0 for i in range(self.block_count)
+            ])
+            self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
+            self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
+            self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["moe_shared_expert_intermediate_size"])
+            self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"])
+            self.gguf_writer.add_expert_shared_count(self.hparams["n_shared_experts"])
+            self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])
+            self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
+            self.gguf_writer.add_expert_group_count(self.hparams["n_group"])
+
+        # number of experts used per token (top-k)
+        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
+            self.gguf_writer.add_expert_used_count(n_experts_used)

     def set_vocab(self):
         super().set_vocab()

         # The tokenizer _does_ add a BOS token (via post_processor type
         # TemplateProcessing) but does not set add_bos_token to true in the
         # config, so we need to explicitly override it here.
-        self.gguf_writer.add_add_bos_token(True)
+        if not self.is_moe:
+            self.gguf_writer.add_add_bos_token(True)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if self.is_moe and bid is not None:
+            if name.endswith("mixer.gate.e_score_correction_bias"):
+                new_name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+                mapped_name = self.map_tensor_name(new_name)
+                return [(mapped_name, data_torch)]
+
+            if name.endswith("mixer.dt_bias"):
+                new_name = name.replace("dt_bias", "dt.bias")
+                mapped_name = self.map_tensor_name(new_name)
+                return [(mapped_name, data_torch)]
+
+            if name.endswith("mixer.conv1d.weight"):
+                squeezed_data = data_torch.squeeze()
+                mapped_name = self.map_tensor_name(name)
+                return [(mapped_name, squeezed_data)]
+
+            if name.endswith("mixer.A_log"):
+                transformed_data = -torch.exp(data_torch)
+                reshaped_data = transformed_data.squeeze().reshape(-1, 1)
+                mapped_name = self.map_tensor_name(name)
+                return [(mapped_name, reshaped_data)]
+
+            if name.endswith("mixer.D"):
+                reshaped_data = data_torch.squeeze().reshape(-1, 1)
+                mapped_name = self.map_tensor_name(name)
+                return [(mapped_name, reshaped_data)]
+
+            if name.endswith("mixer.norm.weight"):
+                reshaped_data = data_torch.reshape(8, 512)
+                mapped_name = self.map_tensor_name(name)
+                return [(mapped_name, reshaped_data)]
+
+            if name.find("mixer.experts") != -1:
+                n_experts = self.hparams["n_routed_experts"]
+                assert bid is not None
+
+                if self._experts is None:
+                    self._experts = [{} for _ in range(self.block_count)]
+
+                self._experts[bid][name] = data_torch
+
+                if len(self._experts[bid]) >= n_experts * 2:
+                    # merge the experts into a single tensor
+                    tensors: list[tuple[str, Tensor]] = []
+                    for w_name in ["down_proj", "up_proj"]:
+                        datas: list[Tensor] = []
+
+                        for xid in range(n_experts):
+                            ename = f"backbone.layers.{bid}.mixer.experts.{xid}.{w_name}.weight"
+                            datas.append(self._experts[bid][ename])
+                            del self._experts[bid][ename]
+
+                        data_torch = torch.stack(datas, dim=0)
+                        merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+                        new_name = self.map_tensor_name(merged_name)
+                        tensors.append((new_name, data_torch))
+
+                    return tensors
+                else:
+                    return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")


 @ModelBase.register("BailingMoeForCausalLM")
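The layer layout above is driven by `hybrid_override_pattern`, which marks each block as Mamba2 (`M`), attention (`*`), or MLP (`-`), with the MoE variant using `E` for expert layers instead of `-`. A small standalone sketch of the classification (the pattern string below is invented for illustration; real checkpoints provide their own in the config):

# Standalone illustration of how the hybrid override pattern is split into
# SSM, attention, and MLP/expert layer indices (mirrors the list comprehensions
# in NemotronHModel.__init__ above). The pattern below is made up.
def classify_layers(pattern: str, is_moe: bool) -> dict[str, list[int]]:
    mlp_marker = "E" if is_moe else "-"
    return {
        "ssm":  [i for i, c in enumerate(pattern) if c == "M"],
        "attn": [i for i, c in enumerate(pattern) if c == "*"],
        "mlp":  [i for i, c in enumerate(pattern) if c == mlp_marker],
    }

print(classify_layers("MM*EMM*E", is_moe=True))
# {'ssm': [0, 1, 4, 5], 'attn': [2, 6], 'mlp': [3, 7]}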

gguf-py/gguf/constants.py

Lines changed: 29 additions & 0 deletions
@@ -413,6 +413,7 @@ class MODEL_ARCH(IntEnum):
     JAIS = auto()
     NEMOTRON = auto()
     NEMOTRON_H = auto()
+    NEMOTRON_H_MOE = auto()
     EXAONE = auto()
     EXAONE4 = auto()
     GRANITE = auto()
@@ -786,6 +787,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.JAIS: "jais",
     MODEL_ARCH.NEMOTRON: "nemotron",
     MODEL_ARCH.NEMOTRON_H: "nemotron_h",
+    MODEL_ARCH.NEMOTRON_H_MOE: "nemotron_h_moe",
     MODEL_ARCH.EXAONE: "exaone",
     MODEL_ARCH.EXAONE4: "exaone4",
     MODEL_ARCH.GRANITE: "granite",
@@ -2529,6 +2531,33 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.NEMOTRON_H_MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_D,
+        MODEL_TENSOR.SSM_NORM,
+        MODEL_TENSOR.SSM_OUT,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        # experts
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        # shared expert
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+    ],
     MODEL_ARCH.EXAONE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
gguf-py/gguf/tensor_mapping.py

Lines changed: 7 additions & 2 deletions
@@ -379,6 +379,7 @@ class TensorNameMap:
         "model.layers.{bid}.feed_forward.gate", # lfm2moe
         "model.layers.{bid}.mlp.router.gate",   # afmoe
         "layers.{bid}.gate",                    # mistral-large
+        "backbone.layers.{bid}.mixer.gate",     # nemotron-h-moe
     ),

     MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@@ -392,6 +393,7 @@ class TensorNameMap:
         "model.layers.{bid}.mlp.expert_bias",                     # afmoe
         "model.layers.{bid}.feed_forward.expert_bias",            # lfm2moe
         "model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2
+        "backbone.layers.{bid}.mixer.gate.e_score_correction"     # nemotron-h-moe
     ),

     # Feed-forward up
@@ -440,7 +442,7 @@ class TensorNameMap:
         "layers.{bid}.feed_forward.experts.w3",            # mixtral (merged)
         "transformer.decoder_layer.{bid}.moe.linear_v",    # Grok (merged)
         "transformer.blocks.{bid}.ffn.experts.mlp.v1",     # dbrx
-        "model.layers.{bid}.mlp.experts.up_proj",          # qwen2moe olmoe (merged) ernie4.5-moe
+        "model.layers.{bid}.mlp.experts.up_proj",          # qwen2moe olmoe (merged) ernie4.5-moe, nemotron-h-moe (merged)
         "model.layers.{bid}.block_sparse_moe.experts.w3",  # phimoe (merged)
         "model.layers.{bid}.feed_forward.experts.up_proj", # llama4
         "encoder.layers.{bid}.mlp.experts.mlp.w1",         # nomic-bert-moe
@@ -454,6 +456,7 @@ class TensorNameMap:
         "model.layers.{bid}.feed_forward.down_proj",
         "model.layers.{bid}.mlp.shared_mlp.up_proj",          # hunyuan
         "layers.{bid}.shared_experts.w3",                     # mistral-large
+        "backbone.layers.{bid}.mixer.shared_experts.up_proj", # nemotron-h-moe
     ),

     MODEL_TENSOR.FFN_UP_CHEXP: (
@@ -548,7 +551,7 @@ class TensorNameMap:
         "layers.{bid}.feed_forward.experts.w2",              # mixtral (merged)
         "transformer.decoder_layer.{bid}.moe.linear_1",      # Grok (merged)
         "transformer.blocks.{bid}.ffn.experts.mlp.w2",       # dbrx
-        "model.layers.{bid}.mlp.experts.down_proj",          # qwen2moe olmoe (merged) ernie4.5-moe
+        "model.layers.{bid}.mlp.experts.down_proj",          # qwen2moe olmoe (merged) ernie4.5-moe nemotron-h-moe (merged)
         "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
         "model.layers.{bid}.block_sparse_moe.experts.w2",    # phimoe (merged)
         "model.layers.{bid}.feed_forward.experts.down_proj", # llama4
@@ -563,6 +566,7 @@ class TensorNameMap:
         "model.layers.{bid}.shared_mlp.output_linear",          # granitemoe
         "model.layers.{bid}.mlp.shared_mlp.down_proj",          # hunyuan
         "layers.{bid}.shared_experts.w2",                       # mistral-large
+        "backbone.layers.{bid}.mixer.shared_experts.down_proj", # nemotron-h-moe
     ),

     MODEL_TENSOR.FFN_DOWN_CHEXP: (
@@ -706,6 +710,7 @@ class TensorNameMap:
         "model.layers.{bid}.mamba.dt_proj",        # jamba falcon-h1 granite-hybrid
         "model.layers.layers.{bid}.mixer.dt_proj", # plamo2
         "model.layers.{bid}.linear_attn.dt_proj",  # qwen3next
+        "backbone.layers.{bid}.mixer.dt",          # nemotron-h-moe
     ),

     MODEL_TENSOR.SSM_DT_NORM: (
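These mappings let gguf-py resolve the new `backbone.layers.{bid}.mixer.*` source names to the standard `blk.{bid}.*` GGUF tensor names. A hedged sketch of how the lookup could be exercised, assuming gguf-py's `get_tensor_name_map` helper and `TensorNameMap.get_name` behave as they are used in convert_hf_to_gguf.py (the block count below is arbitrary):

# Hedged sketch: resolve one of the new nemotron-h-moe source names to its GGUF
# tensor name via the architecture's tensor name map. Assumes the gguf-py
# helpers shown here keep the signatures used by convert_hf_to_gguf.py.
import gguf

tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.NEMOTRON_H_MOE, 64)  # 64 = arbitrary block count
name = tensor_map.get_name("backbone.layers.3.mixer.gate.weight", try_suffixes=(".weight", ".bias"))
print(name)  # expected: "blk.3.ffn_gate_inp", per the FFN_GATE_INP mapping above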

src/llama-arch.cpp

Lines changed: 35 additions & 0 deletions
@@ -75,6 +75,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_JAIS,            "jais"           },
     { LLM_ARCH_NEMOTRON,        "nemotron"       },
     { LLM_ARCH_NEMOTRON_H,      "nemotron_h"     },
+    { LLM_ARCH_NEMOTRON_H_MOE,  "nemotron_h_moe" },
     { LLM_ARCH_EXAONE,          "exaone"         },
     { LLM_ARCH_EXAONE4,         "exaone4"        },
     { LLM_ARCH_RWKV6,           "rwkv6"          },
@@ -1763,6 +1764,39 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_NEMOTRON_H_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+            { LLM_TENSOR_OUTPUT,         "output" },
+            { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
+            // mamba(2) ssm layers
+            { LLM_TENSOR_SSM_IN,         "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D,     "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_DT,         "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_A,          "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_D,          "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_NORM,       "blk.%d.ssm_norm" },
+            { LLM_TENSOR_SSM_OUT,        "blk.%d.ssm_out" },
+            // attention layers
+            { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+            // dense FFN
+            { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
+            // MoE FFN (for MoE layers)
+            { LLM_TENSOR_FFN_GATE_INP,   "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_UP_EXPS,    "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,  "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B,"blk.%d.exp_probs_b" },
+            // MoE shared expert layer
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,   "blk.%d.ffn_up_shexp" },
+        },
+    },
     {
         LLM_ARCH_EXAONE,
         {
@@ -2817,6 +2851,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
         case LLM_ARCH_LFM2:
         case LLM_ARCH_LFM2MOE:
        case LLM_ARCH_NEMOTRON_H:
+        case LLM_ARCH_NEMOTRON_H_MOE:
         case LLM_ARCH_QWEN3NEXT:
             return true;
         default:

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@ enum llm_arch {
     LLM_ARCH_JAIS,
     LLM_ARCH_NEMOTRON,
     LLM_ARCH_NEMOTRON_H,
+    LLM_ARCH_NEMOTRON_H_MOE,
     LLM_ARCH_EXAONE,
     LLM_ARCH_EXAONE4,
     LLM_ARCH_RWKV6,

src/llama-graph.cpp

Lines changed: 9 additions & 0 deletions
@@ -1089,6 +1089,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                 cur = ggml_relu(ctx0, cur);
                 cb(cur, "ffn_moe_relu", il);
             } break;
+        case LLM_FFN_RELU_SQR:
+            if (gate_exps) {
+                // TODO: add support for gated squared relu
+                GGML_ABORT("fatal error: gated squared relu not implemented");
+            } else {
+                cur = ggml_relu(ctx0, cur);
+                cur = ggml_sqr(ctx0, cur);
+                cb(cur, "ffn_moe_relu_sqr", il);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
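The new `LLM_FFN_RELU_SQR` branch applies a squared ReLU to the routed expert activations when no gate tensor is present (the gated form is left as a TODO). For intuition, an equivalent computation in PyTorch; this is purely illustrative, while llama.cpp performs it with `ggml_relu` followed by `ggml_sqr`:

# Illustration of the non-gated squared-ReLU expert activation: relu(x)^2.
import torch

x = torch.tensor([-1.0, 0.5, 2.0])
y = torch.relu(x) ** 2
print(y)  # tensor([0.0000, 0.2500, 4.0000])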
