|
16 | 16 |
|
17 | 17 | import collections.abc |
18 | 18 | from dataclasses import dataclass |
19 | | -from typing import List, Optional, Tuple, Union |
| 19 | +from typing import Callable, List, Optional, Tuple, Union |
20 | 20 |
|
21 | 21 | import torch |
22 | 22 | import torch.nn as nn |
23 | 23 | import torch.utils.checkpoint |
24 | 24 |
|
25 | 25 | from ...activations import ACT2FN |
| 26 | +from ...modeling_flash_attention_utils import FlashAttentionKwargs |
26 | 27 | from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling |
27 | | -from ...modeling_utils import PreTrainedModel |
| 28 | +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
| 29 | +from ...processing_utils import Unpack |
28 | 30 | from ...utils import ( |
29 | 31 | add_code_sample_docstrings, |
30 | 32 | add_start_docstrings, |
@@ -92,6 +94,55 @@ def __init__(self, config: InternVLVisionConfig): |
92 | 94 | self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity() |
93 | 95 | self.k_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity() |
94 | 96 |
|
| 97 | + def forward( |
| 98 | + self, |
| 99 | + hidden_states: torch.Tensor, |
| 100 | + attention_mask: Optional[torch.Tensor] = None, |
| 101 | + output_attentions: Optional[bool] = None, |
| 102 | + **kwargs: Unpack[FlashAttentionKwargs], |
| 103 | + ): |
| 104 | + batch_size, seq_len, _ = hidden_states.size() |
| 105 | + |
| 106 | + query_states = self.q_proj(hidden_states) |
| 107 | + key_states = self.k_proj(hidden_states) |
| 108 | + value_states = self.v_proj(hidden_states) |
| 109 | + |
| 110 | + query_states = self.q_norm(query_states) |
| 111 | + key_states = self.k_norm(key_states) |
| 112 | + |
| 113 | + query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) |
| 114 | + key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) |
| 115 | + value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) |
| 116 | + |
| 117 | + attention_interface: Callable = eager_attention_forward |
| 118 | + if self.config._attn_implementation != "eager": |
| 119 | + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): |
| 120 | + logger.warning_once( |
| 121 | + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " |
| 122 | + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' |
| 123 | + ) |
| 124 | + else: |
| 125 | + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] |
| 126 | + |
| 127 | + attn_output, attn_weights = attention_interface( |
| 128 | + self, |
| 129 | + query_states, |
| 130 | + key_states, |
| 131 | + value_states, |
| 132 | + attention_mask, |
| 133 | + dropout=0.0 if not self.training else self.attention_dropout, |
| 134 | + scaling=self.scale, |
| 135 | + is_causal=False, |
| 136 | + **kwargs, |
| 137 | + ) |
| 138 | + attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim) |
| 139 | + |
| 140 | + output = self.projection_layer(attn_output) |
| 141 | + output = self.projection_dropout(output) |
| 142 | + |
| 143 | + outputs = (output, attn_weights) if output_attentions else (output, None) |
| 144 | + return outputs |
| 145 | + |
95 | 146 |
|
96 | 147 | class InternVLVisionPreTrainedModel(PreTrainedModel): |
97 | 148 | """ |
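Editor's note: `eager_attention_forward` is referenced by the new `forward` but is defined elsewhere in the file, outside this hunk. For readers following the dispatch through `ALL_ATTENTION_FUNCTIONS`, below is a minimal sketch of the callable contract such a backend has to satisfy, i.e. the `(module, query, key, value, attention_mask, ...) -> (attn_output, attn_weights)` convention used by the `attention_interface(...)` call above. The function name and the absence of key/value head grouping are assumptions for illustration, not the actual transformers implementation.

from typing import Optional

import torch
import torch.nn as nn


def eager_attention_forward_sketch(
    module: nn.Module,
    query: torch.Tensor,  # (batch, num_heads, seq_len, head_dim)
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # Raw attention scores: (batch, num_heads, seq_len, seq_len)
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Additive mask, cropped to the key length
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Back to (batch, seq_len, num_heads, head_dim) so the caller can flatten the heads
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights

The SDPA and flash-attention backends registered in `ALL_ATTENTION_FUNCTIONS` follow the same contract, which is why the caller only needs to reshape the returned output to `(batch_size, seq_len, embed_dim)` before the output projection.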
@@ -609,26 +660,7 @@ def get_image_features( |
609 | 660 |
|
610 | 661 | @add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING) |
611 | 662 | @replace_return_docstrings(output_type=InternVLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) |
612 | | - def forward( |
613 | | - self, |
614 | | - input_ids: torch.LongTensor = None, |
615 | | - pixel_values: torch.FloatTensor = None, |
616 | | - attention_mask: Optional[torch.Tensor] = None, |
617 | | - position_ids: Optional[torch.LongTensor] = None, |
618 | | - past_key_values: Optional[List[torch.FloatTensor]] = None, |
619 | | - inputs_embeds: Optional[torch.FloatTensor] = None, |
620 | | - vision_feature_layer: Optional[int] = None, |
621 | | - vision_feature_select_strategy: Optional[str] = None, |
622 | | - labels: Optional[torch.LongTensor] = None, |
623 | | - use_cache: Optional[bool] = None, |
624 | | - output_attentions: Optional[bool] = None, |
625 | | - output_hidden_states: Optional[bool] = None, |
626 | | - return_dict: Optional[bool] = None, |
627 | | - cache_position: Optional[torch.LongTensor] = None, |
628 | | - logits_to_keep: Union[int, torch.Tensor] = 0, |
629 | | - image_sizes: Optional[torch.Tensor] = None, |
630 | | - **lm_kwargs, |
631 | | - ) -> Union[Tuple, InternVLCausalLMOutputWithPast]: |
| 663 | + def forward(self, **super_kwargs) -> Union[Tuple, InternVLCausalLMOutputWithPast]: |
632 | 664 | r""" |
633 | 665 | Args: |
634 | 666 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
@@ -679,25 +711,7 @@ def forward( |
679 | 711 | >>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)) |
680 | 712 | The images depict the Statue of Liberty and the Golden Gate Bridge. |
681 | 713 | ```""" |
682 | | - super().forward( |
683 | | - input_ids=input_ids, |
684 | | - pixel_values=pixel_values, |
685 | | - attention_mask=attention_mask, |
686 | | - position_ids=position_ids, |
687 | | - past_key_values=past_key_values, |
688 | | - inputs_embeds=inputs_embeds, |
689 | | - vision_feature_layer=vision_feature_layer, |
690 | | - vision_feature_select_strategy=vision_feature_select_strategy, |
691 | | - labels=labels, |
692 | | - use_cache=use_cache, |
693 | | - output_attentions=output_attentions, |
694 | | - output_hidden_states=output_hidden_states, |
695 | | - return_dict=return_dict, |
696 | | - cache_position=cache_position, |
697 | | - logits_to_keep=logits_to_keep, |
698 | | - image_sizes=image_sizes, |
699 | | - **lm_kwargs, |
700 | | - ) |
| 714 | + super().forward(**super_kwargs) |
701 | 715 |
|
702 | 716 |
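Editor's note: the rewritten `forward` keeps only the signature and the model-specific docstring, delegating all argument handling to the parent class. A minimal, self-contained illustration of this thin-override pattern, using made-up class names that are not part of this PR:

class BaseForConditionalGeneration:
    def forward(self, input_ids=None, pixel_values=None, **kwargs):
        # The real argument handling and loss computation live in the parent.
        return {"input_ids": input_ids, "pixel_values": pixel_values, **kwargs}


class ThinOverride(BaseForConditionalGeneration):
    def forward(self, **super_kwargs):
        r"""Model-specific docstring and usage example are attached here."""
        return super().forward(**super_kwargs)

In the diff the `super().forward(**super_kwargs)` call is not returned; if this is a modular file whose body is regenerated by the conversion script, only the signature and docstring matter, but that detail is outside this hunk.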
|
703 | 717 | __all__ = [ |
|