# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loss computation utilities for EasyDeL models.
This module provides a comprehensive suite of loss functions optimized for large-scale
language model training and inference. It includes memory-efficient implementations
using chunking strategies, custom VJP for gradient computation, and support for
various normalization and regularization techniques.
Classes:
SpecialLossNormalizingFactor: Enum for dynamic loss normalization strategies
LossConfig: Configuration class for customizing loss computation behavior
LossMetrics: Container for loss metrics and auxiliary training information
Loss Functions:
ForCausalLMLoss: Causal language modeling with token shifting
ForSequenceClassificationLoss: Single/multi-label classification and regression
ForQuestionAnsweringLoss: Span-based question answering (SQuAD-style)
ForTokenClassification: Token-level classification (NER, POS tagging)
Utility Functions:
cross_entropy_blockwise_logits: Memory-efficient CE for large vocabularies
sparse_cross_entropy_chunked_vocab: Chunked vocabulary processing
sparse_cross_entropy_chunked_tokens: Chunked token processing
compute_weighted_cross_entropy: Standard weighted CE with z-loss
auxiliary_load_balancing_loss_func: MoE load balancing loss
Key Features:
- Memory-efficient chunking for large vocabulary/sequence lengths
- Flexible loss normalization (per-token, per-sequence, weighted)
- Label smoothing and z-loss regularization
- Custom VJP for efficient gradient computation
- Support for packed sequences and attention masking
- Mixed precision computation support
- MoE auxiliary loss for expert load balancing
Example:
>>> from easydel.infra import LossConfig, ForCausalLMLoss
>>>
>>> # Configure loss with label smoothing and z-loss
>>> config = LossConfig(
... label_smoothing=0.1,
... z_loss=1e-4,
... loss_normalizing_factor="NUM_REAL_TARGET_TOKENS",
... chunk_vocab_size=8192 # Enable vocabulary chunking
... )
>>>
>>> # Compute loss for language modeling
>>> metrics = ForCausalLMLoss(
... logits=model_output, # [batch, seq_len, vocab_size]
... labels=targets, # [batch, seq_len]
... attention_mask=mask, # [batch, seq_len]
... config=config
... )
>>> print(f"Loss: {metrics.loss}, Accuracy: {metrics.accuracy}")
"""
import dataclasses
import enum
import typing as tp
from dataclasses import fields
from functools import reduce
from operator import mul
import chex
import flax
import flax.struct
import jax
import jax.numpy as jnp
from eformer.escale import PartitionAxis
from eformer.escale.partition.constraints import with_sharding_constraint
from eformer.pytree import auto_pytree
from jax import lax
from jax.sharding import PartitionSpec
from easydel.utils.compiling_utils import hash_fn
@enum.unique
class SpecialLossNormalizingFactor(enum.Enum):
    """Loss normalizing factors that are derived from the batch at run time.

    A member of this enum tells a loss function to compute its divisor from
    properties of the current batch instead of using a fixed constant.

    Attributes:
        NO_WEIGHT_NUM_REAL_TARGET_TOKENS: Divide by the number of non-padding
            target tokens, ignoring any provided loss weights.
        NUM_REAL_TARGET_TOKENS: Divide by the number of non-padding target
            tokens, taking provided loss weights into account.
        NUM_TOTAL_TARGET_TOKENS: Divide by the total number of target tokens,
            padding included.
        AVERAGE_PER_SEQUENCE: Average the loss per sequence in the batch.
    """

    NO_WEIGHT_NUM_REAL_TARGET_TOKENS = 0
    NUM_REAL_TARGET_TOKENS = 1
    NUM_TOTAL_TARGET_TOKENS = 2
    AVERAGE_PER_SEQUENCE = 3
# Short alias for the normalizing-factor enum, used by the helpers further down.
SLNF = SpecialLossNormalizingFactor
# Accepted spellings of a loss normalizing factor: a numeric constant, a string,
# an SLNF member, or None.
FACTOR_TYPE = tp.Optional[float | int | str | SLNF] # noqa
@auto_pytree
class LossConfig:
    """Configuration for customizing loss computation behavior.

    Attributes:
        ignore_index (int): Target value that is ignored and does not contribute
            to the loss. Defaults to -100.
        label_smoothing (float): Amount of label smoothing to apply; 0.0 means no
            smoothing. Defaults to 0.0.
        z_loss (float): Coefficient for the z-loss regularization term.
            Defaults to 0.0.
        loss_normalizing_factor (FACTOR_TYPE): How to normalize the loss. Can be a
            constant float/int, the string name of a `SpecialLossNormalizingFactor`
            member, or the member itself. Defaults to "NUM_REAL_TARGET_TOKENS".
        num_labels (int | None): Number of labels for classification tasks.
            Defaults to None.
        problem_type (str | None): Problem type for sequence classification
            (e.g. "single_label_classification", "multi_label_classification").
            Defaults to None.
        divide_weight_sum (bool): If True, the loss is additionally divided by the
            sum of weights. Defaults to False.
        shift_tokens (bool): If True (typical for causal LM), logits and labels are
            shifted so the model predicts the next token. Defaults to True.
        break_on_nan (bool): If True, a NaN encountered during loss computation
            triggers a break request. Defaults to True.
        reduction ("none" | "mean" | "sum" | None): Reduction applied to the loss;
            None selects the default of the specific loss function.
            Defaults to None.
        num_classification_labels (int | None): Number of labels specifically for
            sequence classification; alias for `num_labels`. Defaults to None.
        classification_problem_type ("regression" | "single_label_classification" |
            "multi_label_classification" | None): Problem type specifically for
            sequence classification; alias for `problem_type`. Defaults to None.
        chunk_vocab_size (int | None): Vocabulary chunk size; the module docstring
            describes setting it as enabling vocabulary chunking. Defaults to None.
        chunk_token_size (int | None): Presumably the token chunk size for the
            token-chunked cross-entropy path; the consumer is outside this class,
            verify against the caller. Defaults to None.
        chunk_block_size (int | None): Presumably the block size for the blockwise
            cross-entropy path; verify against the caller. Defaults to None.
        compute_dtype ("fp32" | "bf16" | None): Presumably the precision used for
            loss computation; verify against the caller. Defaults to None.
    """

    ignore_index: int = -100
    label_smoothing: float = 0.0
    z_loss: float = 0.0
    loss_normalizing_factor: FACTOR_TYPE = "NUM_REAL_TARGET_TOKENS"
    # Annotated as an int count, matching `num_classification_labels`.
    num_labels: int | None = None
    problem_type: str | None = None
    divide_weight_sum: bool = False
    shift_tokens: bool = True
    break_on_nan: bool = True
    reduction: tp.Literal["none", "mean", "sum"] | None = None
    num_classification_labels: int | None = None
    classification_problem_type: (
        tp.Literal["regression", "single_label_classification", "multi_label_classification"] | None
    ) = None
    chunk_vocab_size: int | None = None
    chunk_token_size: int | None = None
    chunk_block_size: int | None = None
    compute_dtype: tp.Literal["fp32", "bf16"] | None = None

    def __repr__(self):
        """Render one `name: value` line per field, wrapped in `ClassName(...)`."""
        cls_name = self.__class__.__name__
        field_lines = [f" {f.name}: {getattr(self, f.name)!r}".replace("\n", "\n ") for f in fields(self)]
        return f"{cls_name}(\n" + "\n".join(field_lines) + "\n)"

    __str__ = __repr__
    # Hashing is delegated to the project's shared helper.
    __hash__ = hash_fn
@auto_pytree
class LossMetrics:
    """Container for metrics produced by loss computation and training steps.

    Every field is optional and defaults to None.

    Attributes:
        loss: The primary computed loss value.
        z_loss: The computed z-loss regularization term.
        weight_sum: The sum of weights used in the loss calculation.
        accuracy: Computed accuracy, if applicable.
        learning_rate: The learning rate used for the current step.
        max_grad_norm: The maximum gradient norm observed.
        mean_grad_norm: The mean gradient norm observed.
        grad_norms: A pytree containing the norms of gradients for each parameter.
        chosen_rewards: Rewards for the chosen sequence in preference-based tasks.
        rejected_rewards: Rewards for the rejected sequence in preference-based tasks.
        other_metrics: A mapping for any additional custom metrics.
        execution_time: Time taken for the computation step.
    """

    loss: float | chex.Array | None = None
    z_loss: float | chex.Array | None = None
    weight_sum: float | chex.Array | None = None
    accuracy: float | chex.Array | None = None
    learning_rate: float | chex.Array | None = None
    max_grad_norm: flax.struct.PyTreeNode | None = None
    mean_grad_norm: flax.struct.PyTreeNode | None = None
    grad_norms: flax.struct.PyTreeNode | None = None
    chosen_rewards: jax.Array | None = None
    rejected_rewards: jax.Array | None = None
    other_metrics: tp.Mapping[str, jax.Array] | None = None
    execution_time: float | None = None
def _logsumexp_chunked(x: jnp.ndarray, chunk_size: int) -> jnp.ndarray:
"""Compute logsumexp over the last dimension in chunks for memory efficiency.
This function computes log(sum(exp(x))) over the last dimension using a chunked
approach to reduce memory usage for large vocabulary sizes.
Args:
x: Input array with shape [..., V] where V is vocabulary size.
chunk_size: Size of chunks to process at a time.
Returns:
Array with shape [...] containing logsumexp values.
"""
# x: [..., V]
V: int = x.shape[-1] # static python int
n_full = V // chunk_size
tail = V - n_full * chunk_size # static python int
# Pass 1: max
def max_body(i, m):
start = i * chunk_size
chunk = lax.dynamic_slice_in_dim(x, start, chunk_size, axis=-1)
return jnp.maximum(m, jnp.max(chunk, axis=-1))
m = jnp.full(x.shape[:-1], -jnp.inf, dtype=x.dtype)
m = lax.fori_loop(0, n_full, max_body, m)
if tail:
start = n_full * chunk_size
chunk = lax.dynamic_slice_in_dim(x, start, tail, axis=-1)
m = jnp.maximum(m, jnp.max(chunk, axis=-1))
# Pass 2: sum of exp(x - m)
def sum_body(i, s):
start = i * chunk_size
chunk = lax.dynamic_slice_in_dim(x, start, chunk_size, axis=-1)
return s + jnp.sum(jnp.exp(chunk - m[..., None]), axis=-1)
s = jnp.zeros_like(m)
s = lax.fori_loop(0, n_full, sum_body, s)
if tail:
start = n_full * chunk_size
chunk = lax.dynamic_slice_in_dim(x, start, tail, axis=-1)
s = s + jnp.sum(jnp.exp(chunk - m[..., None]), axis=-1)
return jnp.log(s) + m
[docs]def cross_entropy_blockwise_logits(
logits: jax.Array, # [B, T, V] or [N, V]
targets: jax.Array, # [B, T] or [N]
weights: jax.Array | None = None, # [B, T] or [N]
*,
ignore_index: int = -100,
label_smoothing: float = 0.0,
z_loss: float = 0.0,
block_size: int = 8192,
dtype: jnp.dtype | None = jnp.float32,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
"""
Blockwise sparse cross-entropy from logits without materializing softmax or one-hot.
Returns (total_loss, total_z_loss, weight_sum, accuracy).
"""
# Flatten tokens
if logits.ndim == 3:
B, T, V = logits.shape
L = B * T
logits2d = logits.reshape(L, V)
y = targets.reshape(L)
w = None if weights is None else weights.reshape(L).astype(jnp.float32)
elif logits.ndim == 2:
L, V = logits.shape
logits2d = logits
y = targets
w = None if weights is None else weights.astype(jnp.float32)
else:
raise ValueError(f"logits must be [B, T, V] or [N, V], got {logits.shape}")
if block_size <= 0:
raise ValueError(f"block_size must be > 0, got {block_size}")
# Upcast for numerical stability
logits2d = logits2d.astype(dtype or logits2d.dtype)
# Valid/weights
valid = y != ignore_index
y_safe = jnp.where(valid, y, 0)
w = valid.astype(jnp.float32) if w is None else valid.astype(jnp.float32) * w
# Accumulators
neg_inf = jnp.array(-jnp.inf, dtype=jnp.float32)
m = jnp.full((L,), neg_inf)
log_z = jnp.full((L,), neg_inf)
o = jnp.zeros((L,), dtype=jnp.float32) # sum of target logits
sum_logits = jnp.zeros((L,), dtype=jnp.float32) # for smoothing
best_logit = jnp.full((L,), neg_inf)
best_id = jnp.zeros((L,), dtype=jnp.int32)
n_full = V // block_size
tail = V - n_full * block_size
def process_block(start: int, size: int, m, log_z, o, sum_logits, best_logit, best_id):
# Static slice sizes: size is either block_size (in a fori_loop) or tail (once)
chunk = lax.dynamic_slice_in_dim(logits2d, start, size, axis=1) # [L, size]
# Running logsumexp via updated max
chunk_max = jnp.max(chunk, axis=1)
new_m = jnp.maximum(m, chunk_max)
log_z = new_m + jnp.log(jnp.exp(log_z - new_m) + jnp.sum(jnp.exp(chunk - new_m[:, None]), axis=1))
m = new_m
# Accumulate target logit (sparse)
in_block = (y_safe >= start) & (y_safe < start + size)
idx = jnp.where(in_block, (y_safe - start).astype(jnp.int32), 0)
logit_y_b = jnp.take_along_axis(chunk, idx[:, None], axis=1)[:, 0]
o = o + jnp.where(in_block, logit_y_b, 0.0)
# Sum logits for smoothing
sum_logits = sum_logits + jnp.sum(chunk, axis=1)
# Streamed argmax for accuracy
block_best = jnp.argmax(chunk, axis=1)
block_best_id = start + block_best.astype(jnp.int32)
update = chunk_max > best_logit
best_logit = jnp.where(update, chunk_max, best_logit)
best_id = jnp.where(update, block_best_id, best_id)
return m, log_z, o, sum_logits, best_logit, best_id
def full_body(i, carry):
start = i * block_size
return process_block(start, block_size, *carry)
carry = (m, log_z, o, sum_logits, best_logit, best_id)
if n_full > 0:
carry = lax.fori_loop(0, n_full, full_body, carry)
if tail:
start = n_full * block_size
carry = process_block(start, tail, *carry)
m, log_z, o, sum_logits, best_logit, best_id = carry
# Base CE
nll = log_z - o # [L]
# Label smoothing: (1-eps)*NLL + eps*(log_z - mean(logits))
if label_smoothing and label_smoothing != 0.0:
eps = jnp.asarray(label_smoothing, dtype=jnp.float32)
mean_logits = sum_logits / float(V)
nll = (1.0 - eps) * nll + eps * (log_z - mean_logits)
# z-loss term
zterm = (z_loss * (log_z**2)) if (z_loss and z_loss != 0.0) else 0.0
per_tok = nll + (zterm if (z_loss and z_loss != 0.0) else 0.0)
total_loss = jnp.sum(per_tok * w)
total_z_loss = jnp.sum((zterm if (z_loss and z_loss != 0.0) else 0.0) * w)
weight_sum = jnp.sum(w)
# Weighted accuracy
acc = jnp.sum((best_id == y_safe).astype(jnp.float32) * w) / jnp.maximum(weight_sum, 1e-8)
return total_loss, total_z_loss, weight_sum, acc
def sparse_cross_entropy_chunked_vocab(
    logits: jnp.ndarray,  # [..., V]
    targets: jnp.ndarray,  # [...]
    weights: jnp.ndarray | None = None,  # [...]
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
    z_loss: float = 0.0,
    reduction: str = "mean",
    chunk_size: int = 8192,
    compute_dtype: jnp.dtype = jnp.float32,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Sparse cross-entropy whose log-partition is computed in vocabulary chunks.

    Args:
        logits: Logits with shape [..., V].
        targets: Integer labels with shape [...].
        weights: Optional per-element weights with shape [...].
        ignore_index: Label value whose positions get zero weight.
        label_smoothing: Smoothing factor; applied only when > 0.
        z_loss: Coefficient of the squared log-partition penalty; applied only
            when > 0.
        reduction: "mean" divides the loss by the weight sum; any other value
            leaves it as a weighted sum.
        chunk_size: Vocabulary chunk size passed to `_logsumexp_chunked`.
        compute_dtype: Dtype the logits and weights are cast to.

    Returns:
        `(total_loss, total_z_loss, weight_sum, accuracy)`. `total_z_loss` is
        always a weighted sum; only `total_loss` is affected by `reduction`.
    """
    logits = logits.astype(compute_dtype)

    # Ignored positions are gathered at index 0 and zeroed out through the weights.
    valid = targets != ignore_index
    safe_targets = jnp.where(valid, targets, 0)

    lse = _logsumexp_chunked(logits, chunk_size)  # [...]
    logit_y = jnp.take_along_axis(logits, safe_targets[..., None], axis=-1)[..., 0]
    nll = lse - logit_y

    if label_smoothing > 0.0:
        eps = label_smoothing
        nll = (1.0 - eps) * nll + eps * (lse - jnp.mean(logits, axis=-1))

    # A plain 0.0 when z-loss is disabled, so it adds nothing below.
    z_term = (z_loss * jnp.square(lse)) if z_loss > 0.0 else 0.0
    nll = nll + z_term

    w = valid.astype(compute_dtype)
    if weights is not None:
        w = w * weights.astype(compute_dtype)

    total_loss = jnp.sum(nll * w)
    total_z_loss = jnp.sum(z_term * w)
    weight_sum = jnp.sum(w)
    if reduction == "mean":
        total_loss = total_loss / jnp.maximum(weight_sum, 1e-8)

    # Weighted accuracy.
    correct = (jnp.argmax(logits, axis=-1) == targets).astype(compute_dtype) * w
    accuracy = jnp.sum(correct) / jnp.maximum(weight_sum, 1e-8)
    return total_loss, total_z_loss, weight_sum, accuracy
