Skip to content

Commit 0c18820

Browse files
authored
[kernels] adding RMSNorm kernel for mps devices (#42058)
add kernel
1 parent dfe6e4c commit 0c18820

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

src/transformers/integrations/hub_kernels.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ def use_kernel_func_from_hub(func_name: str):
111111
layer_name="RMSNorm",
112112
)
113113
},
114+
"mps": {
115+
Mode.INFERENCE: LayerRepository(
116+
repo_id="kernels-community/mlx_rmsnorm",
117+
layer_name="RMSNorm",
118+
)
119+
},
114120
"npu": {
115121
Mode.INFERENCE: LayerRepository(
116122
repo_id="kernels-community/liger_kernels",

0 commit comments

Comments
 (0)