[kernels] refactor function kernel calling #41577
Conversation
def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]]):
    if kernel_name in mapping and isinstance(mapping[kernel_name], ModuleType):
the main utility function applied to the case of causal-conv1d
run-slow: falcon_mamba, mamba

This comment contains run-slow, running the specified jobs: models: ['models/falcon_mamba', 'models/mamba']

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker left a comment
nice
raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
_KERNEL_SIMPLE_MAPPING: dict[str, str] = {
why "simple" ?
causal_conv1d = lazy_load_kernel("causal-conv1d", _KERNEL_MAPPING_ACROSS_MODELS)
causal_conv1d_update, causal_conv1d_fn = (
    (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
    if causal_conv1d is not None
    else (None, None)
)
we should not need to pass the kernel mapping, `lazy_load_kernel` can handle it!
Done
[For maintainers] Suggested jobs to run (before merge): run-slow: falcon_mamba, mamba
ArthurZucker left a comment
Thanks 🤗
ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]] = _KERNEL_MODULE_MAPPING):
the default should be `None`: if it is `None`, fall back to the kernel mapping inside the function body. Python will cry otherwise!
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update, causal_conv1d_fn = (
    (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
    if causal_conv1d is not None
    else (None, None)
)
nice and simple
* refactor function kernel calling
* nit
* don't pass the mapping
* use _kernels_available
* rm import
What does this PR do?
This should simplify lazy kernel loading in Transformers.
We define a mapping between each kernel name and the repository it should be pulled from, then load it with the `lazy_load_kernel` function. This function adds the kernel to a global cache shared across all models. If the kernel isn't available, we check whether it's installed as a module for backward compatibility; otherwise, we return `None`.