easydel.kernels.tpu_ops.ring_attention_pallas._ring_attention

Efficient Ring Attention Implementation for Single-Device Execution

This module provides an optimized implementation of ring attention, originally inspired by the work of Liu et al. (2023) ([https://arxiv.org/abs/2310.01889](https://arxiv.org/abs/2310.01889)). It incorporates the following enhancements:

  • Single-Device Focus: Adapted for efficient execution on a single device, removing the need for parallel communication primitives.

  • Enhanced JIT Compatibility: Streamlined for smoother integration with JAX’s Just-In-Time (JIT) compilation.

  • Performance Optimizations: Includes code optimizations for improved speed and memory usage.

Note: Although based on existing implementations, this version is significantly modified to improve usability and performance in single-device and multi-host settings. It also adds a softmax scale option to support custom scaling factors.
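To illustrate the idea behind single-device execution, here is a minimal sketch of blockwise attention with an online softmax, assuming a single head, no batching, no causal mask, and a sequence length divisible by the block size. It is not the module's Pallas kernel: the name `blockwise_attention` and its parameters are hypothetical and chosen only for this example. Key/value blocks are visited one at a time with `jax.lax.scan`, which stands in for the ring of devices, and `softmax_scale` defaults to `head_dim ** -0.5` when not supplied.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("block_size",))
def blockwise_attention(q, k, v, softmax_scale=None, block_size=128):
    """Hypothetical sketch. q: [q_len, head_dim], k/v: [kv_len, head_dim]."""
    kv_len, head_dim = k.shape
    if softmax_scale is None:
        # Conventional default; an assumption, not taken from the module.
        softmax_scale = head_dim ** -0.5

    # Assumes kv_len is divisible by block_size.
    num_blocks = kv_len // block_size
    k_blocks = k.reshape(num_blocks, block_size, head_dim)
    v_blocks = v.reshape(num_blocks, block_size, v.shape[-1])

    def step(carry, kv_block):
        acc, row_max, row_sum = carry
        k_blk, v_blk = kv_block
        scores = (q @ k_blk.T) * softmax_scale  # [q_len, block_size]
        new_max = jnp.maximum(row_max, scores.max(axis=-1))
        # Rescale previously accumulated values to the new running maximum.
        correction = jnp.exp(row_max - new_max)
        probs = jnp.exp(scores - new_max[:, None])
        acc = acc * correction[:, None] + probs @ v_blk
        row_sum = row_sum * correction + probs.sum(axis=-1)
        return (acc, new_max, row_sum), None

    q_len = q.shape[0]
    init = (
        jnp.zeros((q_len, v.shape[-1]), dtype=jnp.float32),  # weighted values
        jnp.full((q_len,), -jnp.inf, dtype=jnp.float32),     # running row max
        jnp.zeros((q_len,), dtype=jnp.float32),              # running row sum
    )
    (acc, _, row_sum), _ = jax.lax.scan(step, init, (k_blocks, v_blocks))
    return acc / row_sum[:, None]
```

Because only one key/value block is materialized per step, the full `[q_len, kv_len]` score matrix is never built, and the loop runs under `jax.jit` without any collective communication.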