Skip to content

Commit cb6f03f

Browse files
BakerBunker and lvyuanjun.lyj authored
Fix Qwen3-Omni inference when mixing video and image inputs in one batch (#41741)
* Fix Qwen3-Omni inference when mixing video and image inputs in one batch
* Fix `router_aux_loss_coef`

Co-authored-by: lvyuanjun.lyj <lvyuanjun.lyj@alibaba-inc.com>
1 parent 8fc5420 commit cb6f03f

File tree

2 files changed

+49
-40
lines changed

2 files changed

+49
-40
lines changed

src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1723,8 +1723,9 @@ def forward(
17231723
past_key_values=past_key_values,
17241724
)
17251725

1726-
def _deepstack_process(self, hidden_states, visual_pos_masks, visual_embeds):
1727-
visual_pos_masks = visual_pos_masks[..., 0]
1726+
def _deepstack_process(
1727+
self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor
1728+
):
17281729
visual_pos_masks = visual_pos_masks.to(hidden_states.device)
17291730
visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
17301731
hidden_states = hidden_states.clone()
@@ -1859,6 +1860,7 @@ def __init__(self, config):
18591860
self.rope_deltas = None
18601861
self.num_experts = config.text_config.num_experts
18611862
self.num_experts_per_tok = config.text_config.num_experts_per_tok
1863+
self.router_aux_loss_coef = config.text_config.router_aux_loss_coef
18621864
self.post_init()
18631865

18641866
def get_input_embeddings(self):
@@ -2067,6 +2069,7 @@ def forward(
20672069

20682070
visual_embeds_multiscale = None
20692071
visual_pos_masks = None
2072+
image_mask, video_mask = None, None
20702073
# 2. Merge text , audios , image and video
20712074
if input_features is not None:
20722075
audio_features = self.get_audio_features(
@@ -2086,9 +2089,6 @@ def forward(
20862089
)
20872090
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
20882091

2089-
visual_pos_masks = image_mask
2090-
visual_embeds_multiscale = image_embeds_multiscale
2091-
20922092
if pixel_values_videos is not None:
20932093
video_embeds, video_embeds_multiscale = self.get_video_features(pixel_values_videos, video_grid_thw)
20942094

@@ -2098,20 +2098,27 @@ def forward(
20982098
)
20992099
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
21002100

2101-
if visual_embeds_multiscale is None:
2102-
visual_embeds_multiscale = video_embeds_multiscale
2103-
visual_pos_masks = video_mask
2104-
else:
2105-
visual_pos_masks = video_mask | image_mask
2106-
visual_embeds_multiscale_joint = ()
2107-
image_mask_joint = image_mask[visual_pos_masks]
2108-
video_mask_joint = video_mask[visual_pos_masks]
2109-
for img_embed, vid_embed in zip(visual_embeds_multiscale, video_embeds_multiscale):
2110-
embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1])
2111-
embed_joint[image_mask_joint, :] = img_embed
2112-
embed_joint[video_mask_joint, :] = vid_embed
2113-
visual_embeds_multiscale_joint = visual_embeds_multiscale_joint + (embed_joint,)
2114-
visual_embeds_multiscale = visual_embeds_multiscale_joint
2101+
if image_mask is not None and video_mask is not None:
2102+
image_mask = image_mask[..., 0]
2103+
video_mask = video_mask[..., 0]
2104+
visual_pos_masks = video_mask | image_mask
2105+
visual_embeds_multiscale_joint = ()
2106+
image_mask_joint = image_mask[visual_pos_masks]
2107+
video_mask_joint = video_mask[visual_pos_masks]
2108+
for img_embed, vid_embed in zip(image_embeds_multiscale, video_embeds_multiscale):
2109+
embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1])
2110+
embed_joint[image_mask_joint, :] = img_embed
2111+
embed_joint[video_mask_joint, :] = vid_embed
2112+
visual_embeds_multiscale_joint = visual_embeds_multiscale_joint + (embed_joint,)
2113+
visual_embeds_multiscale = visual_embeds_multiscale_joint
2114+
elif image_mask is not None:
2115+
image_mask = image_mask[..., 0]
2116+
visual_embeds_multiscale = image_embeds_multiscale
2117+
visual_pos_masks = image_mask
2118+
elif video_mask is not None:
2119+
video_mask = video_mask[..., 0]
2120+
visual_embeds_multiscale = video_embeds_multiscale
2121+
visual_pos_masks = video_mask
21152122