[docs]def sparse_cross_entropy_chunked_tokens(
logits: jnp.ndarray, # [B, T, V] or [N, V]
targets: jnp.ndarray, # [B, T] or [N]
weights: jnp.ndarray | None = None,
ignore_index: int = -100,
label_smoothing: float = 0.0,
z_loss: float = 0.0,
reduction: str = "sum", # sum here; normalize outside for consistency
token_chunk_size: int = 8192,
compute_dtype: jnp.dtype = jnp.float32,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
logits = logits.astype(compute_dtype)
V = logits.shape[-1]
logits2d = logits.reshape(-1, V)
targets1d = targets.reshape(-1)
weights1d = None if weights is None else weights.reshape(-1).astype(compute_dtype)
N: int = logits2d.shape[0]
n_full = N // token_chunk_size
tail = N - n_full * token_chunk_size
def body(i, carry):
tot, wsum, acc_sum, zsum = carry
start = i * token_chunk_size
chunk_logits = lax.dynamic_slice_in_dim(logits2d, start, token_chunk_size, axis=0)
chunk_targets = lax.dynamic_slice_in_dim(targets1d, start, token_chunk_size, axis=0)
chunk_weights = (
None if weights1d is None else lax.dynamic_slice_in_dim(weights1d, start, token_chunk_size, axis=0)
)
lse = jax.scipy.special.logsumexp(chunk_logits, axis=-1)
logit_y = jnp.take_along_axis(chunk_logits, chunk_targets[:, None], axis=-1)[:, 0]
valid = chunk_targets != ignore_index
nll = lse - logit_y
if label_smoothing > 0.0:
eps = label_smoothing
nll = (1.0 - eps) * nll + eps * (lse - jnp.mean(chunk_logits, axis=-1))
zterm = (z_loss * jnp.square(lse)) if z_loss > 0.0 else 0.0
nll = nll + (zterm if z_loss > 0.0 else 0.0)
w = valid.astype(compute_dtype) if chunk_weights is None else valid.astype(compute_dtype) * chunk_weights
loss_sum = jnp.sum(nll * w)
w_sum = jnp.sum(w)
z_sum = jnp.sum((zterm if z_loss > 0.0 else 0.0) * w)
acc = jnp.sum((jnp.argmax(chunk_logits, axis=-1) == chunk_targets).astype(compute_dtype) * w)
return (tot + loss_sum, wsum + w_sum, acc_sum + acc, zsum + z_sum)
# Full chunks
init = (
jnp.array(0.0, compute_dtype),
jnp.array(0.0, compute_dtype),
jnp.array(0.0, compute_dtype),
jnp.array(0.0, compute_dtype),
)
carry = init
carry = lax.fori_loop(0, n_full, body, carry)
# Tail
if tail:
start = n_full * token_chunk_size
chunk_logits = lax.dynamic_slice_in_dim(logits2d, start, tail, axis=0)
chunk_targets = lax.dynamic_slice_in_dim(targets1d, start, tail, axis=0)
chunk_weights = None if weights1d is None else lax.dynamic_slice_in_dim(weights1d, start, tail, axis=0)
lse = jax.scipy.special.logsumexp(chunk_logits, axis=-1)
logit_y = jnp.take_along_axis(chunk_logits, chunk_targets[:, None], axis=-1)[:, 0]
valid = chunk_targets != ignore_index
nll = lse - logit_y
if label_smoothing > 0.0:
eps = label_smoothing
nll = (1.0 - eps) * nll + eps * (lse - jnp.mean(chunk_logits, axis=-1))
zterm = (z_loss * jnp.square(lse)) if z_loss > 0.0 else 0.0
nll = nll + (zterm if z_loss > 0.0 else 0.0)
w = valid.astype(compute_dtype) if chunk_weights is None else valid.astype(compute_dtype) * chunk_weights
loss_sum = jnp.sum(nll * w)
w_sum = jnp.sum(w)
z_sum = jnp.sum((zterm if z_loss > 0.0 else 0.0) * w)
acc = jnp.sum((jnp.argmax(chunk_logits, axis=-1) == chunk_targets).astype(compute_dtype) * w)
carry = (carry[0] + loss_sum, carry[1] + w_sum, carry[2] + acc, carry[3] + z_sum)
total_loss, total_wsum, acc_sum, total_z_loss = carry
if reduction == "mean":
total_loss = total_loss / jnp.maximum(total_wsum, 1e-8)
accuracy = acc_sum / jnp.maximum(total_wsum, 1e-8)
return total_loss, total_z_loss, total_wsum, accuracy
[docs]def dynamic_cross_entropy_loss(
logits: jnp.ndarray,
targets: jnp.ndarray,
weight: jnp.ndarray | None = None,
ignore_index: int = -100,
reduction: str = "mean",
label_smoothing: float = 0.0,
compute_dtype: jnp.dtype = jnp.float32,
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""
Computes the cross-entropy loss with optional label smoothing and ignore index,
dynamically handling different reduction types.
Args:
logits (jnp.ndarray): The predicted logits from the model (batch_size, ..., num_classes).
targets (jnp.ndarray): The target labels (batch_size, ...).
weight (tp.Optional[jnp.ndarray]): Optional weights for each element (batch_size, ...).
Defaults to None.
ignore_index (int): Index in the target labels to ignore. Defaults to -100.
reduction (str): Specifies the reduction method: 'mean', 'sum', or 'none'.
Defaults to "mean".
label_smoothing (float): The amount of label smoothing to apply (0.0 means no smoothing).
Defaults to 0.0.
Returns:
tp.Tuple[jnp.ndarray, jnp.ndarray]:
- The computed loss (scalar if reduction is 'mean' or 'sum', array otherwise).
- The normalization factor (sum of weights or count of non-ignored elements).
Raises:
ValueError: If an invalid reduction method is specified.
"""
logits = logits.astype(compute_dtype)
valid = targets != ignore_index
safe_targets = jnp.where(valid, targets, 0)
lse = jax.scipy.special.logsumexp(logits, axis=-1)
logit_y = jnp.take_along_axis(logits, safe_targets[..., None], axis=-1)[..., 0]
nll = lse - logit_y
if label_smoothing > 0.0:
eps = label_smoothing
nll = (1.0 - eps) * nll + eps * (lse - jnp.mean(logits, axis=-1))
w = valid.astype(compute_dtype) if weight is None else valid.astype(compute_dtype) * weight.astype(compute_dtype)
loss = nll * w
norm = jnp.maximum(jnp.sum(w), 1e-8)
if reduction == "mean":
return jnp.sum(loss) / norm, norm
elif reduction == "sum":
return jnp.sum(loss), norm
elif reduction == "none":
return loss, w
else:
raise ValueError(f"Invalid reduction: {reduction}. Must be 'mean', 'sum', or 'none'.")
[docs]def sigmoid_cross_entropy_with_logits(
logits: jnp.ndarray,
labels: jnp.ndarray,
weights: jnp.ndarray | None = None,
label_smoothing: float = 0.0,
axis: int | tuple | None = None,
) -> jnp.ndarray:
"""
Computes sigmoid cross-entropy loss between logits and labels.
Measures the probability error in discrete classification tasks in which each
class is independent and not mutually exclusive. For instance, one could
perform multilabel classification where a picture can contain both an elephant
and a dog at the same time.
Args:
logits: The predicted logits from the model.
labels: The target labels.
weights (tp.Optional[jnp.ndarray]): Optional weights for the loss computation.
Defaults to None.
label_smoothing (float): Amount of label smoothing to apply (0.0 means no smoothing).
Defaults to 0.0.
axis (tp.Optional[tp.Union[int, tuple]]): The axis or axes along which to compute the mean.
If None, the mean is computed over all elements. Defaults to None.
Returns:
jnp.ndarray: The computed sigmoid cross-entropy loss.
"""
if label_smoothing > 0.0:
labels = labels * (1.0 - label_smoothing) + 0.5 * label_smoothing
log_p = jax.nn.log_sigmoid(logits)
log_not_p = jax.nn.log_sigmoid(-logits)
loss = -labels * log_p - (1.0 - labels) * log_not_p
if weights is not None:
loss *= weights
if axis is None:
return jnp.mean(loss)
else:
return jnp.mean(loss, axis=axis)
def onehot(labels, num_classes, on_value=1.0, off_value=0.0):
    """Create one-hot encoded versions of integer labels.

    Args:
        labels (jnp.ndarray): An array of integer labels.
        num_classes (int): The total number of classes.
        on_value (float): Value written at the position of the label.
            Defaults to 1.0.
        off_value (float): Value written everywhere else. Defaults to 0.0.

    Returns:
        jnp.ndarray: Array of shape `labels.shape + (num_classes,)`.
    """
    # Class ids reshaped to (1, ..., 1, num_classes) so they line up with labels[..., None].
    class_ids = jnp.arange(num_classes)[(None,) * labels.ndim]
    is_label = lax.eq(labels[..., None], class_ids)
    return lax.select(
        is_label,
        jnp.full(is_label.shape, on_value),
        jnp.full(is_label.shape, off_value),
    )
@jax.custom_vjp
def cross_entropy_with_logits(
    logits: chex.Array,
    targets: chex.Array,
    z_loss: float,
) -> tuple[chex.Array, chex.Array]:
    """Cross-entropy against (soft) targets plus a z-loss penalty.

    The z-loss term is `z_loss * logsumexp(logits) ** 2`, i.e. a penalty on the
    squared log-partition function. Gradients come from the custom VJP rules
    registered for this function rather than from tracing this body.

    Args:
        logits (chex.Array): Predicted logits, shape (batch_size, ..., num_classes).
        targets (chex.Array): Target distribution over classes (typically one-hot
            or smoothed one-hot), same shape as `logits`.
        z_loss (float): Coefficient of the z-loss term; 0 yields a zero penalty.

    Returns:
        tuple[chex.Array, chex.Array]:
            - Cross-entropy plus z-loss for each example, shape (batch_size, ...).
            - The z-loss term alone, same shape.
    """
    log_partition = jax.scipy.special.logsumexp(logits, axis=-1, keepdims=False)
    log_probs = logits - jnp.expand_dims(log_partition, axis=-1)
    cross_entropy = -jnp.sum(targets * log_probs, axis=-1)
    z_penalty = z_loss * jnp.square(log_partition)
    return cross_entropy + z_penalty, z_penalty
def _cross_entropy_with_logits_fwd(
    logits: chex.Array,
    targets: chex.Array,
    z_loss: float = 0.0,
) -> tuple[
    tuple[chex.Array, chex.Array],
    tuple[chex.Array, float, chex.Array, chex.Array, chex.Array, chex.Array],
]:
    """Forward rule of the custom VJP for `cross_entropy_with_logits`.

    Computes the same outputs as the primal function with a max-shifted softmax
    and keeps the intermediates the backward rule needs.

    Args:
        logits: Model predictions with shape (batch_size, ..., num_classes).
        targets: Target distribution with the same shape as `logits`.
        z_loss: Coefficient for z-loss regularization.

    Returns:
        A pair `(outputs, residuals)` where
            - outputs is `(cross_entropy + z_term, z_term)`;
            - residuals is the 6-tuple
              `(targets, z_loss, exp_shifted, sum_exp, log_softmax, log_z)`.
    """
    # Shift by the row maximum so the exponentials cannot overflow.
    max_logit = logits.max(axis=-1, keepdims=True)
    shifted = logits - max_logit
    exp_shifted = jnp.exp(shifted)
    sum_exp = jnp.sum(exp_shifted, axis=-1, keepdims=True)
    log_sum_exp = jnp.log(sum_exp)

    log_softmax = shifted - log_sum_exp
    ce = -jnp.sum(targets * log_softmax, axis=-1)

    # Log-partition of the unshifted logits.
    log_z = jnp.squeeze(log_sum_exp + max_logit, axis=-1)
    z = z_loss * jax.lax.square(log_z)

    outputs = (ce + z, z)
    residuals = (targets, z_loss, exp_shifted, sum_exp, log_softmax, log_z)
    return outputs, residuals
def _cross_entropy_with_logits_bwd(
    res: tuple[chex.Array, float, chex.Array, chex.Array, chex.Array, chex.Array],
    g: tuple[chex.Array, chex.Array],
) -> tuple[chex.Array, chex.Array, chex.Array]:
    """Backward rule of the custom VJP for `cross_entropy_with_logits`.

    Args:
        res: Residuals from the forward rule:
            `(targets, z_loss, exp_shifted, sum_exp, log_softmax, log_z)`.
        g: Cotangents of the two primal outputs. Only `g[0]` (the cotangent of
            the combined loss) is used; the cotangent of the separate z-loss
            output is not propagated.

    Returns:
        Cotangents with respect to `(logits, targets, z_loss)`. The `z_loss`
        cotangent is always zero.
    """
    loss_cotangent = g[0]
    targets, z_loss, exp_shifted, sum_exp, log_softmax, log_z = res

    softmax = exp_shifted / sum_exp
    # d(ce)/d(logits) = softmax - targets; d(z)/d(logits) = 2 * z_loss * log_z * softmax.
    deriv = softmax - targets + jnp.expand_dims(2 * z_loss * log_z, -1) * softmax

    g_logits = jnp.expand_dims(loss_cotangent, -1) * deriv
    g_targets = -jnp.expand_dims(loss_cotangent, -1) * log_softmax
    return (
        jnp.asarray(g_logits, dtype=log_softmax.dtype),
        jnp.asarray(g_targets, dtype=targets.dtype),
        jnp.array(0.0, dtype=log_softmax.dtype),
    )
# Register the hand-written forward/backward rules as the VJP of cross_entropy_with_logits.
cross_entropy_with_logits.defvjp(_cross_entropy_with_logits_fwd, _cross_entropy_with_logits_bwd)
def compute_weighted_cross_entropy(
    logits: chex.Array,
    targets: chex.Array,
    weights: chex.Array | None = None,
    label_smoothing: float = 0.0,
    z_loss: float = 0.0,
    loss_normalizing_factor: float | None = None,
    compute_dtype: jnp.dtype = jnp.float32,
) -> tuple[chex.Array, chex.Array, chex.Array]:
    """Compute weighted cross-entropy loss, z-loss, and weight sum.

    Targets are expanded into smoothed one-hot vectors, and the entropy of that
    smoothed distribution is subtracted so a perfect prediction scores zero.

    Args:
        logits: The predicted logits, shape [..., vocab_size].
        targets: The target class labels (integers), shape [...].
        weights: Optional weights for each example, shape [...].
        label_smoothing: Label smoothing factor in [0, 1).
        z_loss: Non-negative coefficient for the auxiliary z-loss term.
        loss_normalizing_factor: Optional divisor applied to both losses.
        compute_dtype: Dtype used for the logits, soft targets and weight sum.

    Returns:
        A tuple `(total_loss, total_z_loss, weight_sum)`. The losses are summed
        over all elements; `weight_sum` is the sum of `weights`, or the number
        of target elements when `weights` is None.

    Raises:
        TypeError: If `logits`, `targets` or `weights` is not a JAX array.
        ValueError: If `label_smoothing` or `z_loss` is out of range, or the
            shapes of `logits` and `targets` are incompatible.
    """
    if not isinstance(logits, jax.Array):
        raise TypeError(f"logits must be a JAX array, got {type(logits)}")
    if not isinstance(targets, jax.Array):
        raise TypeError(f"targets must be a JAX array, got {type(targets)}")
    if weights is not None and not isinstance(weights, jax.Array):
        raise TypeError(f"weights must be a JAX array or None, got {type(weights)}")
    if not 0.0 <= label_smoothing < 1.0:
        raise ValueError(f"label_smoothing must be in range 0~1, got {label_smoothing}")
    if z_loss < 0.0:
        raise ValueError(f"z_loss must be non-negative, got {z_loss}")
    if logits.ndim != targets.ndim + 1:
        raise ValueError(f"Incorrect shapes. Got shape {logits.shape} logits and {targets.shape} targets")

    vocab_size = logits.shape[-1]
    # Probability mass on the true class and on each of the other classes.
    confidence = 1.0 - label_smoothing
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    # Entropy of the smoothed target distribution (the 1e-20 guards log(0)).
    normalizing_constant = -(
        confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )

    soft_targets = onehot(targets, vocab_size, on_value=confidence, off_value=low_confidence).astype(compute_dtype)
    total_loss, total_z_loss = cross_entropy_with_logits(logits.astype(compute_dtype), soft_targets, z_loss=z_loss)
    total_loss = total_loss - normalizing_constant

    if weights is None:
        # Unweighted: every target element counts once.
        weight_sum = jnp.array(reduce(mul, targets.shape, 1), dtype=compute_dtype)
    else:
        total_loss = total_loss * weights
        total_z_loss = total_z_loss * weights
        weight_sum = jnp.sum(weights.astype(compute_dtype))

    if loss_normalizing_factor is not None:
        total_loss = total_loss / loss_normalizing_factor
        total_z_loss = total_z_loss / loss_normalizing_factor

    return jnp.sum(total_loss), jnp.sum(total_z_loss), weight_sum
def compute_weighted_cross_entropy_and_accuracy(
    logits: chex.Array,
    targets: chex.Array,
    weights: chex.Array | None = None,
    label_smoothing: float = 0.0,
    z_loss: float = 0.0,
    loss_normalizing_factor: float | None = None,
    compute_dtype: jnp.dtype = jnp.float32,
) -> tuple[chex.Array, chex.Array, chex.Array, chex.Array]:
    """Compute weighted cross-entropy loss, z-loss, weight sum, and accuracy.

    The three loss quantities come from `compute_weighted_cross_entropy`;
    accuracy is computed here from the arg-max of the logits.

    Args:
        logits: The predicted logits.
        targets: The target class labels (integers).
        weights: Optional weights for each example.
        label_smoothing: Label smoothing factor.
        z_loss: Coefficient for the auxiliary z-loss term.
        loss_normalizing_factor: A factor to normalize the loss.
        compute_dtype: Dtype used for the loss and accuracy computation.

    Returns:
        A tuple `(total_loss, total_z_loss, weight_sum, accuracy)`. Accuracy is
        the plain mean of correct predictions when `weights` is None, otherwise
        the weight-averaged fraction.
    """
    total_loss, total_z_loss, weight_sum = compute_weighted_cross_entropy(
        logits=logits,
        targets=targets,
        weights=weights,
        label_smoothing=label_smoothing,
        z_loss=z_loss,
        loss_normalizing_factor=loss_normalizing_factor,
        compute_dtype=compute_dtype,
    )

    is_correct = (jnp.argmax(logits, axis=-1) == targets).astype(compute_dtype)
    if weights is not None:
        w = weights.astype(compute_dtype)
        # Floor the denominator so an all-zero weight vector does not divide by zero.
        accuracy = jnp.sum(is_correct * w) / jnp.maximum(jnp.sum(w), 1e-8)
    else:
        accuracy = jnp.mean(is_correct)

    return total_loss, total_z_loss, weight_sum, accuracy
def cross_entropy_loss_and_accuracy(
    source,
    target,
    valid=None,
    compute_dtype: jnp.dtype = jnp.float32,
):
    """Compute cross-entropy loss and accuracy with optional masking.

    Args:
        source: Logits with shape [..., num_classes].
        target: Integer labels with shape [...].
        valid: Optional mask/weights with the shape of `target`; positions with
            a zero entry do not contribute. None treats every position as valid.
        compute_dtype: Data type for computation.

    Returns:
        Tuple of `(loss, accuracy)` scalars, both normalized by the sum of
        `valid` (floored at 1e-8).
    """
    source = source.astype(compute_dtype)
    if valid is None:
        valid = jnp.ones_like(target, dtype=compute_dtype)
    else:
        valid = valid.astype(compute_dtype)

    # Masked-out labels may be out of range (e.g. -100 padding); gathering them
    # could produce a NaN fill value that survives the multiplication by zero.
    safe_target = jnp.where(valid != 0, target, 0)

    lse = jax.scipy.special.logsumexp(source, axis=-1)
    logit_y = jnp.take_along_axis(source, safe_target[..., None], axis=-1)[..., 0]
    nll = (lse - logit_y) * valid

    weight_sum = jnp.maximum(jnp.sum(valid), 1e-8)
    loss = jnp.sum(nll) / weight_sum

    preds = jnp.argmax(source, axis=-1)
    correct = (preds == target).astype(compute_dtype) * valid
    accuracy = jnp.sum(correct) / weight_sum
    return loss, accuracy
def convert_special_loss_normalizing_factor_to_enum(x: str) -> SLNF:
    """Convert the string name of a SpecialLossNormalizingFactor to the enum member.

    The lookup is case-insensitive: the input is upper-cased before matching
    against the member names.

    Args:
        x: Name of the enum member, in any letter case.

    Returns:
        The corresponding SpecialLossNormalizingFactor member.

    Raises:
        ValueError: If the upper-cased string is not a member name.
    """
    x = x.upper()
    member = SLNF.__members__.get(x)
    if member is None:
        raise ValueError(f'Could not convert string "{x}" to SpecialLossNormalizingFactor')
    return member
@jax.vmap
def _sum_weights_per_segment(
    positions: chex.Array,
    segment_ids: chex.Array,
    weights: chex.Array,
) -> chex.Array:
    """Give every position the total weight of the packed segment it belongs to.

    The function is wrapped in ``jax.vmap``, so the body below sees one row at
    a time (the leading axis of every argument is mapped over).

    Args:
        positions: Position indices within each segment; a value of 0 is
            treated as the start of a new segment.
        segment_ids: Segment identifiers; positions whose id is 0 are masked
            out (presumably padding -- confirm against the data pipeline).
        weights: Weights to sum per segment.

    Returns:
        Array with, at each position, the summed weight of the segment that
        position falls in (0 where no segment end follows the position).
    """

    def _fill_zeros_from_neighbor(values, reverse=False):
        """Replace each zero with the closest preceding non-zero entry.

        Args:
            values: Input array.
            reverse: If True, scan from the end so zeros take the closest
                following non-zero entry instead.

        Returns:
            Array of the same length with zeros filled in (leading zeros, in
            scan order, stay zero).
        """

        def step(carried, current):
            filled = jnp.where(current == 0, carried, current)
            return filled, filled

        _, filled_values = jax.lax.scan(step, jnp.zeros_like(values[0]), values, reverse=reverse)
        return filled_values

    # A segment ends right before the next one starts; the last position of
    # the row is always treated as an end.
    is_segment_start = positions == 0
    is_segment_end = jnp.concatenate([is_segment_start[1:], jnp.ones(1)])
    # Drop ends that fall on positions with segment id 0.
    is_segment_end = is_segment_end * (segment_ids != 0)

    # Running total of the weights, kept only at segment ends.
    cumulative_at_ends = is_segment_end * jnp.cumsum(weights)

    # Differences between consecutive (forward-filled) end totals are the
    # per-segment sums, located at each segment's final position.
    forward_filled = _fill_zeros_from_neighbor(cumulative_at_ends)
    totals_at_ends = jnp.concatenate([cumulative_at_ends[0:1], jnp.diff(forward_filled)])

    # Spread each segment's total backwards over all of its positions.
    return _fill_zeros_from_neighbor(totals_at_ends, reverse=True)
def get_factor_and_weight(
    loss_normalizing_factor: FACTOR_TYPE,
    batch: tp.Mapping[str, chex.Array],
    compute_dtype: jnp.dtype = jnp.float32,
) -> tuple[float | None, chex.Array | None]:
    """Resolve a loss normalizing factor and the loss weights for a batch.

    Args:
        loss_normalizing_factor: Either a plain value (returned untouched,
            including None) or a special factor given as an enum member or
            its string name.
        batch: Mapping holding the batch. ``"decoder_loss_weights"`` is used
            when present; otherwise weights are derived from
            ``"decoder_target_tokens" > 0``. ``"decoder_segment_ids"`` and
            ``"decoder_positions"`` are read for per-sequence averaging of
            packed inputs.
        compute_dtype: Data type of weights derived from the target tokens.

    Returns:
        Tuple of (normalizing factor, loss weights).

    Raises:
        ValueError: If the special factor is not one of the supported kinds.
    """
    loss_weights = batch.get("decoder_loss_weights", None)

    # Anything that is not a special (string / enum) factor passes straight through.
    if not isinstance(loss_normalizing_factor, (str, SLNF)):
        return loss_normalizing_factor, loss_weights

    if isinstance(loss_normalizing_factor, str):
        loss_normalizing_factor = convert_special_loss_normalizing_factor_to_enum(loss_normalizing_factor)

    if loss_weights is None:
        loss_weights = jnp.asarray(batch["decoder_target_tokens"] > 0, compute_dtype)

    if loss_normalizing_factor == SLNF.NUM_REAL_TARGET_TOKENS:
        return jnp.sum(loss_weights), loss_weights

    if loss_normalizing_factor == SLNF.NUM_TOTAL_TARGET_TOKENS:
        return jnp.prod(batch["decoder_target_tokens"].shape), loss_weights

    if loss_normalizing_factor == SLNF.AVERAGE_PER_SEQUENCE:
        if "decoder_segment_ids" in batch:
            per_sequence_total = _sum_weights_per_segment(
                batch["decoder_positions"],
                batch["decoder_segment_ids"],
                loss_weights,
            )
        else:
            per_sequence_total = jnp.sum(loss_weights, axis=-1, keepdims=True)
        # Each sequence's weights are rescaled to sum to one; divisions by a
        # zero total are mapped to 0 instead of NaN/inf.
        loss_weights = jnp.nan_to_num(loss_weights / per_sequence_total, nan=0, posinf=0, neginf=0)
        return jnp.sum(loss_weights), loss_weights

    raise ValueError(f"Unsupported value of loss_normalizing_factor: {loss_normalizing_factor}")
[docs]def auxiliary_load_balancing_loss_func(
gate_logits: chex.Array | tuple[chex.Array, ...],
num_experts: int,
top_k: int,
attention_mask: chex.Array | None = None,
compute_dtype: jnp.dtype = jnp.float32,
) -> jax.Array | int:
r"""
Computes auxiliary load balancing loss as in Switch Transformer - implemented in JAX.
See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
experts is too unbalanced.
Args:
gate_logits:
Logits from the `gate`. Should be a tuple/list of JAX arrays,
with each array corresponding to a layer and having shape
[batch_size * sequence_length, num_experts].
Alternatively, can be a single stacked array of shape
[num_layers * batch_size * sequence_length, num_experts].
num_experts:
Number of experts. Must be provided if `gate_logits` is not None.
top_k:
The number of experts to route per-token, can be also interpreted as the `top-k` routing
parameter.
attention_mask (`jax.numpy.ndarray`, *optional*):
The attention_mask used in forward function
shape [batch_size, sequence_length] if not None.
Returns:
The auxiliary loss as a JAX scalar array, or 0 if gate_logits is None.
"""
if gate_logits is None:
return 0
if num_experts is None:
raise ValueError("num_experts must be specified if gate_logits is provided.")
# If gate_logits is a tuple or list, concatenate them.
# Assumes individual layer logits are already on the correct device.
if isinstance(gate_logits, tuple | list):
# Ensure all logits are JAX arrays before concatenation
gate_logits_list = [jnp.asarray(layer_gate.reshape(-1, num_experts)) for layer_gate in gate_logits]
concatenated_gate_logits = jnp.concatenate(gate_logits_list, axis=0)
elif isinstance(gate_logits, jnp.ndarray):
concatenated_gate_logits = gate_logits
else:
raise TypeError(f"gate_logits must be a JAX array, tuple/list of JAX arrays, or None. Got {type(gate_logits)}")
routing_weights = jax.nn.softmax(concatenated_gate_logits, axis=-1)
_, selected_experts = jax.lax.top_k(routing_weights, k=top_k)
expert_mask = jax.nn.one_hot(selected_experts, num_classes=num_experts, dtype=compute_dtype)
if attention_mask is None:
tokens_per_expert = jnp.mean(expert_mask, axis=0)
router_prob_per_expert = jnp.mean(routing_weights, axis=0)
else:
attention_mask = jnp.asarray(attention_mask)
if attention_mask.ndim != 2:
raise ValueError(f"attention_mask must have shape [batch_size, sequence_length], got {attention_mask.shape}")
batch_size, sequence_length = attention_mask.shape
total_tokens_per_layer = batch_size * sequence_length
num_effective_tokens = concatenated_gate_logits.shape[0]
if num_effective_tokens % total_tokens_per_layer != 0:
raise ValueError(
f"Total tokens in gate_logits ({num_effective_tokens}) is not divisible by "
f"batch_size*sequence_length ({total_tokens_per_layer}). Ensure gate_logits are correctly concatenated."
)
num_hidden_layers = num_effective_tokens // total_tokens_per_layer
mask_expanded = jnp.expand_dims(attention_mask, axis=(0, 3, 4))
target_mask_shape = (
num_hidden_layers,
batch_size,
sequence_length,
top_k,
num_experts,
)
expert_attention_mask_broadcast = jnp.broadcast_to(mask_expanded, target_mask_shape)
expert_attention_mask = jnp.reshape(expert_attention_mask_broadcast, (-1, top_k, num_experts))
masked_expert_contributions = expert_mask * expert_attention_mask
tokens_per_expert_numerator = jnp.sum(masked_expert_contributions, axis=0)
tokens_per_expert_denominator = jnp.sum(expert_attention_mask, axis=0)
tokens_per_expert_denominator = jnp.where(tokens_per_expert_denominator == 0, 1.0, tokens_per_expert_denominator)
tokens_per_expert = tokens_per_expert_numerator / tokens_per_expert_denominator
mask_expanded_router = jnp.expand_dims(attention_mask, axis=(0, 3))
target_router_mask_shape = (
num_hidden_layers,
batch_size,
sequence_length,
num_experts,
)
router_attention_mask_broadcast = jnp.broadcast_to(mask_expanded_router, target_router_mask_shape)
router_per_expert_attention_mask = jnp.reshape(router_attention_mask_broadcast, (-1, num_experts))
masked_routing_weights = routing_weights * router_per_expert_attention_mask
router_prob_numerator = jnp.sum(masked_routing_weights, axis=0)
router_prob_denominator = jnp.sum(router_per_expert_attention_mask, axis=0)
router_prob_denominator = jnp.where(router_prob_denominator == 0, 1.0, router_prob_denominator)
router_prob_per_expert = router_prob_numerator / router_prob_denominator
router_prob_per_expert_expanded = jnp.expand_dims(router_prob_per_expert, axis=0)
per_expert_loss_terms = tokens_per_expert * router_prob_per_expert_expanded
overall_loss = jnp.sum(per_expert_loss_terms)
final_loss = overall_loss * num_experts
return jnp.asarray(final_loss, dtype=jnp.float32)
def fixed_cross_entropy(
    source: jax.Array,
    target: jax.Array,
    attention_mask: jax.Array | None = None,
    config: LossConfig | None = None,
    num_items_in_batch: int | None = None,
    batch: tp.Mapping[str, chex.Array] | None = None,
    **kwargs: tp.Any,
) -> LossMetrics:
    """
    Jax implementation of fixed cross-entropy loss with z-loss, label smoothing, masking.

    The computation path is selected from ``config``:

    1. ``config.reduction`` set: ``dynamic_cross_entropy_loss`` with that reduction.
    2. ``NO_WEIGHT_NUM_REAL_TARGET_TOKENS`` normalizing factor: plain masked mean
       via ``cross_entropy_loss_and_accuracy``.
    3. Otherwise: factor and weights are resolved with ``get_factor_and_weight``
       and the loss is computed by the chunked-vocab, chunked-token, blockwise or
       dense implementation, depending on which chunk size is configured.

    Args:
        source: Predicted logits, shape (batch_size, num_classes) or (batch_size * seq_len, num_classes).
        target: True labels, shape (batch_size,) or (batch_size * seq_len,). Must be integers.
        attention_mask: tp.Optional, boolean mask applied to the loss. When omitted,
            positions equal to ``config.ignore_index`` are masked out.
        config: tp.Optional loss configuration; a default ``LossConfig`` is used when None.
        num_items_in_batch: tp.Optional, used when reduction should be sum.
        batch: tp.Optional batch for dynamic loss normalization
        **kwargs: Additional keyword arguments (unused).

    Returns:
        The computed cross-entropy loss in LossMetrics.

    Raises:
        ValueError: If ``source`` or ``target`` is None.
    """
    if config is None:
        config = LossConfig()
    if source is None or target is None:
        raise ValueError("Logits and labels cannot be None")

    if config.compute_dtype == "fp32":
        compute_dtype = jnp.float32
    elif config.compute_dtype == "bf16":
        compute_dtype = jnp.bfloat16
    else:
        compute_dtype = source.dtype

    mask = attention_mask if attention_mask is not None else (target != config.ignore_index)
    loss_factor = config.loss_normalizing_factor

    # The factor may arrive as the enum member, its raw value, or its name in
    # any letter case. Compare by equality: an identity check against
    # `.value` only matches when the caller happens to hold the very same
    # object, and get_factor_and_weight rejects this member, so a miss here
    # ends in a ValueError.
    no_weight = SLNF.NO_WEIGHT_NUM_REAL_TARGET_TOKENS
    uses_no_weight = (
        loss_factor is no_weight
        or (type(loss_factor) is type(no_weight.value) and loss_factor == no_weight.value)
        or (isinstance(loss_factor, str) and loss_factor.upper() == no_weight.name)
    )

    if config.reduction is not None:
        loss, norm = dynamic_cross_entropy_loss(
            logits=source,
            targets=target,
            weight=mask.astype(compute_dtype),
            ignore_index=config.ignore_index,
            reduction=config.reduction,
            label_smoothing=config.label_smoothing,
            compute_dtype=compute_dtype,
        )
        total_z_loss = jnp.array(0.0, compute_dtype)
        weight_sum = norm
        preds = jnp.argmax(source, axis=-1)
        correct = (preds == target).astype(compute_dtype) * mask.astype(compute_dtype)
        accuracy = jnp.sum(correct) / jnp.maximum(norm, 1e-8)
    elif uses_no_weight:
        loss, accuracy = cross_entropy_loss_and_accuracy(
            source.astype(compute_dtype), target, mask.astype(compute_dtype)
        )
        total_z_loss = jnp.array(0.0, compute_dtype)
        weight_sum = jnp.sum(mask.astype(compute_dtype))
    else:
        if batch is None:
            lf = loss_factor
            if isinstance(lf, str):
                lf = convert_special_loss_normalizing_factor_to_enum(lf)
            if isinstance(lf, SLNF):
                # Every special factor needs the targets and/or weights;
                # an empty batch makes get_factor_and_weight raise KeyError on
                # "decoder_target_tokens".
                batch = {
                    "decoder_target_tokens": target,
                    "decoder_loss_weights": mask.astype(compute_dtype),
                }
            else:
                # NOTE(review): for a constant / None factor no weights are
                # synthesized, so `mask` is not forwarded to the loss helpers
                # below. Kept as-is for backward compatibility -- confirm this
                # is intended.
                batch = {}
        loss_normalizing_factor, loss_weights = get_factor_and_weight(config.loss_normalizing_factor, batch)

        use_chunk_vocab = config.chunk_vocab_size is not None
        use_chunk_tokens = config.chunk_token_size is not None
        use_block_size = config.chunk_block_size is not None

        if use_chunk_vocab:
            total_loss, total_z_loss, weight_sum, accuracy = sparse_cross_entropy_chunked_vocab(
                source,
                target,
                weights=loss_weights,  # normalized loss weights, not the raw mask
                ignore_index=config.ignore_index,
                label_smoothing=config.label_smoothing,
                z_loss=config.z_loss,
                reduction="sum",
                chunk_size=config.chunk_vocab_size,
                # NOTE(review): this keyword used to be spelled `ccompute_dtype`;
                # assumed to be a typo for the `compute_dtype` keyword used by the
                # sibling helpers -- confirm against the callee's signature.
                compute_dtype=compute_dtype,
            )
        elif use_chunk_tokens:
            total_loss, total_z_loss, weight_sum, accuracy = sparse_cross_entropy_chunked_tokens(
                source,
                target,
                weights=loss_weights,
                ignore_index=config.ignore_index,
                label_smoothing=config.label_smoothing,
                z_loss=config.z_loss,
                reduction="sum",
                token_chunk_size=config.chunk_token_size,
                compute_dtype=compute_dtype,
            )
        elif use_block_size:
            total_loss, total_z_loss, weight_sum, accuracy = cross_entropy_blockwise_logits(
                logits=source,
                targets=target,
                weights=loss_weights,
                ignore_index=config.ignore_index,
                label_smoothing=config.label_smoothing,
                z_loss=config.z_loss,
                block_size=int(config.chunk_block_size),
                dtype=compute_dtype,
            )
        else:
            total_loss, total_z_loss, weight_sum, accuracy = compute_weighted_cross_entropy_and_accuracy(
                logits=source,
                targets=target,
                weights=loss_weights,
                label_smoothing=config.label_smoothing,
                z_loss=config.z_loss,
                loss_normalizing_factor=loss_normalizing_factor,
                compute_dtype=compute_dtype,
            )

        # The chunked / blockwise helpers return un-normalized sums; the dense
        # helper receives the factor itself, so it must not be divided again.
        needs_normalizing = use_chunk_vocab or use_chunk_tokens or use_block_size
        if needs_normalizing and loss_normalizing_factor is not None:
            total_loss = total_loss / loss_normalizing_factor
            total_z_loss = total_z_loss / loss_normalizing_factor

        # Optional post-normalization
        if num_items_in_batch is not None:
            loss = total_loss / num_items_in_batch
        elif config.divide_weight_sum:
            loss = total_loss / jnp.maximum(weight_sum, 1e-8)
        else:
            loss = total_loss

    return LossMetrics(loss=loss, z_loss=total_z_loss, weight_sum=weight_sum, accuracy=accuracy)
def ForCausalLMLoss(
    logits: jax.Array,
    labels: jax.Array,
    attention_mask: jax.Array | None = None,
    config: LossConfig | None = None,
    paxis: PartitionAxis | None = None,
    num_items_in_batch: int | None = None,
    batch: tp.Mapping[str, chex.Array] | None = None,
    **kwargs: tp.Any,
) -> LossMetrics:
    """
    Causal language modeling loss (next-token prediction) in Jax.

    When ``config.shift_tokens`` is set, the logits at step ``t`` are scored
    against the label at step ``t + 1``: the last logit and the first label
    (and the first mask entry) are dropped. Otherwise inputs are used as given.

    Args:
        logits: Predicted logits, shape (batch_size, seq_len, vocab_size).
        labels: True labels, shape (batch_size, seq_len). Must be integers.
        attention_mask: tp.Optional mask, shape (batch_size, seq_len).
        config: tp.Optional loss configuration; a default ``LossConfig`` is used when None.
        paxis: tp.Optional partition axes; when given, sharding constraints are
            applied to ``logits`` and ``labels`` first.
        num_items_in_batch: tp.Optional, used when reduction should be sum.
        batch: tp.Optional batch for dynamic loss normalization
        **kwargs: Additional keyword arguments for the cross-entropy loss.

    Returns:
        The computed causal language modeling loss.

    Raises:
        ValueError: If ``logits`` or ``labels`` is None.
    """
    if logits is None or labels is None:
        raise ValueError("Logits and labels cannot be None")

    if paxis is not None:
        logits_spec = PartitionSpec(paxis.batch_axis, paxis.sequence_axis, paxis.hidden_state_axis)
        logits = with_sharding_constraint(logits, logits_spec)
        labels_spec = PartitionSpec(paxis.batch_axis, paxis.sequence_axis)
        labels = with_sharding_constraint(labels, labels_spec)

    if config is None:
        config = LossConfig()

    if config.shift_tokens:
        pred_logits = logits[:, :-1, :]
        next_labels = labels[:, 1:]
        loss_mask = None if attention_mask is None else attention_mask[:, 1:]
    else:
        pred_logits, next_labels, loss_mask = logits, labels, attention_mask

    return fixed_cross_entropy(
        source=pred_logits,
        target=next_labels,
        attention_mask=loss_mask,
        config=config,
        num_items_in_batch=num_items_in_batch,
        batch=batch,
        **kwargs,
    )
def ForSequenceClassificationLoss(
    logits: jax.Array,
    labels: jax.Array,
    attention_mask: jax.Array | None = None,
    config: LossConfig | None = None,
    paxis: PartitionAxis | None = None,
    batch: tp.Mapping[str, chex.Array] | None = None,
    **kwargs: tp.Any,
) -> LossMetrics:
    """
    Jax implementation of loss function for sequence classification.

    The problem type is taken from ``config.problem_type``; when it is None it
    is inferred (and stored back on ``config``) as follows: one label means
    regression, several labels with int32/int64 ``labels`` means single-label
    classification, anything else means multi-label classification.

    Args:
        logits: Predicted logits, shape (batch_size, num_labels) or (batch_size, 1) or (batch_size,) for regression.
        labels: True labels, shape (batch_size,) or (batch_size, num_labels) for multi label classification.
        attention_mask: tp.Optional mask forwarded to the cross-entropy loss
            (single-label classification only).
        config: Configuration with problem_type and num_labels attributes. A
            default ``LossConfig`` is used when None, and a missing
            ``num_labels`` is inferred from the last axis of ``logits``.
        paxis: tp.Optional partition axes (unused here).
        batch: tp.Optional batch for dynamic loss normalization
        **kwargs: Additional keyword arguments for the cross-entropy loss.

    Returns:
        The computed sequence classification loss.

    Raises:
        ValueError: If ``logits`` or ``labels`` is None, or the problem type is not recognized.
    """
    if logits is None or labels is None:
        raise ValueError("Logits and labels cannot be None")

    # `config` is optional in the signature; dereferencing None would fail
    # with an opaque AttributeError.
    if config is None:
        config = LossConfig()

    num_labels = config.num_labels
    if num_labels is None:
        # Fall back to what the logits themselves say; 1-D logits are a
        # single regression output per example.
        num_labels = logits.shape[-1] if logits.ndim > 1 else 1

    if config.problem_type is None:
        if num_labels == 1:
            config.problem_type = "regression"
        elif num_labels > 1 and (labels.dtype == jnp.int32 or labels.dtype == jnp.int64):
            config.problem_type = "single_label_classification"
        else:
            config.problem_type = "multi_label_classification"

    if config.problem_type == "regression":
        # Mean squared error.
        loss = jnp.mean((logits.squeeze() - labels.squeeze()) ** 2)
    elif config.problem_type == "single_label_classification":
        return fixed_cross_entropy(
            source=logits.reshape(-1, num_labels),
            target=labels.reshape(-1),
            attention_mask=attention_mask,
            config=config,
            batch=batch,
            **kwargs,
        )
    elif config.problem_type == "multi_label_classification":
        loss = jnp.mean(
            sigmoid_cross_entropy_with_logits(
                logits=logits,
                labels=labels,
                label_smoothing=config.label_smoothing,
            )
        )
    else:
        raise ValueError(f"Invalid problem_type: {config.problem_type}")
    return LossMetrics(total_loss=loss, loss=loss)
def ForQuestionAnsweringLoss(
    start_logits: jax.Array,
    end_logits: jax.Array,
    start_positions: jax.Array,
    end_positions: jax.Array,
    config: LossConfig | None = None,
    paxis: PartitionAxis | None = None,
    batch: tp.Mapping[str, chex.Array] | None = None,
    **kwargs: tp.Any,
) -> LossMetrics:
    """
    Span-extraction question answering loss in Jax.

    Start and end positions are scored independently with cross-entropy over
    the sequence axis and the two results are averaged. Positions are clipped
    to ``[0, seq_len]``; the value ``seq_len`` is used as the ignore index.

    Args:
        start_logits: Predicted start logits, shape (batch_size, seq_len).
        end_logits: Predicted end logits, shape (batch_size, seq_len).
        start_positions: True start positions, shape (batch_size,).
        end_positions: True end positions, shape (batch_size,).
        config: tp.Optional loss configuration; copied with ``ignore_index``
            replaced, so the caller's object is not modified.
        paxis: tp.Optional partition axes (unused here).
        batch: tp.Optional batch for dynamic loss normalization
        **kwargs: Additional keyword arguments for the cross-entropy loss.

    Returns:
        The computed question answering loss, with loss, accuracy, z-loss and
        weight sum each averaged over the start and end predictions.

    Raises:
        ValueError: If any of the logits or positions is None.
    """
    required = (start_logits, end_logits, start_positions, end_positions)
    if any(item is None for item in required):
        raise ValueError("Logits and labels cannot be None")

    # One past the last valid index doubles as the "ignore" label.
    seq_len = start_logits.shape[1]
    start_positions = jnp.clip(start_positions, 0, seq_len)
    end_positions = jnp.clip(end_positions, 0, seq_len)
    span_config = dataclasses.replace(config or LossConfig(), ignore_index=seq_len)

    start_metrics = fixed_cross_entropy(
        source=start_logits,
        target=start_positions,
        config=span_config,
        batch=batch,
        **kwargs,
    )
    end_metrics = fixed_cross_entropy(
        source=end_logits,
        target=end_positions,
        config=span_config,
        batch=batch,
        **kwargs,
    )

    return LossMetrics(
        loss=(start_metrics.loss + end_metrics.loss) / 2,
        accuracy=(start_metrics.accuracy + end_metrics.accuracy) / 2,
        z_loss=(start_metrics.z_loss + end_metrics.z_loss) / 2,
        weight_sum=(start_metrics.weight_sum + end_metrics.weight_sum) / 2,
    )
def ForTokenClassification(
    logits: jax.Array,
    labels: jax.Array,
    config: LossConfig | None = None,
    paxis: PartitionAxis | None = None,
    batch: tp.Mapping[str, chex.Array] | None = None,
    **kwargs: tp.Any,
) -> LossMetrics:
    """
    Token-level classification loss in Jax.

    A thin wrapper that forwards the inputs unchanged to ``fixed_cross_entropy``.

    Args:
        logits: Predicted logits, shape (batch_size, seq_len, num_labels).
        labels: True labels, shape (batch_size, seq_len). Must be integers.
        config: tp.Optional loss configuration, passed through as-is.
        paxis: tp.Optional partition axes (unused here).
        batch: tp.Optional batch for dynamic loss normalization
        **kwargs: Additional keyword arguments for the cross-entropy loss.

    Returns:
        The computed token classification loss.

    Raises:
        ValueError: If ``logits`` or ``labels`` is None.
    """
    if logits is None or labels is None:
        raise ValueError("Logits and labels cannot be None")
    return fixed_cross_entropy(
        source=logits,
        target=labels,
        config=config,
        batch=batch,
        **kwargs,
    )
# Lookup table from a task-head name to the loss function that handles it.
# NOTE(review): the keys look like model class-name suffixes (e.g. a
# "...ForCausalLM" model) -- confirm how callers derive the key.
LOSS_MAPPING = {
    "ForCausalLM": ForCausalLMLoss,
    "ForQuestionAnswering": ForQuestionAnsweringLoss,
    "ForSequenceClassification": ForSequenceClassificationLoss,
    "ForTokenClassification": ForTokenClassification,
}