@@ -1723,8 +1723,9 @@ def forward(
17231723 past_key_values = past_key_values ,
17241724 )
17251725
1726- def _deepstack_process (self , hidden_states , visual_pos_masks , visual_embeds ):
1727- visual_pos_masks = visual_pos_masks [..., 0 ]
1726+ def _deepstack_process (
1727+ self , hidden_states : torch .Tensor , visual_pos_masks : torch .Tensor , visual_embeds : torch .Tensor
1728+ ):
17281729 visual_pos_masks = visual_pos_masks .to (hidden_states .device )
17291730 visual_embeds = visual_embeds .to (hidden_states .device , hidden_states .dtype )
17301731 hidden_states = hidden_states .clone ()
@@ -1859,6 +1860,7 @@ def __init__(self, config):
18591860 self .rope_deltas = None
18601861 self .num_experts = config .text_config .num_experts
18611862 self .num_experts_per_tok = config .text_config .num_experts_per_tok
1863+ self .router_aux_loss_coef = config .text_config .router_aux_loss_coef
18621864 self .post_init ()
18631865
18641866 def get_input_embeddings (self ):
@@ -2067,6 +2069,7 @@ def forward(
20672069
20682070 visual_embeds_multiscale = None
20692071 visual_pos_masks = None
2072+ image_mask , video_mask = None , None
20702073 # 2. Merge text , audios , image and video
20712074 if input_features is not None :
20722075 audio_features = self .get_audio_features (
@@ -2086,9 +2089,6 @@ def forward(
20862089 )
20872090 inputs_embeds = inputs_embeds .masked_scatter (image_mask , image_embeds )
20882091
2089- visual_pos_masks = image_mask
2090- visual_embeds_multiscale = image_embeds_multiscale
2091-
20922092 if pixel_values_videos is not None :
20932093 video_embeds , video_embeds_multiscale = self .get_video_features (pixel_values_videos , video_grid_thw )
20942094
@@ -2098,20 +2098,27 @@ def forward(
20982098 )
20992099 inputs_embeds = inputs_embeds .masked_scatter (video_mask , video_embeds )
21002100
2101- if visual_embeds_multiscale is None :
2102- visual_embeds_multiscale = video_embeds_multiscale
2103- visual_pos_masks = video_mask
2104- else :
2105- visual_pos_masks = video_mask | image_mask
2106- visual_embeds_multiscale_joint = ()
2107- image_mask_joint = image_mask [visual_pos_masks ]
2108- video_mask_joint = video_mask [visual_pos_masks ]
2109- for img_embed , vid_embed in zip (visual_embeds_multiscale , video_embeds_multiscale ):
2110- embed_joint = img_embed .new_zeros (visual_pos_masks .sum (), img_embed .shape [- 1 ])
2111- embed_joint [image_mask_joint , :] = img_embed
2112- embed_joint [video_mask_joint , :] = vid_embed
2113- visual_embeds_multiscale_joint = visual_embeds_multiscale_joint + (embed_joint ,)
2114- visual_embeds_multiscale = visual_embeds_multiscale_joint
2101+ if image_mask is not None and video_mask is not None :
2102+ image_mask = image_mask [..., 0 ]
2103+ video_mask = video_mask [..., 0 ]
2104+ visual_pos_masks = video_mask | image_mask
2105+ visual_embeds_multiscale_joint = ()
2106+ image_mask_joint = image_mask [visual_pos_masks ]
2107+ video_mask_joint = video_mask [visual_pos_masks ]
2108+ for img_embed , vid_embed in zip (image_embeds_multiscale , video_embeds_multiscale ):
2109+ embed_joint = img_embed .new_zeros (visual_pos_masks .sum (), img_embed .shape [- 1 ])
2110+ embed_joint [image_mask_joint , :] = img_embed
2111+ embed_joint [video_mask_joint , :] = vid_embed
2112+ visual_embeds_multiscale_joint = visual_embeds_multiscale_joint + (embed_joint ,)
2113+ visual_embeds_multiscale = visual_embeds_multiscale_joint
2114+ elif image_mask is not None :
2115+ image_mask = image_mask [..., 0 ]
2116+ visual_embeds_multiscale = image_embeds_multiscale
2117+ visual_pos_masks = image_mask
2118+ elif video_mask is not None :
2119+ video_mask = video_mask [..., 0 ]
2120+ visual_embeds_multiscale = video_embeds_multiscale
2121+ visual_pos_masks = video_mask
21152122
21162123 if feature_attention_mask is not None :
21172124 audio_feature_lengths = torch .sum (feature_attention_mask , dim = 1 )
0 commit comments