easydel.infra.loss_utils

easydel.infra.loss_utils.ForCausalLMLoss(logits: Array, labels: Array, attention_mask: Optional[Array] = None, config: Optional[LossConfig] = None, paxis: Optional[PartitionAxis] = None, num_items_in_batch: Optional[int] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics

JAX implementation of the loss function for causal language models.

Parameters
  • logits – Predicted logits, shape (batch_size, seq_len, vocab_size).

  • labels – True labels, shape (batch_size, seq_len). Must be integers.

  • num_items_in_batch – Optional; used when the reduction is a sum.

  • batch – Optional batch used for dynamic loss normalization.

  • **kwargs – Additional keyword arguments for the cross-entropy loss.

Returns

A LossMetrics holding the computed causal language modeling loss.
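A minimal sketch of what a causal LM loss of this kind computes, assuming the LossConfig defaults (shift_tokens=True meaning position t is scored against token t + 1, ignore_index=-100, and normalization by the number of real target tokens). causal_lm_loss_sketch is a hypothetical helper for illustration only; it ignores config, paxis, and num_items_in_batch and is not EasyDeL's implementation.

```python
import jax
import jax.numpy as jnp


def causal_lm_loss_sketch(logits, labels, ignore_index=-100):
    """Hypothetical helper: mean next-token cross-entropy over non-ignored labels."""
    # Shift so that position t is scored against token t + 1 (assumed shift_tokens=True).
    logits = logits[:, :-1, :]
    labels = labels[:, 1:]
    valid = labels != ignore_index
    # Replace ignored labels with a safe index before gathering; they are masked below.
    safe_labels = jnp.where(valid, labels, 0)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    token_logp = jnp.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    nll = jnp.where(valid, -token_logp, 0.0)
    # Divide by the number of real target tokens (assumed NUM_REAL_TARGET_TOKENS).
    return nll.sum() / jnp.maximum(valid.sum(), 1)


logits = jnp.zeros((2, 5, 11))  # uniform predictions over an 11-token vocabulary
labels = jnp.array([[1, 2, 3, -100, -100], [4, 5, 6, 7, 8]])
loss = causal_lm_loss_sketch(logits, labels)  # ≈ log(11) ≈ 2.398 for uniform logits
```

Labels equal to ignore_index contribute neither to the summed loss nor to the token count, so changing the logits that score them leaves the result unchanged.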

easydel.infra.loss_utils.ForQuestionAnsweringLoss(start_logits: Array, end_logits: Array, start_positions: Array, end_positions: Array, config: Optional[LossConfig] = None, paxis: Optional[PartitionAxis] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics

JAX implementation of the loss function for question answering.

Parameters
  • start_logits – Predicted start logits, shape (batch_size, seq_len).

  • end_logits – Predicted end logits, shape (batch_size, seq_len).

  • start_positions – True start positions, shape (batch_size,).

  • end_positions – True end positions, shape (batch_size,).

  • batch – Optional batch used for dynamic loss normalization.

  • **kwargs – Additional keyword arguments for the cross-entropy loss.

Returns

A LossMetrics holding the computed question answering loss.
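A hedged sketch of a span-extraction loss. It assumes the common convention (used, for example, by Hugging Face Transformers) of averaging the start-position and end-position cross-entropies and, as a further assumption, ignores gold positions that fall outside the sequence. qa_loss_sketch is hypothetical; EasyDeL's handling of config, paxis, and batch is not modeled.

```python
import jax
import jax.numpy as jnp


def qa_loss_sketch(start_logits, end_logits, start_positions, end_positions):
    """Hypothetical helper: mean of start- and end-position cross-entropy."""
    seq_len = start_logits.shape[-1]

    def span_ce(logits, positions):
        # Gold positions outside [0, seq_len) are treated as ignored.
        valid = (positions >= 0) & (positions < seq_len)
        safe = jnp.where(valid, positions, 0)
        logp = jax.nn.log_softmax(logits, axis=-1)
        picked = jnp.take_along_axis(logp, safe[:, None], axis=-1)[:, 0]
        return jnp.where(valid, -picked, 0.0).sum() / jnp.maximum(valid.sum(), 1)

    return (span_ce(start_logits, start_positions) + span_ce(end_logits, end_positions)) / 2.0


start_logits = jnp.zeros((2, 7))  # uniform over 7 positions
end_logits = jnp.zeros((2, 7))
loss = qa_loss_sketch(start_logits, end_logits, jnp.array([0, 3]), jnp.array([2, 6]))
```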

easydel.infra.loss_utils.ForSequenceClassificationLoss(logits: Array, labels: Array, attention_mask: Optional[Array] = None, config: Optional[LossConfig] = None, paxis: Optional[PartitionAxis] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics

JAX implementation of the loss function for sequence classification.

Parameters
  • labels – True labels, shape (batch_size,), or (batch_size, num_labels) for multi-label classification.

  • logits – Predicted logits, shape (batch_size, num_labels), or (batch_size, 1) or (batch_size,) for regression.

  • config – Configuration with problem_type and num_labels attributes.

  • batch – Optional batch used for dynamic loss normalization.

  • **kwargs – Additional keyword arguments for the cross-entropy loss.

Returns

A LossMetrics holding the computed sequence classification loss.
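The docstring implies a dispatch on the problem type. The sketch below assumes the three classification_problem_type values listed on LossConfig map to mean squared error, softmax cross-entropy, and per-label sigmoid cross-entropy respectively; that is the usual pairing, not something the source confirms. seq_cls_loss_sketch is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def seq_cls_loss_sketch(logits, labels, problem_type):
    """Hypothetical helper dispatching on the problem type."""
    if problem_type == "regression":
        # Mean squared error; drop a trailing num_labels == 1 axis if present.
        preds = logits.squeeze(-1) if logits.ndim > labels.ndim else logits
        return jnp.mean((preds - labels) ** 2)
    if problem_type == "single_label_classification":
        # Softmax cross-entropy against one integer label per example.
        logp = jax.nn.log_softmax(logits, axis=-1)
        return -jnp.mean(jnp.take_along_axis(logp, labels[:, None], axis=-1))
    if problem_type == "multi_label_classification":
        # Independent, numerically stable sigmoid cross-entropy per label.
        x, z = logits, labels.astype(logits.dtype)
        per_elem = jnp.maximum(x, 0.0) - x * z + jnp.log1p(jnp.exp(-jnp.abs(x)))
        return jnp.mean(per_elem)
    raise ValueError(f"unknown problem_type: {problem_type!r}")


reg = seq_cls_loss_sketch(jnp.array([[1.0], [3.0]]), jnp.array([0.0, 3.0]), "regression")
single = seq_cls_loss_sketch(jnp.zeros((2, 4)), jnp.array([0, 3]), "single_label_classification")
multi = seq_cls_loss_sketch(jnp.zeros((2, 3)), jnp.ones((2, 3)), "multi_label_classification")
```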

easydel.infra.loss_utils.ForTokenClassification(logits: Array, labels: Array, config: Optional[LossConfig] = None, paxis: Optional[PartitionAxis] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics

JAX implementation of the loss function for token classification.

Parameters
  • logits – Predicted logits, shape (batch_size, seq_len, num_labels).

  • labels – True labels, shape (batch_size, seq_len). Must be integers.

  • config – Configuration with a num_labels attribute.

  • label_smoothing – Label smoothing factor.

  • z_loss – Coefficient for the auxiliary z-loss term.

  • loss_normalizing_factor – A factor to normalize the loss; may also be a SpecialLossNormalizingFactor value.

  • batch – Optional batch used for dynamic loss normalization.

  • **kwargs – Additional keyword arguments for the cross-entropy loss.

Returns

A LossMetrics holding the computed token classification loss.

class easydel.infra.loss_utils.LossConfig(ignore_index: int = -100, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Union[float, int, str, easydel.infra.loss_utils.SpecialLossNormalizingFactor, NoneType] = 'NUM_REAL_TARGET_TOKENS', num_labels: Optional[str] = None, problem_type: Optional[str] = None, divide_weight_sum: bool = False, shift_tokens: bool = True, break_on_nan: bool = True, reduction: Optional[Literal['none', 'mean', 'sum']] = None, num_classification_labels: Optional[int] = None, classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None)

Bases: Mapping

break_on_nan: bool = True
classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None
divide_weight_sum: bool = False
from_tuple()
ignore_index: int = -100
items() → a set-like object providing a view on D's items
keys() → a set-like object providing a view on D's keys
label_smoothing: float = 0.0
loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]] = 'NUM_REAL_TARGET_TOKENS'
num_classification_labels: Optional[int] = None
num_labels: Optional[str] = None
problem_type: Optional[str] = None
reduction: Optional[Literal['none', 'mean', 'sum']] = None
replace(**kwargs)
shift_tokens: bool = True
to_tuple()
values() → an object providing a view on D's values
z_loss: float = 0.0
class easydel.infra.loss_utils.LossMetrics(loss: Union[float, jax.Array, numpy.ndarray, numpy.bool, numpy.number, NoneType] = None, z_loss: Union[float, jax.Array, numpy.ndarray, numpy.bool, numpy.number, NoneType] = None, weight_sum: Union[float, jax.Array, numpy.ndarray, numpy.bool, numpy.number, NoneType] = None, accuracy: Union[float, jax.Array, numpy.ndarray, numpy.bool, numpy.number, NoneType] = None, learning_rate: Union[float, jax.Array, numpy.ndarray, numpy.bool, numpy.number, NoneType] = None, max_grad_norm: Optional[flax.struct.PyTreeNode] = None, mean_grad_norm: Optional[flax.struct.PyTreeNode] = None, grad_norms: Optional[flax.struct.PyTreeNode] = None, chosen_rewards: Optional[jax.Array] = None, rejected_rewards: Optional[jax.Array] = None, other_metrics: Optional[Mapping[str, jax.Array]] = None, execution_time: Optional[float] = None)

Bases: Mapping

accuracy: Optional[Union[float, Array, ndarray, bool, number]] = None
chosen_rewards: Optional[Array] = None
execution_time: Optional[float] = None
from_tuple()
grad_norms: Optional[PyTreeNode] = None
items() → a set-like object providing a view on D's items
keys() → a set-like object providing a view on D's keys
learning_rate: Optional[Union[float, Array, ndarray, bool, number]] = None
loss: Optional[Union[float, Array, ndarray, bool, number]] = None
max_grad_norm: Optional[PyTreeNode] = None
mean_grad_norm: Optional[PyTreeNode] = None
other_metrics: Optional[Mapping[str, Array]] = None
rejected_rewards: Optional[Array] = None
replace(**kwargs)
to_tuple()
values() → an object providing a view on D's values
weight_sum: Optional[Union[float, Array, ndarray, bool, number]] = None
z_loss: Optional[Union[float, Array, ndarray, bool, number]] = None
class easydel.infra.loss_utils.SpecialLossNormalizingFactor(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: Enum

Specially calculated loss normalizing factors that are not constant.

  • NO_WEIGHT_NUM_REAL_TARGET_TOKENS – Divide the loss by the number of real (non-padding) tokens, without computing per-token weights.

  • NUM_REAL_TARGET_TOKENS – Divide the loss by the number of real (non-padding) tokens.

  • NUM_TOTAL_TARGET_TOKENS – Divide the loss by the total number of target tokens.

  • AVERAGE_PER_SEQUENCE – Compute the average loss per sequence.

AVERAGE_PER_SEQUENCE = 3
NO_WEIGHT_NUM_REAL_TARGET_TOKENS = 0
NUM_REAL_TARGET_TOKENS = 1
NUM_TOTAL_TARGET_TOKENS = 2
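One way to read these factors, sketched under the assumption that they follow T5X-style semantics (the member names and values match T5X's enum of the same name) over a (batch, seq) 0/1 weight mask. Factor is a local stand-in copying the member values above, and normalizer_sketch is hypothetical; EasyDeL's own logic may differ, for example for packed sequences.

```python
import enum

import jax.numpy as jnp


class Factor(enum.Enum):
    """Local stand-in with the member names and values listed above."""
    NO_WEIGHT_NUM_REAL_TARGET_TOKENS = 0
    NUM_REAL_TARGET_TOKENS = 1
    NUM_TOTAL_TARGET_TOKENS = 2
    AVERAGE_PER_SEQUENCE = 3


def normalizer_sketch(factor, weights):
    """Hypothetical helper: return (normalizing_factor, weights) for a 0/1 mask."""
    if factor is Factor.NO_WEIGHT_NUM_REAL_TARGET_TOKENS:
        # Same divisor as NUM_REAL_TARGET_TOKENS, but no per-token weights are returned.
        return weights.sum(), None
    if factor is Factor.NUM_REAL_TARGET_TOKENS:
        return weights.sum(), weights
    if factor is Factor.NUM_TOTAL_TARGET_TOKENS:
        # Padding positions count too.
        return float(weights.size), weights
    if factor is Factor.AVERAGE_PER_SEQUENCE:
        # Rescale each row to sum to 1, so the divisor counts non-empty sequences.
        per_seq = weights.sum(axis=-1, keepdims=True)
        scaled = jnp.where(per_seq > 0, weights / jnp.maximum(per_seq, 1e-9), 0.0)
        return scaled.sum(), scaled
    raise ValueError(f"unsupported factor: {factor!r}")


mask = jnp.array([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]])
real_norm, real_weights = normalizer_sketch(Factor.NUM_REAL_TARGET_TOKENS, mask)  # 6 real tokens
```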
easydel.infra.loss_utils.auxiliary_load_balancing_loss_func(gate_logits: Union[Array, ndarray, bool, number, Tuple[Union[Array, ndarray, bool, number], ...]], num_experts: int, top_k: int, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None) → Union[Array, ndarray, bool, number]

Computes the auxiliary load-balancing loss introduced by the Switch Transformer.

See Switch Transformer (https://arxiv.org/abs/2101.03961).

Parameters
  • gate_logits – The logits of the gating network, either as a single array or as a tuple of arrays.

  • num_experts – The number of experts.

  • top_k – The number of experts selected per token.

  • attention_mask – An optional attention mask.

Returns

The auxiliary load-balancing loss as a scalar array.

Raises

ValueError – If num_experts or top_k is invalid.
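A sketch of the Switch Transformer formulation, loss = num_experts * sum_i(f_i * P_i), where f_i is the fraction of tokens routed to expert i and P_i is the mean router probability for expert i. It is extended to top-k routing the way Hugging Face's Mixtral implementation does it (one f vector per routing slot). load_balancing_loss_sketch is hypothetical, assumes each layer's logits have shape (tokens, num_experts), and ignores attention_mask.

```python
import jax
import jax.numpy as jnp


def load_balancing_loss_sketch(gate_logits, num_experts, top_k):
    """Hypothetical helper: Switch-style auxiliary loss with top-k routing, no mask."""
    if num_experts < 1 or not 1 <= top_k <= num_experts:
        raise ValueError("need num_experts >= 1 and 1 <= top_k <= num_experts")
    if isinstance(gate_logits, (tuple, list)):
        # One (tokens, num_experts) array per layer: pool the tokens of all layers.
        gate_logits = jnp.concatenate(gate_logits, axis=0)
    probs = jax.nn.softmax(gate_logits, axis=-1)          # (tokens, E)
    _, selected = jax.lax.top_k(probs, top_k)             # (tokens, k) expert indices
    expert_mask = jax.nn.one_hot(selected, num_experts)   # (tokens, k, E)
    tokens_per_expert = expert_mask.mean(axis=0)          # f_i for each routing slot
    router_prob_per_expert = probs.mean(axis=0)           # P_i
    return num_experts * jnp.sum(tokens_per_expert * router_prob_per_expert[None, :])


balanced = load_balancing_loss_sketch(jnp.zeros((8, 4)), num_experts=4, top_k=1)
collapsed = load_balancing_loss_sketch(
    jnp.tile(jnp.array([[10.0, 0.0, 0.0, 0.0]]), (8, 1)), num_experts=4, top_k=1
)
```

For uniform router probabilities the result equals top_k; with top_k = 1 it approaches num_experts as routing collapses onto a single expert, which is what the penalty discourages.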

easydel.infra.loss_utils.compute_weighted_cross_entropy(logits: Union[Array, ndarray, bool, number], targets: Union[Array, ndarray, bool, number], weights: Optional[Union[Array, ndarray, bool, number]] = None, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Optional[float] = None) → Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]]

Computes the weighted cross-entropy loss, the z-loss, and the sum of weights.

Parameters
  • logits – The predicted logits.

  • targets – The target class labels (integers).

  • weights – Optional weights for each example.

  • label_smoothing – Label smoothing factor.

  • z_loss – Coefficient for the auxiliary z-loss term.

  • loss_normalizing_factor – A factor to normalize the loss.

Returns

A tuple containing the total loss, z-loss, and sum of weights.
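The signature matches T5X's compute_weighted_cross_entropy, so the sketch below assumes the same math: label smoothing through soft targets (with their entropy subtracted so a perfect prediction scores zero), a z-loss of z_loss * log(Z)^2 added to each token's loss, optional per-token weights, and division by loss_normalizing_factor. weighted_xent_sketch is a hypothetical re-implementation, not EasyDeL's code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def weighted_xent_sketch(logits, targets, weights=None, label_smoothing=0.0,
                         z_loss=0.0, loss_normalizing_factor=None):
    """Hypothetical T5X-style helper returning (total_loss, total_z_loss, weight_sum)."""
    vocab = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low = label_smoothing / (vocab - 1)
    # Soft targets: `confidence` on the gold class, `low` everywhere else.
    soft = jax.nn.one_hot(targets, vocab) * (confidence - low) + low
    # Entropy of the soft targets; subtracting it makes a perfect prediction score ~0.
    norm_const = -(confidence * jnp.log(confidence)
                   + (vocab - 1) * low * jnp.log(low + 1e-20))
    log_z = logsumexp(logits, axis=-1)
    log_softmax = logits - log_z[..., None]
    zl = z_loss * jnp.square(log_z)  # auxiliary z-loss on the log-partition function
    loss = -jnp.sum(soft * log_softmax, axis=-1) + zl - norm_const
    weight_sum = jnp.asarray(targets.size, dtype=logits.dtype)
    if weights is not None:
        loss, zl, weight_sum = loss * weights, zl * weights, weights.sum()
    if loss_normalizing_factor is not None:
        loss, zl = loss / loss_normalizing_factor, zl / loss_normalizing_factor
    return loss.sum(), zl.sum(), weight_sum


uniform = jnp.zeros((2, 3, 5))        # 6 tokens, vocabulary of 5
tokens = jnp.zeros((2, 3), dtype=jnp.int32)
total, total_z, wsum = weighted_xent_sketch(uniform, tokens)  # total ≈ 6 * log(5)
```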

easydel.infra.loss_utils.compute_weighted_cross_entropy_and_accuracy(logits: Union[Array, ndarray, bool, number], targets: Union[Array, ndarray, bool, number], weights: Optional[Union[Array, ndarray, bool, number]] = None, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Optional[float] = None) → Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]]

Computes the weighted cross-entropy loss, the z-loss, the sum of weights, and the accuracy.

Parameters
  • logits – The predicted logits.

  • targets – The target class labels (integers).

  • weights – Optional weights for each example.

  • label_smoothing – Label smoothing factor.

  • z_loss – Coefficient for the auxiliary z-loss term.

  • loss_normalizing_factor – A factor to normalize the loss.

Returns

A tuple containing the total loss, z-loss, sum of weights, and accuracy.
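The extra accuracy output can be pictured as the weighted fraction of positions whose argmax matches the target. How the returned value is normalized is an assumption here; weighted_accuracy_sketch is hypothetical.

```python
import jax.numpy as jnp


def weighted_accuracy_sketch(logits, targets, weights=None):
    """Hypothetical helper: weighted fraction of positions where argmax == target."""
    correct = (jnp.argmax(logits, axis=-1) == targets).astype(jnp.float32)
    if weights is None:
        weights = jnp.ones_like(correct)
    # Guard against an all-zero weight mask.
    return jnp.sum(correct * weights) / jnp.maximum(jnp.sum(weights), 1e-9)


preds = jnp.array([[0, 1, 2], [2, 2, 2]])
logits = jnp.eye(3)[preds]                    # one-hot logits whose argmax is `preds`
targets = jnp.array([[0, 1, 0], [2, 0, 0]])  # 3 of 6 positions match
acc = weighted_accuracy_sketch(logits, targets)
```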

easydel.infra.loss_utils.convert_special_loss_normalizing_factor_to_enum(x: str) → SpecialLossNormalizingFactor

Converts the string form of a SpecialLossNormalizingFactor member to the corresponding enum value.

Parameters

x – String form of the enum value.

Returns

The corresponding SpecialLossNormalizingFactor enum value.
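Since LossConfig.loss_normalizing_factor defaults to the string 'NUM_REAL_TARGET_TOKENS', a conversion like this is presumably needed to turn that string into an enum member. The sketch assumes a case-insensitive lookup by member name and a ValueError for unknown strings; LossNormFactor and to_enum_sketch are hypothetical stand-ins.

```python
import enum


class LossNormFactor(enum.Enum):
    """Local stand-in with the member names and values listed for the real enum."""
    NO_WEIGHT_NUM_REAL_TARGET_TOKENS = 0
    NUM_REAL_TARGET_TOKENS = 1
    NUM_TOTAL_TARGET_TOKENS = 2
    AVERAGE_PER_SEQUENCE = 3


def to_enum_sketch(x: str) -> LossNormFactor:
    """Hypothetical converter: look the upper-cased member name up in the enum."""
    try:
        return LossNormFactor[x.strip().upper()]
    except KeyError:
        raise ValueError(f"Unknown loss normalizing factor: {x!r}") from None


factor = to_enum_sketch("num_real_target_tokens")
```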

easydel.infra.loss_utils.cross_entropy_loss_and_accuracy(source, target, valid=None)
easydel.infra.loss_utils.dynamic_cross_entropy_loss(logits: Array, targets: Array, weight: Optional[Array] = None, ignore_index: int = -100, reduction: str = 'mean', label_smoothing: float = 0.0) → Tuple[Array, Array]

Cross-entropy loss with support for masking, weights, and label smoothing.

Parameters
  • logits – Predicted logits, (B, C) or (B, T, C).

  • targets – Ground-truth integer labels (B, …) or probabilities (B, …, C).

  • weight – Optional per-class weights (C,).

  • ignore_index – Value of tokens to ignore (integer targets only).

  • reduction – 'none', 'mean', or 'sum'.

  • label_smoothing – Smoothing factor between 0 and 1.

Returns

A tuple (loss, accuracy); the accuracy is meaningful only for integer targets.
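A minimal sketch for integer targets, assuming PyTorch-style semantics for ignore_index, reduction, and label_smoothing (the smoothed loss mixes in the mean negative log-probability over all classes). Per-class weight and probability targets are left out for brevity; dynamic_xent_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


def dynamic_xent_sketch(logits, targets, ignore_index=-100, reduction="mean",
                        label_smoothing=0.0):
    """Hypothetical helper for integer targets; per-class weights are not modeled."""
    valid = targets != ignore_index
    safe = jnp.where(valid, targets, 0)
    logp = jax.nn.log_softmax(logits, axis=-1)
    nll = -jnp.take_along_axis(logp, safe[..., None], axis=-1)[..., 0]
    # Smoothing term: mean negative log-probability over all classes.
    smooth = -logp.mean(axis=-1)
    loss = jnp.where(valid, (1.0 - label_smoothing) * nll + label_smoothing * smooth, 0.0)
    # Accuracy is counted over non-ignored positions only.
    hits = jnp.where(valid, jnp.argmax(logits, axis=-1) == safe, False)
    acc = hits.sum() / jnp.maximum(valid.sum(), 1)
    if reduction == "none":
        return loss, acc
    if reduction == "sum":
        return loss.sum(), acc
    if reduction == "mean":
        return loss.sum() / jnp.maximum(valid.sum(), 1), acc
    raise ValueError(f"unknown reduction: {reduction!r}")


logits = jnp.array([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 3.0, 0.0]])
targets = jnp.array([0, -100, 1])  # the middle position is ignored
mean_loss, acc = dynamic_xent_sketch(logits, targets)
```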

easydel.infra.loss_utils.fixed_cross_entropy(source: Array, target: Array, attention_mask: Optional[Array] = None, config: Optional[LossConfig] = None, num_items_in_batch: Optional[int] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics

JAX implementation of a fixed cross-entropy loss with z-loss, label smoothing, and masking.

Parameters
  • source – Predicted logits, shape (batch_size, num_classes) or (batch_size * seq_len, num_classes).

  • target – True labels, shape (batch_size,) or (batch_size * seq_len,). Must be integers.

  • num_items_in_batch – Optional; used when the reduction is a sum.

  • attention_mask – Optional boolean mask applied to the loss.

  • batch – Optional batch used for dynamic loss normalization.

  • **kwargs – Additional keyword arguments.

Returns

The computed cross-entropy loss, wrapped in a LossMetrics.

easydel.infra.loss_utils.get_loss_normalizing_factor_and_weights(loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]], batch: Mapping[str, Union[Array, ndarray, bool, number]]) → Tuple[Optional[float], Optional[Union[Array, ndarray, bool, number]]]

Gets the loss normalizing factor and weights from a batch of data.

Parameters
  • loss_normalizing_factor – The loss normalizing factor to use.

  • batch – A dictionary containing the input batch of data.

Returns

A tuple containing the loss normalizing factor and loss weights.

easydel.infra.loss_utils.onehot(labels, num_classes, on_value=1.0, off_value=0.0)

Create a dense one-hot version of an indexed array.

NB: consider using the more standard jax.nn.one_hot instead.

Parameters
  • labels – An n-dim JAX array whose last dimension contains integer indices.

  • num_classes – The number of classes, i.e. the length of each one-hot vector (valid indices run from 0 to num_classes - 1).

  • on_value – The "on" value for the one-hot array; defaults to 1.0.

  • off_value – The "off" value for the one-hot array; defaults to 0.0.

Returns

An (n+1)-dim array whose last dimension contains one-hot vectors of length num_classes.
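The construction amounts to comparing each index against arange(num_classes) and selecting on_value or off_value. onehot_sketch is a hypothetical re-implementation; with this approach, indices outside [0, num_classes) produce a row filled with off_value.

```python
import jax
import jax.numpy as jnp


def onehot_sketch(labels, num_classes, on_value=1.0, off_value=0.0):
    """Hypothetical re-implementation: compare indices against a class-id range."""
    # Broadcast (..., 1) against (num_classes,) to get a (..., num_classes) mask.
    hits = labels[..., None] == jnp.arange(num_classes)
    return jnp.where(hits, on_value, off_value).astype(jnp.float32)


labels = jnp.array([0, 3, 1])
# With the default on/off values this should agree with jax.nn.one_hot.
same = bool((onehot_sketch(labels, 4) == jax.nn.one_hot(labels, 4)).all())
smoothed = onehot_sketch(jnp.array([1]), 3, on_value=0.9, off_value=0.05)
```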

easydel.infra.loss_utils.sigmoid_cross_entropy_with_logits(logits: Array, labels: Array, weights: Optional[Array] = None, label_smoothing: float = 0.0, axis: Optional[Union[int, tuple]] = None) → Array

Computes the sigmoid cross-entropy given logits and labels.

Parameters
  • logits – Input logits.

  • labels – Target tensor with the same shape as logits.

  • weights – Optional weights to apply to the loss.

  • label_smoothing – Float in [0, 1]; the amount of smoothing to apply to the labels.

  • axis – The dimensions to reduce. If None, all dimensions are reduced.

Returns

The sigmoid cross-entropy loss, reduced according to axis.
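A sketch using the standard numerically stable form of the logistic loss. It assumes TensorFlow-style label smoothing (labels pulled toward 0.5) and a mean reduction over axis; the docstring does not say whether the reduction is a mean or a sum. sigmoid_xent_sketch is hypothetical.

```python
import jax.numpy as jnp


def sigmoid_xent_sketch(logits, labels, weights=None, label_smoothing=0.0, axis=None):
    """Hypothetical helper: stable sigmoid cross-entropy, mean-reduced over `axis`."""
    # Pull hard 0/1 labels toward 0.5 (assumed TensorFlow-style smoothing).
    labels = labels * (1.0 - label_smoothing) + 0.5 * label_smoothing
    # max(x, 0) - x*z + log1p(exp(-|x|)) equals -[z*log(sig(x)) + (1-z)*log(1-sig(x))],
    # rearranged so that exp never overflows for large |x|.
    loss = jnp.maximum(logits, 0.0) - logits * labels + jnp.log1p(jnp.exp(-jnp.abs(logits)))
    if weights is not None:
        loss = loss * weights
    return jnp.mean(loss, axis=axis)


x = jnp.array([-2.0, 0.5, 3.0])
z = jnp.array([0.0, 1.0, 1.0])
stable = sigmoid_xent_sketch(x, z)
extreme = sigmoid_xent_sketch(jnp.array([100.0, -100.0]), jnp.array([1.0, 1.0]), axis=-1)
```

Because no exp(x) is ever evaluated for large positive x, logits of magnitude 100 give finite losses of about 0 and 100 instead of overflowing.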