easydel.kernels.ring_attention
- easydel.kernels.ring_attention.ring_attention(query: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], key: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], value: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], bias: typing.Optional[typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] = None, segment_ids: typing.Optional[typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] = None, axis_name: typing.Optional[str] = None, float32_logits: bool = True, softmax_scale: typing.Optional[float] = None, blocksize_q: int = 512, blocksize_k: int = 512, blocksize_c: typing.Optional[int] = None, deterministic: bool = True, dropout_rng: typing.Optional[jax.Array] = None, pdrop: float = 0.0, dtype: numpy.dtype = <class 'jax.numpy.float32'>, policy=<function nothing_saveable>, precision: typing.Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], jax._src.lax.lax.DotAlgorithm, jax._src.lax.lax.DotAlgorithmPreset] = Precision.DEFAULT, prevent_cse: bool = True, cache_idx=None, backend: typing.Literal['cpu', 'gpu', 'tpu'] = Ellipsis, platform: typing.Literal['pallas', 'jax'] = Ellipsis, autocheck: bool = True)
Computes ring attention with blockwise transformers. Supports the JAX and Pallas platforms on TPU, GPU, and CPU backends.

:param query: Query array of shape (batch, q_len, num_heads, dim_per_head).
:param key: Key array of shape (batch, kv_len, num_heads, dim_per_head).
:param value: Value array of shape (batch, kv_len, num_heads, dim_per_head).
:param bias: Optional bias array of shape (batch, num_heads, q_len, kv_len).
:param segment_ids: Optional segment ids array of shape (batch, seq_len).
:param axis_name: Name of the axis to ppermute over.
:param float32_logits: Whether to compute logits in float32.
:param softmax_scale: Scale for the softmax; if not given, depth ** -0.5 is used.
:param blocksize_q: Size of query chunks.
:param blocksize_k: Size of key chunks.
:param blocksize_c: Size of causal blocks.
:param deterministic: Whether to run deterministically (controls whether dropout is applied).
:param dropout_rng: PRNG key for dropout.
:param pdrop: Dropout probability.
:param dtype: dtype of the computation.
:param policy: Checkpoint policy.
:param precision: Precision of the computation.
:param prevent_cse: Whether to prevent common subexpression elimination.
:param platform: Platform to use for the function ('jax' or 'pallas').
:param backend: Requested backend for the function ('cpu', 'tpu', or 'gpu').
:param autocheck: Whether to automatically check the block sizes (blocksize_q and blocksize_k).
- Returns
Output array of shape (batch, q_len, num_heads, dim_per_head).
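
The sketch below illustrates the blockwise idea behind the key chunking described above. It is not EasyDeL's implementation: it is a single-device NumPy reference that walks over the keys in chunks of `blocksize_k` and keeps a running softmax, without the ring `ppermute`, bias, segment ids, dropout, or query chunking. The helper names `reference_attention` and `blockwise_attention` are hypothetical; only the array shapes and the `depth ** -0.5` scale follow the documentation above.

```python
import numpy as np


def reference_attention(q, k, v, scale):
    """Plain softmax attention, used only to check the blockwise result."""
    # q: (batch, q_len, heads, dim); k, v: (batch, kv_len, heads, dim)
    logits = np.einsum("bqhd,bkhd->bhqk", q, k) * scale
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bkhd->bqhd", weights, v)


def blockwise_attention(q, k, v, scale, blocksize_k):
    """Hypothetical helper: attention computed one key/value chunk at a time."""
    batch, q_len, heads, dim = q.shape
    kv_len = k.shape[1]

    acc = np.zeros((batch, heads, q_len, dim))         # unnormalised output
    row_max = np.full((batch, heads, q_len), -np.inf)  # running max of logits
    row_sum = np.zeros((batch, heads, q_len))          # running softmax denominator

    for start in range(0, kv_len, blocksize_k):
        k_blk = k[:, start:start + blocksize_k]
        v_blk = v[:, start:start + blocksize_k]

        logits = np.einsum("bqhd,bkhd->bhqk", q, k_blk) * scale
        new_max = np.maximum(row_max, logits.max(axis=-1))

        # Rescale what was accumulated so far to the new running max.
        correction = np.exp(row_max - new_max)
        probs = np.exp(logits - new_max[..., None])

        row_sum = row_sum * correction + probs.sum(axis=-1)
        acc = acc * correction[..., None] + np.einsum("bhqk,bkhd->bhqd", probs, v_blk)
        row_max = new_max

    out = acc / row_sum[..., None]
    # Back to the documented layout (batch, q_len, num_heads, dim_per_head).
    return out.transpose(0, 2, 1, 3)


rng = np.random.default_rng(0)
q = rng.standard_normal((2, 8, 4, 16))
k = rng.standard_normal((2, 12, 4, 16))
v = rng.standard_normal((2, 12, 4, 16))
scale = q.shape[-1] ** -0.5  # depth ** -0.5

out = blockwise_attention(q, k, v, scale, blocksize_k=4)
```

Because the running maximum and denominator are rescaled at every step, the chunked result matches the unchunked one up to floating-point error, and the full (q_len, kv_len) logit matrix is never materialised at once. In a ring setting, the key/value chunks would arrive from neighbouring devices rather than from a local slice.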