easydel.kernels.flash_attention
- class easydel.kernels.flash_attention.AttentionConfig(blocksize_q: int = 128, blocksize_k: int = 128, softmax_scale: Optional[float] = None, backend: Optional[Backend] = None, platform: Optional[Platform] = None)
Bases: object
Configuration for Flash Attention computation.
- blocksize_k: int = 128
- blocksize_q: int = 128
- replace(**kwargs)
- softmax_scale: Optional[float] = None
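A minimal sketch of how these fields and `replace(**kwargs)` might behave, assuming `AttentionConfig` acts like a dataclass whose `replace` returns a modified copy (as `dataclasses.replace` does) and that a `softmax_scale` of `None` falls back to the conventional `1/sqrt(head_dim)`. Neither assumption is confirmed by this reference; `AttentionConfigSketch` and `resolved_scale` are hypothetical names used only for illustration.

```python
import math
from dataclasses import dataclass
from dataclasses import replace as dc_replace
from typing import Optional


@dataclass(frozen=True)
class AttentionConfigSketch:
    """Hypothetical stand-in mirroring the documented AttentionConfig fields."""

    blocksize_q: int = 128
    blocksize_k: int = 128
    softmax_scale: Optional[float] = None
    backend: Optional[str] = None
    platform: Optional[str] = None

    def replace(self, **kwargs) -> "AttentionConfigSketch":
        # Return a new config with the given fields overridden; self is left unchanged.
        return dc_replace(self, **kwargs)

    def resolved_scale(self, head_dim: int) -> float:
        # Assumption: None means "use the usual 1/sqrt(head_dim) scaling".
        if self.softmax_scale is not None:
            return self.softmax_scale
        return 1.0 / math.sqrt(head_dim)


base = AttentionConfigSketch()
small = base.replace(blocksize_q=64, blocksize_k=64)
print(base.blocksize_q, small.blocksize_q)  # 128 64
print(small.resolved_scale(64))  # 0.125
```

Returning a copy rather than mutating in place keeps a shared default configuration safe to reuse across several attention instances.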
- class easydel.kernels.flash_attention.Backend(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: str, Enum
Supported compute backends.
- CPU = 'cpu'
- GPU = 'gpu'
- TPU = 'tpu'
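Because `Backend` (and likewise `Platform`) subclasses both `str` and `Enum`, its members compare equal to their string values and can be looked up from a string. The long signature shown above is the generic `Enum` constructor; in everyday use a one-argument call such as `Backend("gpu")` simply returns the matching member. The stand-in below (`BackendSketch`, a hypothetical name) copies the documented members to demonstrate standard-library `Enum` behavior; it is not EasyDeL's class.

```python
from enum import Enum


class BackendSketch(str, Enum):
    """Hypothetical stand-in with the same members as the documented Backend enum."""

    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


# The str mixin makes members compare equal to their plain string values.
assert BackendSketch.GPU == "gpu"
# Calling the enum with a value looks up the existing member (no new object is created).
assert BackendSketch("tpu") is BackendSketch.TPU
# Unknown values raise ValueError instead of passing through silently.
try:
    BackendSketch("npu")
except ValueError as err:
    print("rejected:", err)
```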
- class easydel.kernels.flash_attention.FlashAttention(config: Optional[AttentionConfig] = None)
Bases: object
Flash Attention implementation with multiple backend support.
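This reference does not describe the algorithm behind `FlashAttention`. As general background, Flash Attention computes exact softmax attention tile by tile (the tile sizes are presumably what `blocksize_q` and `blocksize_k` control), keeping a running maximum and a running softmax denominator per query so the full score matrix is never materialized. The pure-Python sketch below illustrates that online-softmax tiling for a single head without masking and checks it against a naive reference. It is not EasyDeL's Triton, Pallas, or JAX kernel; `naive_attention` and `tiled_attention` are hypothetical helpers.

```python
import math
from typing import List

Matrix = List[List[float]]


def naive_attention(q: Matrix, k: Matrix, v: Matrix, scale: float) -> Matrix:
    """Reference: build each full score row, softmax it, then weight the rows of V."""
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) / total
                    for d in range(len(v[0]))])
    return out


def tiled_attention(q: Matrix, k: Matrix, v: Matrix, scale: float,
                    blocksize_q: int = 2, blocksize_k: int = 2) -> Matrix:
    """Flash-Attention-style forward pass: visit K/V in blocks, keep running stats."""
    d_v = len(v[0])
    out: Matrix = []
    for q_start in range(0, len(q), blocksize_q):
        q_block = q[q_start:q_start + blocksize_q]
        # Per query: running max (m), softmax denominator (l), unnormalized output (acc).
        m = [-math.inf] * len(q_block)
        l = [0.0] * len(q_block)
        acc = [[0.0] * d_v for _ in q_block]
        for k_start in range(0, len(k), blocksize_k):
            k_block = k[k_start:k_start + blocksize_k]
            v_block = v[k_start:k_start + blocksize_k]
            for i, qi in enumerate(q_block):
                scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_block]
                m_new = max(m[i], max(scores))
                # Rescale everything accumulated under the old max (exp(-inf) == 0 at start).
                correction = math.exp(m[i] - m_new)
                p = [math.exp(s - m_new) for s in scores]
                l[i] = l[i] * correction + sum(p)
                acc[i] = [acc[i][d] * correction
                          + sum(pj * vj[d] for pj, vj in zip(p, v_block))
                          for d in range(d_v)]
                m[i] = m_new
        # Normalize once per query after all key blocks have been seen.
        out.extend([[x / l[i] for x in acc[i]] for i in range(len(q_block))])
    return out


q = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
k = [[1.0, 1.0], [0.0, 2.0], [1.0, -1.0], [0.3, 0.2], [2.0, 0.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
scale = 1.0 / math.sqrt(2)
ref = naive_attention(q, k, v, scale)
got = tiled_attention(q, k, v, scale, blocksize_q=2, blocksize_k=2)
assert all(abs(a - b) < 1e-9 for r, g in zip(ref, got) for a, b in zip(r, g))
```

The sizes are chosen so that neither sequence length divides evenly by its block size, which exercises the ragged final tiles.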
- class easydel.kernels.flash_attention.Platform(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: str, Enum
Supported Flash Attention platforms.
- JAX = 'jax'
- PALLAS = 'pallas'
- TRITON = 'triton'
- easydel.kernels.flash_attention.create_flash_attention(backend: Optional[Union[Backend, str]] = None, platform: Optional[Union[Platform, str]] = None, **kwargs) → FlashAttention
Factory function to create a FlashAttention instance with the specified configuration.
- Parameters
backend – Compute backend to use (GPU, TPU, or CPU)
platform – Platform to use (Triton, Pallas, or JAX)
**kwargs – Additional configuration parameters for AttentionConfig
- Returns
Configured FlashAttention instance
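A minimal sketch of the factory's documented contract: `backend` and `platform` accept either an enum member or a plain string, and any remaining keyword arguments are forwarded to `AttentionConfig`. Every class and function here is a hypothetical stand-in (marked with a `Sketch` suffix) that copies the documented names and fields. How the real function validates combinations such as Triton on TPU, or chooses defaults when both arguments are `None`, is not stated in this reference and is not modeled.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union


class BackendSketch(str, Enum):  # hypothetical stand-in for Backend
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


class PlatformSketch(str, Enum):  # hypothetical stand-in for Platform
    JAX = "jax"
    PALLAS = "pallas"
    TRITON = "triton"


@dataclass
class AttentionConfigSketch:  # hypothetical stand-in; fields copied from the documented signature
    blocksize_q: int = 128
    blocksize_k: int = 128
    softmax_scale: Optional[float] = None
    backend: Optional[BackendSketch] = None
    platform: Optional[PlatformSketch] = None


class FlashAttentionSketch:  # hypothetical stand-in that only stores its config
    def __init__(self, config: Optional[AttentionConfigSketch] = None):
        self.config = config if config is not None else AttentionConfigSketch()


def create_flash_attention_sketch(
    backend: Optional[Union[BackendSketch, str]] = None,
    platform: Optional[Union[PlatformSketch, str]] = None,
    **kwargs,
) -> FlashAttentionSketch:
    # Coerce strings like "gpu" to members; passing a member returns that same member.
    backend = BackendSketch(backend) if backend is not None else None
    platform = PlatformSketch(platform) if platform is not None else None
    # Remaining keyword arguments become config fields; unknown names raise TypeError.
    config = AttentionConfigSketch(backend=backend, platform=platform, **kwargs)
    return FlashAttentionSketch(config)


attn = create_flash_attention_sketch("gpu", "triton", blocksize_q=64)
print(attn.config.backend is BackendSketch.GPU, attn.config.platform is PlatformSketch.TRITON)  # True True
```

Coercing with `Enum(value)` keeps both call styles working, while a misspelled string fails early with a `ValueError` rather than surfacing later inside a kernel.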