easydel.kernels.flash_attention
- class easydel.kernels.flash_attention.AttentionConfig(blocksize_q: int = 128, blocksize_k: int = 128, softmax_scale: Optional[float] = None, backend: Optional[Backend] = None, platform: Optional[Platform] = None)
Bases: object
Configuration for Flash Attention computation.
- blocksize_k: int = 128
- blocksize_q: int = 128
- classmethod from_dict(data: Dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)
Returns a new instance with the specified fields replaced.
- softmax_scale: Optional[float] = None
- to_dict() → Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
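The entries above list `replace`, `to_dict`/`from_dict`, and `to_json`/`from_json` without showing how they combine. The sketch below mirrors the documented fields with a plain frozen dataclass to illustrate the copy-and-modify and serialization round trip. `AttentionConfigSketch` is hypothetical and not part of easydel; the real class is described as a PyTree, `backend`/`platform` are simplified to strings here, and forwarding `to_json`'s keyword arguments to `json.dumps` is an assumption.

```python
import dataclasses
import json
from typing import Any, Dict, Optional


# Hypothetical stand-in mirroring the documented AttentionConfig fields.
# backend/platform are plain strings here for simplicity.
@dataclasses.dataclass(frozen=True)
class AttentionConfigSketch:
    blocksize_q: int = 128
    blocksize_k: int = 128
    softmax_scale: Optional[float] = None
    backend: Optional[str] = None
    platform: Optional[str] = None

    def replace(self, **kwargs: Any) -> "AttentionConfigSketch":
        # Returns a new instance; the original is left unchanged.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AttentionConfigSketch":
        return cls(**data)

    def to_json(self, **kwargs: Any) -> str:
        # Assumption: extra keyword arguments go to json.dumps (e.g. indent=2).
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "AttentionConfigSketch":
        return cls.from_dict(json.loads(json_str))


base = AttentionConfigSketch()
tuned = base.replace(blocksize_q=64, softmax_scale=0.125)
restored = AttentionConfigSketch.from_json(tuned.to_json(indent=2))
```

Making the stand-in frozen means `replace` is the only way to derive a variant, which matches the documented "new instance" behavior and lets `restored == tuned` compare field by field.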
- class easydel.kernels.flash_attention.Backend(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: str, Enum
Supported compute backends.
- CPU = 'cpu'
- GPU = 'gpu'
- TPU = 'tpu'
- class easydel.kernels.flash_attention.FlashAttention(config: Optional[AttentionConfig] = None)
Bases: object
Flash Attention implementation with multiple backend support.
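This page does not explain what `blocksize_q` and `blocksize_k` control. In the FlashAttention algorithm, queries and keys are processed in tiles, and a running maximum and running normalizer let the softmax be accumulated one key block at a time without materializing the full attention matrix. The NumPy sketch below illustrates that idea for a single head without masking. It is not easydel's kernel, and treating `softmax_scale=None` as 1/sqrt(head_dim) is an assumption based on common practice.

```python
import math
from typing import Optional

import numpy as np


def blockwise_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray,
                        blocksize_q: int = 128, blocksize_k: int = 128,
                        softmax_scale: Optional[float] = None) -> np.ndarray:
    """Single-head, unmasked attention computed tile by tile with an online softmax.

    q: (seq_q, head_dim), k: (seq_k, head_dim), v: (seq_k, dim_v).
    """
    seq_q, head_dim = q.shape
    if softmax_scale is None:
        # Assumption: None falls back to the usual 1/sqrt(head_dim) scaling.
        softmax_scale = 1.0 / math.sqrt(head_dim)
    out = np.empty((seq_q, v.shape[1]))
    for qs in range(0, seq_q, blocksize_q):
        q_blk = q[qs:qs + blocksize_q] * softmax_scale
        rows = q_blk.shape[0]
        row_max = np.full(rows, -np.inf)     # running max logit per query row
        row_sum = np.zeros(rows)             # running softmax denominator
        acc = np.zeros((rows, v.shape[1]))   # running unnormalized output
        for ks in range(0, k.shape[0], blocksize_k):
            logits = q_blk @ k[ks:ks + blocksize_k].T
            new_max = np.maximum(row_max, logits.max(axis=1))
            # Rescale earlier partial sums so they share the new maximum.
            correction = np.exp(row_max - new_max)
            p = np.exp(logits - new_max[:, None])
            row_sum = row_sum * correction + p.sum(axis=1)
            acc = acc * correction[:, None] + p @ v[ks:ks + blocksize_k]
            row_max = new_max
        out[qs:qs + blocksize_q] = acc / row_sum[:, None]
    return out


def reference_attention(q, k, v, softmax_scale=None):
    """Naive attention that materializes the full (seq_q, seq_k) score matrix."""
    scale = 1.0 / math.sqrt(q.shape[1]) if softmax_scale is None else softmax_scale
    logits = (q @ k.T) * scale
    weights = np.exp(logits - logits.max(axis=1, keepdims=True))
    return (weights / weights.sum(axis=1, keepdims=True)) @ v


rng = np.random.default_rng(0)
q = rng.standard_normal((10, 8))
k = rng.standard_normal((13, 8))
v = rng.standard_normal((13, 4))
out = blockwise_attention(q, k, v, blocksize_q=4, blocksize_k=5)
print(np.allclose(out, reference_attention(q, k, v)))  # True
```

The block sizes only change how the work is tiled, not the result: any choice of `blocksize_q` and `blocksize_k` yields the same output up to floating-point rounding, which is why they can be tuned per backend for memory and speed.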
- class easydel.kernels.flash_attention.Platform(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: str, Enum
Supported Flash Attention platforms.
- JAX = 'jax'
- PALLAS = 'pallas'
- TRITON = 'triton'
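Because `Backend` and `Platform` subclass both `str` and `Enum`, each member compares equal to its string value, and calling the class with a value looks up the matching member. That is presumably why `create_flash_attention` accepts either an enum member or a plain string. The sketch below uses local copies of the two enums (not imported from easydel) to show this behavior; the `as_enum` helper is hypothetical.

```python
from enum import Enum
from typing import Optional, Type, TypeVar, Union

E = TypeVar("E", bound=Enum)


# Local mirrors of the documented enums (same member names and values);
# these are not imported from easydel.
class Backend(str, Enum):
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


class Platform(str, Enum):
    JAX = "jax"
    PALLAS = "pallas"
    TRITON = "triton"


def as_enum(enum_cls: Type[E], value: Optional[Union[E, str]]) -> Optional[E]:
    """Hypothetical helper: accept a member, its string value, or None."""
    if value is None:
        return None
    # Lookup by value returns the member itself when given a member,
    # and raises ValueError for an unknown string such as "cuda".
    return enum_cls(value)


print(as_enum(Backend, "gpu") is Backend.GPU)  # True
print(Platform.PALLAS == "pallas")             # True (str mixin)
```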
- easydel.kernels.flash_attention.create_flash_attention(backend: Optional[Union[Backend, str]] = None, platform: Optional[Union[Platform, str]] = None, **kwargs) → FlashAttention
Factory function to create a FlashAttention instance with the specified configuration.
- Parameters
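backend – Compute backend to use (GPU, TPU, or CPU), given as a Backend member or its string value.
platform – Platform to use (Triton, Pallas, or JAX), given as a Platform member or its string value.
**kwargs – Additional configuration parameters for AttentionConfig (for example blocksize_q, blocksize_k, softmax_scale).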
- Returns
Configured FlashAttention instance
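The factory's documented contract is to accept `backend` and `platform` as enum members or strings and to pass the remaining keyword arguments on to `AttentionConfig`. The sketch below reproduces that contract with hypothetical stand-ins (`AttentionConfigSketch`, `FlashAttentionSketch`, `create_flash_attention_sketch`) so it runs without easydel; the real factory's internals, including how `None` is resolved, are not described on this page.

```python
import dataclasses
from enum import Enum
from typing import Any, Optional, Union


# Local mirrors of the documented enums; not imported from easydel.
class Backend(str, Enum):
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


class Platform(str, Enum):
    JAX = "jax"
    PALLAS = "pallas"
    TRITON = "triton"


@dataclasses.dataclass(frozen=True)
class AttentionConfigSketch:
    """Hypothetical stand-in with the fields documented for AttentionConfig."""
    blocksize_q: int = 128
    blocksize_k: int = 128
    softmax_scale: Optional[float] = None
    backend: Optional[Backend] = None
    platform: Optional[Platform] = None


class FlashAttentionSketch:
    """Hypothetical stand-in for FlashAttention; it only stores its config."""

    def __init__(self, config: Optional[AttentionConfigSketch] = None):
        self.config = config if config is not None else AttentionConfigSketch()


def create_flash_attention_sketch(
    backend: Optional[Union[Backend, str]] = None,
    platform: Optional[Union[Platform, str]] = None,
    **kwargs: Any,
) -> FlashAttentionSketch:
    # Normalize strings to enum members; None is passed through unchanged.
    config = AttentionConfigSketch(
        backend=Backend(backend) if backend is not None else None,
        platform=Platform(platform) if platform is not None else None,
        **kwargs,  # remaining fields such as blocksize_q; unknown names raise TypeError
    )
    return FlashAttentionSketch(config)


attn = create_flash_attention_sketch(backend="gpu", platform=Platform.TRITON, blocksize_q=64)
```

Against the documented signature, the corresponding real call would presumably look like `create_flash_attention(backend="gpu", platform="triton", blocksize_q=64)`.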