21162123
if feature_attention_mask is not None:
21172124
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)

src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,10 +1243,6 @@ def __init__(self, config: Qwen3OmniMoeTextConfig):
12431243
)
12441244
self.rotary_emb = Qwen3OmniMoeThinkerTextRotaryEmbedding(config)
12451245

1246-
def _deepstack_process(self, hidden_states, visual_pos_masks, visual_embeds):
1247-
visual_pos_masks = visual_pos_masks[..., 0]
1248-
return super()._deepstack_process(hidden_states, visual_pos_masks, visual_embeds)
1249-
12501246

12511247
@dataclass
12521248
class Qwen3OmniMoeThinkerCausalLMOutputWithPast(MoeCausalLMOutputWithPast):
@@ -1274,6 +1270,7 @@ def __init__(self, config):
12741270
super().__init__(config)
12751271
self.num_experts = config.text_config.num_experts
12761272
self.num_experts_per_tok = config.text_config.num_experts_per_tok
1273+
self.router_aux_loss_coef = config.text_config.router_aux_loss_coef
12771274

12781275
def get_audio_features(
12791276
self,
@@ -1342,6 +1339,7 @@ def forward(
13421339

13431340
visual_embeds_multiscale = None
13441341
visual_pos_masks = None
1342+
image_mask, video_mask = None, None
13451343
# 2. Merge text , audios , image and video
13461344
if input_features is not None:
13471345
audio_features = self.get_audio_features(
@@ -1361,9 +1359,6 @@ def forward(
13611359
)
13621360
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
13631361

1364-
visual_pos_masks = image_mask
1365-
visual_embeds_multiscale = image_embeds_multiscale
1366-
13671362
if pixel_values_videos is not None:
13681363
video_embeds, video_embeds_multiscale = self.get_video_features(pixel_values_videos, video_grid_thw)
13691364

@@ -1373,20 +1368,27 @@ def forward(
13731368
)
13741369
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
13751370

1376-
if visual_embeds_multiscale is None:
1377-
visual_embeds_multiscale = video_embeds_multiscale
1378-
visual_pos_masks = video_mask
1379-
else:
1380-
visual_pos_masks = video_mask | image_mask
1381-
visual_embeds_multiscale_joint = ()
1382-
image_mask_joint = image_mask[visual_pos_masks]
1383-
video_mask_joint = video_mask[visual_pos_masks]
1384-
for img_embed, vid_embed in zip(visual_embeds_multiscale, video_embeds_multiscale):
1385-
embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1])
1386-
embed_joint[image_mask_joint, :] = img_embed
1387-
embed_joint[video_mask_joint, :] = vid_embed
1388-
visual_embeds_multiscale_joint = visual_embeds_multiscale_joint + (embed_joint,)
1389-
visual_embeds_multiscale = visual_embeds_multiscale_joint
1371+
if image_mask is not None and video_mask is not None:
1372+
image_mask = image_mask[..., 0]
1373+
video_mask = video_mask[..., 0]
1374+
visual_pos_masks = video_mask | image_mask
1375+
visual_embeds_multiscale_joint = ()
1376+
image_mask_joint = image_mask[visual_pos_masks]
1377+
video_mask_joint = video_mask[visual_pos_masks]
1378+
for img_embed, vid_embed in zip(image_embeds_multiscale, video_embeds_multiscale):
1379+
embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1])
1380+
embed_joint[image_mask_joint, :] = img_embed
1381+
embed_joint[video_mask_joint, :] = vid_embed
1382+
visual_embeds_multiscale_joint = visual_embeds_multiscale_joint + (embed_joint,)
1383+
visual_embeds_multiscale = visual_embeds_multiscale_joint
1384+
elif image_mask is not None:
1385+
image_mask = image_mask[..., 0]
1386+
visual_embeds_multiscale = image_embeds_multiscale
1387+
visual_pos_masks = image_mask
1388+
elif video_mask is not None:
1389+
video_mask = video_mask[..., 0]
1390+
visual_embeds_multiscale = video_embeds_multiscale
1391+
visual_pos_masks = video_mask
13901392

13911393
if feature_attention_mask is not None:
13921394
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)

0 commit comments

Comments (0)