easydel.trainers.utils#
Utility functions and classes for EasyDeL trainers.
This module provides essential utilities for training, including:
  - JAX distributed configuration management
  - Dataset creation and manipulation functions
  - Data collation utilities for various training tasks
  - Conversation formatting and prompt processing
  - Memory and performance profiling tools
  - Training state management utilities
- class easydel.trainers.utils.BCODataCollatorGrain(max_prompt_length: int, max_completion_length: int, pad_token_id: int, label_pad_token_id: int, is_encoder_decoder: bool)[source]#
Bases: _BCODataCollatorMixin
Grain-compatible BCO data collator.
- class easydel.trainers.utils.BCODataCollatorTFDS(max_prompt_length: int, max_completion_length: int, pad_token_id: int, label_pad_token_id: int, is_encoder_decoder: bool)[source]#
Bases: _BCODataCollatorMixin
Data collator for BCO training with TFDS backends.
- class easydel.trainers.utils.CollateMapTransform(collate_fn: callable)[source]#
Bases: MapTransform
Grain transform for applying custom collation functions.
Wraps a user-defined collation function as a Grain MapTransform, allowing custom batch processing logic in the Grain pipeline.
- collate_fn#
Callable that processes/collates data elements.
- Type
callable
- collate_fn: callable#
- class easydel.trainers.utils.DPODataCollatorWithPaddingGrain(max_prompt_length: int, max_completion_length: int, pad_token_id: int = 0, label_pad_token_id: int = -100, is_encoder_decoder: bool | None = False, output_arrays_only: bool = True, prepadded: bool = True)[source]#
Bases: object
DPO data collator that pads the tokenized inputs to the maximum length of the batch.
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- label_pad_token_id: int = -100#
- max_completion_length: int#
- max_prompt_length: int#
- output_arrays_only: bool = True#
- pad_token_id: int = 0#
- prepadded: bool = True#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.trainers.utils.DPODataCollatorWithPaddingTFDS(max_prompt_length: int, max_completion_length: int, pad_token_id: int = 0, label_pad_token_id: int = -100, is_encoder_decoder: bool | None = False, output_arrays_only: bool = True, prepadded: bool = True)[source]#
Bases: object
Advanced data collator for DPO training with TFDS.
Extended version of DataCollatorForPreferenceTFDS with additional features for handling complex DPO scenarios including encoder-decoder models and pre-padded data.
- max_prompt_length#
Maximum length for prompt sequences.
- Type
int
- max_completion_length#
Maximum length for completion sequences.
- Type
int
- pad_token_id#
Token ID to use for padding (default 0).
- Type
int
- label_pad_token_id#
Token ID for label padding (default -100).
- Type
int
- is_encoder_decoder#
Whether using encoder-decoder architecture.
- Type
bool | None
- output_arrays_only#
If True, only return array-type outputs.
- Type
bool
- prepadded#
If True, assumes inputs are already padded.
- Type
bool
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- label_pad_token_id: int = -100#
- max_completion_length: int#
- max_prompt_length: int#
- output_arrays_only: bool = True#
- pad_token_id: int = 0#
- prepadded: bool = True#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.trainers.utils.DataCollatorForCompletionOnlyLM(processing_class: Union[str, PreTrainedTokenizerBase], response_template: str | list[int], instruction_template: str | list[int] | None = None, *args, mlm: bool = False, ignore_index: int = -100, **kwargs)[source]#
Bases: object
Data collator for training on assistant completions only.
This collator masks out non-assistant tokens in the labels, ensuring that the loss is only calculated on the model’s completions (assistant responses) and not on the user prompts or system messages.
This is particularly useful for:
  - Instruction tuning where you only want to train on responses
  - Chat models where user inputs should not contribute to loss
  - Maintaining the model’s ability to understand prompts without being trained to generate them
- processing_class#
Tokenizer or processor for encoding text.
- response_template#
Template or token IDs marking response start.
- instruction_template#
Optional template marking instruction start.
- mlm#
Whether using masked language modeling (default False).
- ignore_index#
Label value to ignore in loss calculation.
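The masking described above can be sketched without the library. This is a minimal illustration, not EasyDeL's implementation: it assumes the response template is already given as token IDs, and `find_subsequence`, `mask_prompt_tokens`, and the token IDs are hypothetical names and values chosen for the example.

```python
def find_subsequence(sequence, pattern):
    """Return the start index of the first occurrence of pattern, or -1."""
    n, m = len(sequence), len(pattern)
    for start in range(n - m + 1):
        if sequence[start:start + m] == pattern:
            return start
    return -1


def mask_prompt_tokens(input_ids, response_template_ids, ignore_index=-100):
    """Copy input_ids to labels and ignore everything up to the end of the template."""
    labels = list(input_ids)
    start = find_subsequence(input_ids, response_template_ids)
    if start == -1:
        # Template not found: no token contributes to the loss.
        return [ignore_index] * len(labels)
    response_start = start + len(response_template_ids)
    for i in range(response_start):
        labels[i] = ignore_index
    return labels


# Hypothetical token IDs: [5, 6] stands in for the tokenized response template.
input_ids = [1, 10, 11, 5, 6, 20, 21, 2]
labels = mask_prompt_tokens(input_ids, [5, 6])
print(labels)  # [-100, -100, -100, -100, -100, 20, 21, 2]
```

Only the tokens after the template keep their IDs as labels, so a loss that skips `ignore_index` positions is computed on the completion alone.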
- class easydel.trainers.utils.DataCollatorForPreferenceGrain(max_prompt_length: int, max_completion_length: int, pad_token_id: int = 0, label_pad_token_id: int = -100, is_encoder_decoder: bool | None = False)[source]#
Bases: object
Data collator for Direct Preference Optimization (DPO) with Grain.
Grain-compatible version of DataCollatorForPreferenceTFDS. Processes single dictionaries instead of lists for Grain’s data pipeline.
- max_prompt_length#
Maximum length for prompt sequences.
- Type
int
- max_completion_length#
Maximum length for completion sequences.
- Type
int
- pad_token_id#
Token ID to use for padding (default 0).
- Type
int
- label_pad_token_id#
Token ID for label padding (default -100).
- Type
int
- is_encoder_decoder#
Whether using encoder-decoder architecture.
- Type
bool | None
Note
Returns NumPy arrays for Grain compatibility. Handles single feature dictionary rather than list.
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- label_pad_token_id: int = -100#
- max_completion_length: int#
- max_prompt_length: int#
- pad_token_id: int = 0#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.trainers.utils.DataCollatorForPreferenceTFDS(max_prompt_length: int, max_completion_length: int, pad_token_id: int = 0, label_pad_token_id: int = -100, is_encoder_decoder: bool | None = False)[source]#
Bases: object
Data collator for Direct Preference Optimization (DPO) with TFDS.
Handles batching and padding of prompt-completion pairs for preference learning. Each example has a prompt, chosen completion, and rejected completion that need to be padded separately.
- max_prompt_length#
Maximum length for prompt sequences.
- Type
int
- max_completion_length#
Maximum length for completion sequences.
- Type
int
- pad_token_id#
Token ID to use for padding (default 0).
- Type
int
- label_pad_token_id#
Token ID for label padding (default -100).
- Type
int
- is_encoder_decoder#
Whether using encoder-decoder architecture.
- Type
bool | None
Note
Supports multimodal inputs with pixel_values and pixel_attention_mask. Can include reference model log probabilities if provided.
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- label_pad_token_id: int = -100#
- max_completion_length: int#
- max_prompt_length: int#
- pad_token_id: int = 0#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.trainers.utils.GRPODataCollatorGrain(max_prompt_length: int, pad_token_id: int = 0)[source]#
Bases: object
Grain-compatible GRPO data collator.
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- max_prompt_length: int#
- pad_token_id: int = 0#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.trainers.utils.GRPODataCollatorTFDS(max_prompt_length: int, pad_token_id: int = 0)[source]#
Bases: object
Data collator for GRPO training with TFDS backends.
GRPO only needs prompts since completions are generated online.
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- max_prompt_length: int#
- pad_token_id: int = 0#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.trainers.utils.HFDataSource(dataset: IterableDataset, shard_options: ShardOptions, num_threads: int = 1)[source]#
Bases: RandomAccessDataSource
Grain-compatible data source for HuggingFace IterableDatasets.
Bridges HuggingFace’s IterableDataset with Google’s Grain data loading library, enabling efficient distributed data loading with proper sharding.
This class handles:
  - Multi-threaded data loading
  - Dataset sharding across distributed workers
  - Thread-safe iteration over dataset shards
- dataset#
The HuggingFace IterableDataset to wrap.
- shard_options#
Grain sharding configuration.
- num_threads#
Number of worker threads for data loading.
Note
Automatically handles dataset sharding based on world size and rank. Issues warnings if dataset shards don’t match expected shard count.
- class easydel.trainers.utils.JaxDistributedConfig[source]#
Bases: object
Configuration manager for JAX distributed training.
This class handles the initialization of JAX distributed computing environments, enabling multi-host and multi-device training setups. Originally from EasyLM project.
The class manages:
  - Multi-process coordination
  - Device assignment
  - Communication setup between processes
Note
This is typically used internally by TrainingArguments and should not need to be configured directly by users in most cases.
- static get_default_config(updates=None)[source]#
Get default configuration for JAX distributed.
- Parameters
updates – Optional dictionary of configuration updates to apply to the default configuration.
- Returns
- Configuration dictionary with the following fields:
initialize_jax_distributed: Whether to initialize distributed
coordinator_address: Address of the coordinator process
num_processes: Total number of processes
process_id: ID of the current process
local_device_ids: Comma-separated list of local device IDs
- Return type
ConfigDict
Note
Uses ml_collections placeholders for required fields that must be provided at runtime.
- classmethod initialize(config=None)[source]#
Initialize JAX distributed with the given configuration.
- Parameters
config – Configuration dictionary or None to use defaults. If provided, should contain distributed setup parameters.
Note
Only initializes if config.initialize_jax_distributed is True. Parses local_device_ids from comma-separated string if provided.
- Raises
RuntimeError – If JAX distributed initialization fails.
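As a rough illustration of the fields listed above, the sketch below builds a plain dictionary with the same keys and parses `local_device_ids` from its comma-separated form. The helper `parse_local_device_ids` and the coordinator address are hypothetical; the real class works with a ConfigDict and performs the actual JAX initialization, which is not shown here.

```python
def parse_local_device_ids(local_device_ids):
    """Turn a comma-separated string such as "0,1,2,3" into a list of ints."""
    if local_device_ids is None:
        return None
    return [int(part) for part in local_device_ids.split(",")]


# Plain-dict stand-in for the configuration; values are made up for the example.
config = {
    "initialize_jax_distributed": True,
    "coordinator_address": "10.0.0.1:1234",  # hypothetical host:port
    "num_processes": 2,
    "process_id": 0,
    "local_device_ids": "0,1,2,3",
}

device_ids = parse_local_device_ids(config["local_device_ids"])
print(device_ids)  # [0, 1, 2, 3]
```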
- class easydel.trainers.utils.RewardDataCollatorWithPaddingGrain(tokenizer: Any, padding: bool | str = 'max_length', max_length: int | None = None, truncation_mode: str = 'keep_end')[source]#
Bases: object
Data collator for reward modeling with Grain data loading.
Similar to RewardDataCollatorWithPaddingTFDS but designed for use with Google’s Grain data loading library. Handles single dictionaries instead of lists of dictionaries.
- tokenizer#
The tokenizer/processor for encoding text.
- Type
Any
- padding#
Padding strategy - ‘max_length’, True, or False.
- Type
bool | str
- max_length#
Maximum sequence length for padding.
- Type
int | None
- truncation_mode#
How to truncate sequences (‘keep_end’ or ‘keep_start’).
- Type
str
Note
Returns NumPy arrays instead of JAX arrays for Grain compatibility. Expects a single feature dictionary rather than a list.
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- padding: bool | str = 'max_length'#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- tokenizer: Any#
- truncation_mode: str = 'keep_end'#
- class easydel.trainers.utils.RewardDataCollatorWithPaddingTFDS(tokenizer: Any, padding: bool | str = 'max_length', max_length: int | None = None, truncation_mode: str = 'keep_end')[source]#
Bases: object
Reward data collator that pads the inputs to the maximum length of the batch.
- Parameters
tokenizer (ProcessingClassType) – The tokenizer used for encoding the data.
padding (Union[bool, str, PaddingStrategy], optional, defaults to 'max_length') – Padding strategy to pass to the tokenizer.
max_length (int or None, optional, defaults to None) – If set will pad the sequence to a maximum provided value.
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- padding: bool | str = 'max_length'#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- tokenizer: Any#
- truncation_mode: str = 'keep_end'#
- class easydel.trainers.utils.ToNumpy[source]#
Bases: MapTransform
Grain transform to convert data elements to NumPy arrays.
Ensures all values in a dictionary are converted to NumPy arrays, which is often required for JAX-based training pipelines.
- easydel.trainers.utils.add_bos_token_if_needed(bos_token_id: int | None, prompt_len_input_ids: int, prompt_tokens: dict[str, list[int]], chosen_prompt_len_input_ids: int, chosen_tokens: dict[str, list[int]], rejected_prompt_len_input_ids: int, rejected_tokens: dict[str, list[int]])[source]#
Add beginning-of-sequence token to prompts if needed.
Ensures all prompt sequences start with BOS token for consistency in preference learning scenarios.
- Parameters
bos_token_id – BOS token ID, or None if not used.
prompt_len_input_ids – Length of main prompt.
prompt_tokens – Main prompt token dictionary.
chosen_prompt_len_input_ids – Length of chosen prompt.
chosen_tokens – Chosen response token dictionary.
rejected_prompt_len_input_ids – Length of rejected prompt.
rejected_tokens – Rejected response token dictionary.
- Returns
(prompt_tokens, chosen_tokens, rejected_tokens) with BOS added.
- Return type
tuple
Note
Only adds BOS if it’s not already present at the beginning. Updates both input_ids and attention_mask.
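The BOS check can be illustrated for a single token dictionary. This is a simplified sketch, not the library function: `prepend_bos` is a hypothetical helper, it uses generic `input_ids` / `attention_mask` keys, and it returns a new dictionary instead of handling the prompt, chosen, and rejected dictionaries together.

```python
def prepend_bos(tokens, bos_token_id):
    """Return a copy of tokens whose input_ids start with bos_token_id."""
    if bos_token_id is None:
        return dict(tokens)  # the model does not use a BOS token
    ids = tokens["input_ids"]
    if ids and ids[0] == bos_token_id:
        return dict(tokens)  # BOS already present, nothing to add
    return {
        "input_ids": [bos_token_id] + ids,
        # The new token is a real token, so its mask entry is 1.
        "attention_mask": [1] + tokens["attention_mask"],
    }


prompt = {"input_ids": [11, 12], "attention_mask": [1, 1]}
with_bos = prepend_bos(prompt, bos_token_id=1)
print(with_bos["input_ids"])       # [1, 11, 12]
print(with_bos["attention_mask"])  # [1, 1, 1]
```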
- easydel.trainers.utils.add_eos_token_if_needed(eos_token_id: int, chosen_tokens: dict[str, list[int]], rejected_tokens: dict[str, list[int]])[source]#
Add end-of-sequence token to responses if needed.
Ensures both chosen and rejected responses end with EOS token for proper sequence termination.
- Parameters
eos_token_id – EOS token ID to add.
chosen_tokens – Chosen response token dictionary.
rejected_tokens – Rejected response token dictionary.
- Returns
(chosen_tokens, rejected_tokens) with EOS added.
- Return type
tuple
Note
Only adds EOS if it’s not already present at the end. Updates both input_ids and attention_mask.
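A corresponding sketch for the EOS check, again for a single dictionary with generic `input_ids` / `attention_mask` keys. `append_eos` is a hypothetical helper used only for illustration; the library function processes the chosen and rejected dictionaries in one call.

```python
def append_eos(tokens, eos_token_id):
    """Return a copy of tokens whose input_ids end with eos_token_id."""
    ids = tokens["input_ids"]
    if ids and ids[-1] == eos_token_id:
        return dict(tokens)  # EOS already present, nothing to add
    return {
        "input_ids": ids + [eos_token_id],
        "attention_mask": tokens["attention_mask"] + [1],
    }


chosen = append_eos({"input_ids": [21, 22], "attention_mask": [1, 1]}, eos_token_id=2)
rejected = append_eos({"input_ids": [31, 2], "attention_mask": [1, 1]}, eos_token_id=2)
print(chosen["input_ids"])    # [21, 22, 2]
print(rejected["input_ids"])  # [31, 2]
```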
- easydel.trainers.utils.conversations_formatting_function(processing_class: AutoTokenizer, messages_field: Literal['messages', 'conversations'], tools: list | None = None)[source]#
Create a formatter for conversation/chat datasets.
Returns a function that applies chat templates to conversation data, converting structured conversations into formatted text suitable for training chat models.
- Parameters
processing_class – Tokenizer with chat template support.
messages_field – Field name containing conversations - either ‘messages’ or ‘conversations’.
tools – Optional list of tools for function calling support.
- Returns
- Function that formats dataset examples using the
tokenizer’s chat template.
- Return type
Callable
Note
Handles both single conversations and batches of conversations. The returned function expects datasets with the specified messages_field containing role-based conversation data.
- easydel.trainers.utils.create_constant_length_dataset(processing_class, dataset, dataset_text_field: str | None = None, formatting_func: Optional[Callable] = None, infinite: bool = False, seq_length: int = 1024, num_of_sequences: int = 1024, chars_per_token: float = 3.6, eos_token_id: int = 0, shuffle: bool = True, append_concat_token: bool = True, add_special_tokens: bool = True) Callable[[], Iterator[dict[str, jax.Array]]][source]#
Creates a generator function that yields constant length chunks of tokens from a stream of text files.
- Parameters
processing_class – The processor used for processing the data.
dataset – Dataset with text files.
dataset_text_field – Name of the field in the dataset that contains the text.
formatting_func – Function that formats the text before tokenization.
infinite – If True the iterator is reset after dataset reaches end else stops.
seq_length – Length of token sequences to return.
num_of_sequences – Number of token sequences to keep in buffer.
chars_per_token – Number of characters per token used to estimate number of tokens in text buffer.
eos_token_id – Id of the end of sequence token if the passed processing_class does not have an EOS token.
shuffle – Shuffle the examples before they are returned.
append_concat_token – If true, appends eos_token_id at the end of each sample being packed.
add_special_tokens – If true, processing_class adds special tokens to each sample being packed.
- Returns
A generator function that yields dictionaries containing input_ids and attention_mask as jnp.arrays
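The packing idea behind constant-length chunks can be sketched in plain Python. This is an assumption-laden simplification: it starts from already tokenized samples, ignores the character-based buffering controlled by `num_of_sequences` and `chars_per_token`, does no shuffling, and yields Python lists rather than JAX arrays. `pack_constant_length` is a hypothetical name.

```python
def pack_constant_length(tokenized_samples, seq_length, eos_token_id):
    """Yield fixed-size chunks from a stream of tokenized samples."""
    buffer = []
    for sample in tokenized_samples:
        # Append the concat token after every sample, then cut full chunks.
        buffer.extend(sample + [eos_token_id])
        while len(buffer) >= seq_length:
            chunk, buffer = buffer[:seq_length], buffer[seq_length:]
            yield {"input_ids": chunk, "attention_mask": [1] * seq_length}
    # Tokens left in the buffer that do not fill a chunk are dropped.


samples = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
chunks = list(pack_constant_length(samples, seq_length=4, eos_token_id=0))
for chunk in chunks:
    print(chunk["input_ids"])
# [1, 2, 3, 0]
# [4, 5, 0, 6]
# [7, 8, 9, 0]
```

Because samples are concatenated before chunking, a chunk may contain the end of one sample and the start of the next, separated by the EOS token.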
- easydel.trainers.utils.create_prompt_creator(processing_class)[source]#
Create a prompt formatting function for conversation data.
- Parameters
processing_class – Tokenizer or processor class used for formatting.
- Returns
- A function that formats conversation samples into prompts
suitable for training.
- Return type
Callable
Note
The returned function expects samples with a ‘conversation’ field containing input/output pairs and formats them using the conversations_formatting_function.
- easydel.trainers.utils.first_true_indices(bools, dtype=<class 'jax.numpy.int32'>)[source]#
Find the index of the first True value along the last axis.
Takes an N-dimensional bool array and returns an (N-1)-dimensional array of integers giving the position of the first True in each “row”.
- Parameters
bools – N-dimensional boolean array to search.
dtype – Data type for the output indices (default jnp.int32).
- Returns
- (N-1)-dimensional array of indices where each element
is the position of the first True in the corresponding row. Returns row_len if no True values found.
- Return type
jnp.ndarray
Note
Uses a clever trick with minimum to find first True efficiently. Returns the row length if no True value is found in a row.
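One way to realize the minimum-based approach is shown below in NumPy. It is a sketch of the idea, not necessarily how the library implements it: every False position is replaced by the row length, every True position keeps its own index, and the minimum along the last axis is then the first True index. `first_true_indices_np` is a hypothetical name.

```python
import numpy as np


def first_true_indices_np(bools, dtype=np.int32):
    """Index of the first True along the last axis, or row length if none."""
    row_len = bools.shape[-1]
    # False positions get the sentinel row_len; True positions keep their index.
    candidates = np.where(bools, np.arange(row_len), row_len)
    return candidates.min(axis=-1).astype(dtype)


mask = np.array([[False, True, True],
                 [False, False, False]])
print(first_true_indices_np(mask))  # [1 3]
```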
- easydel.trainers.utils.get_formatting_func_from_dataset(dataset: Union[Dataset, ConstantLengthDataset], processing_class: AutoTokenizer, tools: list | None = None) Optional[Callable][source]#
Automatically detect and return appropriate formatting function.
Examines dataset structure to determine the appropriate formatting function (chat format or instruction format) based on field names and schemas.
- Parameters
dataset – HuggingFace Dataset to analyze.
processing_class – Tokenizer to use for formatting.
tools – Optional tools for function calling support.
- Returns
- Appropriate formatting function, or None if
no suitable format is detected.
- Return type
Callable | None
Note
Supports:
  - ChatML format (messages/conversations fields)
  - Instruction format (prompt/completion fields)
Returns None if the dataset doesn’t match known formats.
- easydel.trainers.utils.instructions_formatting_function(processing_class: AutoTokenizer)[source]#
Create a formatter for instruction-following datasets.
Returns a function that converts prompt-completion pairs into chat format using the tokenizer’s chat template. Originally from TRL.
- Parameters
processing_class – Tokenizer with chat template support.
- Returns
- Function that formats instruction datasets by converting
prompt/completion pairs to user/assistant conversations.
- Return type
Callable
Note
Expects datasets with ‘prompt’ and ‘completion’ fields. Automatically converts to chat format with user/assistant roles.
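The conversion from a prompt/completion pair to user/assistant messages can be sketched as follows. `to_chat_messages` is a hypothetical helper; the step that renders the messages with the tokenizer's chat template is deliberately left out so the example runs without a tokenizer.

```python
def to_chat_messages(example):
    """Convert one prompt/completion pair into a two-turn conversation."""
    return [
        {"role": "user", "content": example["prompt"]},
        {"role": "assistant", "content": example["completion"]},
    ]


example = {"prompt": "What is 2 + 2?", "completion": "4"}
messages = to_chat_messages(example)
print(messages[0])  # {'role': 'user', 'content': 'What is 2 + 2?'}
print(messages[1])  # {'role': 'assistant', 'content': '4'}
```

A message list of this shape is the kind of input that a Hugging Face tokenizer's `apply_chat_template` method accepts.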
- easydel.trainers.utils.leave_alone_context_manager()[source]#
No-op context manager that does nothing.
Useful as a placeholder when a context manager is required but no actual context management is needed.
- Yields
None – Simply yields control back to the caller.
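A no-op context manager of this kind can be written with the standard library, as in the sketch below. `noop_context` and `run_step` are hypothetical names showing the placeholder pattern: the caller always uses a `with` block, and the no-op stands in when no real context is supplied.

```python
import contextlib


@contextlib.contextmanager
def noop_context():
    """Context manager that does nothing on entry or exit."""
    yield


def run_step(step, ctx=None):
    """Run step() inside ctx, falling back to the no-op context."""
    active = ctx if ctx is not None else noop_context()
    with active:
        return step()


result = run_step(lambda: 42)
print(result)  # 42
```

The standard library's `contextlib.nullcontext` serves the same purpose.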
- easydel.trainers.utils.np_pad(tensors: list[numpy.ndarray], max_lenght: int | None, padding_value: int = 0, padding_side: str = 'right') ndarray[source]#
Pad a list of NumPy tensors to uniform shape.
Similar to pad() but for NumPy arrays instead of JAX arrays.
- Parameters
tensors – List of NumPy arrays to pad.
max_lenght – Target length for padding. If None, uses current max.
padding_value – Value to use for padding (default 0).
padding_side – Where to add padding - ‘left’ or ‘right’ (default ‘right’).
- Returns
Batched and padded array.
- Return type
np.ndarray
- Raises
ValueError – If padding_side is not ‘left’ or ‘right’.
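The padding behavior can be illustrated with a small NumPy sketch. It assumes 1-D input arrays and a target length at least as long as the longest array; `pad_batch` is a hypothetical stand-in and not the library function.

```python
import numpy as np


def pad_batch(arrays, max_length=None, padding_value=0, padding_side="right"):
    """Stack 1-D arrays into a 2-D batch, padding each row to a common length."""
    if padding_side not in ("left", "right"):
        raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side!r}")
    target = max_length if max_length is not None else max(a.shape[-1] for a in arrays)
    out = np.full((len(arrays), target), padding_value, dtype=arrays[0].dtype)
    for row, a in enumerate(arrays):
        n = a.shape[-1]
        if padding_side == "right":
            out[row, :n] = a           # data first, padding after
        else:
            out[row, target - n:] = a  # padding first, data after
    return out


batch = pad_batch([np.array([1, 2, 3]), np.array([4])], padding_side="left")
print(batch)
# [[1 2 3]
#  [0 0 4]]
```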
- easydel.trainers.utils.pad(tensors: list[jax.Array], max_lenght: int | None, padding_value: int = 0, padding_side: str = 'right') Array[source]#
Pad a list of JAX tensors to uniform shape.
- Parameters
tensors – List of JAX arrays to pad.
max_lenght – Target length for padding. If None, uses maximum length found in tensors.
padding_value – Value to use for padding (default 0).
padding_side – Where to add padding - ‘left’ or ‘right’ (default ‘right’).
- Returns
- Batched and padded tensor with shape
[batch_size, *tensor_shape, max_length].
- Return type
jnp.ndarray
- Raises
ValueError – If padding_side is not ‘left’ or ‘right’.
Note
Efficiently handles variable-length sequences by padding to a common length. Preserves dtype of input tensors.
- easydel.trainers.utils.pad_sequence(sequences, batch_first=False, padding_value=0, max_len: int | None = None)[source]#
Pad a list of sequences to the same length.
- Parameters
sequences – List of sequences (arrays) to pad.
batch_first – If True, output has batch dimension first. If False, adds padding to the left (default False).
padding_value – Value to use for padding (default 0).
max_len – Maximum length to pad to. If None, uses longest sequence.
- Returns
Padded sequences as a single batched array.
- Return type
jnp.ndarray
Note
Similar to PyTorch’s pad_sequence but for JAX arrays. When batch_first=False, padding is added to the left.
- easydel.trainers.utils.pad_single(tensor: ndarray, max_length: int | None = None, padding_value: int = 0, padding_side: str = 'right') ndarray[source]#
Pad a single NumPy tensor along the last dimension.
- Parameters
tensor – NumPy array to pad.
max_length – Target length for the last dimension. If None, returns tensor unchanged.
padding_value – Value to use for padding (default 0).
padding_side – Where to add padding - ‘left’ or ‘right’ (default ‘right’).
- Returns
Padded tensor with last dimension of size max_length.
- Return type
np.ndarray
- Raises
ValueError – If padding_side is not ‘left’ or ‘right’.
Note
If tensor is already longer than max_length, it will be truncated from the appropriate side based on padding_side.
- easydel.trainers.utils.pad_to_length(tensor: Union[Array, ndarray, bool, number], length: int, pad_value: int | float, axis: int = -1) Union[Array, ndarray, bool, number][source]#
Pad or truncate a tensor to a specific length along an axis.
- Parameters
tensor – Input array to pad or truncate.
length – Target length for the specified axis.
pad_value – Value to use for padding.
axis – Axis along which to pad/truncate (default -1).
- Returns
Tensor padded or truncated to the specified length.
- Return type
chex.Array
Note
If tensor is already longer than length, it will be truncated. Special handling for 2D tensors.
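A NumPy sketch of padding or truncating along one axis is given below. It makes two assumptions that the description above does not confirm: padding is appended at the end of the axis, and truncation keeps the leading entries. `pad_or_truncate` is a hypothetical name.

```python
import numpy as np


def pad_or_truncate(array, length, pad_value, axis=-1):
    """Return array with exactly `length` entries along `axis`."""
    current = array.shape[axis]
    if current >= length:
        # Assumption: keep the first `length` entries when truncating.
        return np.take(array, np.arange(length), axis=axis)
    pad_width = [(0, 0)] * array.ndim
    pad_width[axis] = (0, length - current)  # pad only at the end of `axis`
    return np.pad(array, pad_width, mode="constant", constant_values=pad_value)


x = np.array([[1, 2, 3], [4, 5, 6]])
print(pad_or_truncate(x, 5, pad_value=0))
# [[1 2 3 0 0]
#  [4 5 6 0 0]]
print(pad_or_truncate(x, 2, pad_value=0))
# [[1 2]
#  [4 5]]
```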
- easydel.trainers.utils.shift_and_pad(mask, *tensors)[source]#
Shift tensors to align with the first non-zero mask position.
Rolls each tensor so that the first ‘1’ in the mask appears at position 0. Useful for aligning sequences that have different starting positions.
- Parameters
mask – Binary mask array indicating valid positions.
*tensors – Additional tensors to shift along with the mask.
- Returns
Shifted mask and tensors (if provided), or just mask if no tensors.
- Return type
tuple
Note
Modifies inputs in-place. Each row is shifted independently based on its first non-zero mask position.
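The row-wise roll can be sketched in NumPy as follows. Unlike the behavior noted above, this illustration works on copies and leaves its inputs untouched; `roll_to_first_valid` is a hypothetical name.

```python
import numpy as np


def roll_to_first_valid(mask, *arrays):
    """Roll each row left so its first non-zero mask entry lands at position 0."""
    mask = mask.copy()
    arrays = [a.copy() for a in arrays]
    for row in range(mask.shape[0]):
        # argmax on a boolean row gives the first True; 0 if the row is all zeros.
        shift = int(np.argmax(mask[row] != 0))
        mask[row] = np.roll(mask[row], -shift)
        for a in arrays:
            a[row] = np.roll(a[row], -shift)
    return (mask, *arrays) if arrays else mask


mask = np.array([[0, 0, 1, 1], [1, 1, 1, 0]])
ids = np.array([[0, 0, 7, 8], [3, 4, 5, 0]])
new_mask, new_ids = roll_to_first_valid(mask, ids)
print(new_mask)
# [[1 1 0 0]
#  [1 1 1 0]]
print(new_ids)
# [[7 8 0 0]
#  [3 4 5 0]]
```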
- easydel.trainers.utils.tolist(x)[source]#
Convert various array types to Python list.
Utility function from HuggingFace for consistent list conversion.
- Parameters
x – Input to convert. Can be:
  - Python list (returned as-is)
  - NumPy array
  - JAX array
  - Tensor with .numpy() method
- Returns
Python list representation of the input.
- Return type
list
Note
Handles tensors by first converting to NumPy if they have a .numpy() method.
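The conversion order described above can be sketched with duck typing. `to_python_list` is a hypothetical stand-in, shown here with NumPy input only; it assumes that whatever `.numpy()` returns also offers `.tolist()`.

```python
import numpy as np


def to_python_list(x):
    """Convert a list or array-like object to a plain Python list."""
    if isinstance(x, list):
        return x  # already a list, returned as-is
    if hasattr(x, "numpy"):
        x = x.numpy()  # framework tensor: go through NumPy first
    return x.tolist()


print(to_python_list([1, 2, 3]))                   # [1, 2, 3]
print(to_python_list(np.array([[1, 2], [3, 4]])))  # [[1, 2], [3, 4]]
```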
- easydel.trainers.utils.truncate_right(input_ids, stop_token_id, pad_token_id)[source]#
Truncate sequences after the first occurrence of stop token.
Replaces all tokens after the first stop token with padding tokens and creates a corresponding attention mask.
- Parameters
input_ids – 2D array of token IDs [batch_size, sequence_length].
stop_token_id – Token ID that marks where to stop.
pad_token_id – Token ID to use for padding truncated positions.
- Returns
- (output_ids, mask) where:
output_ids: Input with post-stop tokens replaced by padding
mask: Binary attention mask (1 for valid, 0 for padded)
- Return type
tuple
Note
Useful for truncating generated sequences at EOS tokens. Preserves the stop token itself in the output.
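The truncation can be sketched in NumPy as shown below. `truncate_after_stop` is a hypothetical re-implementation for illustration: it finds the first stop token in each row, keeps everything up to and including it, and replaces later positions with padding.

```python
import numpy as np


def truncate_after_stop(input_ids, stop_token_id, pad_token_id):
    """Pad out everything after the first stop token in each row."""
    seq_len = input_ids.shape[1]
    positions = np.arange(seq_len)
    is_stop = input_ids == stop_token_id
    # First stop position per row; seq_len when the row has no stop token.
    first_stop = np.where(is_stop, positions, seq_len).min(axis=-1, keepdims=True)
    after_stop = positions[None, :] > first_stop
    output_ids = np.where(after_stop, pad_token_id, input_ids)
    mask = np.where(after_stop, 0, 1)
    return output_ids, mask


ids = np.array([[5, 6, 2, 7, 8],
                [5, 6, 7, 8, 9]])
out, attn = truncate_after_stop(ids, stop_token_id=2, pad_token_id=0)
print(out)
# [[5 6 2 0 0]
#  [5 6 7 8 9]]
print(attn)
# [[1 1 1 0 0]
#  [1 1 1 1 1]]
```

Rows without a stop token are returned unchanged with an all-ones mask.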