
Conversation

@zucchini-nlp
Member

What does this PR do?

Attempts to fix #41093. I believe the is_prefill() logic had edge cases, which were caught in the linked issue. Let's remove it, since the position ids are also prepared in prepare_inputs_for_generation.
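
For context, a generic sketch of how position ids can be prepared outside the model, from the attention mask, during generation (simplified and illustrative only; the actual Qwen2.5-VL path builds 3D multimodal position ids and differs in detail):

import torch

def build_position_ids(attention_mask, cache_position):
    # cumulative sum over the attention mask gives per-token positions;
    # keep only the slice for the tokens processed in the current step
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids[:, -cache_position.shape[0]:]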

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: qwen2_5_vl

@zucchini-nlp zucchini-nlp requested a review from gante September 26, 2025 08:38
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante
Contributor

The new code is not compile-compatible, but it should be fine -- if we provide position_ids, it is not reached 🤗

@zucchini-nlp zucchini-nlp merged commit 1ec0b54 into huggingface:main Oct 6, 2025
19 checks passed
AhnJoonSung pushed a commit to AhnJoonSung/transformers that referenced this pull request Oct 12, 2025
@BenjaminBossan
Member

BenjaminBossan commented Nov 27, 2025

Hi @zucchini-nlp, unfortunately this PR (1ec0b544140feec6a6ff804932bd83c03851732b) breaks Qwen 2.5 VL with some PEFT methods. This time it's not prefix tuning :D but prompt tuning. A reproducer:

import torch
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
from peft import get_peft_model, PromptTuningConfig, PromptTuningInit, TaskType

max_new_tokens = 40
model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
text = "Discuss the most important work by Mary Shelley."
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(text, return_tensors="pt").to(0)

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, device_map=0)
prompt_tune_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.RANDOM,
    num_virtual_tokens=5,
)
model = get_peft_model(model, prompt_tune_config)
model.prompt_encoder.default.embedding.weight.data.zero_()  # make peft almost no-op

torch.manual_seed(0)
with torch.no_grad():
    generated_ids = model.generate(
        # seq len should be 5 virtual + 9 normal tokens
        **inputs,
        max_new_tokens=max_new_tokens,
        # use_cache=False,  # without cache, it works
    )
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])

The error is:

    generated_ids = model.generate(
                    ^^^^^^^^^^^^^^^
  File "/home/name/work/forks/peft/src/peft/peft_model.py", line 2050, in generate
    outputs = self.base_model.generate(**kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/generation/utils.py", line 2542, in generate
    result = decoding_method(
             ^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/generation/utils.py", line 2765, in _sample
    outputs = model_forward(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/utils/generic.py", line 783, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 1464, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 1293, in forward
    outputs = self.language_model(
              ^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 891, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 741, in forward
    hidden_states, self_attn_weights = self.self_attn(
                                       ^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 666, in forward
    attn_output, attn_weights = attention_interface(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/integrations/sdpa_attention.py", line 96, in sdpa_attention_forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [16, 29] but got: [16, 15].

The mismatched shapes stem from the key, which has seq len 29, whereas query and value have 15. With the previous commit, all of them have seq len 14, which is what is actually expected (5 virtual tokens + 9 normal input tokens).

When the cache is disabled, or with other models like Llama or Qwen3 VL, this error does not occur.

To give a short description: in prompt tuning, we create some extra embeddings, concat them with the inputs_embeds, and pass the larger embeds into the model. We also remove the position_ids from the model kwargs, maybe that's related?
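
For illustration, a minimal sketch of that input preparation (simplified, with a made-up helper name; not the actual PEFT implementation):

import torch

def prepend_virtual_tokens(inputs_embeds, prompt_embeddings, model_kwargs):
    """Simplified sketch of prompt tuning's input preparation (hypothetical helper)."""
    batch_size = inputs_embeds.shape[0]
    # learned virtual-token embeddings, expanded over the batch: (batch, num_virtual, hidden)
    virtual = prompt_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
    # prepend them to the real token embeddings -> longer sequence
    inputs_embeds = torch.cat([virtual.to(inputs_embeds.dtype), inputs_embeds], dim=1)
    # grow the attention mask by the same number of positions
    attention_mask = model_kwargs.get("attention_mask")
    if attention_mask is not None:
        prefix_mask = attention_mask.new_ones(batch_size, virtual.shape[1])
        model_kwargs["attention_mask"] = torch.cat([prefix_mask, attention_mask], dim=1)
    # position_ids (and cache_position, see further down the thread) are dropped
    # so the model rebuilds them from the longer inputs_embeds
    model_kwargs.pop("position_ids", None)
    return inputs_embeds, model_kwargs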

The previous PR, 0947b9042c7eae073b0e4f641f65c13647705a30, still works.

@zucchini-nlp
Member Author

Interesting, since this block was copied from Qwen2-VL for consistency 🤔 (and in the hope it would fix the linked issue)

@BenjaminBossan
Member

I got a very similar error with "Qwen/Qwen2-VL-2B-Instruct":

    attn_output = torch.nn.functional.scaled_dot_product_attention(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [12, 29] but got: [12, 15].

@zucchini-nlp
Member Author

I will take a look after v5-rc0. If it fails on qwen-vl, then it has been there for a looong time haha

@BenjaminBossan
Member

> I will take a look after v5-rc0. If it fails on qwen-vl, then it has been there for a looong time haha

Thanks. I'd say it's not super urgent, it can wait for after v5.

@zucchini-nlp
Member Author

@BenjaminBossan it seems like cache_position has also disappeared, which is causing the issue. We try to infer prefill vs decoding stage based on the cache position, because Qwen-VL builds position ids differently in these stages. Is it also deleted on purpose?
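
Roughly, the current assumption looks like this (illustrative sketch with a made-up helper, not the literal model code):

def infer_stage(cache_position):
    # a missing cache_position is treated as a prefill step -- the assumption that
    # breaks once a wrapper such as PEFT pops cache_position from the kwargs
    if cache_position is None or cache_position[0] == 0:
        return "prefill"  # position ids are rebuilt from scratch for the full sequence
    return "decode"       # position ids are continued from the cached length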

@BenjaminBossan
Member

Thanks for investigating, @zucchini-nlp. I looked into why that happens and found this comment:

> For transformers>=4.38.0 - for some architectures such as Llama, cache_position is passed in the forward
> pass to keep track of the position ids of the cache. We have to pop that from model_kwargs as
> cache_position is properly created by the model, using the passed inputs_embeds:

https://github.com/huggingface/peft/blob/b10527e82c2171568f538f5b822817e8a753672a/src/peft/peft_model.py#L2179-L2183

The original PR that introduced it was huggingface/peft#1484. So IIUC, for these Qwen VL models the comment's assumption does not hold: cache_position is not created on the fly when it's missing. Some quick solutions that come to mind:

  1. PEFT needs to keep the cache_position and adjust it correctly.
  2. Same as 1. but only for some architectures like Qwen VL.
  3. Adjust Qwen VL in transformers to calculate the cache position.

As to how to adjust the cache_position, does something like this look right?

is_prefill = (model_kwargs.get("cache_position") is not None) and (model_kwargs["cache_position"][0] == 0)
...
if is_prefill:
    # virtual tokens are prepended to the inputs_embeds, so extend the cache position
    new_seq_len = model_kwargs['inputs_embeds'].shape[1]
    model_kwargs["cache_position"] = torch.arange(new_seq_len).to(dtype=model_kwargs["cache_position"].dtype, device=model_kwargs["cache_position"].device)
else:
    # leave model_kwargs["cache_position"] as is
    pass
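
If that sketch is right, for the reproducer above the prefill step would get cache_position = torch.arange(14) (9 prompt tokens + 5 virtual tokens), which matches the seq len 14 that the earlier commit produced.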

@zucchini-nlp
Member Author

Ah I see, that makes sense now. In that case the best would be to fix it on the transformers side and adjust the position id preparation step. Instead of assuming that cache_position = None means prefill, we would need to check the past key values as well. I can make a PR on Monday.
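
A rough sketch of that direction (hypothetical names and condition, not the final patch):

def is_prefill(cache_position, past_key_values):
    if cache_position is not None:
        return cache_position[0] == 0
    # cache_position was dropped (e.g. by PEFT prompt tuning): fall back to the
    # cache itself -- an empty cache means prefill, anything else means decoding
    return past_key_values is None or past_key_values.get_seq_length() == 0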

@BenjaminBossan
Member

Thanks so much @zucchini-nlp


Development

Successfully merging this pull request may close these issues.

IndexError: The shape of the mask [1406] at index 0 does not match the shape of the indexed tensor [1405] at index 0
