easydel.infra.loss_utils
- easydel.infra.loss_utils.ForCausalLMLoss(logits: Array, labels: Array, attention_mask: Optional[Array] = None, config: Optional[LossConfig] = None, paxis: Optional[PartitionAxis] = None, num_items_in_batch: Optional[int] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics
JAX implementation of the loss function for causal language models.
- Parameters
logits – Predicted logits, shape (batch_size, seq_len, vocab_size).
labels – True labels, shape (batch_size, seq_len). Must be integers.
num_items_in_batch – Optional; used when the loss should be reduced by summation.
batch – Optional batch used for dynamic loss normalization.
**kwargs – Additional keyword arguments for the cross-entropy loss.
- Returns
The computed causal language modeling loss.
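The shift-and-mask step behind a causal LM loss can be sketched as follows. This is a minimal NumPy illustration, not EasyDeL's implementation. `causal_lm_loss_sketch` is a hypothetical helper, and it assumes the defaults described for LossConfig: shift_tokens=True, ignore_index=-100, and normalization by the number of real target tokens.

```python
import numpy as np


def causal_lm_loss_sketch(logits, labels, ignore_index=-100):
    """Hypothetical helper: mean next-token cross-entropy over non-ignored targets."""
    # Shift so that the logits at position t are scored against the token at t + 1.
    shifted_logits = logits[:, :-1, :]
    shifted_labels = labels[:, 1:]

    # Numerically stable log-softmax over the vocabulary axis.
    m = shifted_logits.max(axis=-1, keepdims=True)
    log_norm = np.log(np.exp(shifted_logits - m).sum(axis=-1, keepdims=True))
    log_probs = shifted_logits - m - log_norm

    # Ignored positions are swapped for a valid id before the gather, then masked out.
    mask = shifted_labels != ignore_index
    safe_labels = np.where(mask, shifted_labels, 0)
    token_nll = -np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]

    # Average over real (non-ignored) target tokens only.
    return np.sum(token_nll * mask) / np.maximum(np.sum(mask), 1)
```

With uniform logits every real target costs log(vocab_size), so the result is independent of how many positions are ignored.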
- easydel.infra.loss_utils.ForQuestionAnsweringLoss(start_logits: Array, end_logits: Array, start_positions: Array, end_positions: Array, config: Optional[LossConfig] = None, paxis: Optional[PartitionAxis] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics
JAX implementation of the loss function for question answering.
- Parameters
start_logits – Predicted start logits, shape (batch_size, seq_len).
end_logits – Predicted end logits, shape (batch_size, seq_len).
start_positions – True start positions, shape (batch_size,).
end_positions – True end positions, shape (batch_size,).
batch – Optional batch used for dynamic loss normalization.
**kwargs – Additional keyword arguments for the cross-entropy loss.
- Returns
The computed question answering loss.
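A common span-extraction loss treats start and end prediction as two classifications over sequence positions and averages their cross-entropies. Hugging Face Transformers uses this convention. The sketch below assumes EasyDeL does the same. `qa_loss_sketch` is hypothetical and omits any clamping of out-of-range positions.

```python
import numpy as np


def _mean_nll(logits, targets):
    # Mean negative log-likelihood of integer targets under a softmax over the last axis.
    m = logits.max(axis=-1, keepdims=True)
    log_probs = logits - m - np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))
    return -np.take_along_axis(log_probs, targets[:, None], axis=-1).mean()


def qa_loss_sketch(start_logits, end_logits, start_positions, end_positions):
    """Hypothetical helper: average of start- and end-position cross-entropies."""
    start_loss = _mean_nll(start_logits, start_positions)
    end_loss = _mean_nll(end_logits, end_positions)
    return (start_loss + end_loss) / 2.0
```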
- easydel.infra.loss_utils.ForSequenceClassificationLoss(logits: Array, labels: Array, attention_mask: Optional[Array] = None, config: Optional[LossConfig] = None, paxis: Optional[PartitionAxis] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics
JAX implementation of the loss function for sequence classification.
- Parameters
logits – Predicted logits, shape (batch_size, num_labels), or (batch_size, 1) or (batch_size,) for regression.
labels – True labels, shape (batch_size,), or (batch_size, num_labels) for multi-label classification.
config – Configuration with problem_type and num_labels attributes.
batch – Optional batch used for dynamic loss normalization.
**kwargs – Additional keyword arguments for the cross-entropy loss.
- Returns
The computed sequence classification loss.
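The three problem types accepted by problem_type / classification_problem_type usually map to three different losses. The sketch below assumes the Hugging Face Transformers convention:

- mean squared error for regression,
- softmax cross-entropy for single-label classification,
- per-label sigmoid cross-entropy for multi-label classification.

The helper name is hypothetical, and problem_type is passed explicitly rather than inferred from the labels.

```python
import numpy as np


def sequence_classification_loss_sketch(logits, labels, problem_type):
    """Hypothetical helper dispatching on the documented problem types."""
    if problem_type == "regression":
        # Squeeze a trailing size-1 axis so (batch_size, 1) predictions match (batch_size,) labels.
        preds = logits[..., 0] if logits.ndim == 2 and logits.shape[-1] == 1 else logits
        return np.mean((preds - labels) ** 2)
    if problem_type == "single_label_classification":
        # Softmax cross-entropy against integer class ids.
        m = logits.max(axis=-1, keepdims=True)
        log_probs = logits - m - np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))
        return -np.take_along_axis(log_probs, labels[:, None], axis=-1).mean()
    if problem_type == "multi_label_classification":
        # Independent sigmoid per label, in the overflow-safe form.
        return np.mean(np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits))))
    raise ValueError(f"Unknown problem_type: {problem_type}")
```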
- easydel.infra.loss_utils.ForTokenClassification(logits: Array, labels: Array, config: Optional[LossConfig] = None, paxis: Optional[PartitionAxis] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics
JAX implementation of the loss function for token classification.
- Parameters
logits – Predicted logits, shape (batch_size, seq_len, num_labels).
labels – True labels, shape (batch_size, seq_len). Must be integers.
config – Configuration with a num_labels attribute.
label_smoothing – Label smoothing factor.
z_loss – Coefficient for the auxiliary z-loss term.
loss_normalizing_factor – A factor to normalize the loss; can also be a SpecialLossNormalizingFactor value.
batch – Optional batch used for dynamic loss normalization.
**kwargs – Additional keyword arguments for the cross-entropy loss.
- Returns
The computed token classification loss.
- class easydel.infra.loss_utils.LossConfig(ignore_index: int = -100, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]] = 'NUM_REAL_TARGET_TOKENS', num_labels: Optional[str] = None, problem_type: Optional[str] = None, divide_weight_sum: bool = False, shift_tokens: bool = True, break_on_nan: bool = True, reduction: Optional[Literal['none', 'mean', 'sum']] = None, num_classification_labels: Optional[int] = None, classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None)
Bases: object
Configuration class for customizing loss computation behavior.
- ignore_index
Specifies a target value that is ignored and does not contribute to the loss. Defaults to -100.
- Type
int
- label_smoothing
Amount of label smoothing to apply. 0.0 means no smoothing. Defaults to 0.0.
- Type
float
- z_loss
Coefficient for the z-loss regularization term. The term penalizes the squared log-partition function (log Z, the log-sum-exp of the logits), which discourages the logits from growing large. Defaults to 0.0.
- Type
float
- loss_normalizing_factor
How to normalize the loss. Can be a constant float/int, a string representation of a SpecialLossNormalizingFactor enum, or the enum itself. Defaults to 'NUM_REAL_TARGET_TOKENS'.
- Type
FACTOR_TYPE
- num_labels
The number of labels for classification tasks. Used in ForSequenceClassificationLoss. Defaults to None.
- Type
tp.Optional[int]
- problem_type
Specifies the problem type for sequence classification (e.g., 'single_label_classification', 'multi_label_classification'). Defaults to None.
- Type
tp.Optional[str]
- divide_weight_sum
If True, divides the loss by the sum of weights, in addition to the loss_normalizing_factor. Defaults to False.
- Type
bool
- shift_tokens
If True (typically for causal LM), shifts the logits and labels so that the model predicts the next token. Defaults to True.
- Type
bool
- break_on_nan
If True, raises an EasyDeLBreakRequest if a NaN is encountered during loss computation. Defaults to True.
- Type
bool
- reduction
Specifies the reduction to apply to the loss. If None, the default reduction of the specific loss function is used. Defaults to None.
- Type
tp.Optional[tp.Literal['none', 'mean', 'sum']]
- num_classification_labels
Number of labels specifically for sequence classification. Alias for num_labels. Defaults to None.
- Type
tp.Optional[int]
- classification_problem_type
Problem type specifically for sequence classification. Alias for problem_type. Defaults to None.
- Type
tp.Optional[tp.Literal['regression', 'single_label_classification', 'multi_label_classification']]
- break_on_nan: bool = True
- classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None
- divide_weight_sum: bool = False
- classmethod from_dict(data: Dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- ignore_index: int = -100
- label_smoothing: float = 0.0
- loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]] = 'NUM_REAL_TARGET_TOKENS'
- num_classification_labels: Optional[int] = None
- num_labels: Optional[str] = None
- problem_type: Optional[str] = None
- reduction: Optional[Literal['none', 'mean', 'sum']] = None
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- shift_tokens: bool = True
- to_dict() → Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- z_loss: float = 0.0
- class easydel.infra.loss_utils.LossMetrics(loss: Optional[Union[float, Array, ndarray, bool, number]] = None, z_loss: Optional[Union[float, Array, ndarray, bool, number]] = None, weight_sum: Optional[Union[float, Array, ndarray, bool, number]] = None, accuracy: Optional[Union[float, Array, ndarray, bool, number]] = None, learning_rate: Optional[Union[float, Array, ndarray, bool, number]] = None, max_grad_norm: Optional[PyTreeNode] = None, mean_grad_norm: Optional[PyTreeNode] = None, grad_norms: Optional[PyTreeNode] = None, chosen_rewards: Optional[Array] = None, rejected_rewards: Optional[Array] = None, other_metrics: Optional[Mapping[str, Array]] = None, execution_time: Optional[float] = None)
Bases: object
Container for various metrics related to loss computation and model training.
- loss
The primary computed loss value.
- Type
tp.Optional[tp.Union[float, chex.Array]]
- z_loss
The computed z-loss regularization term.
- Type
tp.Optional[tp.Union[float, chex.Array]]
- weight_sum
The sum of weights used in the loss calculation.
- Type
tp.Optional[tp.Union[float, chex.Array]]
- accuracy
Computed accuracy, if applicable.
- Type
tp.Optional[tp.Union[float, chex.Array]]
- learning_rate
The learning rate used for the current step.
- Type
tp.Optional[tp.Union[float, chex.Array]]
- max_grad_norm
The maximum gradient norm observed.
- Type
tp.Optional[flax.struct.PyTreeNode]
- mean_grad_norm
The mean gradient norm observed.
- Type
tp.Optional[flax.struct.PyTreeNode]
- grad_norms
A pytree containing the norms of gradients for each parameter.
- Type
tp.Optional[flax.struct.PyTreeNode]
- chosen_rewards
Rewards for the chosen sequence in preference-based tasks.
- Type
tp.Optional[jax.Array]
- rejected_rewards
Rewards for the rejected sequence in preference-based tasks.
- Type
tp.Optional[jax.Array]
- other_metrics
A dictionary for any additional custom metrics.
- Type
tp.Optional[tp.Mapping[str, jax.Array]]
- execution_time
Time taken for the computation step.
- Type
tp.Optional[float]
- execution_time: Optional[float] = None
- classmethod from_dict(data: Dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- grad_norms: Optional[PyTreeNode] = None
- max_grad_norm: Optional[PyTreeNode] = None
- mean_grad_norm: Optional[PyTreeNode] = None
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- class easydel.infra.loss_utils.SpecialLossNormalizingFactor(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: Enum
Specifies special, dynamically calculated loss normalizing factors.
These enums are used in loss functions to indicate how the loss should be normalized based on properties of the input batch, rather than using a fixed constant.
- NO_WEIGHT_NUM_REAL_TARGET_TOKENS
Divides the loss by the number of non-padding target tokens, ignoring any provided loss weights.
- NUM_REAL_TARGET_TOKENS
Divides the loss by the number of non-padding target tokens, considering provided loss weights.
- NUM_TOTAL_TARGET_TOKENS
Divides the loss by the total number of target tokens, including padding.
- AVERAGE_PER_SEQUENCE
Computes the average loss per sequence in the batch.
- AVERAGE_PER_SEQUENCE = 3
- NO_WEIGHT_NUM_REAL_TARGET_TOKENS = 0
- NUM_REAL_TARGET_TOKENS = 1
- NUM_TOTAL_TARGET_TOKENS = 2
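LossConfig.loss_normalizing_factor accepts a string, a numeric constant, or the enum itself, so the value must be resolved before use. The sketch below mirrors the documented members and values in a hypothetical `NormalizingFactor` enum and resolves strings by member name. The case-insensitive lookup is an assumption, and this is not EasyDeL's resolution code.

```python
from enum import Enum


class NormalizingFactor(Enum):
    """Hypothetical mirror of SpecialLossNormalizingFactor's documented members."""
    NO_WEIGHT_NUM_REAL_TARGET_TOKENS = 0
    NUM_REAL_TARGET_TOKENS = 1
    NUM_TOTAL_TARGET_TOKENS = 2
    AVERAGE_PER_SEQUENCE = 3


def resolve_factor(x):
    """Return an enum member for names/members, a float for numeric constants, or None."""
    if x is None or isinstance(x, NormalizingFactor):
        return x
    if isinstance(x, str):
        # Assumed: lookup by member name, case-insensitive; unknown names raise KeyError.
        return NormalizingFactor[x.upper()]
    return float(x)
```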
- easydel.infra.loss_utils.auxiliary_load_balancing_loss_func(gate_logits: Union[Array, ndarray, bool, number, Tuple[Union[Array, ndarray, bool, number], ...]], num_experts: int, top_k: int, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None) → Union[Array, int]
Computes the auxiliary load-balancing loss from the Switch Transformer, implemented in JAX.
See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss presented in equations (4) to (6) of the paper, which penalizes routing that is too unbalanced across experts.
- Parameters
gate_logits – Logits from the gate. This can be a tuple/list of JAX arrays, one per layer, each of shape [batch_size * sequence_length, num_experts]. It can also be a single stacked array of shape [num_layers * batch_size * sequence_length, num_experts].
num_experts – Number of experts. Must be provided if gate_logits is not None.
top_k – Number of experts each token is routed to (the top-k routing parameter).
attention_mask (jax.numpy.ndarray, optional) – The attention mask used in the forward pass, of shape [batch_size, sequence_length], if provided.
- Returns
The auxiliary loss as a JAX scalar array, or 0 if gate_logits is None.
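Equations (4) to (6) of the Switch Transformer paper define the loss as N · Σ_i f_i · P_i. Here f_i is the fraction of tokens dispatched to expert i, and P_i is the mean router probability for expert i.

The NumPy sketch below follows the top-k generalization used in Hugging Face's Mixtral code. It is an assumption that EasyDeL matches it. The helper name is hypothetical, attention_mask handling is omitted, and the paper's scaling coefficient α is left to the caller.

```python
import numpy as np


def load_balancing_loss_sketch(gate_logits, num_experts, top_k):
    """Hypothetical helper: Switch-style auxiliary loss without attention masking."""
    if isinstance(gate_logits, (tuple, list)):
        # One [tokens, num_experts] array per layer: stack along the token axis.
        gate_logits = np.concatenate(list(gate_logits), axis=0)
    # Router probabilities per token (stable softmax over experts).
    shifted = gate_logits - gate_logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    # Indices of the top_k experts each token is dispatched to.
    selected = np.argsort(-probs, axis=-1)[:, :top_k]
    expert_mask = np.eye(num_experts)[selected]    # [tokens, top_k, num_experts]
    tokens_per_expert = expert_mask.mean(axis=0)   # f_i, per routing slot
    router_prob_per_expert = probs.mean(axis=0)    # P_i
    return num_experts * np.sum(tokens_per_expert * router_prob_per_expert[None, :])
```

In this example, spreading four tokens evenly over four experts gives 1.0. Sending every token to one expert pushes the value toward num_experts.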
- easydel.infra.loss_utils.compute_weighted_cross_entropy(logits: Union[Array, ndarray, bool, number], targets: Union[Array, ndarray, bool, number], weights: Optional[Union[Array, ndarray, bool, number]] = None, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Optional[float] = None) → Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]]
Computes weighted cross-entropy loss, z-loss, and weight sum.
- Parameters
logits – The predicted logits.
targets – The target class labels (integers).
weights – Optional weights for each example.
label_smoothing – Label smoothing factor.
z_loss – Coefficient for the auxiliary z-loss term.
loss_normalizing_factor – A factor to normalize the loss.
- Returns
A tuple containing the total loss, z-loss, and sum of weights.
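This function appears to mirror the signature and return triple of the function of the same name in Google's T5X library. The sketch below reproduces that T5X-style computation in NumPy:

- label-smoothed soft targets,
- a z-loss term z_loss · (log Z)²,
- optional per-example weights,
- an optional constant normalizer.

That EasyDeL's version matches it line for line is an assumption, and the helper name is hypothetical.

```python
import numpy as np


def weighted_cross_entropy_sketch(logits, targets, weights=None, label_smoothing=0.0,
                                  z_loss=0.0, loss_normalizing_factor=None):
    """Hypothetical T5X-style helper returning (total_loss, total_z_loss, weight_sum)."""
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    # Entropy of the smoothed targets; subtracting it turns cross-entropy into a KL divergence.
    normalizing_constant = -(confidence * np.log(confidence)
                             + (vocab_size - 1) * low_confidence * np.log(low_confidence + 1e-20))
    soft_targets = np.where(targets[..., None] == np.arange(vocab_size),
                            confidence, low_confidence)

    # log Z is the log-sum-exp of the logits (the softmax normalizer).
    m = logits.max(axis=-1, keepdims=True)
    log_z = m + np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))
    loss = -np.sum(soft_targets * (logits - log_z), axis=-1)
    z_term = z_loss * np.square(log_z[..., 0])  # penalizes large log Z
    total_loss = loss + z_term - normalizing_constant

    weight_sum = np.prod(targets.shape)
    if weights is not None:
        total_loss = total_loss * weights
        z_term = z_term * weights
        weight_sum = np.sum(weights)
    if loss_normalizing_factor is not None:
        total_loss = total_loss / loss_normalizing_factor
        z_term = z_term / loss_normalizing_factor
    return np.sum(total_loss), np.sum(z_term), weight_sum
```

Note that the returned total loss already includes the z-loss term. The separate z-loss value is reported for logging.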
- easydel.infra.loss_utils.compute_weighted_cross_entropy_and_accuracy(logits: Union[Array, ndarray, bool, number], targets: Union[Array, ndarray, bool, number], weights: Optional[Union[Array, ndarray, bool, number]] = None, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Optional[float] = None) → Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]]
Computes weighted cross-entropy loss, z-loss, weight sum, and accuracy.
- Parameters
logits – The predicted logits.
targets – The target class labels (integers).
weights – Optional weights for each example.
label_smoothing – Label smoothing factor.
z_loss – Coefficient for the auxiliary z-loss term.
loss_normalizing_factor – A factor to normalize the loss.
- Returns
A tuple containing the total loss, z-loss, sum of weights, and accuracy.
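The accuracy returned alongside the loss is presumably the weighted fraction of positions whose arg-max prediction equals the target. Below is a hypothetical NumPy sketch under that assumption. The text above does not state whether EasyDeL divides by the weight sum in exactly this way.

```python
import numpy as np


def weighted_accuracy_sketch(logits, targets, weights=None):
    """Hypothetical helper: weighted fraction of arg-max predictions that hit the target."""
    correct = (np.argmax(logits, axis=-1) == targets).astype(np.float32)
    if weights is None:
        weights = np.ones_like(correct)
    # Guard against an all-zero weight mask.
    return np.sum(correct * weights) / np.maximum(np.sum(weights), 1e-9)
```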
- easydel.infra.loss_utils.convert_special_loss_normalizing_factor_to_enum(x: str) → SpecialLossNormalizingFactor
Converts a stringified version of SpecialLossNormalizingFactor to an enum.
- Parameters
x – Stringified version of the enum value.
- Returns
The corresponding SpecialLossNormalizingFactor enum value.
- easydel.infra.loss_utils.dynamic_cross_entropy_loss(logits: Array, targets: Array, weight: Optional[Array] = None, ignore_index: int = -100, reduction: str = 'mean', label_smoothing: float = 0.0) → Tuple[Array, Array]
Computes the cross-entropy loss with optional label smoothing and ignore index, dynamically handling different reduction types.
- Parameters
logits (jnp.ndarray) – The predicted logits from the model, shape (batch_size, ..., num_classes).
targets (jnp.ndarray) – The target labels, shape (batch_size, ...).
weight (tp.Optional[jnp.ndarray]) – Optional weights for each element, shape (batch_size, ...). Defaults to None.
ignore_index (int) – Index in the target labels to ignore. Defaults to -100.
reduction (str) – Specifies the reduction method: 'mean', 'sum', or 'none'. Defaults to 'mean'.
label_smoothing (float) – The amount of label smoothing to apply (0.0 means no smoothing). Defaults to 0.0.
- Returns
The computed loss (scalar if reduction is 'mean' or 'sum', array otherwise).
The normalization factor (sum of weights or count of non-ignored elements).
- Return type
tp.Tuple[jnp.ndarray, jnp.ndarray]
- Raises
ValueError – If an invalid reduction method is specified.
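The ignore-index and reduction handling described above can be sketched in NumPy as follows. There are two assumptions:

- Label smoothing follows the PyTorch convention of mixing the one-hot target with a uniform distribution (label_smoothing / num_classes per class).
- The returned normalizer is the sum of weights over non-ignored elements.

The helper is hypothetical.

```python
import numpy as np


def dynamic_cross_entropy_sketch(logits, targets, weight=None, ignore_index=-100,
                                 reduction="mean", label_smoothing=0.0):
    """Hypothetical helper returning (loss, normalizer)."""
    num_classes = logits.shape[-1]
    valid = targets != ignore_index
    safe_targets = np.where(valid, targets, 0)  # placeholder id, masked out below

    m = logits.max(axis=-1, keepdims=True)
    log_probs = logits - m - np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))

    # Assumed PyTorch-style smoothing: (1 - eps) * one_hot + eps / num_classes.
    one_hot = (safe_targets[..., None] == np.arange(num_classes)).astype(logits.dtype)
    soft = one_hot * (1.0 - label_smoothing) + label_smoothing / num_classes
    per_elem = -np.sum(soft * log_probs, axis=-1)

    # Ignored elements get zero weight; explicit weights are applied on top.
    w = valid.astype(logits.dtype) if weight is None else weight * valid
    per_elem = per_elem * w
    normalizer = np.sum(w)

    if reduction == "mean":
        return np.sum(per_elem) / np.maximum(normalizer, 1e-9), normalizer
    if reduction == "sum":
        return np.sum(per_elem), normalizer
    if reduction == "none":
        return per_elem, normalizer
    raise ValueError(f"Invalid reduction method: {reduction}")
```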
- easydel.infra.loss_utils.fixed_cross_entropy(source: Array, target: Array, attention_mask: Optional[Array] = None, config: Optional[LossConfig] = None, num_items_in_batch: Optional[int] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics
JAX implementation of a fixed cross-entropy loss with z-loss, label smoothing, and masking.
- Parameters
source – Predicted logits, shape (batch_size, num_classes) or (batch_size * seq_len, num_classes).
target – True labels, shape (batch_size,) or (batch_size * seq_len,). Must be integers.
num_items_in_batch – Optional; used when the loss should be reduced by summation.
attention_mask – Optional boolean mask applied to the loss.
batch – Optional batch used for dynamic loss normalization.
**kwargs – Additional keyword arguments.
- Returns
The computed cross-entropy loss in LossMetrics.
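Passing num_items_in_batch switches to a sum reduction, which matters under gradient accumulation. Averaging each micro-batch separately weights tokens unequally when micro-batches contain different numbers of real tokens.

The NumPy sketch below follows the pattern of Hugging Face Transformers' own fixed_cross_entropy: sum the per-token losses, then divide by the token count of the whole accumulated batch. Treating EasyDeL's function as identical is an assumption, and the helper name is hypothetical.

```python
import numpy as np


def fixed_cross_entropy_sketch(source, target, ignore_index=-100, num_items_in_batch=None):
    """Hypothetical helper: token cross-entropy with mean or sum / num_items reduction."""
    valid = target != ignore_index
    safe = np.where(valid, target, 0)
    m = source.max(axis=-1, keepdims=True)
    log_probs = source - m - np.log(np.exp(source - m).sum(axis=-1, keepdims=True))
    nll = -np.take_along_axis(log_probs, safe[:, None], axis=-1)[:, 0] * valid
    if num_items_in_batch is not None:
        # Sum here; divide by the real-token count of the whole accumulated batch.
        return np.sum(nll) / num_items_in_batch
    # Otherwise, a plain mean over this call's non-ignored tokens.
    return np.sum(nll) / np.maximum(np.sum(valid), 1)
```

Splitting a four-token batch into micro-batches of one and three tokens shows the difference. Adding the two results computed with num_items_in_batch=4 reproduces the full-batch mean, whereas averaging the two micro-batch means does not.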
- easydel.infra.loss_utils.get_loss_normalizing_factor_and_weights(loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]], batch: Mapping[str, Union[Array, ndarray, bool, number]]) → Tuple[Optional[float], Optional[Union[Array, ndarray, bool, number]]]
Gets the loss normalizing factor and weights from a batch of data.
- Parameters
loss_normalizing_factor – The loss normalizing factor to use.
batch – A dictionary containing the input batch of data.
- Returns
A tuple containing the loss normalizing factor and loss weights.
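The documented SpecialLossNormalizingFactor members can be turned into a (normalizing factor, weights) pair roughly as follows. This is a hypothetical sketch modeled on the T5X function of the same name, with two differences:

- It takes the target tokens and loss weights as arrays instead of a batch mapping, because EasyDeL's batch key names are not given here.
- It assumes padding positions hold token id 0.

```python
import numpy as np


def factor_and_weights_sketch(factor, target_tokens, loss_weights, pad_id=0):
    """Hypothetical helper returning (normalizing_factor, weights) for a factor name."""
    real = (target_tokens != pad_id).astype(np.float32)  # assumed padding convention
    if factor == "NO_WEIGHT_NUM_REAL_TARGET_TOKENS":
        # Ignore the supplied weights: every real token counts once.
        return float(real.sum()), real
    weights = loss_weights * real
    if factor == "NUM_REAL_TARGET_TOKENS":
        return float(weights.sum()), weights
    if factor == "NUM_TOTAL_TARGET_TOKENS":
        # Padding included: the denominator is the full array size.
        return float(target_tokens.size), weights
    if factor == "AVERAGE_PER_SEQUENCE":
        # Rescale so each sequence's weights sum to 1, then divide by the number of sequences.
        per_seq = weights.sum(axis=-1, keepdims=True)
        return float(target_tokens.shape[0]), weights / np.maximum(per_seq, 1e-9)
    # Any other value is treated as a plain constant.
    return float(factor), weights
```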
- easydel.infra.loss_utils.onehot(labels, num_classes, on_value=1.0, off_value=0.0)
Creates one-hot encoded versions of integer labels.
- Parameters
labels (jnp.ndarray) – An array of integer labels.
num_classes (int) – The total number of classes.
on_value (float) – The value to use for the 'on' state (corresponding to the label). Defaults to 1.0.
off_value (float) – The value to use for the 'off' states. Defaults to 0.0.
- Returns
The one-hot encoded array with shape labels.shape + (num_classes,).
- Return type
jnp.ndarray
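One way to build the one-hot array is to compare each label against every class id and let broadcasting append the class axis. Flax's common_utils.onehot uses this approach, but it is an assumption that EasyDeL's version is identical. The NumPy helper below is hypothetical. Passing on_value = 1 - ε and off_value = ε / (num_classes - 1) yields label-smoothed targets.

```python
import numpy as np


def onehot_sketch(labels, num_classes, on_value=1.0, off_value=0.0):
    """Hypothetical helper: one-hot encode integer labels along a new trailing axis."""
    labels = np.asarray(labels)
    # Broadcasting (..., 1) against (num_classes,) yields labels.shape + (num_classes,).
    hits = labels[..., None] == np.arange(num_classes)
    return np.where(hits, on_value, off_value).astype(np.float32)
```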
- easydel.infra.loss_utils.sigmoid_cross_entropy_with_logits(logits: Array, labels: Array, weights: Optional[Array] = None, label_smoothing: float = 0.0, axis: Optional[Union[int, tuple]] = None) → Array
Computes sigmoid cross-entropy loss between logits and labels.
Measures the probability error in discrete classification tasks in which each class is independent and not mutually exclusive. For instance, one could perform multilabel classification where a picture can contain both an elephant and a dog at the same time.
- Parameters
logits – The predicted logits from the model.
labels – The target labels.
weights (tp.Optional[jnp.ndarray]) – Optional weights for the loss computation. Defaults to None.
label_smoothing (float) – Amount of label smoothing to apply (0.0 means no smoothing). Defaults to 0.0.
axis (tp.Optional[tp.Union[int, tuple]]) – The axis or axes along which to compute the mean. If None, the mean is computed over all elements. Defaults to None.
- Returns
The computed sigmoid cross-entropy loss.
- Return type
jnp.ndarray
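For independent binary labels, the loss is -[z·log σ(x) + (1 - z)·log(1 - σ(x))]. It is usually evaluated in the overflow-safe form max(x, 0) - x·z + log(1 + exp(-|x|)).

Below is a hypothetical NumPy sketch. The binary label-smoothing rule, which pulls each target toward 0.5 as z·(1 - ε) + ε/2, is an assumption about how label_smoothing is applied here.

```python
import numpy as np


def sigmoid_cross_entropy_sketch(logits, labels, weights=None, label_smoothing=0.0, axis=None):
    """Hypothetical helper: mean sigmoid cross-entropy with optional smoothing and weights."""
    labels = np.asarray(labels, dtype=logits.dtype)
    if label_smoothing > 0.0:
        # Assumed binary smoothing rule: move each target toward 0.5.
        labels = labels * (1.0 - label_smoothing) + 0.5 * label_smoothing
    # Overflow-safe form of -[z*log(sigmoid(x)) + (1 - z)*log(1 - sigmoid(x))].
    loss = np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
    if weights is not None:
        loss = loss * weights
    return np.mean(loss, axis=axis)
```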