Commit 6daa3ee

Fix InternVL attention when using qk_norm (38B and 78B) (#37620)
* fix internvlvision attention when using qk_norm
* nit
* modular
1 parent 27a25be commit 6daa3ee

File tree

2 files changed: +58 −47 lines

src/transformers/models/internvl/modeling_internvl.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -150,10 +150,7 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
         query_states = self.q_norm(query_states)
-
-        key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
         key_states = self.k_norm(key_states)
 
         query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
```
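Context for the hunk above: the 38B and 78B InternVL checkpoints set `qk_norm=True`, so `q_norm`/`k_norm` are `InternVLVisionRMSNorm` layers whose learned weight spans the full `embed_dim`. The old code reshaped to `(-1, num_heads, head_dim)` before normalizing, leaving a `head_dim`-sized last axis that the `(embed_dim,)`-sized norm cannot match. A minimal sketch of the shape mismatch, assuming a recent PyTorch with `nn.RMSNorm` standing in for `InternVLVisionRMSNorm`:

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 64, 4
head_dim = embed_dim // num_heads
q = torch.randn(2, 5, embed_dim)  # (batch, seq_len, embed_dim)
q_norm = nn.RMSNorm(embed_dim)    # learned weight shape: (embed_dim,)

# Fixed ordering: normalize across embed_dim first, then split into heads.
fixed = q_norm(q).reshape(2, 5, num_heads, head_dim).transpose(1, 2)
print(fixed.shape)  # torch.Size([2, 4, 5, 16])

# Old ordering: the per-head reshape puts head_dim on the last axis,
# which no longer matches the (embed_dim,)-sized normalized shape.
try:
    q_norm(q.reshape(-1, num_heads, head_dim))
except RuntimeError as err:
    print("old ordering fails:", err)
```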
```diff
@@ -860,13 +857,13 @@ def get_image_features(
     @replace_return_docstrings(output_type=InternVLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
-        pixel_values: torch.FloatTensor = None,
+        input_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        vision_feature_layer: Optional[int] = None,
+        vision_feature_layer: Optional[Union[int, List[int]]] = None,
         vision_feature_select_strategy: Optional[str] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
```
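The widened `vision_feature_layer` annotation mirrors how sibling VLMs in transformers select vision features: an `int` picks a single hidden state from the vision tower, while a list selects several layers to combine. A hedged sketch of that selection logic (the helper name and the concatenate-on-feature-axis choice are illustrative, not InternVL's exact code):

```python
import torch

def select_vision_features(hidden_states, vision_feature_layer):
    """Pick one vision-tower layer, or concatenate several on the feature axis."""
    if isinstance(vision_feature_layer, int):
        return hidden_states[vision_feature_layer]
    return torch.cat([hidden_states[idx] for idx in vision_feature_layer], dim=-1)

# 24 encoder layers plus the embedding output, each (batch, patches, hidden).
hidden_states = [torch.randn(1, 576, 1024) for _ in range(25)]
print(select_vision_features(hidden_states, -1).shape)        # torch.Size([1, 576, 1024])
print(select_vision_features(hidden_states, [-2, -1]).shape)  # torch.Size([1, 576, 2048])
```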

src/transformers/models/internvl/modular_internvl.py

Lines changed: 55 additions & 41 deletions
@@ -16,15 +16,17 @@
1616

1717
import collections.abc
1818
from dataclasses import dataclass
19-
from typing import List, Optional, Tuple, Union
19+
from typing import Callable, List, Optional, Tuple, Union
2020

2121
import torch
2222
import torch.nn as nn
2323
import torch.utils.checkpoint
2424

2525
from ...activations import ACT2FN
26+
from ...modeling_flash_attention_utils import FlashAttentionKwargs
2627
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
27-
from ...modeling_utils import PreTrainedModel
28+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
29+
from ...processing_utils import Unpack
2830
from ...utils import (
2931
add_code_sample_docstrings,
3032
add_start_docstrings,
```diff
@@ -92,6 +94,55 @@ def __init__(self, config: InternVLVisionConfig):
         self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
         self.k_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
 
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ):
+        batch_size, seq_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = self.q_norm(query_states)
+        key_states = self.k_norm(key_states)
+
+        query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scale,
+            is_causal=False,
+            **kwargs,
+        )
+        attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
+
+        output = self.projection_layer(attn_output)
+        output = self.projection_dropout(output)
+
+        outputs = (output, attn_weights) if output_attentions else (output, None)
+        return outputs
+
 
 class InternVLVisionPreTrainedModel(PreTrainedModel):
     """
```
```diff
@@ -609,26 +660,7 @@ def get_image_features(
 
     @add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=InternVLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        pixel_values: torch.FloatTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        vision_feature_layer: Optional[int] = None,
-        vision_feature_select_strategy: Optional[str] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        logits_to_keep: Union[int, torch.Tensor] = 0,
-        image_sizes: Optional[torch.Tensor] = None,
-        **lm_kwargs,
-    ) -> Union[Tuple, InternVLCausalLMOutputWithPast]:
+    def forward(**super_kwargs):
         r"""
         Args:
             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
```
```diff
@@ -679,25 +711,7 @@ def forward(
         >>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
         The images depict the Statue of Liberty and the Golden Gate Bridge.
         ```"""
-        super().forward(
-            input_ids=input_ids,
-            pixel_values=pixel_values,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            vision_feature_layer=vision_feature_layer,
-            vision_feature_select_strategy=vision_feature_select_strategy,
-            labels=labels,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            cache_position=cache_position,
-            logits_to_keep=logits_to_keep,
-            image_sizes=image_sizes,
-            **lm_kwargs,
-        )
+        super().forward(**super_kwargs)
 
 
 __all__ = [
```
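For readers new to the repo's modular system: `modular_internvl.py` is the source from which `modeling_internvl.py` is auto-generated, which is why the `* modular` commit bullet matters. As I understand the converter's convention, a `def forward(**super_kwargs)` whose body is just `super().forward(**super_kwargs)` plus a docstring tells it to re-expand the parent's full signature in the generated file while swapping in only the docstring, so the long argument list above could be deleted without changing the generated code. A runnable sketch of what the generated override amounts to (class names are illustrative, not the actual InternVL classes):

```python
class ParentForConditionalGeneration:
    def forward(self, input_ids=None, pixel_values=None, labels=None):
        return {"input_ids": input_ids, "pixel_values": pixel_values, "labels": labels}


class InternVLLikeModel(ParentForConditionalGeneration):
    # The modular shorthand expands to an ordinary delegating override with
    # the parent's signature and the model-specific docstring.
    def forward(self, input_ids=None, pixel_values=None, labels=None):
        r"""Model-specific docstring taken from the modular file."""
        return super().forward(input_ids=input_ids, pixel_values=pixel_values, labels=labels)


print(InternVLLikeModel().forward(input_ids=[1, 2, 3]))
```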
