
Commit 1fb3fc4

[kernels] refactor function kernel calling (#41577)
* refactor function kernel calling
* nit
* don't pass the mapping
* use _kernels_available
* rm import
1 parent 9176af5 commit 1fb3fc4
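At a glance: the per-model, cached _lazy_load_causal_conv1d() helpers are removed and every call site now goes through the shared lazy_load_kernel() entry point added to integrations/hub_kernels.py. A minimal sketch of the new calling convention, distilled from the diffs below; the absolute import path is an assumption (inside the library the relative form from ...integrations.hub_kernels import lazy_load_kernel is used):

    from transformers.integrations.hub_kernels import lazy_load_kernel

    # Resolve the kernel module once; lazy_load_kernel caches the result (a module or None).
    causal_conv1d = lazy_load_kernel("causal-conv1d")

    # Call sites unpack the two functions, falling back to (None, None) on the slow path.
    causal_conv1d_update, causal_conv1d_fn = (
        (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
        if causal_conv1d is not None
        else (None, None)
    )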

File tree

4 files changed (+112, -69 lines)


src/transformers/integrations/hub_kernels.py

Lines changed: 55 additions & 0 deletions
@@ -14,12 +14,16 @@
 import re
 from collections.abc import Callable
 from functools import partial
+from types import ModuleType
 from typing import Optional, Union

 from ..modeling_flash_attention_utils import lazy_import_flash_attention
+from ..utils import logging
 from .flash_attention import flash_attention_forward


+logger = logging.get_logger(__name__)
+
 try:
     from kernels import (
         Device,
@@ -158,6 +162,13 @@ def register_kernel_mapping(*args, **kwargs):
     raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")


+_HUB_KERNEL_MAPPING: dict[str, str] = {
+    "causal-conv1d": "kernels-community/causal-conv1d",
+}
+
+_KERNEL_MODULE_MAPPING: dict[str, Optional[ModuleType]] = {}
+
+
 def is_kernel(attn_implementation: Optional[str]) -> bool:
     """Check whether `attn_implementation` matches a kernel pattern from the hub."""
     return (
@@ -220,9 +231,53 @@ def load_and_register_attn_kernel(attn_implementation: str, attention_wrapper: O
     ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])


+def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]] = _KERNEL_MODULE_MAPPING):
+    if kernel_name in mapping and isinstance(mapping[kernel_name], ModuleType):
+        return mapping[kernel_name]
+    if kernel_name not in _HUB_KERNEL_MAPPING:
+        logger.warning(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
+        mapping[kernel_name] = None
+        return None
+    if _kernels_available:
+        from kernels import get_kernel
+
+        try:
+            kernel = get_kernel(_HUB_KERNEL_MAPPING[kernel_name])
+            mapping[kernel_name] = kernel
+        except FileNotFoundError:
+            mapping[kernel_name] = None
+
+    else:
+        # Try to import is_{kernel_name}_available from ..utils
+        import importlib
+
+        new_kernel_name = kernel_name.replace("-", "_")
+        func_name = f"is_{new_kernel_name}_available"
+
+        try:
+            utils_mod = importlib.import_module("..utils.import_utils", __package__)
+            is_kernel_available = getattr(utils_mod, func_name, None)
+        except Exception:
+            is_kernel_available = None
+
+        if callable(is_kernel_available) and is_kernel_available():
+            # Try to import the module "{kernel_name}" from parent package level
+            try:
+                module = importlib.import_module(f"{kernel_name}")
+                mapping[kernel_name] = module
+                return module
+            except Exception:
+                mapping[kernel_name] = None
+        else:
+            mapping[kernel_name] = None

+    return mapping[kernel_name]
+
+
 __all__ = [
     "LayerRepository",
     "use_kernel_forward_from_hub",
     "register_kernel_mapping",
     "replace_kernel_forward_from_hub",
+    "lazy_load_kernel",
 ]
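For reference, a hedged usage sketch of the resolution order implemented by lazy_load_kernel above (behaviour read off the function body; get_kernel belongs to the external kernels package, and "does-not-exist" is a made-up name for illustration):

    # 1. A module already cached in _KERNEL_MODULE_MAPPING is returned immediately.
    # 2. A name missing from _HUB_KERNEL_MAPPING logs a warning and is cached as None.
    # 3. If the `kernels` package is available, get_kernel() fetches the mapped hub repo
    #    ("kernels-community/causal-conv1d" here); a FileNotFoundError caches None.
    # 4. Otherwise the matching is_{name}_available() check is looked up in utils.import_utils
    #    and, when it passes, the module is imported directly; any failure caches None.

    mod = lazy_load_kernel("causal-conv1d")   # hub kernel, local package, or None
    mod = lazy_load_kernel("causal-conv1d")   # second call is served from the cache
    lazy_load_kernel("does-not-exist")        # warns and returns None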

src/transformers/models/falcon_mamba/modeling_falcon_mamba.py

Lines changed: 19 additions & 32 deletions
@@ -30,12 +30,11 @@
 from ...activations import ACT2FN
 from ...configuration_utils import PreTrainedConfig
 from ...generation import GenerationMixin
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging
 from ...utils.import_utils import (
-    is_causal_conv1d_available,
-    is_kernels_available,
     is_mamba_ssm_available,
     is_mambapy_available,
 )
@@ -162,33 +161,6 @@ def reset(self):
             self.ssm_states[layer_idx].zero_()


-def _lazy_load_causal_conv1d():
-    global _causal_conv1d_cache
-    if _causal_conv1d_cache is not None:
-        return _causal_conv1d_cache
-
-    if is_kernels_available():
-        from kernels import get_kernel
-
-        try:
-            _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
-        except FileNotFoundError:
-            # no kernel binary match, fallback to slow path
-            _causal_conv1d_cache = (None, None)
-        else:
-            _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
-    elif is_causal_conv1d_available():
-        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-
-        _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
-    else:
-        _causal_conv1d_cache = (None, None)
-    return _causal_conv1d_cache
-
-
-_causal_conv1d_cache = None
-
-
 def rms_forward(hidden_states, variance_epsilon=1e-6):
     """
     Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will
@@ -268,7 +240,12 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int):
         self.rms_eps = config.mixer_rms_eps

     def warn_slow_implementation(self):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -323,7 +300,12 @@ def cuda_kernels_forward(
             )

         else:
-            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+            causal_conv1d = lazy_load_kernel("causal-conv1d")
+            causal_conv1d_update, causal_conv1d_fn = (
+                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+                if causal_conv1d is not None
+                else (None, None)
+            )
             hidden_states, gate = projected_states.chunk(2, dim=1)

             if attention_mask is not None:
@@ -518,7 +500,12 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )

src/transformers/models/falcon_mamba/modular_falcon_mamba.py

Lines changed: 19 additions & 6 deletions
@@ -19,6 +19,7 @@
 import torch
 from torch import nn

+from ...integrations.hub_kernels import lazy_load_kernel
 from ...utils import auto_docstring, logging
 from ...utils.import_utils import (
     is_mamba_ssm_available,
@@ -35,7 +36,6 @@
     MambaOutput,
     MambaPreTrainedModel,
     MambaRMSNorm,
-    _lazy_load_causal_conv1d,
 )


@@ -54,8 +54,6 @@
 else:
     selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

-_causal_conv1d_cache = None
-

 class FalconMambaConfig(MambaConfig):
     """
@@ -258,7 +256,12 @@ def rms_forward(hidden_states, variance_epsilon=1e-6):


 class FalconMambaMixer(MambaMixer):
     def warn_slow_implementation(self):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -324,7 +327,12 @@ def cuda_kernels_forward(
             )

         else:
-            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+            causal_conv1d = lazy_load_kernel("causal-conv1d")
+            causal_conv1d_update, causal_conv1d_fn = (
+                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+                if causal_conv1d is not None
+                else (None, None)
+            )
             hidden_states, gate = projected_states.chunk(2, dim=1)

             if attention_mask is not None:
@@ -518,7 +526,12 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )

src/transformers/models/mamba/modeling_mamba.py

Lines changed: 19 additions & 31 deletions
@@ -25,6 +25,7 @@
 from ...activations import ACT2FN
 from ...configuration_utils import PreTrainedConfig
 from ...generation import GenerationMixin
+from ...integrations.hub_kernels import lazy_load_kernel
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
@@ -33,8 +34,6 @@
     logging,
 )
 from ...utils.import_utils import (
-    is_causal_conv1d_available,
-    is_kernels_available,
     is_mamba_ssm_available,
     is_mambapy_available,
 )
@@ -54,32 +53,6 @@
 else:
     selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

-_causal_conv1d_cache = None
-
-
-def _lazy_load_causal_conv1d():
-    global _causal_conv1d_cache
-    if _causal_conv1d_cache is not None:
-        return _causal_conv1d_cache
-
-    if is_kernels_available():
-        from kernels import get_kernel
-
-        try:
-            _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
-        except FileNotFoundError:
-            # no kernel binary match, fallback to slow path
-            _causal_conv1d_cache = (None, None)
-        else:
-            _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
-    elif is_causal_conv1d_available():
-        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-
-        _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
-    else:
-        _causal_conv1d_cache = (None, None)
-    return _causal_conv1d_cache
-

 class MambaCache:
     """
@@ -236,7 +209,12 @@ def __init__(self, config: MambaConfig, layer_idx: int):
         self.warn_slow_implementation()

     def warn_slow_implementation(self):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -287,7 +265,12 @@ def cuda_kernels_forward(
             )

         else:
-            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+            causal_conv1d = lazy_load_kernel("causal-conv1d")
+            causal_conv1d_update, causal_conv1d_fn = (
+                (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+                if causal_conv1d is not None
+                else (None, None)
+            )
             hidden_states, gate = projected_states.chunk(2, dim=1)

             if attention_mask is not None:
@@ -451,7 +434,12 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
-        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update, causal_conv1d_fn = (
+            (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
+            if causal_conv1d is not None
+            else (None, None)
+        )
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
