easydel.kernels.flash_attention

class easydel.kernels.flash_attention.AttentionConfig(blocksize_q: int = 128, blocksize_k: int = 128, softmax_scale: Optional[float] = None, backend: Optional[Backend] = None, platform: Optional[Platform] = None)

Bases: object

Configuration for Flash Attention computation.

backend: Optional[Backend] = None
blocksize_k: int = 128
blocksize_q: int = 128
platform: Optional[Platform] = None
replace(**kwargs)
softmax_scale: Optional[float] = None
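The following is a minimal, self-contained sketch of a config object with the same fields and defaults as AttentionConfig. The class name `AttentionConfigSketch` and the helper `resolved_scale` are hypothetical. Two behaviours are assumptions, not taken from the library: that `replace(**kwargs)` returns a new config with the given fields overridden (as `dataclasses.replace` does), and that a `softmax_scale` of `None` falls back to the common `1/sqrt(head_dim)` convention.

```python
import dataclasses
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class AttentionConfigSketch:
    """Hypothetical stand-in mirroring the documented AttentionConfig fields."""

    blocksize_q: int = 128
    blocksize_k: int = 128
    softmax_scale: Optional[float] = None
    backend: Optional[str] = None   # simplified: plain string instead of the Backend enum
    platform: Optional[str] = None  # simplified: plain string instead of the Platform enum

    def replace(self, **kwargs) -> "AttentionConfigSketch":
        # Assumed behaviour: return a copy with the given fields overridden.
        return dataclasses.replace(self, **kwargs)

    def resolved_scale(self, head_dim: int) -> float:
        # Assumed convention: use 1/sqrt(head_dim) when no scale is configured.
        if self.softmax_scale is not None:
            return self.softmax_scale
        return 1.0 / math.sqrt(head_dim)


base = AttentionConfigSketch()
small_blocks = base.replace(blocksize_q=64, blocksize_k=64)
print(small_blocks.blocksize_q, base.blocksize_q)  # 64 128
print(base.resolved_scale(64))  # 0.125
```

Returning a modified copy rather than mutating in place keeps a shared default config safe to reuse across several attention instances.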
class easydel.kernels.flash_attention.Backend(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: str, Enum

Supported compute backends.

CPU = 'cpu'
GPU = 'gpu'
TPU = 'tpu'
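Because Backend derives from both `str` and `Enum`, standard Python enum rules apply: a plain string can be converted to a member by value, and members compare equal to their string values. The snippet below demonstrates this with a local copy of the documented members (not an import from easydel).

```python
from enum import Enum


class Backend(str, Enum):
    """Local copy of the documented members, for illustration only."""

    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


# Lookup by value turns a user-supplied string into the enum member.
print(Backend("gpu") is Backend.GPU)  # True

# The str mixin makes a member compare equal to its plain string value.
print(Backend.TPU == "tpu")  # True

# .value returns the underlying string explicitly.
print(Backend.CPU.value)  # cpu
```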
class easydel.kernels.flash_attention.FlashAttention(config: Optional[AttentionConfig] = None)

Bases: object

Flash Attention implementation with multiple backend support.

static repeat_kv_heads(key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], num_reps: int) -> Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]]

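Repeats the key and value heads `num_reps` times so that their head count matches the number of query heads (as required for grouped-query or multi-query attention).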
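A minimal JAX sketch of the head-repetition idea, assuming tensors laid out as `(batch, seq_len, num_kv_heads, head_dim)` so that the head axis is axis 2. The function name `repeat_kv_heads_sketch` and that layout are assumptions for illustration; the library's own axis convention is not stated here.

```python
from typing import Tuple

import jax.numpy as jnp


def repeat_kv_heads_sketch(
    key: jnp.ndarray, value: jnp.ndarray, num_reps: int
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Hypothetical sketch: repeat each KV head `num_reps` times along axis 2.

    Assumes the layout (batch, seq_len, num_kv_heads, head_dim).
    """
    if num_reps == 1:
        return key, value
    # jnp.repeat keeps copies adjacent: heads [h0, h1] -> [h0, h0, h1, h1] for num_reps=2,
    # so query heads i*num_reps .. (i+1)*num_reps - 1 all use KV head i.
    return jnp.repeat(key, num_reps, axis=2), jnp.repeat(value, num_reps, axis=2)


# 8 query heads sharing 2 KV heads -> num_reps = 8 // 2 = 4
k = jnp.zeros((1, 16, 2, 64))
v = jnp.ones((1, 16, 2, 64))
k_rep, v_rep = repeat_kv_heads_sketch(k, v, num_reps=4)
print(k_rep.shape, v_rep.shape)  # (1, 16, 8, 64) (1, 16, 8, 64)
```

Using `jnp.repeat` (rather than tiling) keeps each group of query heads contiguous, which matches the usual grouped-query attention head assignment.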

class easydel.kernels.flash_attention.Platform(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: str, Enum

Supported Flash Attention platforms.

JAX = 'jax'
PALLAS = 'pallas'
TRITON = 'triton'
easydel.kernels.flash_attention.create_flash_attention(backend: Optional[Union[Backend, str]] = None, platform: Optional[Union[Platform, str]] = None, **kwargs) -> FlashAttention

Factory function to create a FlashAttention instance with the specified configuration.

Parameters
  • backend – Compute backend to use (GPU, TPU, or CPU)

  • platform – Platform to use (Triton, Pallas, or JAX)

  • **kwargs – Additional configuration parameters for AttentionConfig

Returns

Configured FlashAttention instance
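The sketch below illustrates the factory pattern described above: string arguments are normalised to enum members and the remaining keyword arguments are forwarded to a config object. It is self-contained and stops at building the config. The names `ConfigSketch`, `create_config_sketch`, and `_DEFAULT_PLATFORM`, and the backend-to-platform defaults (Triton on GPU, Pallas on TPU, JAX on CPU), are hypothetical assumptions, not the library's documented rules.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union


class Backend(str, Enum):
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


class Platform(str, Enum):
    JAX = "jax"
    PALLAS = "pallas"
    TRITON = "triton"


# Assumed defaults, for illustration only; the real factory may choose differently.
_DEFAULT_PLATFORM = {
    Backend.GPU: Platform.TRITON,
    Backend.TPU: Platform.PALLAS,
    Backend.CPU: Platform.JAX,
}


@dataclass
class ConfigSketch:
    blocksize_q: int = 128
    blocksize_k: int = 128
    softmax_scale: Optional[float] = None
    backend: Optional[Backend] = None
    platform: Optional[Platform] = None


def create_config_sketch(
    backend: Optional[Union[Backend, str]] = None,
    platform: Optional[Union[Platform, str]] = None,
    **kwargs,
) -> ConfigSketch:
    """Hypothetical factory: normalise strings to enums, then forward kwargs."""
    # Backend("gpu") and Backend(Backend.GPU) both return Backend.GPU.
    backend = Backend(backend) if backend is not None else None
    if platform is not None:
        platform = Platform(platform)
    elif backend is not None:
        platform = _DEFAULT_PLATFORM[backend]
    # Remaining keyword arguments (e.g. blocksize_q) become config fields.
    return ConfigSketch(backend=backend, platform=platform, **kwargs)


cfg = create_config_sketch(backend="gpu", blocksize_q=64)
print(cfg.platform is Platform.TRITON, cfg.blocksize_q)  # True 64
```

Leaving both fields as `None` when no backend is given defers the choice to a later stage, for example device detection at call time.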

easydel.kernels.flash_attention.free_gpu_in_process() -> int

Returns the index of the GPU with the most available memory among the devices reported by `jax.local_devices()`.

Returns

Index of the GPU with the most free memory

Return type

int
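A minimal sketch of how such a selection could be done with JAX, assuming free memory is computed as `bytes_limit - bytes_in_use` from `Device.memory_stats()`. The function name `free_device_index_sketch` is hypothetical, the sketch considers every local device rather than filtering for GPUs, and devices that report no statistics (for example CPU in some JAX versions) are treated as having zero free memory.

```python
import jax


def free_device_index_sketch() -> int:
    """Hypothetical sketch: index into jax.local_devices() of the device with the
    most free memory, where free = bytes_limit - bytes_in_use."""
    best_index, best_free = 0, -1
    for index, device in enumerate(jax.local_devices()):
        try:
            stats = device.memory_stats() or {}
        except Exception:  # memory stats are not implemented on every platform
            stats = {}
        free = stats.get("bytes_limit", 0) - stats.get("bytes_in_use", 0)
        if free > best_free:
            best_index, best_free = index, free
    return best_index


print(free_device_index_sketch())
```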

easydel.kernels.flash_attention.get_device_memory_usage(device: Device) -> float

Gets the memory usage of a specific JAX device from the memory statistics of the local devices.

Parameters

device – JAX device to check

Returns

Memory usage in bytes

Return type

float
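A minimal sketch, assuming usage is read from the `bytes_in_use` entry of `Device.memory_stats()` and returned as a float to match the documented return type. The function name `device_memory_usage_sketch` is hypothetical, and the `0.0` fallback for devices without statistics is an assumption.

```python
import jax


def device_memory_usage_sketch(device) -> float:
    """Hypothetical sketch: bytes currently in use on `device`, as a float.

    Returns 0.0 when the device does not report memory statistics.
    """
    try:
        stats = device.memory_stats() or {}
    except Exception:  # memory stats are not implemented on every platform
        stats = {}
    return float(stats.get("bytes_in_use", 0))


for dev in jax.local_devices():
    print(dev, device_memory_usage_sketch(dev))
```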