easydel.kernels.cpu_ops.ring_attention_jax._ring_attention
Efficient Ring Attention Implementation for Single-Device Execution
This module provides an optimized implementation of ring attention, originally inspired by the work of Liu et al. (2023) ([https://arxiv.org/abs/2310.01889](https://arxiv.org/abs/2310.01889)). It incorporates the following enhancements:
Single-Device Focus: Adapted for efficient execution on a single device, removing the need for parallel communication primitives.
Enhanced JIT Compatibility: Streamlined for smoother integration with JAX’s Just-In-Time (JIT) compilation.
Performance Optimizations: Includes code optimizations for improved speed and memory usage.
Note: Although based on existing implementations, this version is substantially modified to improve usability and performance in single-device and multi-host settings. It also adds a softmax scale option to support custom attention scales.
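The single-device idea can be illustrated with a minimal sketch. This is not the module's actual code: the names `blockwise_attention`, `reference_attention`, and `kv_chunk` are hypothetical, and the default scale of `dim ** -0.5` is an assumption. The sketch assumes unbatched, single-head inputs and a key/value length divisible by the chunk size. Instead of passing key/value blocks between devices, it visits them in a `lax.scan` loop and keeps a running softmax, so the full score matrix is never materialized.

```python
import jax
import jax.numpy as jnp
from jax import lax


def blockwise_attention(q, k, v, softmax_scale=None, kv_chunk=4):
    """Hypothetical helper: attention over KV chunks with a running softmax.

    q: [q_len, dim]; k, v: [kv_len, dim]. kv_len must be divisible by kv_chunk.
    """
    q_len, dim = q.shape
    if softmax_scale is None:
        softmax_scale = dim ** -0.5  # assumed default, not taken from the module
    num_chunks = k.shape[0] // kv_chunk
    k_chunks = k.reshape(num_chunks, kv_chunk, dim)
    v_chunks = v.reshape(num_chunks, kv_chunk, dim)

    def step(carry, kv):
        acc, row_sum, row_max = carry
        k_blk, v_blk = kv
        scores = (q @ k_blk.T) * softmax_scale        # [q_len, kv_chunk]
        new_max = jnp.maximum(row_max, scores.max(axis=-1))
        correction = jnp.exp(row_max - new_max)       # rescale earlier statistics
        probs = jnp.exp(scores - new_max[:, None])
        acc = acc * correction[:, None] + probs @ v_blk
        row_sum = row_sum * correction + probs.sum(axis=-1)
        return (acc, row_sum, new_max), None

    init = (
        jnp.zeros((q_len, dim), q.dtype),        # unnormalized output
        jnp.zeros((q_len,), q.dtype),            # running softmax denominator
        jnp.full((q_len,), -jnp.inf, q.dtype),   # running row maximum
    )
    (acc, row_sum, _), _ = lax.scan(step, init, (k_chunks, v_chunks))
    return acc / row_sum[:, None]


def reference_attention(q, k, v, softmax_scale):
    """Plain attention that materializes the full score matrix, for comparison."""
    return jax.nn.softmax((q @ k.T) * softmax_scale, axis=-1) @ v


# The scale and chunk size are Python values, so mark them static under JIT.
jitted_attention = jax.jit(
    blockwise_attention, static_argnames=("softmax_scale", "kv_chunk")
)
```

Calling `jitted_attention(q, k, v, softmax_scale=0.5, kv_chunk=4)` should agree with `reference_attention(q, k, v, 0.5)` up to floating-point tolerance. The chunk size trades peak memory against the number of scan steps.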