
Commit f80b048

YangKai0616 and ydshieh authored
[XPU] Fix UT errors in the sam3 and lfm series model. (#42798)
* Make sam3 tests pass on XPU
* Update flm2 tests GT for XPU
* Remove the skip tests of local mask for XPU
* Pass position_ids to varlen FA2
* Change modular also
* Skip FA2 bwd tests
* Make style
* Increase rtol
* Adapt to the main branch
* fix cuda 1

---------

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent 6d00f6b commit f80b048

4 files changed: 21 additions & 19 deletions

tests/models/lfm2_moe/test_modeling_lfm2_moe.py

Lines changed: 14 additions & 12 deletions
@@ -176,8 +176,8 @@ def test_model_1a8b_logits(self):
         # Expected mean on dim = -1
         EXPECTED_MEANS = Expectations(
             {
-                ("cuda", None): torch.tensor([[-1.3855, -0.5123, -1.3143, -1.2144, -1.0791, -1.2117, -1.4704, -0.7648, -0.6175, -1.2402, -1.1459, -1.0083, -1.0247, -0.8830, -1.5643, -1.7266, -1.6254,]]),
-                ("xpu", None): torch.tensor([[-1.3863, -0.4653, -1.3246, -1.3199, -1.0940, -1.2254, -1.4716, -0.8852, -0.5920, -1.2182, -1.1782, -1.0268, -1.0114, -0.8816, -1.5774, -1.7408, -1.6147,]]),
+                ("cuda", None): torch.tensor([[-1.3912, -0.4653, -1.3339, -1.3249, -1.0985, -1.2373, -1.4599, -0.7515, -0.6140, -1.2329, -1.1481, -1.0081, -0.9937, -0.8875, -1.5539, -1.7283, -1.6284]]),
+                ("xpu", None): torch.tensor([[-1.3879, -0.4730, -1.3193, -1.3139, -1.0826, -1.2129, -1.4744, -0.7485, -0.6004, -1.2353, -1.1602, -1.0432, -1.0180, -0.9099, -1.5949, -1.7487, -1.5991]]),
             }
         )
         # fmt: on
@@ -188,8 +188,8 @@ def test_model_1a8b_logits(self):
         # Expected portion of the logits
         EXPECTED_SLICES = Expectations(
             {
-                ("cuda", None): torch.tensor([-1.2656, 2.4844, 5.5000, -1.3359, -1.3203, -1.3438, 1.9375, 5.8438, -0.6523, -1.2891]),
-                ("xpu", None): torch.tensor([-1.2656, 2.4531, 5.4375, -1.3438, -1.3203, -1.3516, 1.9297, 5.7812, -0.6719, -1.3203]),
+                ("cuda", None): torch.tensor([-1.2734, 2.4844, 5.5000, -1.3438, -1.3281, -1.3516, 1.9375, 5.8438, -0.6641, -1.2969]),
+                ("xpu", None): torch.tensor([-1.2734, 2.4531, 5.4688, -1.3438, -1.3281, -1.3516, 1.9297, 5.7812, -0.6719, -1.3125]),
             }
         )
         # fmt: on
@@ -219,14 +219,16 @@ def test_model_1a8b_batched_chat_generation(self):
         # fmt: off
         EXPECTED_TEXT_COMPLETIONS = Expectations(
             {
-                ("cuda", None): ["Who are you?, a language model designed to assist with information and tasks? \nI am",
-                    "Complete the text: Lorem ipsum dolor ipsum dolor ipsum dolor ipsum dolor ipsum dolor",
-                    "The Meji Restoration in Japan ended or the Meiji Restoration (1868–1912) marked a pivotal",
-                ],
-                ("xpu", None): ['Who are you? (AI) designed to assist? \nI am an AI assistant developed to',
-                    'Complete the text: Lorem ipsum dolor ipsum dolor ipsum dolor ipsum dolor ipsum dolor',
-                    'The Meji Restoration in Japan ended** \n* **Key Event:** The overthrow of the Tokugawa'
-                ],
+                ("cuda", None): [
+                    "Who are you? (AI) designed to assist? \nI am an AI assistant developed to",
+                    "Complete the text: Lorem ipsum dolor ipsum dolor ipsum dolor ipsum dolor ipsum.",
+                    "The Meji Restoration in Japan ended** \n**A.** The shogunate was abolished, and imperial"
+                ],
+                ("xpu", None): [
+                    "Who are you? (AI) designed to assist? \nI am an AI language model developed",
+                    "Complete the text: Lorem ipsum dolor ipsum dolor ipsum dolor ipsum dolor ipsum dolor",
+                    "The Meji Restoration in Japan ended, which occurred in 1868, marked the: \nA) Establish"
+                ],
             }
         )
         # fmt: on
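The diff above only swaps golden values; the lookup mechanism is unchanged. A minimal sketch of how such device-keyed expectations are typically consumed, assuming the Expectations helper from transformers.testing_utils resolves the entry matching the current accelerator via get_expectation() (treat that method name and the shortened tensors as illustrative):

# Sketch only: device-keyed golden values as used in transformers integration tests.
# Assumes Expectations.get_expectation() picks the entry whose (device, version) key
# matches the accelerator the test runs on; tensor values are truncated here.
import torch
from transformers.testing_utils import Expectations

EXPECTED_MEANS = Expectations(
    {
        ("cuda", None): torch.tensor([[-1.3912, -0.4653, -1.3339]]),
        ("xpu", None): torch.tensor([[-1.3879, -0.4730, -1.3193]]),
    }
)

expected_mean = EXPECTED_MEANS.get_expectation()  # resolves to the cuda or xpu row
# torch.testing.assert_close(actual_mean[:, :3], expected_mean, atol=1e-4, rtol=1e-4)

Keeping one key per backend lets the same test assert numerics on both CUDA and XPU instead of skipping one of them when kernels round slightly differently.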

tests/models/sam3/test_modeling_sam3.py

Lines changed: 3 additions & 0 deletions
@@ -22,6 +22,7 @@
 
 from transformers.testing_utils import (
     backend_empty_cache,
+    require_deterministic_for_xpu,
     require_torch,
     slow,
     torch_device,
@@ -1448,6 +1449,7 @@ def test_semantic_segmentation_output(self):
         # Check that semantic seg has same spatial size as pred_masks
         self.assertEqual(outputs.semantic_seg.shape[-2:], outputs.pred_masks.shape[-2:])
 
+    @require_deterministic_for_xpu
     def test_efficient_multi_prompt_single_image(self):
         """Test efficient inference with multiple prompts on a single image using get_vision_features."""
         raw_image = prepare_coco_cat_image()
@@ -1491,6 +1493,7 @@ def test_efficient_multi_prompt_single_image(self):
         torch.testing.assert_close(outputs_with_embeds.pred_boxes, outputs_direct.pred_boxes, atol=1e-5, rtol=1e-5)
         torch.testing.assert_close(outputs_with_embeds.pred_masks, outputs_direct.pred_masks, atol=1e-5, rtol=1e-5)
 
+    @require_deterministic_for_xpu
    def test_efficient_single_prompt_multi_images(self):
         """Test efficient inference with same prompt on multiple images using get_text_features."""
         raw_image1 = prepare_coco_cat_image()
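The new @require_deterministic_for_xpu marker comes from transformers.testing_utils; its implementation is not part of this diff. As a rough, hypothetical sketch of what such a guard can look like (names and behavior are assumptions, not the library's code):

# Hypothetical sketch -- NOT the transformers implementation. The idea: run the
# wrapped test with deterministic kernels when an XPU device is present, then
# restore the previous setting.
import functools

import torch


def require_deterministic_for_xpu_sketch(test_fn):
    @functools.wraps(test_fn)
    def wrapper(*args, **kwargs):
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            previous = torch.are_deterministic_algorithms_enabled()
            torch.use_deterministic_algorithms(True)
            try:
                return test_fn(*args, **kwargs)
            finally:
                torch.use_deterministic_algorithms(previous)
        return test_fn(*args, **kwargs)

    return wrapper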

tests/models/sam3_tracker_video/test_modeling_sam3_tracker_video.py

Lines changed: 1 addition & 1 deletion
@@ -428,7 +428,7 @@ def test_inference_mask_generation_video_batched_bb(self):
                 ]
             ).to(torch_device),
             atol=1e-4,
-            rtol=1e-4,
+            rtol=1e-3,
         )
 
     def test_inference_propagate_video_from_mask_input(self):
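The only change here is a looser relative tolerance for this video-tracking assertion. torch.testing.assert_close treats a pair as equal when |actual - expected| <= atol + rtol * |expected|, so moving rtol from 1e-4 to 1e-3 widens the allowed error in proportion to each value's magnitude. A tiny self-contained illustration (the numbers are made up, not taken from the test):

# Illustration of the tolerance rule applied by torch.testing.assert_close:
# |actual - expected| <= atol + rtol * |expected|
import torch

expected = torch.tensor([0.9500])
actual = torch.tensor([0.9506])  # absolute error of 6e-4

# Old budget: 1e-4 + 1e-4 * 0.95 ~= 1.95e-4 -> an error of 6e-4 would fail.
# New budget: 1e-4 + 1e-3 * 0.95 ~= 1.05e-3 -> 6e-4 passes.
torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-3)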

tests/test_modeling_common.py

Lines changed: 3 additions & 6 deletions
@@ -2749,12 +2749,6 @@ def flash_attn_inference_equivalence(
         if getattr(config, "sliding_window", None):
             config.sliding_window = 2
 
-        if torch_device == "xpu" and (
-            attn_implementation == "kernels-community/flash-attn2"
-            or attn_implementation == "flash_attention_2"
-        ):
-            self.skipTest("XPU does not support sliding window attention with Flash-Attention-2 currently.")
-
         model = model_class(config)
         if not all(
             submodel._supports_flash_attn for submodel in model.modules() if isinstance(submodel, PreTrainedModel)
@@ -3386,6 +3380,9 @@ def test_flash_attn_2_can_compile_with_attention_mask_None_without_graph_break(s
         if not is_torch_fp16_available_on_device(torch_device):
             self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
 
+        if torch_device == "xpu":
+            self.skipTest("XPU FA2 currently does not support backward.")
+
         torch.compiler.reset()
         dtype = torch.float16
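The first hunk drops the blanket XPU skip around sliding-window FA2 in flash_attn_inference_equivalence, while the second makes the compile test bail out early on XPU because FA2 backward is not supported there yet. The skip itself is the standard unittest pattern; a minimal sketch, assuming a torch_device string like the one transformers.testing_utils provides:

# Minimal sketch of the device-conditional skip pattern, assuming torch_device
# is a plain string such as "cuda" or "xpu" (as in transformers.testing_utils).
import unittest

torch_device = "xpu"  # placeholder; normally imported from transformers.testing_utils


class Fa2CompileTest(unittest.TestCase):
    def test_fa2_compile_with_backward(self):
        if torch_device == "xpu":
            self.skipTest("XPU FA2 currently does not support backward.")
        # ... the rest of the test only runs where FA2 backward is available
        self.assertTrue(True)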
