Skip to content

Commit 2a5e941

Browse files
committed
update test_paged_attention for CPU
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 039a5ff commit 2a5e941

File tree

1 file changed

+23
-4
lines changed

1 file changed

+23
-4
lines changed

tests/generation/test_paged_attention.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from parameterized import parameterized
55

66
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
7-
from transformers.testing_utils import require_flash_attn, require_torch_accelerator, slow
7+
from transformers.testing_utils import slow, torch_device
88

99

1010
_TEST_PROMPTS = [
@@ -23,10 +23,17 @@
2323
"track. The train is stopped for 30 minutes. The train is moving at a speed of 60 km/h. How many kilometers does the train travel in 30 minutes?\n## Step 1: Convert the speed from km/h to km/min",
2424
]
2525

26+
# TODO: Use Expectations once _EXPECTED_OUTPUTS are verified by HF team. Currently fails on A100.
27+
_EXPECTED_OUTPUTS_CPU = [
28+
"orange.\n\n## Step 1: Identify the key characteristics of the fruit\nThe fruit is described as being orange in color and round in shape.\n\n## Step 2: Determine the taste and nutritional value of the fruit\nThe fruit is described as sweet",
29+
"get started with our services.\nWe will be in touch with you shortly to discuss your project and provide a quote.\n\n**Project Details**\n\n* Project Name: _____________________________________________________\n* Project Description: __________________________________________________\n* Project Type (check all that apply",
30+
"track. The train is stopped for 30 minutes. The train is moving at a speed of 60 km/h. How many kilometers does the train travel in 30 minutes?\n## Step 1: Convert the speed from km/h to km/min",
31+
"This riddle is a classic example of a lateral thinking puzzle, which requires the test-taker to think creatively and consider multiple possibilities. The answer is not a straightforward one, and it requires some lateral thinking to arrive at the correct solution.",
32+
"a woman standing on the sidewalk, looking at him. He is immediately drawn to her and feels a strong attraction. He walks up to her and strikes up a conversation.",
33+
]
34+
2635

2736
@slow
28-
@require_flash_attn
29-
@require_torch_accelerator
3037
class TestBatchGeneration(unittest.TestCase):
3138
@classmethod
3239
def setUpClass(cls):
@@ -51,6 +58,11 @@ def setUpClass(cls):
5158
]
5259
)
5360
def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max_batch_tokens):
61+
if attn_impl in ["paged|flash_attention_2", "paged|flex_attention"] and torch_device == "cpu":
62+
self.skipTest(
63+
f"CPU only supports sdpa/eager paged attention for now, but found {attn_impl}. Skipping test."
64+
)
65+
5466
self.model.config.attn_implementation = attn_impl
5567

5668
generation_config = GenerationConfig(
@@ -77,11 +89,13 @@ def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max
7789
f"\n[{attn_impl}] Batch took {end - start:.2f}s with config: blocks={num_blocks}, block_size={block_size}, max_batch_tokens={max_batch_tokens}"
7890
)
7991

92+
expected_outputs = _EXPECTED_OUTPUTS_CPU if torch_device == "cpu" else _EXPECTED_OUTPUTS
93+
8094
for i, req_id in enumerate(batch_outputs):
8195
generated = self.tokenizer.decode(
8296
batch_outputs[req_id].generated_tokens, skip_special_tokens=False
8397
).strip()
84-
expected = _EXPECTED_OUTPUTS[i].strip()
98+
expected = expected_outputs[i].strip()
8599
self.assertTrue(
86100
generated.startswith(expected),
87101
msg=f"[{attn_impl}] Mismatch in request {i}:\nExpected start: {expected}\nGot: {generated}",
@@ -97,6 +111,11 @@ def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max
97111
)
98112
def test_generate_batch_with_sampling(self, attn_impl, num_blocks, block_size, max_batch_tokens):
99113
"""Test batch generation with do_sampling=True to verify sampling works correctly."""
114+
if attn_impl in ["paged|flash_attention_2", "paged|flex_attention"] and torch_device == "cpu":
115+
self.skipTest(
116+
f"CPU only supports sdpa/eager paged attention for now, but found {attn_impl}. Skipping test."
117+
)
118+
100119
self.model.config.attn_implementation = attn_impl
101120

102121
generation_config = GenerationConfig(

0 commit comments

Comments
 (0)