FlexibleAttentionModule: A Versatile Attention Mechanism Factory
The FlexibleAttentionModule class is designed to simplify the creation and execution of different attention mechanisms within
your EasyDeL models. It provides a unified interface for working with various attention types, allowing you to easily
switch between them and experiment with different configurations.
Key Features:
Mechanism Selection: The attn_mechanism argument lets you choose the attention algorithm to use (e.g., “vanilla”, “flash”, “splash”, “ring”, “cudnn”).
Sharding and Partitioning: The class supports JAX sharding to distribute attention computations across multiple devices, so that large models can be processed efficiently. It partitions the query, key, value, bias, and attention weight matrices using PartitionSpec.
Blockwise Attention: Enables blockwise attention for better memory efficiency, especially with long sequences.
Caching Support: Supports attention caching to speed up inference and generation tasks.
Dropout and Determinism: Allows applying dropout to the attention weights and controlling whether the attention computation is deterministic.
Testing Utility: Provides a run_attention_benchmarks method to compare attention mechanisms in terms of accuracy, gradient stability, and computation time.
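The sharding idea can be illustrated with plain JAX, independently of EasyDeL. The sketch below assumes the four mesh axis names used in the TPU example on this page ("dp", "fsdp", "tp", "sp") and a (batch, sequence, heads, head_dim) layout. It only shows how a PartitionSpec places a query tensor on a device mesh; it is not how FlexibleAttentionModule does this internally.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 4-axis mesh that puts every visible device on the "sp" axis.
# On a single CPU this is simply a (1, 1, 1, 1) mesh.
devices = np.array(jax.devices()).reshape(1, 1, 1, -1)
mesh = Mesh(devices, axis_names=("dp", "fsdp", "tp", "sp"))

# Layout for (batch, sequence, heads, head_dim) tensors, mirroring the
# PartitionAxis used in the TPU example: batch over ("dp", "fsdp"),
# sequence over "sp", heads over "tp", head_dim replicated.
qkv_spec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None)
qkv_sharding = NamedSharding(mesh, qkv_spec)

# Assumes the number of devices divides seq_len (64).
batch, seq_len, num_heads, head_dim = 2, 64, 4, 16
q = jax.random.normal(jax.random.PRNGKey(0), (batch, seq_len, num_heads, head_dim))

# Place the array on the mesh according to the spec; values are unchanged.
q_sharded = jax.device_put(q, qkv_sharding)
print(q_sharded.sharding.spec)
```

Key and value tensors would use the same spec, and a bias of shape (batch, 1, q_len, kv_len) would need its own spec with the matching axes.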
How it Works:
Initialization:
During initialization, you provide the desired attn_mechanism, a JAX mesh for sharding, the scaling factor (sm_scale), the number of attention heads, the head dimension, and other configuration parameters. The class sets default values for many parameters based on the chosen attention mechanism and the provided EasyDeL configuration (base_module_class).
Calling the Module:
When you call a FlexibleAttentionModule object, you pass in the query, key, and value states, along with optional inputs such as attention masks, biases, and causal flags. The module selects the attention function that matches attn_mechanism and applies any sharding and partitioning given by the configured partition specifications.
It then runs the attention computation and returns the attention outputs (and, optionally, the attention weights).
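The dispatch step described under "Calling the Module" can be pictured as a lookup from the attn_mechanism string to a function. The sketch below is a hypothetical, plain-JAX illustration: the names vanilla_attention, ATTENTION_FNS, and run_attention are invented here and are not EasyDeL's internal code. It uses the same (batch, sequence, heads, head_dim) layout, additive bias, and sm_scale = 1 / sqrt(head_dim) as the TPU example.

```python
import math
from typing import Callable, Dict

import jax
import jax.numpy as jnp


def vanilla_attention(q, k, v, bias=None, sm_scale=None):
    """Reference softmax attention for (batch, seq, heads, head_dim) inputs.

    `bias` is added to the attention logits and must broadcast to
    (batch, heads, q_len, kv_len), e.g. a (batch, 1, q_len, kv_len) bias.
    """
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(q.shape[-1])
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * sm_scale
    if bias is not None:
        logits = logits + bias
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


# Hypothetical registry mapping an attn_mechanism string to a function.
# Other implementations (blockwise, ring, ...) would be registered here.
ATTENTION_FNS: Dict[str, Callable] = {"vanilla": vanilla_attention}


def run_attention(attn_mechanism, q, k, v, bias=None, sm_scale=None):
    """Looks up the requested mechanism and runs it."""
    if attn_mechanism not in ATTENTION_FNS:
        raise ValueError(f"unknown attn_mechanism: {attn_mechanism!r}")
    return ATTENTION_FNS[attn_mechanism](q, k, v, bias=bias, sm_scale=sm_scale)


q = k = v = jax.random.normal(jax.random.PRNGKey(0), (1, 8, 2, 4))
out = run_attention("vanilla", q, k, v)
print(out.shape)  # (1, 8, 2, 4)
```

Keeping every implementation behind the same signature is what makes switching mechanisms a one-argument change for the caller.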
Advantages:
Flexibility: Allows you to easily switch between different attention mechanisms without major code changes.
Efficiency: Supports JAX sharding for distributed computation, making it possible to handle large models.
FlexibleAttentionModule is an EasyDeL module that can perform the attention operation with different strategies, to help you achieve the best possible performance and numerical stability. The strategies currently supported include:
Flash Attention on TPU/GPU/CPU, known as “flash_attn2”
Ring Attention, which supports longer context lengths such as 1 million tokens or more, known as “ring”
Normal attention, which uses flax.linen.attention with shard_map, known as “vanilla”
Splash Attention on TPUs, known as “splash”
Other attention mechanisms may be added; check the source code for the current list.
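Flash-style and ring-style attention avoid materializing the full (sequence x sequence) score matrix by walking over key/value blocks and keeping a running softmax. The sketch below is a simplified, single-head, unmasked illustration of that online-softmax idea in plain JAX. It is not EasyDeL's kernel; block_k here only loosely corresponds to the block_q/block_k settings in the TPU example.

```python
import math

import jax
import jax.numpy as jnp


def reference_attention(q, k, v, sm_scale):
    # q, k, v: (seq, head_dim) for a single head, no masking.
    weights = jax.nn.softmax((q @ k.T) * sm_scale, axis=-1)
    return weights @ v


def blockwise_attention(q, k, v, sm_scale, block_k):
    """Processes keys/values in chunks of size block_k with an online softmax."""
    seq_q, head_dim = q.shape
    assert k.shape[0] % block_k == 0, "kv length must be a multiple of block_k"
    num_blocks = k.shape[0] // block_k
    k_blocks = k.reshape(num_blocks, block_k, head_dim)
    v_blocks = v.reshape(num_blocks, block_k, v.shape[-1])

    def step(carry, kv):
        acc, row_max, row_sum = carry
        k_blk, v_blk = kv
        logits = (q @ k_blk.T) * sm_scale  # (seq_q, block_k)
        new_max = jnp.maximum(row_max, logits.max(axis=-1))
        correction = jnp.exp(row_max - new_max)  # rescales earlier partial sums
        p = jnp.exp(logits - new_max[:, None])
        acc = acc * correction[:, None] + p @ v_blk
        row_sum = row_sum * correction + p.sum(axis=-1)
        return (acc, new_max, row_sum), None

    init = (
        jnp.zeros((seq_q, v.shape[-1]), q.dtype),  # unnormalized output
        jnp.full((seq_q,), -jnp.inf, q.dtype),  # running row maximum
        jnp.zeros((seq_q,), q.dtype),  # running softmax denominator
    )
    (acc, _, row_sum), _ = jax.lax.scan(step, init, (k_blocks, v_blocks))
    return acc / row_sum[:, None]


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (256, 64))
k = jax.random.normal(kk, (256, 64))
v = jax.random.normal(kv, (256, 64))
scale = 1.0 / math.sqrt(q.shape[-1])
ref = reference_attention(q, k, v, scale)
blk = blockwise_attention(q, k, v, scale, block_k=64)
print(float(jnp.max(jnp.abs(ref - blk))))
```

Only one (seq_q, block_k) score tile exists at a time, which is where the memory savings for long sequences come from; ring attention applies the same running-softmax update while key/value blocks are passed between devices.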
Testing which Attention Module works best
To find out which attention mechanism and which axis_dims work best for your setup, run:
from easydel import FlexibleAttentionModule

print(
    FlexibleAttentionModule.run_attention_benchmarks(
        axis_dims=(1, 1, 1, -1),
        sequence_length=128 * 8,
        num_attention_heads=32,
        num_key_value_heads=32,
        chunk_size=128,
    )
)
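To see what such a comparison measures, or to compare your own implementations, a hand-rolled harness is short. The sketch below is hypothetical: the functions attention, attention_bf16, time_fn, grad_wrt_q, and compare are invented here and do not reproduce run_attention_benchmarks. It checks output agreement, gradient agreement, and wall-clock time between a float32 and a bfloat16 version of the same attention.

```python
import math
import time

import jax
import jax.numpy as jnp


def attention(q, k, v):
    # q, k, v: (batch, seq, heads, head_dim); plain softmax attention.
    scale = 1.0 / math.sqrt(q.shape[-1])
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(logits, axis=-1), v)


def attention_bf16(q, k, v):
    # Same math with bfloat16 inputs, cast back to float32 for comparison.
    out = attention(q.astype(jnp.bfloat16), k.astype(jnp.bfloat16), v.astype(jnp.bfloat16))
    return out.astype(jnp.float32)


def time_fn(fn, *args, repeats=5):
    fn = jax.jit(fn)
    jax.block_until_ready(fn(*args))  # compile and warm up
    start = time.perf_counter()
    for _ in range(repeats):
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / repeats


def grad_wrt_q(fn, q, k, v):
    # Gradient of a scalar loss (sum of outputs) with respect to the queries.
    return jax.grad(lambda q_: fn(q_, k, v).sum())(q)


def compare(fn_ref, fn_test, q, k, v):
    out_diff = jnp.max(jnp.abs(fn_ref(q, k, v) - fn_test(q, k, v)))
    grad_diff = jnp.max(jnp.abs(grad_wrt_q(fn_ref, q, k, v) - grad_wrt_q(fn_test, q, k, v)))
    return {
        "max_out_diff": float(out_diff),
        "max_grad_diff": float(grad_diff),
        "ref_seconds": time_fn(fn_ref, q, k, v),
        "test_seconds": time_fn(fn_test, q, k, v),
    }


rngs = jax.random.split(jax.random.PRNGKey(0), 3)
q, k, v = (jax.random.normal(r, (1, 128, 4, 32)) for r in rngs)
print(compare(attention, attention_bf16, q, k, v))
```

Timing with jax.block_until_ready after a warm-up call keeps compilation time and asynchronous dispatch out of the measurement.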
Example of Using Flash Attention on TPU
import jax
import flax.linen.attention as flt
from fjformer import GenerateRNG
from easydel import PartitionAxis
from easydel.layers.attention import FlexibleAttentionModule, FlaxAttentionModule
from easydel.modules.easydel_modelling_utils import EasyDeLBaseConfig
from jax import numpy as jnp, random, lax
import math
rng_gen = GenerateRNG(seed=42)
config = EasyDeLBaseConfig(
    axis_dims=(1, -1, 1, 1),
    axis_names=("dp", "fsdp", "tp", "sp"),
    block_q=512,
    block_k=512
)
BATCH_SIZE = len(jax.devices())
NUM_ATTN_HEADS = 32
CONTEXT_LENGTH = 8192
HEAD_DIM = 256
def make_fake_input_data(
    batch_size: int,
    num_attention_head: int,
    context_length: int,
    head_dim: int,
):
    # Random query/key/value tensors laid out as (batch, sequence, heads, head_dim).
    q = random.normal(next(rng_gen), (batch_size, context_length, num_attention_head, head_dim), dtype=jnp.float32)
    k = random.normal(next(rng_gen), (batch_size, context_length, num_attention_head, head_dim), dtype=jnp.float32)
    v = random.normal(next(rng_gen), (batch_size, context_length, num_attention_head, head_dim), dtype=jnp.float32)
    # No padding: every position is a real token.
    attention_mask = jnp.ones((batch_size, context_length))
    # Causal mask of shape (batch, 1, context_length, context_length).
    causal_mask = flt.make_causal_mask(attention_mask)
    cm_ = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
    # Broadcast the padding mask over query positions and combine it with the causal mask.
    at_ = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), cm_.shape)
    at_ = flt.combine_masks(at_, cm_)
    # Additive bias: 0 where attention is allowed, the most negative float32 elsewhere.
    attention_bias = lax.select(
        at_ > 0,
        jnp.full(at_.shape, 0.0).astype(jnp.float32),
        jnp.full(at_.shape, jnp.finfo(jnp.float32).min).astype(jnp.float32),
    )
    return (
        q, k, v, attention_mask, causal_mask, attention_bias
    )
q, k, v, attention_mask, causal_mask, attention_bias = make_fake_input_data(
    BATCH_SIZE,
    NUM_ATTN_HEADS,
    CONTEXT_LENGTH,
    HEAD_DIM
)
flash_attention = FlexibleAttentionModule(
    num_attention_heads=NUM_ATTN_HEADS,
    attention_dropout=0.0,
    head_dims=HEAD_DIM,
    partition_axis=PartitionAxis(
        batch_axis=("dp", "fsdp"),
        query_sequence_axis="sp",
        key_sequence_axis="sp",
        head_axis="tp",
        attention_dim_axis=None
    ),
    shard_attention_computation=config.shard_attention_computation,
    precision=lax.Precision("fastest"),
    force_float32_tpu=True,
    attn_mechanism=...,  # check the source for the flash attention mechanism name
    dtype=jnp.float32,
    scan_ring_attention=config.scan_ring_attention,
    mesh=config.mesh,
    sm_scale=1 / math.sqrt(q.shape[-1]),
)
normal_attention = FlexibleAttentionModule(
    num_attention_heads=NUM_ATTN_HEADS,
    attention_dropout=0.0,
    head_dims=HEAD_DIM,
    partition_axis=PartitionAxis(
        batch_axis=("dp", "fsdp"),
        query_sequence_axis="sp",
        key_sequence_axis="sp",
        head_axis="tp",
        attention_dim_axis=None
    ),
    shard_attention_computation=config.shard_attention_computation,
    precision=lax.Precision("fastest"),
    force_float32_tpu=True,
    attn_mechanism="vanilla",
    dtype=jnp.float32,
    scan_ring_attention=config.scan_ring_attention,
    mesh=config.mesh,
    sm_scale=1 / math.sqrt(q.shape[-1]),
)
with config.mesh:
    flash_attn_out = flash_attention(
        query_states=q,
        key_states=k,
        value_states=v,
        bias=attention_bias,
        key_value_sequence_length=CONTEXT_LENGTH,
        query_sequence_length=CONTEXT_LENGTH
    )
    normal_attn_out = normal_attention(
        query_states=q,
        key_states=k,
        value_states=v,
        bias=attention_bias,
        key_value_sequence_length=CONTEXT_LENGTH,
        query_sequence_length=CONTEXT_LENGTH
    )
    # Compare the last 10 values of one head at one position from both mechanisms.
    print(
        flash_attn_out.attention_outputs[0, CONTEXT_LENGTH - 5, NUM_ATTN_HEADS - 1, HEAD_DIM - 10:]
    )
    # Array([-0.05915311, 0.0078501 , 0.03785717, 0.0134844 , 0.08464689,
    # 0.06667967, -0.02629154, -0.0180066 , -0.02972782, 0.02833381], dtype=float32)
    print(
        normal_attn_out.attention_outputs[0, CONTEXT_LENGTH - 5, NUM_ATTN_HEADS - 1, HEAD_DIM - 10:]
    )
    # Array([-0.0590958 , 0.00796138, 0.03789062, 0.01350671, 0.08461153,
    # 0.06662725, -0.0262386 , -0.01806086, -0.0296791 , 0.02824247], dtype=float32)
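The mask-to-bias step inside make_fake_input_data goes through flax.linen.attention helpers. The sketch below is an assumed plain jax.numpy equivalent, not taken from EasyDeL: the function name make_attention_bias is hypothetical. It combines a lower-triangular causal mask with a key padding mask and converts the result into an additive bias of shape (batch, 1, q_len, kv_len), the same layout passed as bias above.

```python
import jax.numpy as jnp


def make_attention_bias(attention_mask, dtype=jnp.float32):
    """Builds an additive causal + padding bias (hypothetical helper).

    attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding.
    Returns (batch, 1, seq_len, seq_len): 0 where attention is allowed and
    the most negative finite value of `dtype` where it is blocked
    (future positions or padded keys).
    """
    seq_len = attention_mask.shape[-1]
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))  # (q, k): q >= k
    padding = attention_mask.astype(bool)[:, None, None, :]  # (b, 1, 1, k)
    allowed = jnp.logical_and(causal[None, None, :, :], padding)  # (b, 1, q, k)
    return jnp.where(allowed, jnp.array(0, dtype), jnp.finfo(dtype).min).astype(dtype)


mask = jnp.array([[1, 1, 1, 0]])  # the last token is padding
bias = make_attention_bias(mask)
print(bias.shape)  # (1, 1, 4, 4)
print((bias[0, 0] == 0).astype(int))  # 1 marks allowed query/key pairs
```

Using the most negative finite value instead of -inf keeps the softmax free of NaNs even when a whole row is masked.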