Description
System Info
- transformers version: 5.0.0.dev0
- Platform: Linux-5.14.0-503.23.1.el9_5.x86_64-x86_64-with-glibc2.34
- Python version: 3.10.15
- Huggingface_hub version: 1.2.2
- Safetensors version: 0.4.5
- Accelerate version: 1.11.0
- Accelerate config: not found
- DeepSpeed version: 0.17.5
- PyTorch version (accelerator?): 2.7.0+cu128 (CUDA)
- Using distributed or parallel set-up in script?:
- Using GPU in script?:
- GPU type: NVIDIA H100 80GB HBM3
Who can help?
When running granite-4.0-micro with transformers v5 (or the latest main branch), I get degraded generation quality.
In the last working version, v4.57.3, the attention code applies RoPE whenever position embeddings are provided. The same code path supports both dense-transformer models, which use RoPE, and hybrid models, which use no position embeddings.
See the relevant code snippet here:
cos, sin = position_embeddings if position_embeddings is not None else (None, None)
if position_embeddings is not None:
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
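For readers less familiar with what this branch does, here is a minimal pure-Python sketch of the rotate-half RoPE formulation (`x * cos + rotate_half(x) * sin`) used by Llama-style attention layers in transformers. The helper names (`rope_cos_sin`, `apply_rotary`, `maybe_apply_rope`) and the base of 10000 are illustrative assumptions, not the library's code. It works on a single head vector rather than tensors, and shows that with RoPE the query-key score depends on relative position, while with `position_embeddings=None` q and k pass through unchanged.

```python
import math
from typing import List, Optional, Sequence, Tuple

Vector = List[float]


def rotate_half(x: Sequence[float]) -> Vector:
    """Map (x1, x2) -> (-x2, x1), where x1/x2 are the two halves of the head dim."""
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])


def rope_cos_sin(position: int, dim: int, base: float = 10000.0) -> Tuple[Vector, Vector]:
    """Hypothetical helper: cos/sin for one position, each frequency repeated across both halves."""
    inv_freq = [base ** (-2 * i / dim) for i in range(dim // 2)]
    angles = [position * f for f in inv_freq] * 2
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


def apply_rotary(x: Sequence[float], cos: Sequence[float], sin: Sequence[float]) -> Vector:
    """Rotate-half RoPE, element-wise: x * cos + rotate_half(x) * sin."""
    rotated = rotate_half(x)
    return [xi * c + ri * s for xi, ri, c, s in zip(x, rotated, cos, sin)]


def maybe_apply_rope(
    q: Sequence[float],
    k: Sequence[float],
    position_embeddings: Optional[Tuple[Vector, Vector]],
) -> Tuple[Vector, Vector]:
    """Mirrors the v4.57.3 branch: rotate q/k only when position embeddings exist."""
    if position_embeddings is None:
        # No-position-embedding path (hybrid models): q and k pass through unchanged
        return list(q), list(k)
    cos, sin = position_embeddings
    return apply_rotary(q, cos, sin), apply_rotary(k, cos, sin)


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Unscaled attention score between one query and one key vector."""
    return sum(x * y for x, y in zip(a, b))
```

With RoPE, the score for a query at position 5 and a key at position 3 matches the score at positions 2 and 0, but differs from positions 5 and 0. Skipping the rotation makes every score position-independent, which would be consistent with a RoPE-trained model losing track of token order.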
After upgrading to v5, the lines that apply RoPE are missing (see here), so the query and key states are never rotated.
Re-applying RoPE in the v5 code resolves the degraded performance with granite-4.0-micro.
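To check whether a given installation is affected without loading the model, one option is to scan the attention layer's `forward` source for the `apply_rotary_pos_emb` call. This sketch assumes the layer is `GraniteMoeHybridAttention` in `transformers.models.granitemoehybrid.modeling_granitemoehybrid`; verify that path against the installed version. It is a simple string heuristic built on Python's `inspect` module, not an official transformers API.

```python
import importlib
import inspect
from typing import Optional

# Assumed module path and class name for the Granite 4.0 attention layer;
# adjust if your installed transformers version organizes it differently.
MODULE_PATH = "transformers.models.granitemoehybrid.modeling_granitemoehybrid"
CLASS_NAME = "GraniteMoeHybridAttention"


def source_applies_rope(source: str) -> bool:
    """Heuristic: does this source text call apply_rotary_pos_emb(...)?"""
    return "apply_rotary_pos_emb(" in source


def installed_attention_applies_rope(
    module_path: str = MODULE_PATH, class_name: str = CLASS_NAME
) -> Optional[bool]:
    """Inspect the installed attention forward(); None if it cannot be inspected."""
    try:
        module = importlib.import_module(module_path)
        forward = getattr(module, class_name).forward
        source = inspect.getsource(forward)
    except (ImportError, AttributeError, OSError, TypeError):
        # transformers missing, class renamed, or source unavailable
        return None
    return source_applies_rope(source)
```

Calling `installed_attention_applies_rope()` would be expected to return `True` on v4.57.3 and `False` on a build missing the RoPE lines. A `None` result means the module or class was not found. Because it only matches text, a commented-out call would still count as present.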
Happy to submit a PR to resolve this.
CC: @gabe-l-hart @shawntan @alex-jw-brooks
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Run the official example code from https://huggingface.co/ibm-granite/granite-4.0-micro:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda"
model_path = "ibm-granite/granite-4.0-micro"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# drop device_map if running on CPU
model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
model.eval()
# change input text as desired
chat = [
{ "role": "user", "content": "Please list one IBM Research laboratory located in the United States. You should only output its name and location." },
]
chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
# tokenize the text
input_tokens = tokenizer(chat, return_tensors="pt").to(device)
# generate output tokens
output = model.generate(**input_tokens,
max_new_tokens=100)
# decode output tokens into text
output = tokenizer.batch_decode(output)
# print output
print(output[0])
Expected behavior
<|start_of_role|>system<|end_of_role|>You are a helpful assistant. Please ensure responses are professional, accurate, and safe.<|end_of_text|>
<|start_of_role|>user<|end_of_role|>Please list one IBM Research laboratory located in the United States. You should only output its name and location.<|end_of_text|>
<|start_of_role|>assistant<|end_of_role|>Almaden Research Center, San Jose, California<|end_of_text|>