# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model output classes for EasyDeL.

Defines standardized output structures for various model types and tasks.
These dataclasses provide consistent interfaces for model outputs while
maintaining compatibility with JAX pytrees.

Classes:
    ModelOutput: Base class for all model outputs
    CausalLMOutput: Output for causal language models
    MoeCausalLMOutput: Output for MoE causal language models
    SequenceClassifierOutput: Output for sequence classification
    ImageClassifierOutput: Output for image classification
    CLIPOutput: Output for CLIP models
    CLIPTextModelOutput: Output for CLIP text encoders
    GreedySearchOutput: Output for greedy generation
    SampleOutput: Output for sampling generation
    BeamSearchOutput: Output for beam search generation

Key Features:
    - Consistent interface across model types
    - JAX pytree compatibility
    - Optional fields with None defaults
    - Dictionary-like access patterns
    - Automatic validation

Example:
    >>> from easydel.infra.modeling_outputs import CausalLMOutput
    >>> output = CausalLMOutput(
    ...     logits=logits,
    ...     hidden_states=hidden_states,
    ...     attentions=attentions
    ... )
    >>> # Access as attribute or dictionary
    >>> logits = output.logits
    >>> logits = output["logits"]
"""
from __future__ import annotations
import typing as tp
from dataclasses import fields, is_dataclass
import chex
from eformer.pytree import auto_pytree
from jax.core import Tracer
# The cache classes are imported only for static type checkers; at runtime the
# two names are bound to ``tp.Any`` so the annotations below still resolve.
# NOTE(review): presumably this avoids a circular import with
# easydel.layers.caching -- confirm.
if tp.TYPE_CHECKING:
    from easydel.layers.caching import TransformerCache, TransformerCacheView
else:
    TransformerCacheView = tp.Any
    TransformerCache = tp.Any
def _is_array(array):
    """Report whether ``array`` is a JAX ``Tracer`` instance.

    Only tracers are recognised; every other object yields ``False``.
    """
    return isinstance(array, Tracer)
class ModelOutput(tp.OrderedDict):
    """Base class for all model outputs.

    Provides a consistent interface for model outputs that behaves like
    both a tuple (for positional access) and a dictionary (for named access).
    Fields whose value is ``None`` are not stored as keys, so they are skipped
    by ``to_tuple`` and by integer indexing.

    Subclasses must use the @auto_pytree decorator to ensure JAX compatibility.

    Methods:
        to_tuple: Convert to tuple, excluding None values

    Note:
        All fields except the first should have None as default.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the mapping and reject subclasses that are not dataclasses.

        Raises:
            TypeError: If the instance belongs to a subclass of ``ModelOutput``
                that is not a dataclass.
        """
        super().__init__(*args, **kwargs)

        is_modeloutput_subclass = self.__class__ != ModelOutput
        if is_modeloutput_subclass and not is_dataclass(self):
            raise TypeError(
                f"{self.__module__}.{self.__class__.__name__} is not a dataclass."
                " This is a subclass of ModelOutput and so must use the @auto_pytree decorator."
            )

    def to_tuple(self) -> tuple[tp.Any, ...]:
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        """
        return tuple(self[key] for key in self)

    def __post_init__(self):
        """Check the ModelOutput dataclass and populate the mapping.

        Only occurs if @auto_pytree decorator has been used.

        Raises:
            ValueError: If the dataclass has no fields, if any field after the
                first has a default other than ``None``, or if the first field
                is an iterable of key/value pairs containing a malformed entry
                after the first one.
        """
        class_fields = fields(self)

        # Safety and consistency checks
        if not class_fields:
            raise ValueError(f"{self.__class__.__name__} has no fields.")
        if not all(field.default is None for field in class_fields[1:]):
            raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")

        first_name = class_fields[0].name
        first_field = getattr(self, first_name)
        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

        if other_fields_are_none and not _is_array(first_field):
            # Only the first field was given: it may itself be a dict or an
            # iterable of (key, value) pairs describing the whole output.
            if isinstance(first_field, dict):
                iterator = first_field.items()
                first_field_iterator = True
            else:
                try:
                    iterator = iter(first_field)
                    first_field_iterator = True
                except TypeError:
                    first_field_iterator = False

            if first_field_iterator:
                for idx, element in enumerate(iterator):
                    is_pair = (
                        isinstance(element, (list, tuple))
                        and len(element) == 2
                        and isinstance(element[0], str)
                    )
                    if not is_pair:
                        if idx == 0:
                            # Not a key/value iterable after all: store the
                            # value itself under the first field's name.
                            self[first_name] = first_field
                        else:
                            raise ValueError(f"Cannot set key/value for {element}. It needs to be a tuple (key, value).")
                        break
                    setattr(self, element[0], element[1])
                    if element[1] is not None:
                        self[element[0]] = element[1]
            elif first_field is not None:
                self[first_name] = first_field
        else:
            # Regular case: mirror every non-None field into the mapping.
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        """Always raise: deleting keys is not supported."""
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        """Always raise: ``setdefault`` is not supported."""
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        """Always raise: ``pop`` is not supported."""
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        """Always raise: ``update`` is not supported."""
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k):
        """Return the value for a string key, or index into ``to_tuple()`` otherwise."""
        if isinstance(k, str):
            # Direct mapping lookup; avoids copying the whole mapping per access.
            return super().__getitem__(k)
        return self.to_tuple()[k]

    def __setattr__(self, name, value):
        """Set an attribute and keep an existing mapping entry of the same name in sync."""
        if name in self.keys() and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` and mirror it as an attribute."""
        # Will raise a KeyException if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def __reduce__(self):
        """Support pickling: dataclass instances are rebuilt from their field values."""
        if not is_dataclass(self):
            return super().__reduce__()
        _fn, _args, *remaining = super().__reduce__()
        args = tuple(getattr(self, field.name) for field in fields(self))
        return _fn, args, *remaining
@auto_pytree
class AttentionLayerOutput(ModelOutput):
    """Output from a single attention layer.

    Contains the attention computation results from a transformer attention
    layer, including optional attention weights and cache views for efficient
    generation.

    Args:
        attention_output: Output tensor from the attention layer with shape
            (batch_size, sequence_length, hidden_size).
        attention_weight: Optional attention weights after softmax with shape
            (batch_size, num_heads, sequence_length, sequence_length).
            Only returned when output_attentions=True.
        cache_view: Optional cache view for efficient autoregressive
            generation. Contains cached key-value pairs from previous steps.
    """

    attention_output: chex.Array
    attention_weight: chex.Array | None = None
    cache_view: TransformerCacheView | None = None
@auto_pytree
class EncoderLayerOutput(ModelOutput):
    """Output from a single encoder layer.

    Contains the outputs from a transformer encoder layer, including
    the processed hidden states and optional attention weights.

    Args:
        hidden_states: Output hidden states from the encoder layer with shape
            (batch_size, sequence_length, hidden_size).
        residual_states: Optional residual connection states before layer norm
            with shape (batch_size, sequence_length, hidden_size).
        attention_weight: Optional attention weights after softmax with shape
            (batch_size, num_heads, sequence_length, sequence_length).
            Only returned when output_attentions=True.
    """

    hidden_states: chex.Array
    residual_states: chex.Array | None = None
    attention_weight: chex.Array | None = None
@auto_pytree
class DecoderLayerOutput(ModelOutput):
    """Output from a single decoder layer.

    Contains the outputs from a transformer decoder layer, including
    hidden states, attention weights, and optional MoE routing information.

    Args:
        hidden_states: Output hidden states from the decoder layer with shape
            (batch_size, sequence_length, hidden_size).
        residual_states: Optional residual connection states before layer norm
            with shape (batch_size, sequence_length, hidden_size).
        cross_attention: Optional cross-attention outputs when using
            encoder-decoder architecture with shape
            (batch_size, sequence_length, hidden_size).
        attention_weight: Optional self-attention weights after softmax with
            shape (batch_size, num_heads, sequence_length, sequence_length).
        router_logits: Optional MoE router logits for expert selection with
            shape (batch_size, sequence_length, num_experts).
        gate_loss: Optional auxiliary loss for MoE load balancing.
        cache_view: Optional cache view for efficient autoregressive
            generation.
    """

    hidden_states: chex.Array
    residual_states: chex.Array | None = None
    cross_attention: chex.Array | None = None
    attention_weight: chex.Array | None = None
    router_logits: chex.Array | None = None
    gate_loss: chex.Array | None = None
    cache_view: TransformerCacheView | None = None
@auto_pytree
class BaseModelOutput(ModelOutput):
    """Base class for model outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state: Sequence of hidden-states at the output of the last
            layer of the model, of shape
            (batch_size, sequence_length, hidden_size).
        hidden_states: Optional tuple of arrays (one for the output of the
            embeddings + one for the output of each layer), each of shape
            (batch_size, sequence_length, hidden_size). Hidden-states of the
            model at the output of each layer plus the initial embedding
            outputs.
        attentions: Optional tuple of arrays (one for each layer), each of
            shape (batch_size, num_heads, sequence_length, sequence_length).
            Attention weights after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        past_key_values: Optional dictionary of arrays keyed by string.
            NOTE(review): presumably cached key/value states for decoding --
            confirm against callers.
        loss: Optional loss array; defaults to None.
    """

    last_hidden_state: chex.Array = None
    hidden_states: tuple[chex.Array] | None = None
    attentions: tuple[chex.Array] | None = None
    past_key_values: dict[str, chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class BaseModelOutputWithNoAttention(ModelOutput):
    """Base class for model outputs, with potential hidden states.

    Args:
        last_hidden_state: Sequence of hidden-states at the output of the last
            layer of the model, of shape
            (batch_size, num_channels, height, width).
        hidden_states: Optional tuple of arrays (one for the output of the
            embeddings, if the model has an embedding layer, + one for the
            output of each layer), each of shape
            (batch_size, num_channels, height, width). Hidden-states of the
            model at the output of each layer plus the optional initial
            embedding outputs.
        loss: Optional loss array; defaults to None.
    """

    last_hidden_state: chex.Array = None
    hidden_states: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class BaseModelOutputWithPoolingAndNoAttention(ModelOutput):
    """Base class for model outputs that also contain a pooling of the last hidden states.

    Args:
        last_hidden_state: Sequence of hidden-states at the output of the last
            layer of the model, of shape
            (batch_size, num_channels, height, width).
        pooler_output: Last layer hidden-state after a pooling operation on
            the spatial dimensions, of shape (batch_size, hidden_size).
        hidden_states: Optional tuple of arrays (one for the output of the
            embeddings, if the model has an embedding layer, + one for the
            output of each layer), each of shape
            (batch_size, num_channels, height, width). Hidden-states of the
            model at the output of each layer plus the optional initial
            embedding outputs.
        loss: Optional loss array; defaults to None.
    """

    last_hidden_state: chex.Array = None
    pooler_output: chex.Array = None
    hidden_states: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class ImageClassifierOutputWithNoAttention(ModelOutput):
    """Base class for outputs of image classification models.

    Args:
        logits: Classification (or regression if config.num_labels==1) scores
            (before SoftMax), of shape (batch_size, config.num_labels).
        hidden_states: Optional tuple of arrays (one for the output of the
            embeddings, if the model has an embedding layer, + one for the
            output of each stage), each of shape
            (batch_size, num_channels, height, width). Hidden-states (also
            called feature maps) of the model at the output of each stage.
            Returned when `output_hidden_states=True` is passed or when
            `config.output_hidden_states=True`.
        loss: Optional loss array; defaults to None.
    """

    logits: chex.Array = None
    hidden_states: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class BaseModelOutputWithPast(ModelOutput):
    """Base class for model outputs that may also carry cached key/value states.

    Args:
        last_hidden_state: Sequence of hidden-states at the output of the last
            layer of the model, of shape
            (batch_size, sequence_length, hidden_size).
        past_key_values: Optional dictionary of pre-computed hidden-states
            (key and values in the attention blocks) that can be used for fast
            auto-regressive decoding. Pre-computed key and value hidden-states
            are of shape [batch_size, max_length].
        hidden_states: Optional tuple of arrays (one for the output of the
            embeddings + one for the output of each layer), each of shape
            (batch_size, sequence_length, hidden_size). Hidden-states of the
            model at the output of each layer plus the initial embedding
            outputs.
        attentions: Optional tuple of arrays (one for each layer), each of
            shape (batch_size, num_heads, sequence_length, sequence_length).
            Attention weights after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        loss: Optional loss array; defaults to None.
    """

    last_hidden_state: chex.Array = None
    past_key_values: dict[str, chex.Array] | None = None
    hidden_states: tuple[chex.Array] | None = None
    attentions: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class BaseModelOutputWithPooling(ModelOutput):
    """Base class for model outputs that also contain a pooling of the last hidden states.

    Args:
        last_hidden_state: Sequence of hidden-states at the output of the last
            layer of the model, of shape
            (batch_size, sequence_length, hidden_size).
        pooler_output: Last layer hidden-state of the first token of the
            sequence (classification token) further processed by a Linear
            layer and a Tanh activation function, of shape
            (batch_size, hidden_size). The Linear layer weights are trained
            from the next sentence prediction (classification) objective
            during pretraining.
        hidden_states: Optional tuple of arrays (one for the output of the
            embeddings + one for the output of each layer), each of shape
            (batch_size, sequence_length, hidden_size). Hidden-states of the
            model at the output of each layer plus the initial embedding
            outputs.
        attentions: Optional tuple of arrays (one for each layer), each of
            shape (batch_size, num_heads, sequence_length, sequence_length).
            Attention weights after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        loss: Optional loss array; defaults to None.
    """

    last_hidden_state: chex.Array = None
    pooler_output: chex.Array = None
    hidden_states: tuple[chex.Array] | None = None
    attentions: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
    """Base class for model outputs with pooling, cross-attentions and cached states.

    Args:
        last_hidden_state: Sequence of hidden-states at the output of the last
            layer of the model, of shape
            (batch_size, sequence_length, hidden_size).
        pooler_output: Last layer hidden-state of the first token of the
            sequence (classification token) after further processing through
            the layers used for the auxiliary pretraining task, of shape
            (batch_size, hidden_size). E.g. for BERT-family of models, this
            returns the classification token after processing through a linear
            layer and a tanh activation function. The linear layer weights are
            trained from the next sentence prediction (classification)
            objective during pretraining.
        hidden_states: Optional tuple of arrays (one for the output of the
            embeddings, if the model has an embedding layer, + one for the
            output of each layer), each of shape
            (batch_size, sequence_length, hidden_size). Hidden-states of the
            model at the output of each layer plus the optional initial
            embedding outputs.
        past_key_values: Optional `TransformerCache` with pre-computed
            hidden-states (key and values in the self-attention blocks and
            optionally, if `config.is_encoder_decoder=True`, in the
            cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        attentions: Optional tuple of arrays (one for each layer), each of
            shape (batch_size, num_heads, sequence_length, sequence_length).
            Attention weights after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        cross_attentions: Optional tuple of arrays (one for each layer), each
            of shape (batch_size, num_heads, sequence_length, sequence_length).
            Attention weights of the decoder's cross-attention layer, after
            the attention softmax, used to compute the weighted average in the
            cross-attention heads.
        loss: Optional loss array; defaults to None.
    """

    last_hidden_state: chex.Array = None
    pooler_output: chex.Array = None
    hidden_states: tuple[chex.Array] | None = None
    past_key_values: TransformerCache | None = None
    attentions: tuple[chex.Array] | None = None
    cross_attentions: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
    """Base class for model outputs that may also contain past key/values.

    The cached key/values can be used to speed up sequential decoding.

    Args:
        last_hidden_state: Sequence of hidden-states at the output of the last
            layer of the model, of shape
            (batch_size, sequence_length, hidden_size). If `past_key_values`
            is used only the last hidden-state of the sequences of shape
            (batch_size, 1, hidden_size) is output.
        past_key_values: Optional `TransformerCache` with pre-computed
            hidden-states (key and values in the self-attention blocks and
            optionally, if `config.is_encoder_decoder=True`, in the
            cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states: Optional tuple of arrays (one for the output of the
            embeddings + one for the output of each layer), each of shape
            (batch_size, sequence_length, hidden_size). Hidden-states of the
            model at the output of each layer plus the initial embedding
            outputs.
        attentions: Optional tuple of arrays (one for each layer), each of
            shape (batch_size, num_heads, sequence_length, sequence_length).
            Attention weights after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        cross_attentions: Optional tuple of arrays (one for each layer), each
            of shape (batch_size, num_heads, sequence_length, sequence_length).
            Attention weights of the decoder's cross-attention layer, after
            the attention softmax, used to compute the weighted average in the
            cross-attention heads.
        loss: Optional loss array; defaults to None.
    """

    last_hidden_state: chex.Array = None
    past_key_values: TransformerCache | None = None
    hidden_states: tuple[chex.Array] | None = None
    attentions: tuple[chex.Array] | None = None
    cross_attentions: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class Seq2SeqModelOutput(ModelOutput):
    """Base class for encoder-decoder model outputs.

    Also contains pre-computed hidden states that can speed up sequential
    decoding.

    Args:
        last_hidden_state: Sequence of hidden-states at the output of the last
            layer of the decoder of the model, of shape
            (batch_size, sequence_length, hidden_size). If `past_key_values`
            is used only the last hidden-state of the sequences of shape
            (batch_size, 1, hidden_size) is output.
        past_key_values: Optional `TransformerCache` with pre-computed
            hidden-states (key and values in the self-attention blocks and in
            the cross-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        decoder_hidden_states: Optional tuple of arrays (one for the output of
            the embeddings + one for the output of each layer), each of shape
            (batch_size, sequence_length, hidden_size). Hidden-states of the
            decoder at the output of each layer plus the initial embedding
            outputs.
        decoder_attentions: Optional tuple of arrays (one for each layer),
            each of shape
            (batch_size, num_heads, sequence_length, sequence_length).
            Attention weights of the decoder, after the attention softmax,
            used to compute the weighted average in the self-attention heads.
        cross_attentions: Optional tuple of arrays (one for each layer), each
            of shape (batch_size, num_heads, sequence_length, sequence_length).
            Attention weights of the decoder's cross-attention layer, after
            the attention softmax, used to compute the weighted average in the
            cross-attention heads.
        encoder_last_hidden_state: Optional sequence of hidden-states at the
            output of the last layer of the encoder of the model, of shape
            (batch_size, sequence_length, hidden_size).
        encoder_hidden_states: Optional tuple of arrays (one for the output of
            the embeddings + one for the output of each layer), each of shape
            (batch_size, sequence_length, hidden_size). Hidden-states of the
            encoder at the output of each layer plus the initial embedding
            outputs.
        encoder_attentions: Optional tuple of arrays (one for each layer),
            each of shape
            (batch_size, num_heads, sequence_length, sequence_length).
            Attention weights of the encoder, after the attention softmax,
            used to compute the weighted average in the self-attention heads.
        loss: Optional loss array; defaults to None.
    """

    last_hidden_state: chex.Array = None
    past_key_values: TransformerCache | None = None
    decoder_hidden_states: tuple[chex.Array] | None = None
    decoder_attentions: tuple[chex.Array] | None = None
    cross_attentions: tuple[chex.Array] | None = None
    encoder_last_hidden_state: chex.Array | None = None
    encoder_hidden_states: tuple[chex.Array] | None = None
    encoder_attentions: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class CausalLMOutputWithCrossAttentions(ModelOutput):
    """Base class for causal language model (or autoregressive) outputs.

    Args:
        logits: Prediction scores of the language modeling head (scores for
            each vocabulary token before SoftMax), of shape
            (batch_size, sequence_length, config.vocab_size).
        past_key_values: Optional `TransformerCache` with the cached key,
            value states of the self-attention and the cross-attention layers
            if model is used in encoder-decoder setting. Only relevant if
            `config.is_decoder = True`. Contains pre-computed hidden-states
            (key and values in the attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states: Optional tuple of arrays (one for the output of the
            embeddings + one for the output of each layer), each of shape
            (batch_size, sequence_length, hidden_size). Hidden-states of the
            model at the output of each layer plus the initial embedding
            outputs.
        attentions: Optional tuple of arrays (one for each layer), each of
            shape (batch_size, num_heads, sequence_length, sequence_length).
            Attention weights after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        cross_attentions: Optional tuple of arrays (one for each layer), each
            of shape (batch_size, num_heads, sequence_length, sequence_length).
            Cross attention weights after the attention softmax, used to
            compute the weighted average in the cross-attention heads.
        loss: Optional loss array; defaults to None.
    """

    logits: chex.Array = None
    past_key_values: TransformerCache | None = None
    hidden_states: tuple[chex.Array] | None = None
    attentions: tuple[chex.Array] | None = None
    cross_attentions: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class MaskedLMOutput(ModelOutput):
    """Base class for masked language model outputs.

    Args:
        logits: Prediction scores of the language modeling head (scores for
            each vocabulary token before SoftMax), of shape
            (batch_size, sequence_length, config.vocab_size).
        hidden_states: Optional tuple of arrays (one for the output of the
            embeddings + one for the output of each layer), each of shape
            (batch_size, sequence_length, hidden_size). Hidden-states of the
            model at the output of each layer plus the initial embedding
            outputs.
        last_hidden_state: Optional array.
            NOTE(review): presumably the hidden-states at the output of the
            last layer, as in the other output classes -- confirm.
        attentions: Optional tuple of arrays (one for each layer), each of
            shape (batch_size, num_heads, sequence_length, sequence_length).
            Attention weights after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        past_key_values: Optional `TransformerCache`.
            NOTE(review): presumably cached key/value states for sequential
            decoding -- confirm.
        loss: Optional loss array; defaults to None.
    """

    logits: chex.Array | None = None
    hidden_states: tuple[chex.Array] | None = None
    last_hidden_state: chex.Array | None = None
    attentions: tuple[chex.Array] | None = None
    past_key_values: TransformerCache | None = None
    loss: chex.Array | None = None
# CausalLMOutput is an alias: it is the very same class object as MaskedLMOutput.
CausalLMOutput = MaskedLMOutput  # type:ignore
@auto_pytree
class Seq2SeqLMOutput(ModelOutput):
    """Base class for sequence-to-sequence language model outputs.

    Args:
        logits: Prediction scores of the language modeling head (scores for
            each vocabulary token before SoftMax), of shape
            (batch_size, sequence_length, config.vocab_size).
        past_key_values: Optional `TransformerCache` with pre-computed
            hidden-states (key and values in the self-attention blocks and in
            the cross-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        decoder_hidden_states: Optional tuple of arrays (one for the output of
            the embeddings + one for the output of each layer), each of shape
            (batch_size, sequence_length, hidden_size). Hidden-states of the
            decoder at the output of each layer plus the initial embedding
            outputs.
        decoder_attentions: Optional tuple of arrays (one for each layer),
            each of shape
            (batch_size, num_heads, sequence_length, sequence_length).
            Attention weights of the decoder, after the attention softmax,
            used to compute the weighted average in the self-attention heads.
        cross_attentions: Optional tuple of arrays (one for each layer), each
            of shape (batch_size, num_heads, sequence_length, sequence_length).
            Attention weights of the decoder's cross-attention layer, after
            the attention softmax, used to compute the weighted average in the
            cross-attention heads.
        encoder_last_hidden_state: Optional sequence of hidden-states at the
            output of the last layer of the encoder of the model, of shape
            (batch_size, sequence_length, hidden_size).
        encoder_hidden_states: Optional tuple of arrays (one for the output of
            the embeddings + one for the output of each layer), each of shape
            (batch_size, sequence_length, hidden_size). Hidden-states of the
            encoder at the output of each layer plus the initial embedding
            outputs.
        encoder_attentions: Optional tuple of arrays (one for each layer),
            each of shape
            (batch_size, num_heads, sequence_length, sequence_length).
            Attention weights of the encoder, after the attention softmax,
            used to compute the weighted average in the self-attention heads.
        loss: Optional loss array; defaults to None.
    """

    logits: chex.Array = None
    past_key_values: TransformerCache | None = None
    decoder_hidden_states: tuple[chex.Array] | None = None
    decoder_attentions: tuple[chex.Array] | None = None
    cross_attentions: tuple[chex.Array] | None = None
    encoder_last_hidden_state: chex.Array | None = None
    encoder_hidden_states: tuple[chex.Array] | None = None
    encoder_attentions: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class NextSentencePredictorOutput(ModelOutput):
    """Base class for outputs of models predicting if two sentences are consecutive or not.

    Args:
        logits: Prediction scores of the next sequence prediction
            (classification) head (scores of True/False continuation before
            SoftMax), of shape (batch_size, 2).
        hidden_states: Optional tuple of arrays (one for the output of the
            embeddings + one for the output of each layer), each of shape
            (batch_size, sequence_length, hidden_size). Hidden-states of the
            model at the output of each layer plus the initial embedding
            outputs.
        attentions: Optional tuple of arrays (one for each layer), each of
            shape (batch_size, num_heads, sequence_length, sequence_length).
            Attention weights after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        loss: Optional loss array; defaults to None.
    """

    logits: chex.Array = None
    hidden_states: tuple[chex.Array] | None = None
    attentions: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class SequenceClassifierOutput(ModelOutput):
    """Base class for outputs of sentence classification models.

    Args:
        logits: Classification (or regression if config.num_labels==1) scores
            (before SoftMax), of shape (batch_size, config.num_labels).
        hidden_states: Optional tuple of arrays (one for the output of the
            embeddings + one for the output of each layer), each of shape
            (batch_size, sequence_length, hidden_size). Hidden-states of the
            model at the output of each layer plus the initial embedding
            outputs.
        attentions: Optional tuple of arrays (one for each layer), each of
            shape (batch_size, num_heads, sequence_length, sequence_length).
            Attention weights after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        past_key_values: Optional `TransformerCache`.
            NOTE(review): presumably cached key/value states for sequential
            decoding -- confirm.
        loss: Optional loss array; defaults to None.
        aux_loss: Optional auxiliary loss array; defaults to None.
            NOTE(review): presumably an MoE load-balancing loss -- confirm.
    """

    logits: chex.Array = None
    hidden_states: tuple[chex.Array] | None = None
    attentions: tuple[chex.Array] | None = None
    past_key_values: TransformerCache | None = None
    loss: chex.Array | None = None
    aux_loss: chex.Array | None = None
@auto_pytree
class Seq2SeqSequenceClassifierOutput(ModelOutput):
    """Output container for sequence-to-sequence sentence classification models.

    Every field defaults to ``None``, so only the values a model actually
    produced need to be supplied.

    Args:
        logits (`chex.Array` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if `config.num_labels == 1`) scores (before SoftMax).
        past_key_values (`TransformerCache`, *optional*):
            Pre-computed hidden-states (keys and values of the self-attention and cross-attention blocks)
            that can be used (see `past_key_values` input) to speed up sequential decoding.
            NOTE(review): earlier documentation described this as a tuple of length `config.n_layers`, each
            entry holding 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`
            and 2 tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`;
            the field is annotated `TransformerCache`, whose layout is defined outside this class - confirm.
        decoder_hidden_states (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for the output of the embeddings + one for the output of each layer)
            of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for each layer) of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights of the decoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
        cross_attentions (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for each layer) of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to
            compute the weighted average in the cross-attention heads.
        encoder_last_hidden_state (`chex.Array` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for the output of the embeddings + one for the output of each layer)
            of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for each layer) of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights of the encoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
        loss (`chex.Array`, *optional*):
            Loss value.
            NOTE(review): presumably the classification loss computed when labels are provided - confirm.
    """

    logits: chex.Array = None
    past_key_values: TransformerCache | None = None
    decoder_hidden_states: tuple[chex.Array] | None = None
    decoder_attentions: tuple[chex.Array] | None = None
    cross_attentions: tuple[chex.Array] | None = None
    encoder_last_hidden_state: chex.Array | None = None
    encoder_hidden_states: tuple[chex.Array] | None = None
    encoder_attentions: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class MultipleChoiceModelOutput(ModelOutput):
    """Output container for multiple choice models.

    Every field defaults to ``None``, so only the values a model actually
    produced need to be supplied.

    Args:
        logits (`chex.Array` of shape `(batch_size, num_choices)`):
            Classification scores (before SoftMax). *num_choices* is the second dimension of the input
            tensors (see the model's `input_ids` argument).
        hidden_states (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for the output of the embeddings + one for the output of each layer)
            of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for each layer) of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        loss (`chex.Array`, *optional*):
            Loss value.
            NOTE(review): presumably the classification loss computed when labels are provided - confirm.
    """

    logits: chex.Array = None
    hidden_states: tuple[chex.Array] | None = None
    attentions: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class TokenClassifierOutput(ModelOutput):
    """Output container for token classification models.

    Every field defaults to ``None``, so only the values a model actually
    produced need to be supplied.

    Args:
        logits (`chex.Array` of shape `(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for the output of the embeddings + one for the output of each layer)
            of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for each layer) of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        loss (`chex.Array`, *optional*):
            Loss value.
            NOTE(review): presumably the token classification loss computed when labels are provided - confirm.
    """

    logits: chex.Array = None
    hidden_states: tuple[chex.Array] | None = None
    attentions: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class QuestionAnsweringModelOutput(ModelOutput):
    """Output container for extractive question answering models.

    Every field defaults to ``None``, so only the values a model actually
    produced need to be supplied.

    Args:
        start_logits (`chex.Array` of shape `(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_logits (`chex.Array` of shape `(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        hidden_states (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for the output of the embeddings + one for the output of each layer)
            of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for each layer) of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        loss (`chex.Array`, *optional*):
            Loss value.
            NOTE(review): presumably the span-extraction loss computed when labels are provided - confirm.
    """

    start_logits: chex.Array = None
    end_logits: chex.Array = None
    hidden_states: tuple[chex.Array] | None = None
    attentions: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
    """Output container for sequence-to-sequence question answering models.

    Every field defaults to ``None``, so only the values a model actually
    produced need to be supplied.

    Args:
        start_logits (`chex.Array` of shape `(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_logits (`chex.Array` of shape `(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        past_key_values (`TransformerCache`, *optional*):
            Pre-computed hidden-states (keys and values of the self-attention and cross-attention blocks)
            that can be used (see `past_key_values` input) to speed up sequential decoding.
            NOTE(review): earlier documentation described this as a tuple of length `config.n_layers`, each
            entry holding 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`
            and 2 tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`;
            the field is annotated `TransformerCache`, whose layout is defined outside this class - confirm.
        decoder_hidden_states (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for the output of the embeddings + one for the output of each layer)
            of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for each layer) of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights of the decoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
        cross_attentions (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for each layer) of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to
            compute the weighted average in the cross-attention heads.
        encoder_last_hidden_state (`chex.Array` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for the output of the embeddings + one for the output of each layer)
            of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for each layer) of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights of the encoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
        loss (`chex.Array`, *optional*):
            Loss value.
            NOTE(review): presumably the span-extraction loss computed when labels are provided - confirm.
    """

    start_logits: chex.Array = None
    end_logits: chex.Array = None
    past_key_values: TransformerCache | None = None
    decoder_hidden_states: tuple[chex.Array] | None = None
    decoder_attentions: tuple[chex.Array] | None = None
    cross_attentions: tuple[chex.Array] | None = None
    encoder_last_hidden_state: chex.Array | None = None
    encoder_hidden_states: tuple[chex.Array] | None = None
    encoder_attentions: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class MoeModelOutput(ModelOutput):
    """Output container for Mixture-of-Experts (MoE) base models.

    Every field defaults to ``None``, so only the values a model actually
    produced need to be supplied.

    Args:
        last_hidden_state (`chex.Array` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for the output of the embeddings + one for the output of each layer)
            of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        past_key_values (`TransformerCache`, *optional*):
            Cache object returned alongside the hidden states.
            NOTE(review): presumably the attention key/value cache used for incremental decoding - confirm.
        attentions (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for each layer) of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        router_logits (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.

            The logits output of the router network, which are used to compute the mixture of experts.
        all_router_losses (`tuple[chex.Array]`, *optional*):
            NOTE(review): presumably one router loss per MoE layer - confirm against the producing models.
        logits (`chex.Array`, *optional*):
            NOTE(review): purpose on a base-model output is not evident from this class - confirm.
        loss (`chex.Array`, *optional*):
            Loss value.
            NOTE(review): confirm which objective populates this field.
    """

    last_hidden_state: chex.Array = None
    hidden_states: tuple[chex.Array] | None = None
    past_key_values: TransformerCache | None = None
    attentions: tuple[chex.Array] | None = None
    router_logits: tuple[chex.Array] | None = None
    all_router_losses: tuple[chex.Array] | None = None
    logits: chex.Array = None
    loss: chex.Array | None = None
@auto_pytree
class MoeCausalLMOutput(MaskedLMOutput):
    """Output container for causal language modeling (CLM) with MoE models.

    Extends `MaskedLMOutput` with the MoE-specific fields listed below; the
    inherited fields are documented on `MaskedLMOutput`. All fields declared
    here default to ``None``.

    Args:
        aux_loss (`chex.Array`, *optional*):
            Auxiliary loss used for training MoE models.
        router_logits (`tuple[chex.Array]`, *optional*):
            Tuple of `chex.Array` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.

            The logits output of the router network, which are used to compute the mixture of experts.
        all_router_losses (`tuple[chex.Array]`, *optional*):
            NOTE(review): presumably one router loss per MoE layer - confirm against the producing models.
        loss (`chex.Array`, *optional*):
            Loss value.
            NOTE(review): presumably the language modeling loss; may re-declare a field that
            `MaskedLMOutput` already provides - confirm.
    """

    aux_loss: chex.Array | None = None
    router_logits: tuple[chex.Array] | None = None
    all_router_losses: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class VLMCausalLMOutput(ModelOutput):
    """Unified output class for Vision-Language Models (VLMs).

    Provides a standardized output structure for all VLM models including
    LLaVA, Qwen2-VL, Qwen3-VL, Gemma3, AyaVision, Mistral3, and Llama4.
    Every field defaults to ``None``.

    Args:
        logits (`chex.Array` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (before SoftMax).
        past_key_values (`TransformerCache`, *optional*):
            Pre-computed hidden-states (key and values in attention blocks) for
            efficient autoregressive generation.
        hidden_states (`tuple[chex.Array]`, *optional*):
            Tuple of hidden-states at output of each layer plus embeddings.
            Shape: `(batch_size, sequence_length, hidden_size)`.
        last_hidden_state (`chex.Array`, *optional*):
            Hidden-state at output of the last layer.
            Shape: `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple[chex.Array]`, *optional*):
            Attention weights after softmax.
            Shape: `(batch_size, num_heads, sequence_length, sequence_length)`.
        image_hidden_states (`chex.Array`, *optional*):
            Projected image features from the vision encoder after the multimodal
            projector. Shape varies by model.
        video_hidden_states (`chex.Array`, *optional*):
            Projected video features for models supporting video input (Qwen2-VL,
            Qwen3-VL, Llama4). Shape varies by model.
        rope_deltas (`chex.Array`, *optional*):
            Position embedding deltas for multi-dimensional RoPE (mRoPE) used in
            Qwen2-VL and Qwen3-VL models.
        router_logits (`tuple[chex.Array]`, *optional*):
            Router logits for MoE VLMs (Qwen3-VL-MoE).
            Shape: `(batch_size, sequence_length, num_experts)`.
        aux_loss (`chex.Array`, *optional*):
            Auxiliary loss for MoE load balancing.
        loss (`chex.Array`, *optional*):
            Language modeling loss when labels are provided.
    """

    logits: chex.Array = None
    past_key_values: TransformerCache | None = None
    hidden_states: tuple[chex.Array] | None = None
    last_hidden_state: chex.Array | None = None
    attentions: tuple[chex.Array] | None = None
    image_hidden_states: chex.Array | None = None
    video_hidden_states: chex.Array | None = None
    rope_deltas: chex.Array | None = None
    router_logits: tuple[chex.Array] | None = None
    aux_loss: chex.Array | None = None
    loss: chex.Array | None = None
@auto_pytree
class MambaOutput(BaseModelOutput):
    """Output from Mamba state-space models.

    Contains the outputs from Mamba models which use selective state-space
    layers instead of attention for sequence modeling. Derives from
    `BaseModelOutput`; the fields declared here all default to ``None``.

    Args:
        last_hidden_state: Final hidden states from the model with shape
            (batch_size, sequence_length, hidden_size).
        cache_params: Optional list of cached state-space parameters for
            efficient autoregressive generation. Each element contains the
            SSM state for a layer.
        hidden_states: Optional tuple of hidden states from all layers.
            Only returned when output_hidden_states=True.
        loss: Optional loss value when labels are provided.
    """

    last_hidden_state: chex.Array = None
    cache_params: list[chex.Array] | None = None
    hidden_states: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class MambaCausalLMOutput(BaseModelOutput):
    """Output from Mamba causal language models.

    Contains the outputs from Mamba models configured for causal language
    modeling, including logits over the vocabulary. Derives from
    `BaseModelOutput`; the fields declared here all default to ``None``.

    Args:
        logits: Prediction scores over the vocabulary with shape
            (batch_size, sequence_length, vocab_size).
        cache_params: Optional list of cached state-space parameters for
            efficient autoregressive generation.
        hidden_states: Optional tuple of hidden states from all layers.
            Only returned when output_hidden_states=True.
        loss: Optional language modeling loss when labels are provided.
    """

    logits: chex.Array = None
    cache_params: list[chex.Array] | None = None
    hidden_states: tuple[chex.Array] | None = None
    loss: chex.Array | None = None
@auto_pytree
class CLIPTextModelOutput(ModelOutput):
    """Output container for CLIP text encoders, including the projected text embedding.

    Every field defaults to ``None``, so only the values a model actually
    produced need to be supplied.

    Args:
        text_embeds (`chex.Array` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of the
            text model.
        last_hidden_state (`chex.Array` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple[chex.Array, ...]`, *optional*):
            Tuple of `chex.Array` (one for the output of the embeddings + one for the output of each layer)
            of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple[chex.Array, ...]`, *optional*):
            Tuple of `chex.Array` (one for each layer) of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    text_embeds: chex.Array = None
    last_hidden_state: chex.Array = None
    hidden_states: tuple[chex.Array, ...] | None = None
    attentions: tuple[chex.Array, ...] | None = None
@auto_pytree
class ImageClassifierOutput(ModelOutput):
    """Output container for image classification models.

    NOTE(review): the declared fields are identical to those of
    `CLIPTextModelOutput` (`text_embeds`, `last_hidden_state`, ...) and there
    is no `logits` or `loss` field, which is unusual for a classifier output.
    The field set is left untouched here because it is part of the public
    interface - confirm against the models that construct this class.

    Every field defaults to ``None``.

    Args:
        text_embeds (`chex.Array` of shape `(batch_size, output_dim)`):
            Embeddings obtained by applying a projection layer to the pooled model output.
            NOTE(review): the name suggests text embeddings; confirm what image classifiers store here.
        last_hidden_state (`chex.Array` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple[chex.Array, ...]`, *optional*):
            Tuple of `chex.Array` (one for the output of the embeddings + one for the output of each layer)
            of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple[chex.Array, ...]`, *optional*):
            Tuple of `chex.Array` (one for each layer) of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    text_embeds: chex.Array = None
    last_hidden_state: chex.Array = None
    hidden_states: tuple[chex.Array, ...] | None = None
    attentions: tuple[chex.Array, ...] | None = None
@auto_pytree
class CLIPOutput(ModelOutput):
    """Output container for CLIP models: similarity logits, embeddings and sub-model outputs.

    Every field defaults to ``None``.

    Args:
        loss (`chex.Array`):
            Training loss.
        logits_per_image (`chex.Array` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the
            image-text similarity scores.
        logits_per_text (`chex.Array` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the
            text-image similarity scores.
        text_embeds (`chex.Array` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of the
            text model.
        image_embeds (`chex.Array` of shape `(batch_size, output_dim)`):
            The image embeddings obtained by applying the projection layer to the pooled output of the
            vision model.
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the text model.
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the vision model.
    """

    loss: chex.Array = None
    logits_per_image: chex.Array = None
    logits_per_text: chex.Array = None
    text_embeds: chex.Array = None
    image_embeds: chex.Array = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> tuple[tp.Any, ...]:
        """Flatten this output into a tuple ordered like ``self.keys()``.

        Entries for `text_model_output` and `vision_model_output` are replaced
        by the result of their own ``to_tuple()``; every other entry is the
        value returned by ``self[key]``.

        Returns:
            A tuple with one element per key yielded by ``self.keys()``.
        """
        nested_keys = ("text_model_output", "vision_model_output")
        flattened = []
        for key in self.keys():
            if key not in nested_keys:
                flattened.append(self[key])
                continue
            nested = getattr(self, key)
            # Both sub-model outputs default to None. Pass an unset one through
            # instead of raising AttributeError on `None.to_tuple()`. Whether
            # keys() can yield such a key depends on ModelOutput.keys(), which is
            # defined outside this class.
            flattened.append(nested.to_tuple() if nested is not None else None)
        return tuple(flattened)
@auto_pytree
class GreedySearchOutput(ModelOutput):
    """Output of decoder-only generation models using greedy search.

    Args:
        sequences (`chex.Array` of shape `(batch_size, max_length)`):
            The generated sequences. Defaults to ``None``.
    """

    sequences: chex.Array = None
@auto_pytree
class SampleOutput(ModelOutput):
    """Output of decoder-only generation models using sampling.

    Args:
        sequences (`chex.Array` of shape `(batch_size, max_length)`):
            The generated sequences. Defaults to ``None``.
    """

    sequences: chex.Array = None
@auto_pytree
class BeamSearchOutput(ModelOutput):
    """Output of decoder-only generation models using beam search.

    Both fields default to ``None``.

    Args:
        sequences (`chex.Array` of shape `(batch_size, max_length)`):
            The generated sequences.
        scores (`chex.Array` of shape `(batch_size,)`):
            The scores (log probabilities) of the generated sequences.
    """

    sequences: chex.Array = None
    scores: chex.Array = None