easydel.infra.loss_utils


easydel.infra.loss_utils.ForCausalLMLoss(logits: Array, labels: Array, attention_mask: Optional[Array] = None, config: Optional[LossConfig] = None, paxis: Optional[PartitionAxis] = None, num_items_in_batch: Optional[int] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics

JAX implementation of the loss function for causal language models.

Parameters
  • logits – Predicted logits, shape (batch_size, seq_len, vocab_size).

  • labels – True labels, shape (batch_size, seq_len). Must be integers.

  • num_items_in_batch – Optional; used when the loss reduction is a sum.

  • batch – Optional batch used for dynamic loss normalization.

  • **kwargs – Additional keyword arguments for the cross-entropy loss.

Returns

A LossMetrics instance holding the computed causal language modeling loss.
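The following minimal NumPy sketch shows the shift-and-mask behavior of next-token cross-entropy. It assumes shift_tokens=True, the default ignore_index of -100, and normalization by the number of real target tokens. It leaves out label smoothing, z-loss, and sharding, and it is an illustration, not EasyDeL's implementation. The names causal_lm_loss and log_softmax are hypothetical. jax.numpy accepts the same calls if you want to run it under JAX.

```python
import numpy as np

IGNORE_INDEX = -100  # same default as LossConfig.ignore_index


def log_softmax(x, axis=-1):
    # Subtract the per-row max before exponentiating for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def causal_lm_loss(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean next-token cross-entropy over non-ignored targets (hypothetical helper).

    logits: (batch_size, seq_len, vocab_size); labels: (batch_size, seq_len) ints.
    """
    # Shift so that position t is scored against token t + 1.
    logits = logits[:, :-1, :]
    targets = labels[:, 1:]
    valid = targets != ignore_index
    # Ignored ids are swapped for 0 so indexing is safe; the mask zeroes them out.
    safe_targets = np.where(valid, targets, 0)
    log_probs = log_softmax(logits)
    token_nll = -np.take_along_axis(log_probs, safe_targets[..., None], axis=-1)[..., 0]
    # Divide by the number of real target tokens (cf. NUM_REAL_TARGET_TOKENS).
    return float((token_nll * valid).sum() / max(int(valid.sum()), 1))
```

With uniform logits over a vocabulary of size V, every real target contributes log(V), so the mean is log(V) regardless of how many labels are ignored.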

easydel.infra.loss_utils.ForQuestionAnsweringLoss(start_logits: Array, end_logits: Array, start_positions: Array, end_positions: Array, config: Optional[LossConfig] = None, paxis: Optional[PartitionAxis] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics

JAX implementation of the loss function for question answering.

Parameters
  • start_logits – Predicted start logits, shape (batch_size, seq_len).

  • end_logits – Predicted end logits, shape (batch_size, seq_len).

  • start_positions – True start positions, shape (batch_size,).

  • end_positions – True end positions, shape (batch_size,).

  • batch – Optional batch used for dynamic loss normalization.

  • **kwargs – Additional keyword arguments for the cross-entropy loss.

Returns

A LossMetrics instance holding the computed question answering loss.
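The following minimal NumPy sketch shows a span-extraction loss. It assumes, as Hugging Face Transformers does, that the start and end cross-entropies are averaged. This section does not state whether EasyDeL combines them the same way or how it treats out-of-range positions. The names qa_span_loss and log_softmax are hypothetical.

```python
import numpy as np


def log_softmax(x):
    # Row-wise log-softmax with max subtraction for stability.
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))


def qa_span_loss(start_logits, end_logits, start_positions, end_positions):
    """Average of start- and end-position cross-entropy (hypothetical helper).

    start_logits, end_logits: (batch_size, seq_len); positions: (batch_size,) ints.
    """
    rows = np.arange(start_logits.shape[0])
    start_nll = -log_softmax(start_logits)[rows, start_positions]
    end_nll = -log_softmax(end_logits)[rows, end_positions]
    # Averaging the two terms is the Hugging Face convention (assumption here).
    return float((start_nll.mean() + end_nll.mean()) / 2)
```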

easydel.infra.loss_utils.ForSequenceClassificationLoss(logits: Array, labels: Array, attention_mask: Optional[Array] = None, config: Optional[LossConfig] = None, paxis: Optional[PartitionAxis] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics

JAX implementation of the loss function for sequence classification.

Parameters
  • logits – Predicted logits, shape (batch_size, num_labels), or (batch_size, 1) or (batch_size,) for regression.

  • labels – True labels, shape (batch_size,), or (batch_size, num_labels) for multi-label classification.

  • config – Configuration with problem_type and num_labels attributes.

  • batch – Optional batch used for dynamic loss normalization.

  • **kwargs – Additional keyword arguments for the cross-entropy loss.

Returns

A LossMetrics instance holding the computed sequence classification loss.
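The sketch below shows one way to dispatch a loss on problem_type. The three branches are mean squared error, softmax cross-entropy, and per-label sigmoid cross-entropy. They follow the Hugging Face Transformers convention and are an assumption about this function's behavior. The names sequence_classification_loss and log_softmax are hypothetical.

```python
import numpy as np


def log_softmax(x):
    # Row-wise log-softmax with max subtraction for stability.
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))


def sequence_classification_loss(logits, labels, problem_type):
    """Pick a loss from problem_type (hypothetical helper)."""
    if problem_type == "regression":
        # Mean squared error; (batch_size, 1) logits are reshaped to match the labels.
        preds = logits.reshape(labels.shape)
        return float(np.mean((preds - labels) ** 2))
    if problem_type == "single_label_classification":
        # Softmax cross-entropy against one integer class per example.
        rows = np.arange(logits.shape[0])
        return float(-log_softmax(logits)[rows, labels].mean())
    if problem_type == "multi_label_classification":
        # Independent sigmoid per label, in the numerically stable BCE form.
        z = labels.astype(logits.dtype)
        per_elem = np.maximum(logits, 0) - logits * z + np.log1p(np.exp(-np.abs(logits)))
        return float(per_elem.mean())
    raise ValueError(f"unsupported problem_type: {problem_type!r}")
```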

easydel.infra.loss_utils.ForTokenClassification(logits: Array, labels: Array, config: Optional[LossConfig] = None, paxis: Optional[PartitionAxis] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics

JAX implementation of the loss function for token classification.

Parameters
  • logits – Predicted logits, shape (batch_size, seq_len, num_labels).

  • labels – True labels, shape (batch_size, seq_len). Must be integers.

  • config – Configuration with a num_labels attribute.

  • label_smoothing – Label smoothing factor.

  • z_loss – Coefficient for the auxiliary z-loss term.

  • loss_normalizing_factor – A factor used to normalize the loss; may also be a SpecialLossNormalizingFactor enum member.

  • batch – Optional batch used for dynamic loss normalization.

  • **kwargs – Additional keyword arguments for the cross-entropy loss.

Returns

A LossMetrics instance holding the computed token classification loss.

class easydel.infra.loss_utils.LossConfig(ignore_index: int = -100, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]] = 'NUM_REAL_TARGET_TOKENS', num_labels: Optional[str] = None, problem_type: Optional[str] = None, divide_weight_sum: bool = False, shift_tokens: bool = True, break_on_nan: bool = True, reduction: Optional[Literal['none', 'mean', 'sum']] = None, num_classification_labels: Optional[int] = None, classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None)

Bases: object

Configuration class for customizing loss computation behavior.

ignore_index

Specifies a target value that is ignored and does not contribute to the loss. Defaults to -100.

Type

int

label_smoothing

Amount of label smoothing to apply. 0.0 means no smoothing. Defaults to 0.0.

Type

float

z_loss

Coefficient for the z-loss regularization term, which penalizes the squared log of the softmax normalizer, log(sum(exp(logits)))², and so discourages the logits from drifting to large magnitudes. Defaults to 0.0.

Type

float

loss_normalizing_factor

How to normalize the loss. Can be a constant float/int, a string representation of a SpecialLossNormalizingFactor enum, or the enum itself. Defaults to "NUM_REAL_TARGET_TOKENS".

Type

FACTOR_TYPE

num_labels

The number of labels for classification tasks. Used in ForSequenceClassificationLoss. Defaults to None.

Type

tp.Optional[int]

problem_type

Specifies the problem type for sequence classification (e.g., "single_label_classification", "multi_label_classification"). Defaults to None.

Type

tp.Optional[str]

divide_weight_sum

If True, divides the loss by the sum of weights, in addition to the loss_normalizing_factor. Defaults to False.

Type

bool

shift_tokens

If True (typically for causal LM), shifts the logits and labels so that the model predicts the next token. Defaults to True.

Type

bool

break_on_nan

If True, raises an EasyDeLBreakRequest if a NaN is encountered during loss computation. Defaults to True.

Type

bool

reduction

Specifies the reduction to apply to the loss. If None, the default reduction of the specific loss function is used. Defaults to None.

Type

tp.Optional[tp.Literal["none", "mean", "sum"]]

num_classification_labels

Number of labels specifically for sequence classification. Alias for num_labels. Defaults to None.

Type

tp.Optional[int]

classification_problem_type

Problem type specifically for sequence classification. Alias for problem_type. Defaults to None.

Type

tp.Optional[tp.Literal["regression", "single_label_classification", "multi_label_classification"]]

break_on_nan: bool = True
classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None
divide_weight_sum: bool = False
classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

ignore_index: int = -100
label_smoothing: float = 0.0
loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]] = 'NUM_REAL_TARGET_TOKENS'
num_classification_labels: Optional[int] = None
num_labels: Optional[str] = None
problem_type: Optional[str] = None
reduction: Optional[Literal['none', 'mean', 'sum']] = None
replace(**kwargs)

Creates a new instance with the specified fields replaced.

shift_tokens: bool = True
to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

z_loss: float = 0.0
class easydel.infra.loss_utils.LossMetrics(loss: Optional[Union[float, Array, ndarray, bool, number]] = None, z_loss: Optional[Union[float, Array, ndarray, bool, number]] = None, weight_sum: Optional[Union[float, Array, ndarray, bool, number]] = None, accuracy: Optional[Union[float, Array, ndarray, bool, number]] = None, learning_rate: Optional[Union[float, Array, ndarray, bool, number]] = None, max_grad_norm: Optional[PyTreeNode] = None, mean_grad_norm: Optional[PyTreeNode] = None, grad_norms: Optional[PyTreeNode] = None, chosen_rewards: Optional[Array] = None, rejected_rewards: Optional[Array] = None, other_metrics: Optional[Mapping[str, Array]] = None, execution_time: Optional[float] = None)

Bases: object

Container for various metrics related to loss computation and model training.

loss

The primary computed loss value.

Type

tp.Optional[tp.Union[float, chex.Array]]

z_loss

The computed z-loss regularization term.

Type

tp.Optional[tp.Union[float, chex.Array]]

weight_sum

The sum of weights used in the loss calculation.

Type

tp.Optional[tp.Union[float, chex.Array]]

accuracy

Computed accuracy, if applicable.

Type

tp.Optional[tp.Union[float, chex.Array]]

learning_rate

The learning rate used for the current step.

Type

tp.Optional[tp.Union[float, chex.Array]]

max_grad_norm

The maximum gradient norm observed.

Type

tp.Optional[flax.struct.PyTreeNode]

mean_grad_norm

The mean gradient norm observed.

Type

tp.Optional[flax.struct.PyTreeNode]

grad_norms

A pytree containing the norms of gradients for each parameter.

Type

tp.Optional[flax.struct.PyTreeNode]

chosen_rewards

Rewards for the chosen sequence in preference-based tasks.

Type

tp.Optional[jax.Array]

rejected_rewards

Rewards for the rejected sequence in preference-based tasks.

Type

tp.Optional[jax.Array]

other_metrics

A dictionary for any additional custom metrics.

Type

tp.Optional[tp.Mapping[str, jax.Array]]

execution_time

Time taken for the computation step.

Type

tp.Optional[float]

accuracy: Optional[Union[float, Array, ndarray, bool, number]] = None
chosen_rewards: Optional[Array] = None
execution_time: Optional[float] = None
classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

grad_norms: Optional[PyTreeNode] = None
learning_rate: Optional[Union[float, Array, ndarray, bool, number]] = None
loss: Optional[Union[float, Array, ndarray, bool, number]] = None
max_grad_norm: Optional[PyTreeNode] = None
mean_grad_norm: Optional[PyTreeNode] = None
other_metrics: Optional[Mapping[str, Array]] = None
rejected_rewards: Optional[Array] = None
replace(**kwargs)

Creates a new instance with the specified fields replaced.

to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

weight_sum: Optional[Union[float, Array, ndarray, bool, number]] = None
z_loss: Optional[Union[float, Array, ndarray, bool, number]] = None
class easydel.infra.loss_utils.SpecialLossNormalizingFactor(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: Enum

Specifies special, dynamically calculated loss normalizing factors.

These enums are used in loss functions to indicate how the loss should be normalized based on properties of the input batch, rather than using a fixed constant.

NO_WEIGHT_NUM_REAL_TARGET_TOKENS

Divides the loss by the number of non-padding target tokens, ignoring any provided loss weights.

NUM_REAL_TARGET_TOKENS

Divides the loss by the number of non-padding target tokens, taking any provided loss weights into account.

NUM_TOTAL_TARGET_TOKENS

Divides the loss by the total number of target tokens, including padding.

AVERAGE_PER_SEQUENCE

Computes the average loss per sequence in the batch.

AVERAGE_PER_SEQUENCE = 3
NO_WEIGHT_NUM_REAL_TARGET_TOKENS = 0
NUM_REAL_TARGET_TOKENS = 1
NUM_TOTAL_TARGET_TOKENS = 2
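The NumPy sketch below gives one plausible reading of the four modes. LossNorm is a hypothetical stand-in that copies the member names and values listed above, and factor_and_weights is a hypothetical helper. The string lookup approximates what convert_special_loss_normalizing_factor_to_enum is documented to do; uppercasing the input is an assumption. The AVERAGE_PER_SEQUENCE handling is also an assumption modeled on T5X: each sequence's weights are rescaled to sum to 1, then the loss is divided by the number of sequences.

```python
import enum

import numpy as np


class LossNorm(enum.Enum):
    """Hypothetical stand-in mirroring SpecialLossNormalizingFactor's members."""
    NO_WEIGHT_NUM_REAL_TARGET_TOKENS = 0
    NUM_REAL_TARGET_TOKENS = 1
    NUM_TOTAL_TARGET_TOKENS = 2
    AVERAGE_PER_SEQUENCE = 3


def factor_and_weights(mode, valid, weights=None):
    """Return (normalizing factor, per-token weights) for a (batch, seq_len) mask."""
    if isinstance(mode, str):
        mode = LossNorm[mode.upper()]  # e.g. "num_real_target_tokens" -> member
    valid = valid.astype(np.float64)
    weights = valid if weights is None else weights * valid
    if mode is LossNorm.NO_WEIGHT_NUM_REAL_TARGET_TOKENS:
        return valid.sum(), weights          # count real tokens, ignore weights
    if mode is LossNorm.NUM_REAL_TARGET_TOKENS:
        return weights.sum(), weights        # weighted count of real tokens
    if mode is LossNorm.NUM_TOTAL_TARGET_TOKENS:
        return float(valid.size), weights    # every position, padding included
    # AVERAGE_PER_SEQUENCE: rows sum to 1 (empty rows stay 0), divide by #sequences.
    per_seq = weights.sum(axis=-1, keepdims=True)
    scaled = np.divide(weights, per_seq, out=np.zeros_like(weights), where=per_seq > 0)
    return float(weights.shape[0]), scaled
```

Given a per-token loss array, the reduced loss is then (token_loss * weights).sum() / factor.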
easydel.infra.loss_utils.auxiliary_load_balancing_loss_func(gate_logits: Union[Array, ndarray, bool, number, Tuple[Union[Array, ndarray, bool, number], ...]], num_experts: int, top_k: int, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None) → Union[Array, int]

Computes the auxiliary load-balancing loss from the Switch Transformer, implemented in JAX.

See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss presented in equations (4) to (6) of the paper, which penalizes routing that is too unbalanced across experts.

Parameters
  • gate_logits – Logits from the gate. Should be a tuple/list of JAX arrays, with each array corresponding to a layer and having shape [batch_size * sequence_length, num_experts]. Alternatively, can be a single stacked array of shape [num_layers * batch_size * sequence_length, num_experts].

  • num_experts – Number of experts. Must be provided if gate_logits is not None.

  • top_k – The number of experts each token is routed to (the top-k routing parameter).

  • attention_mask (jax.numpy.ndarray, optional) – The attention mask used in the forward pass, of shape [batch_size, sequence_length], if not None.

Returns

The auxiliary loss as a JAX scalar array, or 0 if gate_logits is None.
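In the Switch Transformer formulation, the loss is num_experts * sum_i(f_i * P_i). Here f_i is the fraction of tokens dispatched to expert i, and P_i is the mean router probability for expert i. With top_k=1, the loss equals 1 when tokens are spread evenly across experts and grows as they concentrate on a few.

The NumPy sketch below makes several assumptions:
- It takes a single (num_tokens, num_experts) array, so per-layer logits must first be concatenated along the token axis.
- It ignores attention_mask.
- It counts each of the top_k routing slots separately, as the Hugging Face Mixtral port does.

load_balancing_loss is a hypothetical name.

```python
import numpy as np


def softmax(x, axis=-1):
    # Max-shifted softmax for numerical stability.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def load_balancing_loss(gate_logits, num_experts, top_k):
    """Switch-style auxiliary loss for (num_tokens, num_experts) router logits."""
    probs = softmax(gate_logits)                   # router probabilities per token
    top = np.argsort(-probs, axis=-1)[:, :top_k]   # experts each token is sent to
    dispatch = np.eye(num_experts)[top]            # (tokens, top_k, experts) one-hot
    frac_tokens = dispatch.mean(axis=0)            # f_i for each top-k slot
    mean_prob = probs.mean(axis=0)                 # P_i
    return float(num_experts * np.sum(frac_tokens * mean_prob[None, :]))
```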

easydel.infra.loss_utils.compute_weighted_cross_entropy(logits: Union[Array, ndarray, bool, number], targets: Union[Array, ndarray, bool, number], weights: Optional[Union[Array, ndarray, bool, number]] = None, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Optional[float] = None) → Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]]

Computes weighted cross-entropy loss, z-loss, and weight sum.

Parameters
  • logits – The predicted logits.

  • targets – The target class labels (integers).

  • weights – Optional weights for each example.

  • label_smoothing – Label smoothing factor.

  • z_loss – Coefficient for the auxiliary z-loss term.

  • loss_normalizing_factor – A factor to normalize the loss.

Returns

A tuple containing the total loss, z-loss, and sum of weights.

easydel.infra.loss_utils.compute_weighted_cross_entropy_and_accuracy(logits: Union[Array, ndarray, bool, number], targets: Union[Array, ndarray, bool, number], weights: Optional[Union[Array, ndarray, bool, number]] = None, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Optional[float] = None) → Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]]

Computes weighted cross-entropy loss, z-loss, weight sum, and accuracy.

Parameters
  • logits – The predicted logits.

  • targets – The target class labels (integers).

  • weights – Optional weights for each example.

  • label_smoothing – Label smoothing factor.

  • z_loss – Coefficient for the auxiliary z-loss term.

  • loss_normalizing_factor – A factor to normalize the loss.

Returns

A tuple containing the total loss, z-loss, sum of weights, and accuracy.
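These two functions appear to mirror T5X's compute_weighted_cross_entropy, so the NumPy sketch below follows that formulation:
- Targets are smoothed one-hot vectors, and their entropy is subtracted so that a perfect fit scores 0.
- A per-token z-loss of z_loss * (log Z)^2 is added to the loss and also returned on its own.
- The weights and the normalizing factor are both optional.
- Accuracy is the weighted fraction of argmax hits.

Treat it as an illustration under that assumption, not EasyDeL's code. weighted_ce_and_accuracy is a hypothetical name.

```python
import numpy as np


def weighted_ce_and_accuracy(logits, targets, weights=None, label_smoothing=0.0,
                             z_loss=0.0, loss_normalizing_factor=None):
    """Return (total_loss, total_z_loss, weight_sum, accuracy) (hypothetical helper)."""
    vocab = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = label_smoothing / (vocab - 1)
    # Entropy of the smoothed targets; subtracting it makes a perfect fit score 0.
    normalizing_constant = -(
        confidence * np.log(confidence)
        + (vocab - 1) * low_confidence * np.log(low_confidence + 1e-20)
    )
    soft_targets = np.full(logits.shape, low_confidence)
    np.put_along_axis(soft_targets, targets[..., None], confidence, axis=-1)

    # log Z via a max-shifted logsumexp for numerical stability.
    m = logits.max(axis=-1, keepdims=True)
    log_z = (m + np.log(np.exp(logits - m).sum(axis=-1, keepdims=True)))[..., 0]
    log_probs = logits - log_z[..., None]

    token_z = z_loss * log_z ** 2
    token_loss = -(soft_targets * log_probs).sum(axis=-1) - normalizing_constant + token_z

    weights = np.ones(targets.shape) if weights is None else weights
    token_loss, token_z = token_loss * weights, token_z * weights
    if loss_normalizing_factor is not None:
        token_loss = token_loss / loss_normalizing_factor
        token_z = token_z / loss_normalizing_factor
    weight_sum = weights.sum()
    accuracy = ((logits.argmax(axis=-1) == targets) * weights).sum() / weight_sum
    return token_loss.sum(), token_z.sum(), weight_sum, accuracy
```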

easydel.infra.loss_utils.convert_special_loss_normalizing_factor_to_enum(x: str) → SpecialLossNormalizingFactor

Converts a stringified version of SpecialLossNormalizingFactor to an enum.

Parameters

x – Stringified version of the enum value.

Returns

The corresponding SpecialLossNormalizingFactor enum value.

easydel.infra.loss_utils.cross_entropy_loss_and_accuracy(source, target, valid=None)
easydel.infra.loss_utils.dynamic_cross_entropy_loss(logits: Array, targets: Array, weight: Optional[Array] = None, ignore_index: int = -100, reduction: str = 'mean', label_smoothing: float = 0.0) → Tuple[Array, Array]

Computes the cross-entropy loss with optional label smoothing and ignore index, dynamically handling different reduction types.

Parameters
  • logits (jnp.ndarray) – The predicted logits from the model (batch_size, ..., num_classes).

  • targets (jnp.ndarray) – The target labels (batch_size, ...).

  • weight (tp.Optional[jnp.ndarray]) – Optional weights for each element (batch_size, ...). Defaults to None.

  • ignore_index (int) – Index in the target labels to ignore. Defaults to -100.

  • reduction (str) – Specifies the reduction method: 'mean', 'sum', or 'none'. Defaults to 'mean'.

  • label_smoothing (float) – The amount of label smoothing to apply (0.0 means no smoothing). Defaults to 0.0.

Returns

  • The computed loss (scalar if reduction is 'mean' or 'sum', array otherwise).

  • The normalization factor (sum of weights or count of non-ignored elements).

Return type

tp.Tuple[jnp.ndarray, jnp.ndarray]

Raises

ValueError – If an invalid reduction method is specified.
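The NumPy sketch below follows the documented contract: ignore_index masking, optional weights, three reductions, and a ValueError for anything else. The label-smoothing form used here is the PyTorch convention and an assumption about this function: (1 - label_smoothing) times the NLL, plus label_smoothing times the mean negative log-probability over classes. dynamic_ce is a hypothetical name.

```python
import numpy as np


def dynamic_ce(logits, targets, weight=None, ignore_index=-100,
               reduction="mean", label_smoothing=0.0):
    """Return (loss, normalization_factor) (hypothetical helper)."""
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"Invalid reduction: {reduction!r}")
    valid = targets != ignore_index
    safe_targets = np.where(valid, targets, 0)  # placeholder index for ignored slots
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -np.take_along_axis(log_probs, safe_targets[..., None], axis=-1)[..., 0]
    # Label smoothing blends in the average negative log-probability over classes.
    loss = (1.0 - label_smoothing) * nll - label_smoothing * log_probs.mean(axis=-1)
    w = valid.astype(loss.dtype) if weight is None else weight * valid
    loss = loss * w
    norm = w.sum()
    if reduction == "mean":
        return loss.sum() / np.maximum(norm, 1e-8), norm
    if reduction == "sum":
        return loss.sum(), norm
    return loss, norm
```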

easydel.infra.loss_utils.fixed_cross_entropy(source: Array, target: Array, attention_mask: Optional[Array] = None, config: Optional[LossConfig] = None, num_items_in_batch: Optional[int] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics

JAX implementation of a fixed cross-entropy loss with z-loss, label smoothing, and masking.

Parameters
  • source – Predicted logits, shape (batch_size, num_classes) or (batch_size * seq_len, num_classes).

  • target – True labels, shape (batch_size,) or (batch_size * seq_len,). Must be integers.

  • num_items_in_batch – Optional; used when the loss reduction is a sum.

  • attention_mask – Optional boolean mask applied to the loss.

  • batch – Optional batch used for dynamic loss normalization.

  • **kwargs – Additional keyword arguments.

Returns

The computed cross-entropy loss, wrapped in a LossMetrics instance.

easydel.infra.loss_utils.get_loss_normalizing_factor_and_weights(loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]], batch: Mapping[str, Union[Array, ndarray, bool, number]]) → Tuple[Optional[float], Optional[Union[Array, ndarray, bool, number]]]

Gets the loss normalizing factor and weights from a batch of data.

Parameters
  • loss_normalizing_factor – The loss normalizing factor to use.

  • batch – A dictionary containing the input batch of data.

Returns

A tuple containing the loss normalizing factor and loss weights.

easydel.infra.loss_utils.onehot(labels, num_classes, on_value=1.0, off_value=0.0)

Creates one-hot encoded versions of integer labels.

Parameters
  • labels (jnp.ndarray) – An array of integer labels.

  • num_classes (int) – The total number of classes.

  • on_value (float) – The value to use for the "on" state (corresponding to the label). Defaults to 1.0.

  • off_value (float) – The value to use for the "off" states. Defaults to 0.0.

Returns

The one-hot encoded array with shape labels.shape + (num_classes,).

Return type

jnp.ndarray
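The same behavior can be sketched with NumPy broadcasting. This is a hypothetical reimplementation, not the library source. It also shows one common way to build label-smoothed targets: set on_value = 1 - label_smoothing and off_value = label_smoothing / (num_classes - 1).

```python
import numpy as np


def onehot(labels, num_classes, on_value=1.0, off_value=0.0):
    """One-hot encode integer labels; output shape is labels.shape + (num_classes,)."""
    labels = np.asarray(labels)
    # Comparing against every class index adds a trailing class axis via broadcasting.
    hits = labels[..., None] == np.arange(num_classes)
    return np.where(hits, on_value, off_value)
```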

easydel.infra.loss_utils.sigmoid_cross_entropy_with_logits(logits: Array, labels: Array, weights: Optional[Array] = None, label_smoothing: float = 0.0, axis: Optional[Union[int, tuple]] = None) → Array

Computes sigmoid cross-entropy loss between logits and labels.

Measures the probability error in discrete classification tasks in which each class is independent and not mutually exclusive. For instance, one could perform multi-label classification where a picture can contain both an elephant and a dog at the same time.

Parameters
  • logits – The predicted logits from the model.

  • labels – The target labels.

  • weights (tp.Optional[jnp.ndarray]) – Optional weights for the loss computation. Defaults to None.

  • label_smoothing (float) – Amount of label smoothing to apply (0.0 means no smoothing). Defaults to 0.0.

  • axis (tp.Optional[tp.Union[int, tuple]]) – The axis or axes along which to compute the mean. If None, the mean is computed over all elements. Defaults to None.

Returns

The computed sigmoid cross-entropy loss.

Return type

jnp.ndarray
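The NumPy sketch below uses the numerically stable identity max(x, 0) - x * z + log(1 + exp(-|x|)). It equals -[z * log(sigmoid(x)) + (1 - z) * log(1 - sigmoid(x))] but does not overflow for large |x|. The label-smoothing rule here, which pulls hard targets toward 0.5, is an assumption, and sigmoid_ce is a hypothetical name.

```python
import numpy as np


def sigmoid_ce(logits, labels, weights=None, label_smoothing=0.0, axis=None):
    """Mean sigmoid cross-entropy over `axis` (hypothetical helper)."""
    labels = np.asarray(labels, dtype=np.float64)
    if label_smoothing > 0:
        # Move 0/1 targets toward 0.5 (assumed smoothing scheme).
        labels = labels * (1.0 - label_smoothing) + 0.5 * label_smoothing
    # Stable form of binary cross-entropy computed directly from logits.
    loss = np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
    if weights is not None:
        loss = loss * weights
    return np.mean(loss, axis=axis)
```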