easydel.infra.modeling_outputs

Model output classes for EasyDeL.

Defines standardized output structures for various model types and tasks. These dataclasses provide consistent interfaces for model outputs while maintaining compatibility with JAX pytrees.

Classes:

  • ModelOutput: Base class for all model outputs

  • CausalLMOutput: Output for causal language models

  • MoeCausalLMOutput: Output for MoE causal language models

  • SequenceClassifierOutput: Output for sequence classification

  • ImageClassifierOutput: Output for image classification

  • CLIPOutput: Output for CLIP models

  • CLIPTextModelOutput: Output for CLIP text encoders

  • GreedySearchOutput: Output for greedy generation

  • SampleOutput: Output for sampling generation

  • BeamSearchOutput: Output for beam search generation

Key Features:
  • Consistent interface across model types

  • JAX pytree compatibility

  • Optional fields with None defaults

  • Dictionary-like access patterns

  • Automatic validation

Example

>>> from easydel.infra.modeling_outputs import CausalLMOutput
>>> output = CausalLMOutput(
...     logits=logits,
...     hidden_states=hidden_states,
...     attentions=attentions
... )
>>> # Access as attribute or dictionary
>>> logits = output.logits
>>> logits = output["logits"]
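The attribute and dictionary-style access shown above can be imitated with a plain dataclass. The following is a minimal sketch only: SketchOutput and its keys() helper are hypothetical stand-ins written for illustration, not EasyDeL's implementation, and plain Python lists stand in for arrays.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class SketchOutput:
    """Hypothetical stand-in for an output class (not the EasyDeL class)."""

    logits: Any = None
    hidden_states: Optional[tuple] = None
    attentions: Optional[tuple] = None

    def __getitem__(self, key: str) -> Any:
        # Only declared fields are valid keys.
        if key not in {f.name for f in fields(self)}:
            raise KeyError(key)
        return getattr(self, key)

    def keys(self) -> list:
        # Report only the fields that were actually populated.
        return [f.name for f in fields(self) if getattr(self, f.name) is not None]


output = SketchOutput(logits=[[0.1, 0.9]])
same_object = output.logits is output["logits"]  # both paths reach one value
populated = output.keys()
```

Fields left unset keep their None default, which is why only logits is reported as populated here.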
class easydel.infra.modeling_outputs.AttentionLayerOutput(attention_output: Union[Array, ndarray, bool, number], attention_weight: Optional[Union[Array, ndarray, bool, number]] = None, cache_view: Any | None = None)[source]#

Bases: ModelOutput

Output from a single attention layer.

Contains the attention computation results from a transformer attention layer, including optional attention weights and cache views for efficient generation.

Parameters
  • attention_output – Output tensor from the attention layer with shape (batch_size, sequence_length, hidden_size).

  • attention_weight – Optional attention weights after softmax with shape (batch_size, num_heads, sequence_length, sequence_length). Only returned when output_attentions=True.

  • cache_view – Optional cache view for efficient autoregressive generation. Contains cached key-value pairs from previous steps.

attention_output: Union[Array, ndarray, bool, number]#
attention_weight: Optional[Union[Array, ndarray, bool, number]] = None#
cache_view: Any | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
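The five helper methods listed for this class (from_dict, from_json, replace, to_dict, to_json) suggest a round-trip pattern. The sketch below imitates that pattern with the standard library; SketchAttentionOutput is a hypothetical stand-in that only reuses the documented field names, and it uses plain lists because json.dumps cannot serialize arrays directly. How EasyDeL itself encodes arrays is not covered here.

```python
import dataclasses
import json
from typing import Any, Optional


@dataclasses.dataclass(frozen=True)
class SketchAttentionOutput:
    """Hypothetical stand-in mirroring the documented field names."""

    attention_output: Any
    attention_weight: Optional[Any] = None
    cache_view: Optional[Any] = None

    def to_dict(self) -> dict:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> "SketchAttentionOutput":
        return cls(**data)

    def to_json(self, **kwargs) -> str:
        # Extra keyword arguments are forwarded to json.dumps (e.g. indent=2).
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "SketchAttentionOutput":
        return cls.from_dict(json.loads(json_str))

    def replace(self, **kwargs) -> "SketchAttentionOutput":
        # Returns a new instance; the original is left untouched.
        return dataclasses.replace(self, **kwargs)


first = SketchAttentionOutput(attention_output=[[0.5, 0.5]])
restored = SketchAttentionOutput.from_json(first.to_json())
updated = first.replace(attention_weight=[[1.0]])
```

Returning a new instance from replace, rather than mutating in place, matches the "Creates a new instance" wording above and fits the functional style that JAX transformations expect.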

class easydel.infra.modeling_outputs.BaseModelOutput(last_hidden_state: Union[Array, ndarray, bool, number] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, past_key_values: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for model’s outputs, with potential hidden states and attentions.

Parameters
  • last_hidden_state (chex.Array of shape (batch_size, sequence_length, hidden_size)) – Sequence of hidden-states at the output of the last layer of the model.

  • hidden_states (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
last_hidden_state: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
past_key_values: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
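The shapes documented for last_hidden_state, hidden_states, and attentions can be summarized in a small helper. This is a sketch under the assumption that every layer reports its output and attention weights; expected_shapes and its argument names are hypothetical and introduced only for illustration.

```python
def expected_shapes(batch_size, sequence_length, hidden_size, num_heads, num_layers):
    """Return the documented shapes as plain tuples (no arrays are created)."""
    hidden_shape = (batch_size, sequence_length, hidden_size)
    attention_shape = (batch_size, num_heads, sequence_length, sequence_length)
    return {
        "last_hidden_state": hidden_shape,
        # One entry for the embedding output plus one per layer.
        "hidden_states": [hidden_shape] * (num_layers + 1),
        # One entry per layer.
        "attentions": [attention_shape] * num_layers,
    }


shapes = expected_shapes(
    batch_size=2, sequence_length=16, hidden_size=64, num_heads=4, num_layers=12
)
```

Note that hidden_states has one more entry than attentions, because the embedding output is included in the former.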

class easydel.infra.modeling_outputs.BaseModelOutputWithNoAttention(last_hidden_state: Union[Array, ndarray, bool, number] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for model’s outputs, with potential hidden states.

Parameters
  • last_hidden_state (chex.Array of shape (batch_size, num_channels, height, width)) – Sequence of hidden-states at the output of the last layer of the model.

  • hidden_states (tuple(chex.Array | None)) – tp.Tuple of chex.Array (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, num_channels, height, width). Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.

classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
last_hidden_state: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.modeling_outputs.BaseModelOutputWithPast(last_hidden_state: Union[Array, ndarray, bool, number] = None, past_key_values: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for model’s outputs, with potential hidden states and attentions.

Parameters
  • last_hidden_state (chex.Array of shape (batch_size, sequence_length, hidden_size)) – Sequence of hidden-states at the output of the last layer of the model.

  • past_key_values (tp.Dict[str, chex.Array]) – Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape [batch_size, max_length].

  • hidden_states (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
last_hidden_state: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
past_key_values: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions(last_hidden_state: Union[Array, ndarray, bool, number] = None, past_key_values: Any | None = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, cross_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for model’s outputs that may also contain a past key/values (to speed up sequential decoding).

Parameters
  • last_hidden_state (chex.Array of shape (batch_size, sequence_length, hidden_size)) –

    Sequence of hidden-states at the output of the last layer of the model.

    If past_key_values is used only the last hidden-state of the sequences of shape (batch_size, 1, hidden_size) is output.

  • past_key_values (tuple(chex.Array | None)) –

    tp.Tuple of tuple(chex.Array) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head) and, if config.is_encoder_decoder=True, 2 additional tensors of shape (batch_size, num_heads, encoder_sequence_length, embed_size_per_head).

    Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if config.is_encoder_decoder=True in the cross-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.

  • hidden_states (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  • cross_attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights of the decoder’s cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads.

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
cross_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
last_hidden_state: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
past_key_values: Any | None = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.modeling_outputs.BaseModelOutputWithPooling(last_hidden_state: Union[Array, ndarray, bool, number] = None, pooler_output: Union[Array, ndarray, bool, number] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for model’s outputs that also contains a pooling of the last hidden states.

Parameters
  • last_hidden_state (chex.Array of shape (batch_size, sequence_length, hidden_size)) – Sequence of hidden-states at the output of the last layer of the model.

  • pooler_output (chex.Array of shape (batch_size, hidden_size)) – Last layer hidden-state of the first token of the sequence (classification token) further processed by a Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence prediction (classification) objective during pretraining.

  • hidden_states (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
last_hidden_state: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
pooler_output: Union[Array, ndarray, bool, number] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
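The pooler_output description above (first-token hidden state passed through a linear layer and a tanh) can be written out directly. This is a minimal sketch in pure Python: pool_first_token, weight, and bias are hypothetical names, and nested lists stand in for arrays so the example runs without JAX.

```python
import math


def pool_first_token(last_hidden_state, weight, bias):
    """Apply tanh(W @ h_0 + b) to the first token of every sequence.

    last_hidden_state: nested list of shape (batch_size, sequence_length, hidden_size)
    weight: nested list of shape (hidden_size, hidden_size)
    bias: list of length hidden_size
    """
    pooled = []
    for sequence in last_hidden_state:
        first_token = sequence[0]  # the classification token position
        row = [
            math.tanh(sum(w * h for w, h in zip(weight_row, first_token)) + b)
            for weight_row, b in zip(weight, bias)
        ]
        pooled.append(row)
    return pooled


hidden = [[[0.0, 1.0], [5.0, 5.0]]]  # batch_size=1, sequence_length=2, hidden_size=2
identity = [[1.0, 0.0], [0.0, 1.0]]
pooled = pool_first_token(hidden, identity, bias=[0.0, 0.0])
```

The result has shape (batch_size, hidden_size): the sequence axis disappears because only position 0 is used.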

class easydel.infra.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state: Union[Array, ndarray, bool, number] = None, pooler_output: Union[Array, ndarray, bool, number] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, past_key_values: Any | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, cross_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for model’s outputs that also contains a pooling of the last hidden states.

Parameters
  • last_hidden_state (chex.Array of shape (batch_size, sequence_length, hidden_size)) – Sequence of hidden-states at the output of the last layer of the model.

  • pooler_output (chex.Array of shape (batch_size, hidden_size)) – Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns the classification token after processing through a linear layer and a tanh activation function. The linear layer weights are trained from the next sentence prediction (classification) objective during pretraining.

  • hidden_states (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.

  • attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  • cross_attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights of the decoder’s cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads.

  • past_key_values (tuple(chex.Array | None)) –

    tp.Tuple of tuple(chex.Array) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head) and, if config.is_encoder_decoder=True, 2 additional tensors of shape (batch_size, num_heads, encoder_sequence_length, embed_size_per_head).

    Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if config.is_encoder_decoder=True in the cross-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
cross_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
last_hidden_state: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
past_key_values: Any | None = None#
pooler_output: Union[Array, ndarray, bool, number] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.modeling_outputs.BaseModelOutputWithPoolingAndNoAttention(last_hidden_state: Union[Array, ndarray, bool, number] = None, pooler_output: Union[Array, ndarray, bool, number] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for model’s outputs that also contains a pooling of the last hidden states.

Parameters
  • last_hidden_state (chex.Array of shape (batch_size, num_channels, height, width)) – Sequence of hidden-states at the output of the last layer of the model.

  • pooler_output (chex.Array of shape (batch_size, hidden_size)) – Last layer hidden-state after a pooling operation on the spatial dimensions.

  • hidden_states (tuple(chex.Array | None)) – tp.Tuple of chex.Array (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, num_channels, height, width). Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.

classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
last_hidden_state: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
pooler_output: Union[Array, ndarray, bool, number] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
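For this class, pooler_output is described as the last hidden state after pooling over the spatial dimensions. The source does not say which pooling operation is used, so the sketch below assumes global average pooling purely for illustration; global_average_pool is a hypothetical helper operating on nested lists.

```python
def global_average_pool(last_hidden_state):
    """Reduce (batch_size, num_channels, height, width) to (batch_size, num_channels).

    Assumption: mean over height and width. A model may use a different pooling.
    """
    pooled = []
    for image in last_hidden_state:
        channel_means = []
        for channel in image:
            values = [v for row in channel for v in row]
            channel_means.append(sum(values) / len(values))
        pooled.append(channel_means)
    return pooled


feature_map = [[[[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [0.0, 8.0]]]]  # (1, 2, 2, 2)
pooled_features = global_average_pool(feature_map)
```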

class easydel.infra.modeling_outputs.BeamSearchOutput(sequences: Union[Array, ndarray, bool, number] = None, scores: Union[Array, ndarray, bool, number] = None)[source]#

Bases: ModelOutput

Flax Base class for outputs of decoder-only generation models using beam search.

Parameters
  • sequences (chex.Array of shape (batch_size, max_length)) – The generated sequences.

  • scores (chex.Array of shape (batch_size,)) – The scores (log probabilities) of the generated sequences.

classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

replace(**kwargs)#

Creates a new instance with specified fields replaced.

scores: Union[Array, ndarray, bool, number] = None#
sequences: Union[Array, ndarray, bool, number] = None#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.modeling_outputs.CLIPOutput(loss: Union[Array, ndarray, bool, number] = None, logits_per_image: Union[Array, ndarray, bool, number] = None, logits_per_text: Union[Array, ndarray, bool, number] = None, text_embeds: Union[Array, ndarray, bool, number] = None, image_embeds: Union[Array, ndarray, bool, number] = None, text_model_output: BaseModelOutputWithPooling = None, vision_model_output: BaseModelOutputWithPooling = None)[source]#

Bases: ModelOutput

Parameters
  • loss (chex.Array) – Training loss.

  • logits_per_image (chex.Array of shape (image_batch_size, text_batch_size)) – The scaled dot product scores between image_embeds and text_embeds. This represents the image-text similarity scores.

  • logits_per_text (chex.Array of shape (text_batch_size, image_batch_size)) – The scaled dot product scores between text_embeds and image_embeds. This represents the text-image similarity scores.

  • text_embeds (chex.Array of shape (batch_size, output_dim)) – The text embeddings obtained by applying the projection layer to the pooled output of [FlaxCLIPTextModel].

  • image_embeds (chex.Array of shape (batch_size, output_dim)) – The image embeddings obtained by applying the projection layer to the pooled output of [FlaxCLIPVisionModel].

  • text_model_output (BaseModelOutputWithPooling) – The output of the [FlaxCLIPTextModel].

  • vision_model_output (BaseModelOutputWithPooling) – The output of the [FlaxCLIPVisionModel].
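The relationship between logits_per_image and logits_per_text described in the parameter list can be sketched as follows. Two things here are assumptions rather than statements from this page: the embeddings are L2-normalized before the dot product (a common CLIP convention), and the scale is passed in as a hypothetical logit_scale argument. Nested lists stand in for arrays.

```python
import math


def l2_normalize(vec):
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def similarity_logits(image_embeds, text_embeds, logit_scale):
    """Return (logits_per_image, logits_per_text) as nested lists."""
    images = [l2_normalize(v) for v in image_embeds]
    texts = [l2_normalize(v) for v in text_embeds]
    # Shape (text_batch_size, image_batch_size).
    logits_per_text = [
        [logit_scale * sum(t * i for t, i in zip(text, image)) for image in images]
        for text in texts
    ]
    # Transpose to shape (image_batch_size, text_batch_size).
    logits_per_image = [list(column) for column in zip(*logits_per_text)]
    return logits_per_image, logits_per_text


image_embeds = [[1.0, 0.0], [0.0, 2.0]]             # 2 images
text_embeds = [[3.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 texts
logits_per_image, logits_per_text = similarity_logits(
    image_embeds, text_embeds, logit_scale=10.0
)
```

The two matrices hold the same scores; one is the transpose of the other, which is why their documented shapes swap the image and text batch axes.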

classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

image_embeds: Union[Array, ndarray, bool, number] = None#
logits_per_image: Union[Array, ndarray, bool, number] = None#
logits_per_text: Union[Array, ndarray, bool, number] = None#
loss: Union[Array, ndarray, bool, number] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

text_embeds: Union[Array, ndarray, bool, number] = None#
text_model_output: BaseModelOutputWithPooling = None#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

to_tuple() tuple[Any][source]#

Convert self to a tuple containing all the attributes/keys that are not None.
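The one-line description of to_tuple can be read as "keep field order, drop None entries". The sketch below implements exactly that reading on a hypothetical stand-in class; the real method may treat nested outputs such as text_model_output differently, which this page does not specify.

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass
class SketchPairOutput:
    """Hypothetical stand-in with a subset of the documented fields."""

    loss: Any = None
    logits_per_image: Any = None
    logits_per_text: Any = None

    def to_tuple(self) -> tuple:
        # Preserve declaration order and skip fields that are None.
        return tuple(
            getattr(self, f.name)
            for f in fields(self)
            if getattr(self, f.name) is not None
        )


partial = SketchPairOutput(logits_per_image=[[1.0]], logits_per_text=[[2.0]])
as_tuple = partial.to_tuple()
```

Because loss was left as None, the resulting tuple has two entries rather than three, so positional indexing depends on which fields were populated.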

vision_model_output: BaseModelOutputWithPooling = None#
class easydel.infra.modeling_outputs.CLIPTextModelOutput(text_embeds: Union[Array, ndarray, bool, number] = None, last_hidden_state: Union[Array, ndarray, bool, number] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None)[source]#

Bases: ModelOutput

Base class for text model’s outputs that also contains a pooling of the last hidden states.

Parameters
  • text_embeds (chex.Array of shape (batch_size, output_dim)) – The text embeddings obtained by applying the projection layer to the pooled output of [FlaxCLIPTextModel].

  • last_hidden_state (chex.Array of shape (batch_size, sequence_length, hidden_size)) – Sequence of hidden-states at the output of the last layer of the model.

  • hidden_states (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None#
last_hidden_state: Union[Array, ndarray, bool, number] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

text_embeds: Union[Array, ndarray, bool, number] = None#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

easydel.infra.modeling_outputs.CausalLMOutput#

alias of MaskedLMOutput

class easydel.infra.modeling_outputs.CausalLMOutputWithCrossAttentions(logits: Union[Array, ndarray, bool, number] = None, past_key_values: Any | None = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, cross_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for causal language model (or autoregressive) outputs.

Parameters
  • logits (chex.Array of shape (batch_size, sequence_length, config.vocab_size)) – Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

  • hidden_states (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  • cross_attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Cross-attention weights after the attention softmax, used to compute the weighted average in the cross-attention heads.

  • past_key_values (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array tuples of length config.n_layers, with each tuple containing the cached key, value states of the self-attention and the cross-attention layers if model is used in encoder-decoder setting. Only relevant if config.is_decoder = True.

    Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
cross_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
logits: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
past_key_values: Any | None = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
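Since logits holds scores "before SoftMax" with shape (batch_size, sequence_length, config.vocab_size), a consumer typically normalizes the last position to obtain next-token probabilities. The following is a minimal sketch in pure Python; softmax and next_token_ids are hypothetical helpers, and greedy selection is used only as the simplest illustration.

```python
import math


def softmax(scores):
    # Subtract the maximum for numerical stability before exponentiating.
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def next_token_ids(logits):
    """Pick the highest-scoring vocabulary index at the last position of each sequence."""
    ids = []
    for sequence in logits:
        last_position = sequence[-1]
        ids.append(max(range(len(last_position)), key=lambda i: last_position[i]))
    return ids


logits = [[[0.0, 1.0, 2.0], [3.0, 0.0, 0.0]]]  # batch_size=1, sequence_length=2, vocab_size=3
probabilities = softmax(logits[0][-1])
chosen = next_token_ids(logits)
```

Softmax is monotonic, so the argmax of the raw logits and the argmax of the probabilities select the same token.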

class easydel.infra.modeling_outputs.DecoderLayerOutput(hidden_states: Union[Array, ndarray, bool, number], residual_states: Optional[Union[Array, ndarray, bool, number]] = None, cross_attention: Optional[Union[Array, ndarray, bool, number]] = None, attention_weight: Optional[Union[Array, ndarray, bool, number]] = None, router_logits: Optional[Union[Array, ndarray, bool, number]] = None, gate_loss: Optional[Union[Array, ndarray, bool, number]] = None, cache_view: Any | None = None)[source]#

Bases: ModelOutput

Output from a single decoder layer.

Contains the outputs from a transformer decoder layer, including hidden states, attention weights, and optional MoE routing information.

Parameters
  • hidden_states – Output hidden states from the decoder layer with shape (batch_size, sequence_length, hidden_size).

  • residual_states – Optional residual connection states before layer norm with shape (batch_size, sequence_length, hidden_size).

  • cross_attention – Optional cross-attention outputs when using encoder-decoder architecture with shape (batch_size, sequence_length, hidden_size).

  • attention_weight – Optional self-attention weights after softmax with shape (batch_size, num_heads, sequence_length, sequence_length).

  • router_logits – Optional MoE router logits for expert selection with shape (batch_size, sequence_length, num_experts).

  • gate_loss – Optional auxiliary loss for MoE load balancing.

  • cache_view – Optional cache view for efficient autoregressive generation.

attention_weight: Optional[Union[Array, ndarray, bool, number]] = None#
cache_view: Any | None = None#
cross_attention: Optional[Union[Array, ndarray, bool, number]] = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

gate_loss: Optional[Union[Array, ndarray, bool, number]] = None#
hidden_states: Union[Array, ndarray, bool, number]#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

residual_states: Optional[Union[Array, ndarray, bool, number]] = None#
router_logits: Optional[Union[Array, ndarray, bool, number]] = None#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
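The router_logits field is described as MoE router logits used for expert selection, with one score per expert for each token. One common way to use such scores is top-k routing with renormalized weights. The sketch below shows that scheme for a single token; route_token and top_k are hypothetical names, and the actual routing rule depends on the model, which this page does not specify.

```python
import math


def route_token(router_logits, top_k):
    """Return [(expert_index, weight), ...] for the top_k experts of one token."""
    peak = max(router_logits)
    exps = [math.exp(x - peak) for x in router_logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Highest-probability experts first.
    chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    # Renormalize so the selected experts' weights sum to 1.
    weight_sum = sum(probs[i] for i in chosen)
    return [(i, probs[i] / weight_sum) for i in chosen]


routing = route_token([0.0, 2.0, 1.0, -1.0], top_k=2)
expert_ids = [index for index, _ in routing]
expert_weights = [weight for _, weight in routing]
```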

class easydel.infra.modeling_outputs.EncoderLayerOutput(hidden_states: Union[Array, ndarray, bool, number], residual_states: Optional[Union[Array, ndarray, bool, number]] = None, attention_weight: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Output from a single encoder layer.

Contains the outputs from a transformer encoder layer, including the processed hidden states and optional attention weights.

Parameters
  • hidden_states – Output hidden states from the encoder layer with shape (batch_size, sequence_length, hidden_size).

  • residual_states – Optional residual connection states before layer norm with shape (batch_size, sequence_length, hidden_size).

  • attention_weight – Optional attention weights after softmax with shape (batch_size, num_heads, sequence_length, sequence_length). Only returned when output_attentions=True.

attention_weight: Optional[Union[Array, ndarray, bool, number]] = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: Union[Array, ndarray, bool, number]#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

residual_states: Optional[Union[Array, ndarray, bool, number]] = None#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
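
The replace, to_dict, and from_dict methods listed for these layer outputs follow the usual immutable-dataclass pattern. The sketch below imitates that pattern with a hypothetical stand-in class, LayerOutputSketch, built only on the standard library. It is not EasyDeL's implementation, and plain lists stand in for arrays.

```python
import dataclasses
from typing import Any, Optional


@dataclasses.dataclass(frozen=True)
class LayerOutputSketch:
    """Hypothetical stand-in mirroring the fields of EncoderLayerOutput."""

    hidden_states: list
    residual_states: Optional[list] = None
    attention_weight: Optional[list] = None

    def replace(self, **kwargs: Any) -> "LayerOutputSketch":
        # Return a new instance; the original is left untouched.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self) -> dict:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> "LayerOutputSketch":
        return cls(**data)


out = LayerOutputSketch(hidden_states=[[0.1, 0.2]])
with_attn = out.replace(attention_weight=[[1.0]])
restored = LayerOutputSketch.from_dict(with_attn.to_dict())
```

Here out keeps attention_weight as None, while with_attn carries the new value and survives a round trip through a dictionary.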

class easydel.infra.modeling_outputs.GreedySearchOutput(sequences: Union[Array, ndarray, bool, number] = None)[source]#

Bases: ModelOutput

Base class for outputs of decoder-only generation models using greedy search.

Parameters

sequences (chex.Array of shape (batch_size, max_length)) – The generated sequences.

classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

replace(**kwargs)#

Creates a new instance with specified fields replaced.

sequences: Union[Array, ndarray, bool, number] = None#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
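
Greedy search picks the highest-scoring token at every step, so sequences ends up with shape (batch_size, max_length). The following is a minimal sketch of that loop in plain Python. toy_logits is a hypothetical scoring function standing in for a model, and nothing here is EasyDeL's generation code.

```python
def toy_logits(prefix: list[int], vocab_size: int) -> list[float]:
    """Hypothetical model: favours the token after the last one (mod vocab)."""
    target = (prefix[-1] + 1) % vocab_size
    return [1.0 if tok == target else 0.0 for tok in range(vocab_size)]


def greedy_search(prompts: list[list[int]], max_length: int, vocab_size: int) -> list[list[int]]:
    sequences = [list(p) for p in prompts]
    for seq in sequences:
        while len(seq) < max_length:
            logits = toy_logits(seq, vocab_size)
            # Greedy step: take the argmax over the vocabulary.
            seq.append(max(range(vocab_size), key=logits.__getitem__))
    return sequences


sequences = greedy_search([[0], [3]], max_length=4, vocab_size=5)
```

Every row of sequences has exactly max_length entries, matching the documented shape.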

class easydel.infra.modeling_outputs.ImageClassifierOutput(text_embeds: Union[Array, ndarray, bool, number] = None, last_hidden_state: Union[Array, ndarray, bool, number] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None)[source]#

Bases: ModelOutput

Base class for text model’s outputs that also contains a pooling of the last hidden states.

Parameters
  • text_embeds (chex.Array of shape (batch_size, output_dim)) – The text embeddings obtained by applying the projection layer to the pooled output of FlaxCLIPTextModel.

  • last_hidden_state (chex.Array of shape (batch_size, sequence_length, hidden_size)) – Sequence of hidden-states at the output of the last layer of the model.

  • hidden_states (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None#
last_hidden_state: Union[Array, ndarray, bool, number] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

text_embeds: Union[Array, ndarray, bool, number] = None#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.modeling_outputs.ImageClassifierOutputWithNoAttention(logits: Union[Array, ndarray, bool, number] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for outputs of image classification models.

Parameters
  • logits (chex.Array of shape (batch_size, config.num_labels)) – Classification (or regression if config.num_labels==1) scores (before SoftMax).

  • hidden_states (tuple(chex.Array), optional, returned when config.output_hidden_states=True) – tp.Tuple of chex.Array (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each stage) of shape (batch_size, num_channels, height, width). Hidden-states (also called feature maps) of the model at the output of each stage.

classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
logits: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.modeling_outputs.MambaCausalLMOutput(last_hidden_state: Union[Array, ndarray, bool, number] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, past_key_values: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None, logits: Union[Array, ndarray, bool, number] = None, cache_params: list[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None)[source]#

Bases: BaseModelOutput

Output from Mamba causal language models.

Contains the outputs from Mamba models configured for causal language modeling, including logits over the vocabulary.

Parameters
  • logits – Prediction scores over the vocabulary with shape (batch_size, sequence_length, vocab_size).

  • cache_params – Optional list of cached state-space parameters for efficient autoregressive generation.

  • hidden_states – Optional tuple of hidden states from all layers. Only returned when output_hidden_states=True.

  • loss – Optional language modeling loss when labels are provided.

cache_params: list[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
logits: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.modeling_outputs.MambaOutput(last_hidden_state: Union[Array, ndarray, bool, number] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, past_key_values: dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None, cache_params: list[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None)[source]#

Bases: BaseModelOutput

Output from Mamba state-space models.

Contains the outputs from Mamba models which use selective state-space layers instead of attention for sequence modeling.

Parameters
  • last_hidden_state – Final hidden states from the model with shape (batch_size, sequence_length, hidden_size).

  • cache_params – Optional list of cached state-space parameters for efficient autoregressive generation. Each element contains the SSM state for a layer.

  • hidden_states – Optional tuple of hidden states from all layers. Only returned when output_hidden_states=True.

  • loss – Optional loss value when labels are provided.

cache_params: list[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
last_hidden_state: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
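
The role of cache_params can be illustrated with a toy linear recurrence: keeping the state between calls lets each new token be processed in constant time instead of rescanning the prefix. This is a sketch under simplifying assumptions (a scalar state, fixed coefficients a and b, a hypothetical scan helper), not Mamba's selective scan or EasyDeL's cache layout.

```python
def scan(xs: list[float], a: float, b: float, state: float = 0.0) -> tuple[list[float], float]:
    """Run h_t = a * h_{t-1} + b * x_t and return the outputs and final state."""
    outputs = []
    for x in xs:
        state = a * state + b * x
        outputs.append(state)
    return outputs, state


a, b = 0.5, 1.0
tokens = [1.0, 2.0, 3.0, 4.0]

# Full pass over the whole sequence.
full_outputs, _ = scan(tokens, a, b)

# Incremental pass: process a prefix, cache the state, then feed one new token.
_, cached_state = scan(tokens[:3], a, b)
step_outputs, _ = scan(tokens[3:], a, b, state=cached_state)
```

The single cached step reproduces the last output of the full pass, which is the property a state cache relies on during autoregressive generation.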

class easydel.infra.modeling_outputs.MaskedLMOutput(logits: Optional[Union[Array, ndarray, bool, number]] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, past_key_values: Any | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for masked language model outputs.

Parameters
  • logits (chex.Array of shape (batch_size, sequence_length, config.vocab_size)) – Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

  • hidden_states (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None#
logits: Optional[Union[Array, ndarray, bool, number]] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
past_key_values: Any | None = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.modeling_outputs.ModelOutput(*args, **kwargs)[source]#

Bases: OrderedDict

Base class for all model outputs.

Provides a consistent interface for model outputs that behaves like both a tuple (for positional access) and a dictionary (for named access). Automatically filters out None values and provides validation.

Subclasses must use the @auto_pytree decorator to ensure JAX compatibility.

to_tuple()[source]#

Convert to tuple, excluding None values

Note

All fields except the first should have None as default.

pop(key[, default])[source]#

Removes the specified key and returns the corresponding value. If the key is not found, returns the default if given; otherwise raises a KeyError.

setdefault(*args, **kwargs)[source]#

Insert key with a value of default if key is not in the dictionary.

Return the value for key if key is in the dictionary, else default.

to_tuple() tuple[Any][source]#

Convert self to a tuple containing all the attributes/keys that are not None.

update([E, ]**F) None[source]#

Update D from dict/iterable E and F.

If E is present and has a .keys() method, then does: for k in E: D[k] = E[k]. If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v. In either case, this is followed by: for k in F: D[k] = F[k].
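
The behavior described for ModelOutput (attribute access, key access, and a to_tuple() that skips None fields) can be imitated with a small OrderedDict subclass. OutputSketch below is a hypothetical stand-in written for illustration. The real class adds validation and pytree registration that this sketch omits.

```python
from collections import OrderedDict
from typing import Any


class OutputSketch(OrderedDict):
    """Hypothetical stand-in: an ordered mapping with attribute-style reads."""

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def to_tuple(self) -> tuple:
        # Keep field order, drop fields that were left as None.
        return tuple(value for value in self.values() if value is not None)


out = OutputSketch(logits=[0.2, 0.8], hidden_states=None, loss=0.5)
as_tuple = out.to_tuple()
```

Both out.logits and out["logits"] return the same object, and as_tuple contains only the two fields that were set.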

class easydel.infra.modeling_outputs.MoeCausalLMOutput(logits: Optional[Union[Array, ndarray, bool, number]] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, past_key_values: Any | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None, aux_loss: Optional[Union[Array, ndarray, bool, number]] = None, router_logits: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, all_router_losses: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None)[source]#

Bases: MaskedLMOutput

Base class for causal language modeling (CLM) outputs of MoE models.

Parameters
  • aux_loss (chex.Array, optional) – Auxiliary loss used for training MoE models.

  • router_logits (tuple(chex.Array), optional) – tp.Tuple of chex.Array (one for each layer) of shape (batch_size, sequence_length, num_experts). The logits output of the router network, which are used to compute the mixture of experts.

all_router_losses: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
aux_loss: Optional[Union[Array, ndarray, bool, number]] = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

loss: Optional[Union[Array, ndarray, bool, number]] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

router_logits: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
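
One common way to turn router_logits into an auxiliary load-balancing loss is the Switch-Transformer-style formulation: for each expert, multiply the fraction of tokens routed to it by its mean router probability, sum over experts, and scale by the number of experts. The sketch below implements that formulation in plain Python for a single layer with top-1 routing, with the batch and sequence axes flattened into one list of tokens. It is an assumption for illustration; EasyDeL's aux_loss may be computed differently.

```python
import math


def softmax(row: list[float]) -> list[float]:
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def load_balancing_loss(router_logits: list[list[float]]) -> float:
    """router_logits: one row per token, one column per expert (top-1 routing)."""
    num_tokens = len(router_logits)
    num_experts = len(router_logits[0])
    probs = [softmax(row) for row in router_logits]
    # Fraction of tokens whose top-1 choice is each expert.
    counts = [0] * num_experts
    for row in probs:
        counts[max(range(num_experts), key=row.__getitem__)] += 1
    token_fraction = [c / num_tokens for c in counts]
    # Mean router probability assigned to each expert.
    mean_prob = [sum(row[e] for row in probs) / num_tokens for e in range(num_experts)]
    return num_experts * sum(f * p for f, p in zip(token_fraction, mean_prob))


balanced = load_balancing_loss([[2.0, 0.0], [0.0, 2.0]])
collapsed = load_balancing_loss([[2.0, 0.0], [2.0, 0.0]])
```

With tokens split evenly across experts the loss sits at its minimum of 1, and it grows when all tokens collapse onto one expert.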

class easydel.infra.modeling_outputs.MoeModelOutput(last_hidden_state: Union[Array, ndarray, bool, number] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, past_key_values: Any | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, router_logits: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, all_router_losses: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, logits: Union[Array, ndarray, bool, number] = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for MoE model outputs.

Parameters
  • last_hidden_state (chex.Array of shape (batch_size, sequence_length, hidden_size)) – Sequence of hidden-states at the output of the last layer of the model.

  • hidden_states (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  • router_logits (tuple(chex.Array), optional) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, sequence_length, num_experts).

    The logits output of the router network, which are used to compute the mixture of experts.

all_router_losses: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
last_hidden_state: Union[Array, ndarray, bool, number] = None#
logits: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
past_key_values: Any | None = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

router_logits: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.modeling_outputs.MultipleChoiceModelOutput(logits: Union[Array, ndarray, bool, number] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for outputs of multiple choice models.

Parameters
  • logits (chex.Array of shape (batch_size, num_choices)) –

    num_choices is the second dimension of the input tensors (see input_ids).

    Classification scores (before SoftMax).

  • hidden_states (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
logits: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
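
Because logits has shape (batch_size, num_choices), the predicted choice for each example is the index of the largest score in its row. A minimal sketch in plain Python, with nested lists standing in for arrays and a hypothetical helper name:

```python
def predict_choices(logits: list[list[float]]) -> list[int]:
    """Argmax over the num_choices axis for every example in the batch."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


# Two examples, four candidate answers each.
logits = [[0.1, 2.3, -1.0, 0.4], [1.5, 0.2, 0.3, 1.4]]
predictions = predict_choices(logits)
```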

class easydel.infra.modeling_outputs.NextSentencePredictorOutput(logits: Union[Array, ndarray, bool, number] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for outputs of models predicting if two sentences are consecutive or not.

Parameters
  • logits (chex.Array of shape (batch_size, 2)) – Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).

  • hidden_states (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
logits: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.modeling_outputs.QuestionAnsweringModelOutput(start_logits: Union[Array, ndarray, bool, number] = None, end_logits: Union[Array, ndarray, bool, number] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for outputs of question answering models.

Parameters
  • start_logits (chex.Array of shape (batch_size, sequence_length)) – Span-start scores (before SoftMax).

  • end_logits (chex.Array of shape (batch_size, sequence_length)) – Span-end scores (before SoftMax).

  • hidden_states (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
end_logits: Union[Array, ndarray, bool, number] = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

start_logits: Union[Array, ndarray, bool, number] = None#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
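
A typical way to decode start_logits and end_logits is to pick the (start, end) pair with the highest combined score, subject to start <= end and a maximum span length. The sketch below does this for a single example in plain Python. best_span and max_answer_len are hypothetical names, and this decoding step is not part of the output class itself.

```python
def best_span(start_logits: list[float], end_logits: list[float], max_answer_len: int = 30) -> tuple[int, int]:
    """Return the (start, end) token indices with the highest summed score."""
    best_score = float("-inf")
    best = (0, 0)
    for start, s_score in enumerate(start_logits):
        # Only consider ends at or after the start, within the length limit.
        stop = min(len(end_logits), start + max_answer_len)
        for end in range(start, stop):
            score = s_score + end_logits[end]
            if score > best_score:
                best_score, best = score, (start, end)
    return best


start_logits = [0.1, 3.0, 0.2, 0.0]
end_logits = [2.5, 0.3, 2.0, 0.1]
span = best_span(start_logits, end_logits)
```

Taking the two argmaxes independently would pair start index 1 with end index 0, which is not a valid span; the joint search avoids that.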

class easydel.infra.modeling_outputs.SampleOutput(sequences: Union[Array, ndarray, bool, number] = None)[source]#

Bases: ModelOutput

Base class for outputs of decoder-only generation models using sampling.

Parameters

sequences (chex.Array of shape (batch_size, max_length)) – The generated sequences.

classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

replace(**kwargs)#

Creates a new instance with specified fields replaced.

sequences: Union[Array, ndarray, bool, number] = None#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.modeling_outputs.Seq2SeqLMOutput(logits: Union[Array, ndarray, bool, number] = None, past_key_values: Any | None = None, decoder_hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, decoder_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, cross_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, encoder_last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None, encoder_hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, encoder_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for sequence-to-sequence language model outputs.

Parameters
  • logits (chex.Array of shape (batch_size, sequence_length, config.vocab_size)) – Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

  • past_key_values (tuple(chex.Array | None)) –

    tp.Tuple of tuple(chex.Array) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head) and 2 additional tensors of shape (batch_size, num_heads, encoder_sequence_length, embed_size_per_head).

    Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.

  • decoder_hidden_states (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.

  • decoder_attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads.

  • cross_attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights of the decoder’s cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads.

  • encoder_last_hidden_state (chex.Array of shape (batch_size, sequence_length, hidden_size), optional) – Sequence of hidden-states at the output of the last layer of the encoder of the model.

  • encoder_hidden_states (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.

  • encoder_attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads.

cross_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
decoder_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
decoder_hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
encoder_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
encoder_hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
encoder_last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

logits: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
past_key_values: Any | None = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.modeling_outputs.Seq2SeqModelOutput(last_hidden_state: Union[Array, ndarray, bool, number] = None, past_key_values: Any | None = None, decoder_hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, decoder_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, cross_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, encoder_last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None, encoder_hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, encoder_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for model encoder’s outputs that also contains pre-computed hidden states that can speed up sequential decoding.

Parameters
  • last_hidden_state (chex.Array of shape (batch_size, sequence_length, hidden_size)) –

    Sequence of hidden-states at the output of the last layer of the decoder of the model.

    If past_key_values is used only the last hidden-state of the sequences of shape (batch_size, 1, hidden_size) is output.

  • past_key_values (tuple(chex.Array | None)) –

    tp.Tuple of tuple(chex.Array) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head) and 2 additional tensors of shape (batch_size, num_heads, encoder_sequence_length, embed_size_per_head).

    Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.

  • decoder_hidden_states (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.

  • decoder_attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads.

  • cross_attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights of the decoder’s cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads.

  • encoder_last_hidden_state (chex.Array of shape (batch_size, sequence_length, hidden_size), optional) – Sequence of hidden-states at the output of the last layer of the encoder of the model.

  • encoder_hidden_states (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.

  • encoder_attentions (tuple(chex.Array | None)) –

    tp.Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads.

cross_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
decoder_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
decoder_hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
encoder_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
encoder_hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
encoder_last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

last_hidden_state: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
past_key_values: Any | None = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
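
The helper methods listed for this class (replace, to_dict, from_dict, to_json, from_json) follow the familiar immutable-dataclass pattern. The stand-in below uses only the standard library to show the expected behavior; `ToyOutput` is hypothetical and is not how EasyDeL implements these methods, and it stores plain lists instead of arrays so the JSON round trip stays trivial.

```python
import dataclasses
import json
from typing import Any, Optional

@dataclasses.dataclass(frozen=True)
class ToyOutput:
    """Hypothetical stand-in for a model output; not an EasyDeL class."""
    last_hidden_state: Optional[list] = None
    loss: Optional[float] = None
    past_key_values: Any = None

    def replace(self, **kwargs):
        # Return a new instance; the original is left unchanged.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self) -> dict:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: dict):
        return cls(**data)

    def to_json(self, **kwargs) -> str:
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str):
        return cls.from_dict(json.loads(json_str))

out = ToyOutput(last_hidden_state=[[0.1, 0.2]], loss=None)
with_loss = out.replace(loss=0.5)          # new object, `out` is untouched
restored = ToyOutput.from_json(with_loss.to_json())
```

The point of the sketch is the contract: `replace` never mutates its receiver, and the dict and JSON forms round-trip back to an equal object.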

class easydel.infra.modeling_outputs.Seq2SeqQuestionAnsweringModelOutput(start_logits: Union[Array, ndarray, bool, number] = None, end_logits: Union[Array, ndarray, bool, number] = None, past_key_values: Any | None = None, decoder_hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, decoder_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, cross_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, encoder_last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None, encoder_hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, encoder_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for outputs of sequence-to-sequence question answering models.

Parameters
  • start_logits (chex.Array of shape (batch_size, sequence_length)) – Span-start scores (before SoftMax).

  • end_logits (chex.Array of shape (batch_size, sequence_length)) – Span-end scores (before SoftMax).

  • past_key_values (tuple(chex.Array | None)) –

    Tuple of tuple(chex.Array) of length config.n_layers, with each inner tuple holding 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head) and 2 additional tensors of shape (batch_size, num_heads, encoder_sequence_length, embed_size_per_head).

    Contains pre-computed hidden-states (keys and values in the self-attention blocks and in the cross-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.

  • decoder_hidden_states (tuple(chex.Array | None)) –

    Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.

  • decoder_attentions (tuple(chex.Array | None)) –

    Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads.

  • cross_attentions (tuple(chex.Array | None)) –

    Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights of the decoder’s cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads.

  • encoder_last_hidden_state (chex.Array of shape (batch_size, sequence_length, hidden_size), optional) – Sequence of hidden-states at the output of the last layer of the encoder of the model.

  • encoder_hidden_states (tuple(chex.Array | None)) –

    Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.

  • encoder_attentions (tuple(chex.Array | None)) –

    Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads.
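
The start_logits and end_logits fields are usually decoded into an answer span by picking the highest-scoring (start, end) pair with start <= end. The following is a minimal sketch of that post-processing step, assuming one example at a time; `best_span` and `max_answer_len` are hypothetical names and are not part of EasyDeL.

```python
import numpy as np

def best_span(start_logits, end_logits, max_answer_len=30):
    """Return (start, end) maximizing start_logits[s] + end_logits[e] with s <= e."""
    # Pairwise scores: rows index the start position, columns the end position.
    scores = start_logits[:, None] + end_logits[None, :]
    seq_len = scores.shape[0]
    s_idx, e_idx = np.indices((seq_len, seq_len))
    # Disallow spans that end before they start or exceed the length cap.
    valid = (e_idx >= s_idx) & (e_idx - s_idx < max_answer_len)
    scores = np.where(valid, scores, -np.inf)
    start, end = np.unravel_index(np.argmax(scores), scores.shape)
    return int(start), int(end)

# One example's logits, i.e. a single row of the (batch_size, sequence_length) arrays.
start_logits = np.array([0.1, 2.0, 0.3, 0.2, 0.0])
end_logits = np.array([0.0, 0.1, 0.4, 3.0, 0.2])
span = best_span(start_logits, end_logits)
```

Scoring pairs jointly, rather than taking two independent argmaxes, avoids returning a span whose end precedes its start.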

cross_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
decoder_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
decoder_hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
encoder_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
encoder_hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
encoder_last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None#
end_logits: Union[Array, ndarray, bool, number] = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

loss: Optional[Union[Array, ndarray, bool, number]] = None#
past_key_values: Any | None = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

start_logits: Union[Array, ndarray, bool, number] = None#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.modeling_outputs.Seq2SeqSequenceClassifierOutput(logits: Union[Array, ndarray, bool, number] = None, past_key_values: Any | None = None, decoder_hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, decoder_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, cross_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, encoder_last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None, encoder_hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, encoder_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for outputs of sequence-to-sequence sentence classification models.

Parameters
  • logits (chex.Array of shape (batch_size, config.num_labels)) – Classification (or regression if config.num_labels==1) scores (before SoftMax).

  • past_key_values (tuple(chex.Array | None)) –

    Tuple of tuple(chex.Array) of length config.n_layers, with each inner tuple holding 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head) and 2 additional tensors of shape (batch_size, num_heads, encoder_sequence_length, embed_size_per_head).

    Contains pre-computed hidden-states (keys and values in the self-attention blocks and in the cross-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.

  • decoder_hidden_states (tuple(chex.Array | None)) –

    Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.

  • decoder_attentions (tuple(chex.Array | None)) –

    Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads.

  • cross_attentions (tuple(chex.Array | None)) –

    Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights of the decoder’s cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads.

  • encoder_last_hidden_state (chex.Array of shape (batch_size, sequence_length, hidden_size), optional) – Sequence of hidden-states at the output of the last layer of the encoder of the model.

  • encoder_hidden_states (tuple(chex.Array | None)) –

    Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.

  • encoder_attentions (tuple(chex.Array | None)) –

    Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads.
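
The nested layout described for past_key_values can be easier to read as shapes. The sketch below only mirrors that textual description with zero-filled NumPy arrays; all sizes are made up, and since the field is typed `Any | None`, a real model may return a cache object instead of raw tuples.

```python
import numpy as np

n_layers, batch_size, num_heads, head_dim = 3, 2, 4, 8
decoder_len, encoder_len = 6, 10

def make_layer_cache():
    # Self-attention key/value: one entry per already-decoded position.
    self_k = np.zeros((batch_size, num_heads, decoder_len, head_dim))
    self_v = np.zeros((batch_size, num_heads, decoder_len, head_dim))
    # Cross-attention key/value: computed from the encoder output, so their
    # length is the encoder sequence length.
    cross_k = np.zeros((batch_size, num_heads, encoder_len, head_dim))
    cross_v = np.zeros((batch_size, num_heads, encoder_len, head_dim))
    return (self_k, self_v, cross_k, cross_v)

# Outer tuple: one entry per decoder layer.
past_key_values = tuple(make_layer_cache() for _ in range(n_layers))
```

Only the first two tensors of each layer depend on the decoded length; the cross-attention pair is determined by the encoder input, which is why caching it saves work on every decoding step.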

cross_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
decoder_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
decoder_hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
encoder_attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
encoder_hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
encoder_last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

logits: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
past_key_values: Any | None = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.modeling_outputs.SequenceClassifierOutput(logits: Union[Array, ndarray, bool, number] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, past_key_values: Any | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None, aux_loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for outputs of sentence classification models.

Parameters
  • logits (chex.Array of shape (batch_size, config.num_labels)) – Classification (or regression if config.num_labels==1) scores (before SoftMax).

  • hidden_states (tuple(chex.Array | None)) –

    Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(chex.Array | None)) –

    Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
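
Since logits holds pre-softmax scores, and a single label means regression, downstream code typically branches on the size of the last axis. A minimal sketch of that interpretation follows; `interpret_logits` is a hypothetical helper, not an EasyDeL function.

```python
import numpy as np

def interpret_logits(logits):
    """Turn (batch_size, num_labels) scores into predictions."""
    num_labels = logits.shape[-1]
    if num_labels == 1:
        # Regression: the single score per example is the prediction itself.
        return logits[:, 0]
    # Classification: softmax over labels, then pick the most likely one.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    return probs.argmax(axis=-1)

class_preds = interpret_logits(np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]))
reg_preds = interpret_logits(np.array([[0.7], [-1.2]]))
```

The softmax does not change which label wins, so it can be skipped when only the predicted class is needed; it is kept here because calibrated-looking probabilities are often reported alongside the label.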

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
aux_loss: Optional[Union[Array, ndarray, bool, number]] = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
logits: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
past_key_values: Any | None = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.modeling_outputs.TokenClassifierOutput(logits: Union[Array, ndarray, bool, number] = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for outputs of token classification models.

Parameters
  • logits (chex.Array of shape (batch_size, sequence_length, config.num_labels)) – Classification scores (before SoftMax).

  • hidden_states (tuple(chex.Array | None)) –

    Tuple of chex.Array (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the initial embedding outputs.

  • attentions (tuple(chex.Array | None)) –

    Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
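
For token classification the logits carry one score vector per token, so decoding is an argmax over the last axis followed by a label lookup. The sketch below assumes a made-up three-label tag set; `id2label` and its contents are hypothetical.

```python
import numpy as np

# Hypothetical label set; a real model defines its own id-to-label mapping.
id2label = {0: "O", 1: "B-PER", 2: "I-PER"}

# Logits for one sequence of four tokens:
# shape (batch_size=1, sequence_length=4, num_labels=3).
logits = np.array([[
    [3.0, 0.1, 0.2],
    [0.2, 2.5, 0.1],
    [0.1, 0.3, 2.2],
    [2.8, 0.2, 0.1],
]])

# Argmax over the last axis gives one label id per token.
pred_ids = logits.argmax(axis=-1)
tags = [[id2label[int(i)] for i in row] for row in pred_ids]
```

In practice, positions that correspond to padding or special tokens are usually masked out before the tags are reported.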

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
logits: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.modeling_outputs.VLMCausalLMOutput(logits: Union[Array, ndarray, bool, number] = None, past_key_values: Any | None = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, image_hidden_states: Optional[Union[Array, ndarray, bool, number]] = None, video_hidden_states: Optional[Union[Array, ndarray, bool, number]] = None, rope_deltas: Optional[Union[Array, ndarray, bool, number]] = None, router_logits: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, aux_loss: Optional[Union[Array, ndarray, bool, number]] = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Unified output class for Vision-Language Models (VLMs).

Provides a standardized output structure for all VLMs, including LLaVA, Qwen2-VL, Qwen3-VL, Gemma3, AyaVision, Mistral3, and Llama4.

Parameters
  • logits (chex.Array of shape (batch_size, sequence_length, config.vocab_size)) – Prediction scores of the language modeling head (before SoftMax).

  • past_key_values (TransformerCache, optional) – Pre-computed hidden-states (key and values in attention blocks) for efficient autoregressive generation.

  • hidden_states (tuple(chex.Array), optional) – Tuple of hidden-states at output of each layer plus embeddings. Shape: (batch_size, sequence_length, hidden_size).

  • last_hidden_state (chex.Array, optional) – Hidden-state at output of the last layer. Shape: (batch_size, sequence_length, hidden_size).

  • attentions (tuple(chex.Array), optional) – Attention weights after softmax. Shape: (batch_size, num_heads, sequence_length, sequence_length).

  • image_hidden_states (chex.Array, optional) – Projected image features from the vision encoder after the multimodal projector. Shape varies by model.

  • video_hidden_states (chex.Array, optional) – Projected video features for models supporting video input (Qwen2-VL, Qwen3-VL, Llama4). Shape varies by model.

  • rope_deltas (chex.Array, optional) – Position embedding deltas for multi-dimensional RoPE (mRoPE) used in Qwen2-VL and Qwen3-VL models.

  • router_logits (tuple(chex.Array), optional) – Router logits for MoE VLMs (Qwen3-VL-MoE). Shape: (batch_size, sequence_length, num_experts).

  • aux_loss (chex.Array, optional) – Auxiliary loss for MoE load balancing.

  • loss (chex.Array, optional) – Language modeling loss when labels are provided.
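
To illustrate what an array of shape (batch_size, sequence_length, num_experts) in router_logits represents, the sketch below applies top-k routing, a common way MoE layers consume such scores. Whether a particular model uses this exact scheme (the value of k, the renormalization) is an assumption here, and `route_top_k` is a hypothetical helper.

```python
import numpy as np

def route_top_k(router_logits, k=2):
    """Pick the k highest-scoring experts per token and renormalize their weights."""
    shifted = router_logits - router_logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    # Indices of the k largest probabilities along the expert axis, highest first.
    top_idx = np.argsort(-probs, axis=-1)[..., :k]
    top_p = np.take_along_axis(probs, top_idx, axis=-1)
    # Renormalize so the selected experts' weights sum to 1 per token.
    return top_idx, top_p / top_p.sum(axis=-1, keepdims=True)

# One layer's router logits: (batch_size=1, sequence_length=2, num_experts=4).
router_logits = np.array([[[0.1, 2.0, 0.3, 1.5], [3.0, 0.2, 0.1, 0.4]]])
experts, weights = route_top_k(router_logits, k=2)
```

Keeping the raw router logits in the output is what allows an auxiliary load-balancing term such as aux_loss to be computed from them.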

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
aux_loss: Optional[Union[Array, ndarray, bool, number]] = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
image_hidden_states: Optional[Union[Array, ndarray, bool, number]] = None#
last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None#
logits: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
past_key_values: Any | None = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

rope_deltas: Optional[Union[Array, ndarray, bool, number]] = None#
router_logits: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

video_hidden_states: Optional[Union[Array, ndarray, bool, number]] = None#