easydel.layers.operations._attention_outputs

class easydel.layers.operations._attention_outputs.AttentionOutput(attention_weights: jaxtyping.Float[Array, 'batch num_heads seq_len seq_len'] | None = None, attention_outputs: jaxtyping.Float[Array, 'batch seq_len num_heads head_dim'] | None = None, cache_view: easydel.layers.caching.transformer.cache.TransformerCacheView | easydel.layers.caching.ragged_page.cache.RaggedPagesCacheView | None = None)

Bases: OperationOutput

This dataclass encapsulates the results of an attention computation. Every field is optional and defaults to None.

attention_outputs: jaxtyping.Float[Array, 'batch seq_len num_heads head_dim'] | None = None
attention_weights: jaxtyping.Float[Array, 'batch num_heads seq_len seq_len'] | None = None
cache_view: easydel.layers.caching.transformer.cache.TransformerCacheView | easydel.layers.caching.ragged_page.cache.RaggedPagesCacheView | None = None
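The shape annotations on the fields above imply a particular axis layout: the weights put the head axis before the sequence axes, while the outputs put the sequence axis before the head axis. The sketch below illustrates that relationship with a hypothetical stand-in class (`AttentionOutputSketch`) and plain nested lists instead of JAX arrays. The helpers `nested_shape` and `layout_is_consistent` are illustrative names that are not part of EasyDeL, and the real class may validate nothing at all.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class AttentionOutputSketch:
    """Hypothetical stand-in mirroring the three optional fields."""

    attention_weights: Optional[Any] = None  # [batch, num_heads, seq_len, seq_len]
    attention_outputs: Optional[Any] = None  # [batch, seq_len, num_heads, head_dim]
    cache_view: Optional[Any] = None


def nested_shape(x):
    """Return the shape of a rectangular nested list as a tuple."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        if not x:
            break
        x = x[0]
    return tuple(shape)


def layout_is_consistent(out):
    """Cross-check the documented axis layout when both arrays are present."""
    if out.attention_weights is None or out.attention_outputs is None:
        return True  # nothing to cross-check
    w_batch, w_heads, w_query, w_key = nested_shape(out.attention_weights)
    o_batch, o_seq, o_heads, _head_dim = nested_shape(out.attention_outputs)
    return (w_batch, w_heads, w_query) == (o_batch, o_heads, o_seq) and w_query == w_key


# batch=1, num_heads=2, seq_len=3, head_dim=4
weights = [[[[0.0] * 3 for _ in range(3)] for _ in range(2)] for _ in range(1)]
outputs = [[[[0.0] * 4 for _ in range(2)] for _ in range(3)] for _ in range(1)]

result = AttentionOutputSketch(attention_weights=weights, attention_outputs=outputs)
outputs_only = AttentionOutputSketch(attention_outputs=outputs)
```

Because every field defaults to None, calling code should check each one before use, as `layout_is_consistent` does for `outputs_only`.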
classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.
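The five methods listed above follow a common immutable-record pattern: `replace` returns a modified copy, and the `to_*` / `from_*` pairs round-trip through a dictionary or a JSON string. The sketch below reproduces that pattern with the standard library only. `OutputSketch` is a hypothetical stand-in, not the EasyDeL base class, and it uses plain lists because how the real methods encode JAX arrays or cache views is not described here.

```python
import dataclasses
import json
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class OutputSketch:
    """Hypothetical stand-in with the same method names as the class above."""

    attention_weights: Optional[list] = None
    attention_outputs: Optional[list] = None

    def replace(self, **kwargs):
        # Build a new instance; the original is left unchanged.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self):
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data):
        return cls(**data)

    def to_json(self, **kwargs):
        # Extra keyword arguments are forwarded to json.dumps in this sketch.
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str):
        return cls.from_dict(json.loads(json_str))


original = OutputSketch(attention_outputs=[[1.0, 2.0]])
updated = original.replace(attention_weights=[[0.25, 0.75]])
restored = OutputSketch.from_json(updated.to_json(sort_keys=True))
```

In this sketch `original` keeps `attention_weights` as None after the call to `replace`, and `restored` compares equal to `updated`.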