Commit ffb47c2

Enable FP8 conversion on sm < 89
1 parent 6412e35 commit ffb47c2

File tree

6 files changed: +487 -149 lines

python/test/regression/test_cast_matmul.py

Lines changed: 1 addition & 4 deletions
@@ -15,10 +15,7 @@
 
 input_dtypes = ["bfloat16", "float16", "float32"]
 if is_cuda():
-    input_dtypes += ["int8", "float8_e5m2"]
-    cc = torch.cuda.get_device_capability(0)
-    if cc >= (8, 9):
-        input_dtypes += ["float8_e4m3fn"]
+    input_dtypes += ["int8", "float8_e5m2", "float8_e4m3fn"]
 elif is_hip_cdna3():
     input_dtypes += [
         "int8",

python/test/unit/language/test_compile_errors.py

Lines changed: 1 addition & 3 deletions
@@ -355,14 +355,12 @@ def kernel():
 @pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15])
 def test_fp8_support(fresh_triton_cache, dtype):
     warning_dtypes = []
-    supported_dtypes = [tl.float8e5]
+    supported_dtypes = [tl.float8e5, tl.float8e4nv]
     if is_cuda():
         cc = torch.cuda.get_device_capability(0)
         supported_dtypes.append(tl.float8e4b15)
         if cc >= (9, 0):
             warning_dtypes.append(tl.float8e4b15)
-        if cc >= (8, 9):
-            supported_dtypes.append(tl.float8e4nv)
     elif is_hip():
         supported_dtypes += [tl.float8e4nv, tl.float8e4b8, tl.float8e5b16]
         if is_hip_cdna4():
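
For illustration, the kind of cast this unlocks: a kernel that converts through tl.float8e4nv is now expected to compile on CUDA devices below sm_89 instead of raising a CompilationError. A minimal sketch (hypothetical kernel, not part of the test file):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def fp8e4nv_roundtrip(X, Y, N: tl.constexpr):
        offs = tl.arange(0, N)
        x = tl.load(X + offs)
        # Downcast to float8e4nv and back; previously gated on cc >= (8, 9).
        tl.store(Y + offs, x.to(tl.float8e4nv).to(tl.float16))

    x = torch.randn(128, device="cuda", dtype=torch.float16)
    y = torch.empty_like(x)
    fp8e4nv_roundtrip[(1,)](x, y, N=128)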

python/test/unit/language/test_conversions.py

Lines changed: 1 addition & 8 deletions
@@ -276,8 +276,7 @@ def test_typeconvert_upcast(src_dtype, dst_dtype, device):
     # On HIP, fp8e4nv upcasting to fp32 is only supported on CDNA4, and
     # fp8e4nv upcasting to bf16 and fp16 is only supported on CDNA3 and CDNA4.
     if is_cuda():
-        if ((src_dtype == 'float8e4nv' and torch.cuda.get_device_capability(0) < (8, 9))
-                or src_dtype in ('float8e4b8', 'float8e5b16')):
+        if src_dtype in ('float8e4b8', 'float8e5b16'):
             # If the dtype should error out in the given device, we assert that and return
             with pytest.raises(triton.CompilationError, match="not supported in this architecture"):
                 launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device)

@@ -333,12 +332,6 @@ def test_typeconvert_upcast(src_dtype, dst_dtype, device):
 def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
 
     if is_cuda():
-        if src_dtype != 'float32' and torch.cuda.get_device_capability(0) < (9, 0):
-            pytest.skip("non-float32 downcast tests only supported on NVGPU with compute capability 9.0+")
-
-        if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and torch.cuda.get_device_capability(0) < (9, 0):
-            pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+")
-
         if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne':
             pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU CDNA3")
 
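
The removed skips mean non-float32 downcasts, including RTNE rounding to float8e5 and float8e4nv, are now exercised below compute capability 9.0. A minimal sketch of such a downcast (hypothetical kernel; assumes the fp_downcast_rounding keyword of tensor.to and PyTorch's torch.float8_e5m2 dtype):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def downcast_rtne(X, Y, N: tl.constexpr):
        offs = tl.arange(0, N)
        x = tl.load(X + offs)
        # Round-to-nearest-even downcast from float32 to float8e5.
        tl.store(Y + offs, x.to(tl.float8e5, fp_downcast_rounding="rtne"))

    x = torch.randn(64, device="cuda", dtype=torch.float32)
    y = torch.empty(64, device="cuda", dtype=torch.float8_e5m2)
    downcast_rtne[(1,)](x, y, N=64)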

python/test/unit/language/test_core.py

Lines changed: 0 additions & 2 deletions
@@ -1237,8 +1237,6 @@ def test_abs_fp8(in_dtype, device):
         cc = torch.cuda.get_device_capability()
         if in_dtype == tl.float8e4b15 and cc >= (9, 0):
             pytest.skip("float8e4b15 not supported on CUDA >= 9.0")
-        if in_dtype == tl.float8e4nv and cc < (8, 9):
-            pytest.skip("float8e4nv not supported on CUDA < 8.9")
 
     @triton.jit
     def abs_kernel(X, Z, SIZE: tl.constexpr):

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 3 deletions
@@ -128,7 +128,7 @@ class CUDAOptions:
     enable_fp_fusion: bool = True
     launch_cooperative_grid: bool = False
     launch_pdl: bool = False
-    supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15")
+    supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5", "fp8e4b15")
     deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
     default_dot_input_precision: str = "tf32"
     allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")

@@ -191,8 +191,6 @@ def parse_options(self, opts) -> Any:
 
         if "supported_fp8_dtypes" not in args:
             supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes)
-            if capability >= 89:
-                supported_fp8_dtypes.add("fp8e4nv")
             args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
 
         if "deprecated_fp8_dot_operand_dtypes" not in args:
