tests
Cyrilvallez committed Sep 9, 2025
commit 2c19e6e681c6aaa7e4ac89e42e9d8a7091ed1267
142 changes: 137 additions & 5 deletions tests/models/qwen3_next/test_modeling_qwen3_next.py
@@ -14,17 +14,14 @@
# limitations under the License.

import copy
import tempfile
import unittest

import pytest
from parameterized import parameterized

from transformers import Qwen3NextConfig, is_torch_available
from transformers.testing_utils import (
require_torch,
slow,
torch_device,
)
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device


if is_torch_available():
@@ -40,6 +37,7 @@
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextDynamicCache

from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
from ...generation.test_utils import has_similar_generate_outputs
from ...test_modeling_common import (
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
_config_zero_init,
@@ -149,6 +147,108 @@ def test_past_key_values_format(self):
self.assertEqual(self_attention_layer_keys.shape, default_self_attention_shape)
self.assertEqual(self_attention_layer_values.shape, default_self_attention_shape)

    @pytest.mark.generate
    def test_generate_continue_from_past_key_values(self):
"Needs to be overwritten as Qwen3-Next has non-standard cache."
        # Tests that we can continue generating from past key values, returned from a previous `generate` call
        for model_class in self.all_generative_model_classes:
            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config).to(torch_device)
            model.eval()

            generate_kwargs = {
                "pad_token_id": -1,
                "eos_token_id": -1,
                "forced_eos_token_id": None,
                "encoder_no_repeat_ngram_size": 0,
                "use_cache": True,
                "do_sample": False,
                "return_dict_in_generate": True,
                "output_scores": True,
            }

            # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
            outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=4)
            # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
            # inputs may need to be tweaked across `generate` calls (like the attention mask).
            outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=3)

            # Continue from the tokens generated above, preparing the inputs accordingly
            inputs["past_key_values"] = outputs_cached.past_key_values
            new_attention_len = outputs_cached.sequences.shape[-1]

            inputs["input_ids"] = outputs_cached.sequences
            if "attention_mask" in inputs:
                inputs["attention_mask"] = torch.nn.functional.pad(
                    inputs["attention_mask"],
                    (0, new_attention_len - inputs["attention_mask"].shape[1]),
                    mode="constant",
                    value=1,
                )
            first_caches_scores = outputs_cached.scores
            outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1)
            full_cached_scores = first_caches_scores + outputs_cached.scores
            outputs_cached.scores = full_cached_scores

            # The two sets of generated text and past kv should be equal to each other
            self.assertTrue(has_similar_generate_outputs(outputs, outputs_cached))
            for layer_idx in range(len(outputs_cached.past_key_values)):
                for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
                    # Diff with the main test: we need to skip layers where it stays None
                    if outputs.past_key_values[layer_idx][kv_idx] is not None:
                        self.assertTrue(
                            torch.allclose(
                                outputs.past_key_values[layer_idx][kv_idx],
                                outputs_cached.past_key_values[layer_idx][kv_idx],
                            )
                        )

    @pytest.mark.generate
    def test_generate_continue_from_inputs_embeds(self):
"Needs to be overwritten as Qwen3-Next has non-standard cache."
        for model_class in self.all_generative_model_classes:
            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
            model = model_class(config).to(torch_device).eval()
            input_ids = inputs_dict.pop("input_ids")
            model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
            model.generation_config.forced_eos_token_id = None
            model.config.is_decoder = True
            model.generation_config.use_cache = True

            generation_kwargs = {
                "return_dict_in_generate": True,
                "do_sample": False,
            }

            # Traditional way of generating text, with `return_dict_in_generate` to return the past key values.
            input_embeds = model.get_input_embeddings()(input_ids)
            outputs = model.generate(inputs_embeds=input_embeds, max_new_tokens=4, **generation_kwargs)

            # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens)
            initial_output = model.generate(inputs_embeds=input_embeds, max_new_tokens=3, **generation_kwargs)
            continued_embeds = torch.cat([input_embeds, model.get_input_embeddings()(initial_output.sequences)], dim=1)
            cached_output = model.generate(
                inputs_embeds=continued_embeds,
                max_new_tokens=1,
                past_key_values=initial_output.past_key_values,
                **generation_kwargs,
            )

            # Combine the (3 + 1) generated tokens and verify it matches with full generation.
            combined_output_sequences = torch.concat([initial_output.sequences, cached_output.sequences], axis=1)
            self.assertListEqual(outputs.sequences.tolist(), combined_output_sequences.tolist())
            # The two sets of past kv should be equal to each other
            for layer_idx in range(len(cached_output.past_key_values)):
                for kv_idx in range(len(cached_output.past_key_values[layer_idx])):
                    # Diff with the main test: we need to skip layers where it stays None
                    if outputs.past_key_values[layer_idx][kv_idx] is not None:
                        self.assertTrue(
                            torch.allclose(
                                outputs.past_key_values[layer_idx][kv_idx],
                                cached_output.past_key_values[layer_idx][kv_idx],
                            )
                        )

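    # A minimal sketch of the comparison the two overrides above perform, assuming only that the hybrid
    # cache can be indexed per layer and that the linear-attention (gated DeltaNet) layers expose `None`
    # in place of key/value tensors. `_assert_hybrid_caches_match` is a hypothetical helper used for
    # illustration; it is not part of the actual Qwen3NextDynamicCache API.
    def _assert_hybrid_caches_match(self, cache_a, cache_b):
        for layer_idx in range(len(cache_a)):
            for kv_idx in range(len(cache_a[layer_idx])):
                # Gated DeltaNet layers keep no key/value states here, so there is nothing to compare
                if cache_a[layer_idx][kv_idx] is None:
                    continue
                torch.testing.assert_close(cache_a[layer_idx][kv_idx], cache_b[layer_idx][kv_idx])
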
    def test_attention_outputs(self):
        "Needs to be overridden as Qwen3-Next alternates between attention layers and gated DeltaNet layers."
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -249,6 +349,38 @@ def test_eager_matches_sdpa_inference(
    def test_multi_gpu_data_parallel_forward(self):
        pass

    @require_torch_multi_gpu
    def test_can_use_device_map(self):
"""
Test that this model can be dispatched on multiple gpus. It's not obvious as the Cache is not standard,
ant each layer need to use the correct device on which it reside (i.e. it needs to be lazy initialized).
"""
        for model_class in self.all_generative_model_classes:
            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
            inputs_dict = {k: v.to(0) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
            model = model_class(config).eval()

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                del model
                model = model_class.from_pretrained(
                    tmpdirname,
                    device_map={
                        "lm_head": 0,
                        "model.embed_tokens": 0,
                        "model.norm": 0,
                        "model.layers.0": 0,
                        "model.layers.1": 1,
                    },
                )

            # Check that we indeed use 2 different devices for each layer
            self.assertTrue({param.device for param in model.model.layers[0].parameters()} == {torch.device(0)})
            self.assertTrue({param.device for param in model.model.layers[1].parameters()} == {torch.device(1)})

            # This should not crash
            _ = model.generate(**inputs_dict, max_new_tokens=5, min_new_tokens=5)


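# A minimal sketch of the "lazy initialization" the device-map test above relies on, assuming a
# simplified per-layer cache rather than the real Qwen3NextDynamicCache: each layer's state is only
# allocated the first time that layer runs, so the tensors land on whatever device the device_map
# assigned to that layer. `_LazyPerLayerCache` is hypothetical and for illustration only.
class _LazyPerLayerCache:
    def __init__(self, num_layers: int):
        # Nothing is allocated up front; slots are filled on first use by each layer
        self.states = [None] * num_layers

    def get_state(self, layer_idx: int, shape, device, dtype=None):
        if self.states[layer_idx] is None:
            # Allocated lazily on the device of the layer that requests it
            self.states[layer_idx] = torch.zeros(*shape, device=device, dtype=dtype)
        return self.states[layer_idx]
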
@slow
class Qwen3NextIntegrationTest(unittest.TestCase):