
Conversation

HelloKS (Contributor) commented Dec 14, 2025

Make sure to read the contributing guidelines before submitting a PR

Hello, this is my first contribution to llama.cpp.

This PR adds support for "KORMo-Team/KORMo-10B-sft", a model trained from scratch on open Korean resources and more.

From what I understand, this model shares its architecture with the LLaMA 3 family but uses a different tokenizer and different tensor names. I tested it locally and it seems to work well.

(screenshot)

Let me know if I can improve this!

For testing: https://huggingface.co/hell0ks/KORMo-10B-sft-gguf
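
For a quick local smoke test of the converted GGUF, something like the following works (a rough sketch using the llama-cpp-python bindings, assuming a build that already supports the converted arch; the file name is illustrative):

    # Rough smoke-test sketch; requires `pip install llama-cpp-python`
    # and a local copy of the converted GGUF (file name is illustrative).
    from llama_cpp import Llama

    llm = Llama(model_path="KORMo-10B-sft-Q4_K_M.gguf", n_ctx=4096)
    out = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Introduce yourself briefly in Korean."}],
        max_tokens=256,
    )
    print(out["choices"][0]["message"]["content"])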

CISC (Collaborator) commented Dec 14, 2025

It's not llama3, it looks like qwen3 to me.

HelloKS (Contributor, Author) commented Dec 14, 2025

It's not llama3, it looks like qwen3 to me.

I compared the Qwen3 transformers implementation with KORMo's custom code, and Q/K normalization is not applied in the KORMo one.

(screenshot)

Qwen3

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

KORMo

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(hidden_shape).transpose(1, 2)
        key_states = key_states.view(hidden_shape).transpose(1, 2)
        value_states = value_states.view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

LLaMA

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

...And in the model's paper: "The base structure followed the Llama-3 series architecture."
(screenshot from the paper)

Not sure why they decided to use "KORMoForCausalLM" instead of "LlamaForCausalLM" then...

HelloKS changed the title from "model : add KORMo model" to "model: add KORMo model" on Dec 14, 2025
github-actions bot added the "python (python script changes)" label on Dec 14, 2025
CISC (Collaborator) commented Dec 14, 2025

It's not llama3, it looks like qwen3 to me.

I compared the Qwen3 transformers implementation with KORMo's custom code, and Q/K normalization is not applied in the KORMo one.

Right, so probably qwen2 then.

... And in model's paper: "The base structure followed the Llama-3 series architecture"

It's a bit of an odd statement.

Not sure why they decided to use "KORMoForCausalLM" beside "LlamaForCausalLM" then..

That is indeed the weirdest thing of all, there's nothing there warranting a new arch.

CISC (Collaborator) commented Dec 14, 2025

I suggest trying to move this to qwen2, the pre-tokenizer certainly is qwen2, and the chat template is almost identical to qwen's as well...

HelloKS (Contributor, Author) commented Dec 14, 2025

I suggest trying to move this to qwen2, the pre-tokenizer certainly is qwen2, and the chat template is almost identical to qwen's as well...

It doesn't even have sliding window attention (SWA) like Qwen2, per its "custom" modeling code (which just looks copy & pasted from the llama code, from what I see).

Not sure about the tokenizer though.
Update: I checked the Qwen3, LLaMA 3.1, and KORMo pre-tokenizer regexes, and their configs are identical:

  "pre_tokenizer": {
    "type": "Sequence",
    "pretokenizers": [
      {
        "type": "Split",
        "pattern": {
          "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
        },
        "behavior": "Isolated",
        "invert": false
      },
      {
        "type": "ByteLevel",
        "add_prefix_space": false,
        "trim_offsets": false,
        "use_regex": false
      }
    ]
  },
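
For reference, this is easy to diff directly from the tokenizer.json files (a rough sketch; the local file paths are illustrative):

    # Rough sketch: compare the pre_tokenizer sections of two tokenizer.json files.
    # Paths are illustrative -- download tokenizer.json from each HF repo first.
    import json

    def pre_tokenizer(path):
        with open(path, encoding="utf-8") as f:
            return json.load(f)["pre_tokenizer"]

    kormo = pre_tokenizer("KORMo-10B-sft/tokenizer.json")
    qwen3 = pre_tokenizer("Qwen3-8B/tokenizer.json")
    llama31 = pre_tokenizer("Llama-3.1-8B-Instruct/tokenizer.json")

    print("KORMo == Qwen3:", kormo == qwen3)
    print("KORMo == LLaMA 3.1:", kormo == llama31)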

It seems they:

  • took the LLaMA architecture
  • optimized the tokenizer a bit for Korean and English only
  • hardcoded some config values from llama's modeling code
  • tried something with an FA3 implementation, but never got it off the ground
  • changed the tensor names for some reason
  • slapped "KORMoForCausalLM" on it instead of the standard llama one

I'm getting confused. What do you think?

CISC (Collaborator) commented Dec 14, 2025

Not sure about the tokenizer though. Update: I checked the Qwen3, LLaMA 3.1, and KORMo pre-tokenizer regexes, and their configs are identical.

They are not; LLaMA 3.1's regex differs slightly, and KORMo is using Qwen's.

CISC (Collaborator) commented Dec 14, 2025

I'm getting confused. What do you think?

As I said, try moving everything to qwen2.

HelloKS (Contributor, Author) commented Dec 14, 2025

They are not; LLaMA 3.1's regex differs slightly, and KORMo is using Qwen's.

I'm sorry, you were right. The pre-tokenizers are different. My fault.

Qwen3 and KORMo:

"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"

LLaMA 3.1:

"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"

I changed it as you suggested.

Addendum: for convert_hf_to_gguf.py, I think LLaMA is still the right base. The model has neither Q/K/V bias (Qwen2) nor Q/K norm (Qwen3).
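
This is also easy to confirm from the checkpoint itself rather than the modeling code (a rough sketch; the shard filename is illustrative):

    # Rough sketch: list which arch-specific tensors the checkpoint actually contains.
    # Shard filename is illustrative; repeat over all shards for a full check.
    # Needs torch installed for framework="pt".
    from safetensors import safe_open

    with safe_open("model-00001-of-00005.safetensors", framework="pt") as f:
        names = list(f.keys())

    has_qkv_bias = any(n.endswith(("q_proj.bias", "k_proj.bias", "v_proj.bias")) for n in names)
    has_qk_norm  = any(n.endswith(("q_norm.weight", "k_norm.weight")) for n in names)
    print("Q/K/V bias present:", has_qkv_bias)  # Qwen2-style attention bias
    print("Q/K norm present:  ", has_qk_norm)   # Qwen3-style per-head norm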

CISC (Collaborator) commented Dec 14, 2025

Addendum: for convert_hf_to_gguf.py, I think LLaMA is still the right base. The model has neither Q/K/V bias (Qwen2) nor Q/K norm (Qwen3).

I still think Qwen2 is correct; it just needs to be updated so that the bias is actually optional, like in Qwen2MoE. In fact, I just tested it and it works perfectly.

llama.cpp/src/llama-model.cpp

Lines 3420 to 3423 in 5dbb758

    // optional bias tensors
    layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
    layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
    layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);

    if (model.layers[il].bq) {
        Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
        cb(Qcur, "Qcur", il);
    }

There is, however, something screwy with the chat template. Not sure what's going on, but with --jinja enabled there seems to be a lot of extra whitespace, which completely ruins generation. Forcing it to e.g. --chat-template chatml makes it work decently, but it is prone to thinking loops due to the missing <think> token.

Edit: Oh, I see, you submitted a fixed one, what a weird bug. :)
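
One way to spot template whitespace issues is to render the prompt and inspect its repr (a rough sketch using transformers; llama.cpp's --jinja uses its own template engine, so this only approximates what it produces, but it makes stray whitespace in the template itself easy to see):

    # Rough sketch: render the chat template and make stray whitespace visible.
    # May need trust_remote_code=True depending on the repo.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("KORMo-Team/KORMo-10B-sft")
    prompt = tok.apply_chat_template(
        [{"role": "user", "content": "hello"}],
        tokenize=False,
        add_generation_prompt=True,
    )
    print(repr(prompt))  # repr() exposes extra spaces and newlines from the template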

HelloKS (Contributor, Author) commented Dec 15, 2025

I still think Qwen2 is correct; it just needs to be updated so that the bias is actually optional, like in Qwen2MoE. In fact, I just tested it and it works perfectly.

llama.cpp/src/llama-model.cpp

Lines 3420 to 3423 in 5dbb758

    // optional bias tensors
    layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
    layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
    layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);

    if (model.layers[il].bq) {
        Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
        cb(Qcur, "Qcur", il);
    }

Thanks for guiding me in the right direction. I appreciate it.

Now the model happily runs with the Qwen2 architecture, and Qwen 2.5 also still works well:

  • KORMo-Team/KORMo-10B-sft (output screenshot)
  • Qwen/Qwen2.5-7B-Instruct (output screenshot)
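
For reference, the converter-side change can be roughly this small (an illustrative sketch following the existing Qwen2Model pattern in convert_hf_to_gguf.py; the class name and exact registration hook may differ from the merged code):

    # Illustrative sketch only, inside convert_hf_to_gguf.py -- not necessarily the merged code.
    # Reuses the Qwen2 conversion path, since KORMo has no Q/K/V bias and no Q/K norm.
    # (KORMo's renamed tensors would additionally need entries in gguf-py's tensor mapping.)
    @ModelBase.register("KORMoForCausalLM")
    class KORMoModel(Qwen2Model):
        model_arch = gguf.MODEL_ARCH.QWEN2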

There is, however, something screwy with the chat template. Not sure what's going on, but with --jinja enabled there seems to be a lot of extra whitespace, which completely ruins generation. Forcing it to e.g. --chat-template chatml makes it work decently, but it is prone to thinking loops due to the missing <think> token.

Edit: Oh, I see, you submitted a fixed one, what a weird bug. :)

Yes, it was me lol. They merged the fix today, so no more problems!

github-actions bot added the "model (Model specific)" label on Dec 15, 2025
CISC merged commit 9d52f17 into ggml-org:master on Dec 15, 2025
49 of 78 checks passed