Source code for easydel.layers.operations._attention_outputs
from __future__ import annotations
from eformer.pytree import auto_pytree
from jax import Array
from jaxtyping import Float
from ..caching import RaggedPagesCacheView, TransformerCacheView
from ._operation_impl import OperationOutput
@auto_pytree
class AttentionOutput(OperationOutput):
    """Container for the results of an attention computation.

    All fields are optional and default to ``None``, so an instance may
    carry any subset of them.

    Attributes:
        attention_weights: Attention weights, annotated with shape
            ``(batch, num_heads, seq_len, seq_len)``.
        attention_outputs: Attention outputs, annotated with shape
            ``(batch, seq_len, num_heads, head_dim)``.
        cache_view: A ``TransformerCacheView`` or ``RaggedPagesCacheView``
            carried alongside the results. NOTE(review): presumably the
            cache view used or updated by the operation -- confirm
            against callers.
    """

    # Field order is kept as declared: positional construction and the
    # pytree layout both depend on it.
    attention_weights: Float[Array, "batch num_heads seq_len seq_len"] | None = None
    attention_outputs: Float[Array, "batch seq_len num_heads head_dim"] | None = None
    cache_view: TransformerCacheView | RaggedPagesCacheView | None = None