easydel.layers.ops._lightning_attention

easydel.layers.ops._lightning_attention

easydel.layers.ops._lightning_attention.build_slope_tensor(n_attention_heads) [source]
easydel.layers.ops._lightning_attention.lightning_attention(q: jax.Array, k: jax.Array, v: jax.Array, slope_rate: jax.Array, position_ids: typing.Optional[jax.Array] = None, attn_mask: typing.Optional[jax.Array] = None, past_key_value: typing.Optional[jax.Array] = None, init_cache: bool = False, dtype: numpy.dtype = jax.numpy.float32) → Tuple[Array, Optional[Array]] [source]