easydel.infra.loss_utils
- easydel.infra.loss_utils.ForCausalLMLoss(logits: Array, labels: Array, attention_mask: Optional[Array] = None, config: Optional[LossConfig] = None, paxis: Optional[PartitionAxis] = None, num_items_in_batch: Optional[int] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics
JAX implementation of the loss function for causal language models.
- Parameters
logits – Predicted logits, shape (batch_size, seq_len, vocab_size).
labels – True labels, shape (batch_size, seq_len). Must be integers.
num_items_in_batch – Optional; used when the reduction is a sum.
batch – Optional batch used for dynamic loss normalization.
**kwargs – Additional keyword arguments for the cross-entropy loss.
- Returns
The computed causal language modeling loss.
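The next-token objective can be sketched in plain JAX as follows. This is a minimal sketch, not EasyDeL's implementation: it assumes the defaults suggested by LossConfig (shift_tokens=True, ignore_index=-100, normalization by the number of real target tokens) and omits label smoothing, z-loss, and sharding. The name `causal_lm_loss_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def causal_lm_loss_sketch(logits, labels, ignore_index=-100):
    """Hypothetical helper: mean next-token cross-entropy over real target tokens."""
    # Assumed shift_tokens=True: the logits at position t score the label at t + 1.
    logits = logits[:, :-1, :]
    labels = labels[:, 1:]
    # Targets equal to ignore_index (e.g. padding) are masked out of the loss.
    mask = labels != ignore_index
    safe_labels = jnp.where(mask, labels, 0)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Log-probability assigned to the target token at every position.
    token_ll = jnp.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    nll = jnp.where(mask, -token_ll, 0.0)
    # Normalize by the number of real target tokens (assumed NUM_REAL_TARGET_TOKENS).
    return nll.sum() / jnp.maximum(mask.sum(), 1)


logits = jnp.zeros((2, 4, 8))  # uniform predictions over an 8-token vocabulary
labels = jnp.array([[5, 1, 2, -100], [3, 3, -100, -100]])
print(causal_lm_loss_sketch(logits, labels))  # log(8), about 2.0794
```

After the shift, only three targets are real, so the padding positions do not dilute the average.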
- easydel.infra.loss_utils.ForQuestionAnsweringLoss(start_logits: Array, end_logits: Array, start_positions: Array, end_positions: Array, config: Optional[LossConfig] = None, paxis: Optional[PartitionAxis] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics
JAX implementation of the loss function for question answering.
- Parameters
start_logits – Predicted start logits, shape (batch_size, seq_len).
end_logits – Predicted end logits, shape (batch_size, seq_len).
start_positions – True start positions, shape (batch_size,).
end_positions – True end positions, shape (batch_size,).
batch – Optional batch used for dynamic loss normalization.
**kwargs – Additional keyword arguments for the cross-entropy loss.
- Returns
The computed question answering loss.
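A common formulation, used for example by Hugging Face Transformers' question-answering heads, averages two cross-entropies: one over start positions and one over end positions. The sketch below assumes that convention; whether EasyDeL averages in exactly this way is an assumption, and both helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def position_ce(logits, positions):
    """Mean cross-entropy of (batch_size, seq_len) logits against integer positions."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    picked = jnp.take_along_axis(log_probs, positions[:, None], axis=-1)[:, 0]
    return -picked.mean()


def qa_loss_sketch(start_logits, end_logits, start_positions, end_positions):
    """Hypothetical helper: average of the start- and end-position losses."""
    start_loss = position_ce(start_logits, start_positions)
    end_loss = position_ce(end_logits, end_positions)
    return (start_loss + end_loss) / 2


start_logits = jnp.zeros((2, 5)).at[0, 1].set(20.0).at[1, 3].set(20.0)
end_logits = jnp.zeros((2, 5))
loss = qa_loss_sketch(start_logits, end_logits, jnp.array([1, 3]), jnp.array([2, 4]))
print(loss)  # start loss near 0, end loss log(5), so roughly log(5) / 2
```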
- easydel.infra.loss_utils.ForSequenceClassificationLoss(logits: Array, labels: Array, attention_mask: Optional[Array] = None, config: Optional[LossConfig] = None, paxis: Optional[PartitionAxis] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics
JAX implementation of the loss function for sequence classification.
- Parameters
logits – Predicted logits, shape (batch_size, num_labels), or (batch_size, 1) or (batch_size,) for regression.
labels – True labels, shape (batch_size,), or (batch_size, num_labels) for multi-label classification.
config – Configuration with problem_type and num_labels attributes.
batch – Optional batch used for dynamic loss normalization.
**kwargs – Additional keyword arguments for the cross-entropy loss.
- Returns
The computed sequence classification loss.
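The problem type selects one of three losses. The sketch below assumes the conventional choices, which match the split used by Hugging Face Transformers: mean squared error for regression, softmax cross-entropy for single-label classification, and sigmoid binary cross-entropy for multi-label classification. It is not EasyDeL's code, and `seq_cls_loss_sketch` is a hypothetical name; the string values are taken from LossConfig.classification_problem_type.

```python
import jax
import jax.numpy as jnp


def seq_cls_loss_sketch(logits, labels, problem_type):
    """Hypothetical helper dispatching on a classification_problem_type value."""
    if problem_type == "regression":
        # Mean squared error; (batch_size, 1) logits are reshaped to match the labels.
        preds = logits.reshape(labels.shape)
        return jnp.mean((preds - labels) ** 2)
    if problem_type == "single_label_classification":
        # Softmax cross-entropy against integer class ids of shape (batch_size,).
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        picked = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]
        return -picked.mean()
    if problem_type == "multi_label_classification":
        # Independent sigmoid (binary) cross-entropy per label, averaged over all entries.
        per_label = -(labels * jax.nn.log_sigmoid(logits)
                      + (1.0 - labels) * jax.nn.log_sigmoid(-logits))
        return per_label.mean()
    raise ValueError(f"unknown problem_type: {problem_type!r}")
```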
- easydel.infra.loss_utils.ForTokenClassification(logits: Array, labels: Array, config: Optional[LossConfig] = None, paxis: Optional[PartitionAxis] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics
JAX implementation of the loss function for token classification.
- Parameters
logits – Predicted logits, shape (batch_size, seq_len, num_labels).
labels – True labels, shape (batch_size, seq_len). Must be integers.
config – Configuration with a num_labels attribute.
label_smoothing – Label smoothing factor (a LossConfig field; not in the signature above).
z_loss – Coefficient for the auxiliary z-loss term (a LossConfig field; not in the signature above).
loss_normalizing_factor – A factor used to normalize the loss; may also be a SpecialLossNormalizingFactor value (a LossConfig field; not in the signature above).
batch – Optional batch used for dynamic loss normalization.
**kwargs – Additional keyword arguments for the cross-entropy loss.
- Returns
The computed token classification loss.
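Token classification reduces to a masked per-token cross-entropy. The sketch below assumes PyTorch-style uniform label smoothing and the common z-loss term z_loss * log(Z)**2, where log(Z) is the logsumexp of the logits; how EasyDeL combines these is an assumption, and `token_cls_loss_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def token_cls_loss_sketch(logits, labels, ignore_index=-100,
                          label_smoothing=0.0, z_loss=0.0):
    """Hypothetical helper: masked per-token cross-entropy with smoothing and z-loss."""
    num_labels = logits.shape[-1]
    mask = labels != ignore_index
    safe_labels = jnp.where(mask, labels, 0)
    # Smoothed targets: (1 - eps) on the true label plus eps spread over all labels.
    targets = (jax.nn.one_hot(safe_labels, num_labels) * (1.0 - label_smoothing)
               + label_smoothing / num_labels)
    log_z = logsumexp(logits, axis=-1)
    log_probs = logits - log_z[..., None]
    per_token = -(targets * log_probs).sum(axis=-1)
    # z-loss: penalize a large log-partition function log(Z) to keep logits bounded.
    per_token = per_token + z_loss * jnp.square(log_z)
    # Average over labelled tokens only.
    return jnp.where(mask, per_token, 0.0).sum() / jnp.maximum(mask.sum(), 1)
```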
- class easydel.infra.loss_utils.LossConfig(ignore_index: int = -100, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Union[float, int, str, easydel.infra.loss_utils.SpecialLossNormalizingFactor, NoneType] = 'NUM_REAL_TARGET_TOKENS', num_labels: Optional[str] = None, problem_type: Optional[str] = None, divide_weight_sum: bool = False, shift_tokens: bool = True, break_on_nan: bool = True, reduction: Optional[Literal['none', 'mean', 'sum']] = None, num_classification_labels: Optional[int] = None, classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None)
Bases: Mapping
- break_on_nan: bool = True
- classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None
- divide_weight_sum: bool = False
- from_tuple()
- ignore_index: int = -100
- items() → a set-like object providing a view on D's items
- keys() → a set-like object providing a view on D's keys
- label_smoothing: float = 0.0
- loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]] = 'NUM_REAL_TARGET_TOKENS'
- num_classification_labels: Optional[int] = None
- num_labels: Optional[str] = None
- problem_type: Optional[str] = None
- reduction: Optional[Literal['none', 'mean', 'sum']] = None
- replace(**kwargs)
- shift_tokens: bool = True
- to_tuple()
- values() → an object providing a view on D's values
- z_loss: float = 0.0
- class easydel.infra.loss_utils.LossMetrics(loss: Union[float, jax.Array, numpy.ndarray, numpy.bool, numpy.number, NoneType] = None, z_loss: Union[float, jax.Array, numpy.ndarray, numpy.bool, numpy.number, NoneType] = None, weight_sum: Union[float, jax.Array, numpy.ndarray, numpy.bool, numpy.number, NoneType] = None, accuracy: Union[float, jax.Array, numpy.ndarray, numpy.bool, numpy.number, NoneType] = None, learning_rate: Union[float, jax.Array, numpy.ndarray, numpy.bool, numpy.number, NoneType] = None, max_grad_norm: Optional[flax.struct.PyTreeNode] = None, mean_grad_norm: Optional[flax.struct.PyTreeNode] = None, grad_norms: Optional[flax.struct.PyTreeNode] = None, chosen_rewards: Optional[jax.Array] = None, rejected_rewards: Optional[jax.Array] = None, other_metrics: Optional[Mapping[str, jax.Array]] = None, execution_time: Optional[float] = None)
Bases: Mapping
- execution_time: Optional[float] = None
- from_tuple()
- grad_norms: Optional[PyTreeNode] = None
- items() → a set-like object providing a view on D's items
- keys() → a set-like object providing a view on D's keys
- max_grad_norm: Optional[PyTreeNode] = None
- mean_grad_norm: Optional[PyTreeNode] = None
- replace(**kwargs)
- to_tuple()
- values() → an object providing a view on D's values
- class easydel.infra.loss_utils.SpecialLossNormalizingFactor(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: Enum
Special loss normalizing factors that are computed from the data rather than being constants.
NO_WEIGHT_NUM_REAL_TARGET_TOKENS: Divide the loss by the number of real (non-padding) tokens, without computing loss weights.
NUM_REAL_TARGET_TOKENS: Divide the loss by the number of real (non-padding) tokens.
NUM_TOTAL_TARGET_TOKENS: Divide the loss by the total number of target tokens.
AVERAGE_PER_SEQUENCE: Compute the average loss per sequence.
- AVERAGE_PER_SEQUENCE = 3
- NO_WEIGHT_NUM_REAL_TARGET_TOKENS = 0
- NUM_REAL_TARGET_TOKENS = 1
- NUM_TOTAL_TARGET_TOKENS = 2
- easydel.infra.loss_utils.auxiliary_load_balancing_loss_func(gate_logits: Union[Array, ndarray, bool, number, Tuple[Union[Array, ndarray, bool, number], ...]], num_experts: int, top_k: int, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None) → Union[Array, ndarray, bool, number]
Computes the auxiliary load-balancing loss described in the Switch Transformer paper (https://arxiv.org/abs/2101.03961).
- Parameters
gate_logits – The logits of the gating (router) network, either as a single array or as a tuple of arrays.
num_experts – The number of experts.
top_k – The number of experts selected per token.
attention_mask – An optional attention mask.
- Returns
The auxiliary load-balancing loss as a scalar array.
- Raises
ValueError – If num_experts or top_k is invalid.
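The Switch Transformer loss multiplies, for each expert, the fraction of tokens routed to it by the mean router probability it receives, sums over experts, and scales by num_experts, so it is smallest when routing is uniform. The sketch below follows the top-k generalization in the style of Hugging Face's Mixtral load_balancing_loss_func; it ignores attention_mask and the tuple-of-arrays input, and whether EasyDeL matches it exactly is an assumption. The helper name is hypothetical.

```python
import jax
import jax.numpy as jnp


def load_balancing_loss_sketch(gate_logits, num_experts, top_k):
    """Hypothetical helper; gate_logits has shape (num_tokens, num_experts)."""
    routing_weights = jax.nn.softmax(gate_logits, axis=-1)
    # Indices of the top_k experts chosen for every token: (num_tokens, top_k).
    _, selected = jax.lax.top_k(routing_weights, top_k)
    # expert_mask[t, k, e] == 1 if the k-th choice of token t is expert e.
    expert_mask = jax.nn.one_hot(selected, num_experts)
    # Fraction of tokens dispatched to each expert, per top-k slot: (top_k, num_experts).
    tokens_per_expert = expert_mask.mean(axis=0)
    # Mean router probability assigned to each expert: (num_experts,).
    router_prob_per_expert = routing_weights.mean(axis=0)
    return num_experts * jnp.sum(tokens_per_expert * router_prob_per_expert[None, :])


gate_logits = jnp.zeros((6, 4))  # 6 tokens, 4 experts, uniform router
print(load_balancing_loss_sketch(gate_logits, num_experts=4, top_k=2))  # 2.0
```

With a uniform router the loss equals top_k; routing every token to one expert pushes it toward num_experts * top_k.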
- easydel.infra.loss_utils.compute_weighted_cross_entropy(logits: Union[Array, ndarray, bool, number], targets: Union[Array, ndarray, bool, number], weights: Optional[Union[Array, ndarray, bool, number]] = None, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Optional[float] = None) → Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]]
Computes the weighted cross-entropy loss, the z-loss, and the sum of weights.
- Parameters
logits – The predicted logits.
targets – The target class labels (integers).
weights – Optional per-example weights.
label_smoothing – Label smoothing factor.
z_loss – Coefficient for the auxiliary z-loss term.
loss_normalizing_factor – A factor to normalize the loss.
- Returns
A tuple containing the total loss, z-loss, and sum of weights.
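Several names in this module (compute_weighted_cross_entropy, SpecialLossNormalizingFactor, NUM_REAL_TARGET_TOKENS) mirror those in the T5X losses module, so the sketch below follows that formulation: smoothed one-hot targets, subtraction of the target entropy so that a perfect prediction scores zero loss, a z-loss term z_loss * log(Z)**2, per-example weights, and an optional constant divisor. Treat it as an illustration under that assumption rather than EasyDeL's source; the helper name is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def weighted_cross_entropy_sketch(logits, targets, weights=None, label_smoothing=0.0,
                                  z_loss=0.0, loss_normalizing_factor=None):
    """Hypothetical T5X-style helper returning (total_loss, total_z_loss, weight_sum)."""
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    # Entropy of the smoothed target distribution; subtracting it means a perfect
    # prediction still scores zero loss when label smoothing is enabled.
    normalizing_constant = -(
        confidence * jnp.log(confidence)
        + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    # Soft targets: `confidence` on the true class, `low_confidence` elsewhere.
    soft_targets = (jax.nn.one_hot(targets, vocab_size) * (confidence - low_confidence)
                    + low_confidence)
    log_z = logsumexp(logits, axis=-1)
    per_example_z = z_loss * jnp.square(log_z)
    per_example = -(soft_targets * (logits - log_z[..., None])).sum(axis=-1)
    # The z-loss is folded into the total loss and also reported separately.
    per_example = per_example - normalizing_constant + per_example_z
    if weights is None:
        weights = jnp.ones_like(per_example)
    total_loss = (per_example * weights).sum()
    total_z_loss = (per_example_z * weights).sum()
    weight_sum = weights.sum()
    if loss_normalizing_factor is not None:
        total_loss = total_loss / loss_normalizing_factor
        total_z_loss = total_z_loss / loss_normalizing_factor
    return total_loss, total_z_loss, weight_sum
```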
- easydel.infra.loss_utils.compute_weighted_cross_entropy_and_accuracy(logits: Union[Array, ndarray, bool, number], targets: Union[Array, ndarray, bool, number], weights: Optional[Union[Array, ndarray, bool, number]] = None, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Optional[float] = None) → Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]]
Computes the weighted cross-entropy loss, the z-loss, the sum of weights, and the accuracy.
- Parameters
logits – The predicted logits.
targets – The target class labels (integers).
weights – Optional per-example weights.
label_smoothing – Label smoothing factor.
z_loss – Coefficient for the auxiliary z-loss term.
loss_normalizing_factor – A factor to normalize the loss.
- Returns
A tuple containing the total loss, z-loss, sum of weights, and accuracy.
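The additional accuracy output can be read as a weighted argmax match rate. The sketch below assumes the weights mask out padding in the same way as for the loss and that the result is divided by the weight sum; the exact normalization EasyDeL uses is not stated here, and the helper name is hypothetical.

```python
import jax.numpy as jnp


def weighted_accuracy_sketch(logits, targets, weights=None):
    """Hypothetical helper: weighted fraction of positions whose argmax is the target."""
    correct = (jnp.argmax(logits, axis=-1) == targets).astype(jnp.float32)
    if weights is None:
        weights = jnp.ones_like(correct)
    # Positions with weight 0 (e.g. padding) do not affect the result.
    return (correct * weights).sum() / jnp.maximum(weights.sum(), 1e-8)


logits = jnp.array([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]])
targets = jnp.array([0, 2, 2])
print(weighted_accuracy_sketch(logits, targets, jnp.array([1.0, 1.0, 0.0])))  # 0.5
```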
- easydel.infra.loss_utils.convert_special_loss_normalizing_factor_to_enum(x: str) → SpecialLossNormalizingFactor
Converts the string name of a SpecialLossNormalizingFactor member to the corresponding enum value.
- Parameters
x – The member name as a string, for example 'NUM_REAL_TARGET_TOKENS'.
- Returns
The corresponding SpecialLossNormalizingFactor enum value.
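Because LossConfig stores loss_normalizing_factor as the string 'NUM_REAL_TARGET_TOKENS' by default, a conversion like the one below is needed before dispatching on the enum. The sketch uses a stand-in enum with the documented member values and assumes a case-insensitive lookup by name; whether EasyDeL upper-cases its input or raises ValueError for unknown names is an assumption, and both names are hypothetical.

```python
import enum


class SpecialLossNormalizingFactorSketch(enum.Enum):
    """Stand-in mirroring the member names and values documented above."""
    NO_WEIGHT_NUM_REAL_TARGET_TOKENS = 0
    NUM_REAL_TARGET_TOKENS = 1
    NUM_TOTAL_TARGET_TOKENS = 2
    AVERAGE_PER_SEQUENCE = 3


def convert_sketch(x: str) -> SpecialLossNormalizingFactorSketch:
    """Hypothetical converter: case-insensitive lookup by member name."""
    try:
        return SpecialLossNormalizingFactorSketch[x.upper()]
    except KeyError:
        raise ValueError(f"unknown special loss normalizing factor: {x!r}") from None


factor = convert_sketch("num_real_target_tokens")
print(factor.name, factor.value)  # NUM_REAL_TARGET_TOKENS 1
```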
- easydel.infra.loss_utils.dynamic_cross_entropy_loss(logits: Array, targets: Array, weight: Optional[Array] = None, ignore_index: int = -100, reduction: str = 'mean', label_smoothing: float = 0.0) → Tuple[Array, Array]
Cross-entropy loss with support for masking, per-class weights, and label smoothing.
- Parameters
logits – Predicted logits, shape (B, C) or (B, T, C).
targets – Ground-truth integer labels, shape (B, …), or class probabilities, shape (B, …, C).
weight – Optional per-class weights, shape (C,).
ignore_index – Target value to ignore (integer targets only).
reduction – 'none', 'mean', or 'sum'.
label_smoothing – Smoothing factor between 0 and 1.
- Returns
A tuple (loss, accuracy); the accuracy is only meaningful for integer targets.
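For integer targets, the masking and reduction rules can be sketched as below, assuming PyTorch CrossEntropyLoss semantics: ignored positions contribute nothing, 'mean' divides by the number of non-ignored targets, and label smoothing mixes the target NLL with the mean NLL over all classes. The sketch omits per-class weight and probability targets, and the helper name is hypothetical.

```python
import jax
import jax.numpy as jnp


def dynamic_ce_sketch(logits, targets, ignore_index=-100, reduction="mean",
                      label_smoothing=0.0):
    """Hypothetical helper for integer targets; per-class weights are omitted."""
    valid = targets != ignore_index
    safe = jnp.where(valid, targets, 0)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    nll = -jnp.take_along_axis(log_probs, safe[..., None], axis=-1)[..., 0]
    # PyTorch-style smoothing: mix the target NLL with the mean NLL over all classes.
    smooth = -log_probs.mean(axis=-1)
    loss = (1.0 - label_smoothing) * nll + label_smoothing * smooth
    loss = jnp.where(valid, loss, 0.0)
    # Accuracy is computed over non-ignored targets only.
    correct = (jnp.argmax(logits, axis=-1) == targets) & valid
    accuracy = correct.sum() / jnp.maximum(valid.sum(), 1)
    if reduction == "none":
        return loss, accuracy
    if reduction == "sum":
        return loss.sum(), accuracy
    if reduction == "mean":
        return loss.sum() / jnp.maximum(valid.sum(), 1), accuracy
    raise ValueError(f"unsupported reduction: {reduction!r}")
```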
- easydel.infra.loss_utils.fixed_cross_entropy(source: Array, target: Array, attention_mask: Optional[Array] = None, config: Optional[LossConfig] = None, num_items_in_batch: Optional[int] = None, batch: Optional[Mapping[str, Union[Array, ndarray, bool, number]]] = None, **kwargs: Any) → LossMetrics
JAX implementation of a fixed cross-entropy loss with z-loss, label smoothing, and masking.
- Parameters
source – Predicted logits, shape (batch_size, num_classes) or (batch_size * seq_len, num_classes).
target – True labels, shape (batch_size,) or (batch_size * seq_len,). Must be integers.
num_items_in_batch – Optional; used when the reduction is a sum.
attention_mask – Optional boolean mask applied to the loss.
batch – Optional batch used for dynamic loss normalization.
**kwargs – Additional keyword arguments.
- Returns
The computed cross-entropy loss in LossMetrics.
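The role of num_items_in_batch is easiest to see with gradient accumulation. Assuming the Hugging Face Transformers convention that the parameter name suggests, the sketch sums the loss and divides by the token count of the full accumulated batch instead of averaging per micro-batch. It is an illustration only, leaves out z-loss, label smoothing, and attention_mask, and uses a hypothetical helper name.

```python
import jax
import jax.numpy as jnp


def fixed_ce_sketch(source, target, ignore_index=-100, num_items_in_batch=None):
    """Hypothetical helper; source is (N, C) logits, target is (N,) integer labels."""
    valid = target != ignore_index
    safe = jnp.where(valid, target, 0)
    log_probs = jax.nn.log_softmax(source, axis=-1)
    nll = -jnp.take_along_axis(log_probs, safe[:, None], axis=-1)[:, 0]
    total = jnp.where(valid, nll, 0.0).sum()
    if num_items_in_batch is None:
        # Mean over the real tokens of this (micro-)batch.
        return total / jnp.maximum(valid.sum(), 1)
    # Sum, then divide by the token count of the whole accumulated batch, so that
    # micro-batches with different numbers of real tokens are weighted correctly.
    return total / num_items_in_batch
```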
- easydel.infra.loss_utils.get_loss_normalizing_factor_and_weights(loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]], batch: Mapping[str, Union[Array, ndarray, bool, number]]) → Tuple[Optional[float], Optional[Union[Array, ndarray, bool, number]]]
Gets the loss normalizing factor and weights from a batch of data.
- Parameters
loss_normalizing_factor – The loss normalizing factor to use.
batch – A dictionary containing the input batch of data.
- Returns
A tuple containing the loss normalizing factor and loss weights.
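Three of the special factors differ only in the denominator and, for AVERAGE_PER_SEQUENCE, in how the weights are rescaled. The sketch below takes a 0/1 loss mask directly instead of a batch mapping, because the batch keys EasyDeL reads are not documented here; the string dispatch and the helper name are assumptions, and NO_WEIGHT_NUM_REAL_TARGET_TOKENS is omitted.

```python
import jax.numpy as jnp


def normalizing_factor_and_weights_sketch(kind, weights):
    """Hypothetical helper; `weights` is a (batch, seq_len) 0/1 loss mask."""
    weights = jnp.asarray(weights, jnp.float32)
    if kind == "NUM_REAL_TARGET_TOKENS":
        # Denominator: number of non-padding target tokens.
        return weights.sum(), weights
    if kind == "NUM_TOTAL_TARGET_TOKENS":
        # Denominator: every target position, padding included.
        return float(weights.size), weights
    if kind == "AVERAGE_PER_SEQUENCE":
        # Give every sequence total weight 1, then divide by the number of sequences.
        per_seq = weights.sum(axis=-1, keepdims=True)
        return float(weights.shape[0]), weights / jnp.maximum(per_seq, 1e-8)
    raise ValueError(f"unsupported factor: {kind!r}")
```

Dividing the weighted loss sum by the returned factor then yields, respectively, a per-token mean, a mean over all positions, or a mean of per-sequence means.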
- easydel.infra.loss_utils.onehot(labels, num_classes, on_value=1.0, off_value=0.0)
Create a dense one-hot version of an indexed array.
NB: consider using the more standard jax.nn.one_hot instead.
- Parameters
labels – an n-dim JAX array of integer indices.
num_classes – the number of classes; valid indices run from 0 to num_classes - 1.
on_value – the 'on' value for the one-hot array, defaults to 1.0.
off_value – the 'off' value for the one-hot array, defaults to 0.0.
- Returns
A (n+1)-dim array whose last dimension contains one-hot vectors of length num_classes.
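The encoding and its relation to jax.nn.one_hot can be sketched as below. This is an illustrative reimplementation of the documented signature, not EasyDeL's source; with the default on/off values it produces the same array as jax.nn.one_hot, and a label equal to num_classes yields an all-off row.

```python
import jax
import jax.numpy as jnp


def onehot_sketch(labels, num_classes, on_value=1.0, off_value=0.0):
    """Dense one-hot encoding of integer labels along a new last axis."""
    # Compare each label against every class index 0..num_classes-1.
    hits = labels[..., None] == jnp.arange(num_classes)
    return jnp.where(hits, on_value, off_value).astype(jnp.float32)


labels = jnp.array([[0, 2], [1, 3]])
print(onehot_sketch(labels, 4).shape)  # (2, 2, 4)
```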
- easydel.infra.loss_utils.sigmoid_cross_entropy_with_logits(logits: Array, labels: Array, weights: Optional[Array] = None, label_smoothing: float = 0.0, axis: Optional[Union[int, tuple]] = None) → Array
Computes the sigmoid cross-entropy given logits and labels.
- Parameters
logits – Input logits.
labels – Target tensor with the same shape as logits.
weights – Optional weights to apply to the loss.
label_smoothing – Float in [0, 1]; the amount of smoothing to apply to the labels.
axis – The dimensions to reduce. If None, all dimensions are reduced.
- Returns
Sigmoid cross-entropy loss, reduced over axis (over all dimensions when axis is None).
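A numerically stable form of the per-element loss is max(x, 0) - x*z + log(1 + exp(-|x|)), which avoids overflow for large |x|. The sketch below uses it, assuming TensorFlow-style smoothing of binary labels toward 0.5 and a mean reduction over axis; EasyDeL may reduce with a sum or a weight-normalized mean instead, and the helper name is hypothetical.

```python
import jax.numpy as jnp


def sigmoid_ce_sketch(logits, labels, weights=None, label_smoothing=0.0, axis=None):
    """Hypothetical helper: elementwise sigmoid cross-entropy, then a mean over `axis`."""
    # Smooth binary labels toward 0.5 (TensorFlow-style convention, assumed here).
    labels = labels * (1.0 - label_smoothing) + 0.5 * label_smoothing
    # Stable form of -[z*log(sigmoid(x)) + (1-z)*log(1-sigmoid(x))].
    loss = jnp.maximum(logits, 0) - logits * labels + jnp.log1p(jnp.exp(-jnp.abs(logits)))
    if weights is not None:
        loss = loss * weights
    return loss.mean(axis=axis)


print(sigmoid_ce_sketch(jnp.array([100.0]), jnp.array([0.0])))  # 100.0, no overflow
```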